# Source code for bamt.nodes.base

import pickle
from typing import Union


class BaseNode:
    """
    Base class for nodes.
    """

    def __init__(self, name: str):
        """
        :param name: name for node (taken from column name)

        Attributes initialised here:
            type: node type (``"abstract"`` for this base class)
            disc_parents: list with discrete parents
            cont_parents: list with continuous parents
            children: node's children
        """
        self.name = name
        self.type = "abstract"
        self.disc_parents = []
        self.cont_parents = []
        self.children = []

    def __repr__(self):
        return f"{self.name}"

    def __eq__(self, other):
        """
        Two nodes are equal when their name, type, parent lists and children
        list are all equal.

        NOTE: because ``__eq__`` is defined without ``__hash__``, Python sets
        ``__hash__`` to ``None`` and instances are unhashable.
        """
        if not isinstance(other, BaseNode):
            # don't attempt to compare against unrelated types
            return NotImplemented
        return (
            self.name == other.name
            and self.type == other.type
            and self.disc_parents == other.disc_parents
            and self.cont_parents == other.cont_parents
            and self.children == other.children
        )

    @staticmethod
    def choose_serialization(model) -> Union[str, Exception]:
        """
        Check whether ``model`` can be pickled and loaded back after its pickle
        bytes go through a latin-1 text round trip with quote substitution.

        :param model: object to test.
        :return: ``"pickle"`` when the round trip succeeds; otherwise the
            exception that occurred. The exception is returned, not raised.
        """
        try:
            raw = pickle.dumps(model, protocol=4)
            # latin-1 maps each byte 0-255 to exactly one code point, so the
            # bytes <-> str conversion itself is lossless.
            text = raw.decode("latin1").replace("'", '"')
            # Compared by class name rather than isinstance; presumably so that
            # catboost does not need to be imported here -- TODO confirm.
            if type(model).__name__ == "CatBoostRegressor":
                # Only single quotes are rewritten (to double quotes).
                payload = text.encode("latin1")
            else:
                # Both quote characters end up as single quotes.
                payload = text.replace('"', "'").encode("latin1")
            # Only success matters; the reconstructed object is discarded.
            pickle.loads(payload)
            return "pickle"
        except Exception as ex:
            return ex

    @staticmethod
    def get_dist(node_info, pvals):
        """
        Base implementation: does nothing and returns ``None``.

        NOTE(review): presumably overridden by concrete node types -- confirm.
        """