Source code for bamt.nodes.base

import os
import pickle
from typing import Union

from bamt.config import config

# Root directory under which node models are stored, read from the
# "models_storage" option of the "NODES" config section.
# NOTE(review): when the option is missing, the fallback is a human-readable
# placeholder string rather than a usable path, so later filesystem calls on
# STORAGE will presumably fail with a "not found" error -- confirm intended.
STORAGE = config.get(
    "NODES", "models_storage", fallback="models_storage is not defined"
)


class BaseNode:
    """
    Base class for nodes.

    Attributes:
        name: name of the node (taken from column name)
        type: node type; the base class sets it to "abstract"
        disc_parents: list with discrete parents
        cont_parents: list with continuous parents
        children: node's children
    """

    def __init__(self, name: str):
        """
        :param name: name for node (taken from column name)
        """
        self.name = name
        self.type = "abstract"
        self.disc_parents = []
        self.cont_parents = []
        self.children = []

    def __repr__(self):
        """Return the node name as its textual representation."""
        return f"{self.name}"

    def __eq__(self, other):
        """
        Compare two nodes by name, type, parents and children.

        Returns NotImplemented for objects that are not BaseNode instances
        so Python can fall back to the other operand's comparison.
        """
        # NOTE: __eq__ is defined without __hash__, so instances are
        # unhashable (cannot be used as dict keys or set members).
        if not isinstance(other, BaseNode):
            # don't attempt to compare against unrelated types
            return NotImplemented

        return (
            self.name == other.name
            and self.type == other.type
            and self.disc_parents == other.disc_parents
            and self.cont_parents == other.cont_parents
            and self.children == other.children
        )

    @staticmethod
    def choose_serialization(model) -> Union[str, Exception]:
        """
        Check whether ``model`` survives a pickle round trip through text.

        The model is pickled (protocol 4), decoded to a latin1 string with
        single quotes replaced by double quotes, converted back to bytes and
        unpickled again. For every model except ``CatBoostRegressor`` the
        double quotes are turned back into single quotes before unpickling.

        Args:
            model: object to test.

        Return:
            The string "pickle" if the round trip succeeds, otherwise the
            exception instance that was raised (returned, not raised).
        """
        try:
            pickled = pickle.dumps(model, protocol=4)
            model_ser = pickled.decode("latin1").replace("'", '"')

            if type(model).__name__ == "CatBoostRegressor":
                restored = model_ser.encode("latin1")
            else:
                restored = model_ser.replace('"', "'").encode("latin1")

            # Only success/failure of loading matters; the result is unused.
            pickle.loads(restored)
            return "pickle"
        except Exception as ex:
            # Deliberately broad: the caller receives the failure as a value.
            return ex

    @staticmethod
    def get_path_joblib(node_name: str, specific: str = "") -> str:
        """
        Args:
            node_name: name of node
            specific: more specific unique name for node.
                For example, combination.

        Return:
            Path to save a joblib file.

        Raises:
            IndexError: if STORAGE contains no numerically named entry.
        """
        if not isinstance(specific, str):
            specific = str(specific)

        # os.listdir() returns entries in arbitrary order, so the last list
        # element is not reliably the highest index; take the numeric maximum
        # and ignore entries whose names are not plain decimal numbers.
        indices = [
            int(entry) for entry in os.listdir(STORAGE) if entry.isdecimal()
        ]
        if not indices:
            # IndexError matches what indexing an empty listing used to raise.
            raise IndexError(
                f"no numbered model directory found in {STORAGE!r}"
            )
        index = str(max(indices))

        node_dir = os.path.join(STORAGE, index, node_name.replace(" ", "_"))
        # exist_ok avoids the check-then-create race of a separate isdir test.
        os.makedirs(node_dir, exist_ok=True)

        return os.path.abspath(
            os.path.join(node_dir, f"{specific}.joblib.compressed")
        )

    @staticmethod
    def get_dist(node_info, pvals):
        """
        Placeholder that does nothing and returns None in the base class.

        NOTE(review): presumably overridden by concrete node classes to
        return a distribution for the given parent values -- confirm.
        """
        pass