import pickle
from typing import Union
class BaseNode:
    """
    Base class for nodes.

    Holds a node's name, a type tag, and its links to discrete parents,
    continuous parents and children. Subclasses are presumably expected to
    set a concrete ``type`` and override :meth:`get_dist` -- confirm against
    the concrete node classes.
    """

    # Defining __eq__ already makes instances unhashable (Python implicitly
    # sets __hash__ to None); stated explicitly so the intent is visible.
    # Equality is by mutable attribute values, so a value-based hash would
    # be unsafe.
    __hash__ = None  # type: ignore[assignment]

    def __init__(self, name: str) -> None:
        """
        Create a node with no parents and no children.

        :param name: name for the node (taken from the column name).

        Attributes set here:

        - ``name``: the node's name.
        - ``type``: node type tag; ``"abstract"`` for this base class.
        - ``disc_parents``: list of discrete parents (initially empty).
        - ``cont_parents``: list of continuous parents (initially empty).
        - ``children``: list of the node's children (initially empty).
        """
        self.name = name
        self.type = "abstract"
        self.disc_parents = []
        self.cont_parents = []
        self.children = []

    def __repr__(self) -> str:
        """Return the node's name as its representation."""
        return f"{self.name}"

    def __eq__(self, other: object) -> bool:
        """
        Compare two nodes field by field.

        Nodes are equal when their ``name``, ``type``, ``disc_parents``,
        ``cont_parents`` and ``children`` are all equal. For operands that
        are not ``BaseNode`` instances, ``NotImplemented`` is returned so
        Python can fall back to the reflected comparison.
        """
        if not isinstance(other, BaseNode):
            # don't attempt to compare against unrelated types
            return NotImplemented
        return (
            self.name == other.name
            and self.type == other.type
            and self.disc_parents == other.disc_parents
            and self.cont_parents == other.cont_parents
            and self.children == other.children
        )

    @staticmethod
    def choose_serialization(model) -> Union[str, Exception]:
        """
        Check whether ``model`` survives a pickle/text round trip.

        Steps:

        1. Pickle the model with protocol 4.
        2. Decode the bytes as latin-1.
        3. Swap every single quote for a double quote.
        4. Re-encode the text, with a type-dependent step first:
           - ``CatBoostRegressor``: re-encoded as is.
           - Any other model: double quotes are swapped back to single
             quotes, then re-encoded.
        5. Unpickle the resulting bytes.

        NOTE(review): the quote swapping presumably mirrors how serialized
        models are stored and restored elsewhere (e.g. embedded in JSON).
        Confirm against the save/load code.

        Only a raised exception counts as failure. A round trip that
        unpickles but alters the object still reports success. For example,
        in the non-CatBoost branch an original ``"`` byte comes back as
        ``'``.

        :param model: object to test.
        :return: ``"pickle"`` if the round trip raised nothing; otherwise
            the exception instance, which is returned, not raised.
        """
        try:
            raw = pickle.dumps(model, protocol=4)
            # latin-1 maps every byte 0-255 to one code point, so this
            # decode/encode pair is lossless apart from the quote swaps.
            text = raw.decode("latin1").replace("'", '"')
            if type(model).__name__ == "CatBoostRegressor":
                payload = text.encode("latin1")
            else:
                payload = text.replace('"', "'").encode("latin1")
            # The unpickled object itself is not needed; success is the signal.
            pickle.loads(payload)
            return "pickle"
        except Exception as ex:
            # Deliberately broad: any failure means "not pickle-serializable"
            # and is handed back to the caller instead of being raised.
            return ex

    @staticmethod
    def get_dist(node_info, pvals):
        """
        Placeholder for a node's distribution lookup; returns ``None`` here.

        Presumably overridden by concrete node subclasses -- confirm against
        them. The meaning of ``node_info`` and ``pvals`` is not visible from
        this class.
        """