Source code for bamt.nodes.conditional_logit_node

import itertools
import pickle
import random
from typing import Optional, List, Union, Dict

import joblib
import numpy as np
from pandas import DataFrame
from sklearn import linear_model

from bamt.log import logger_nodes
from .base import BaseNode
from .schema import LogitParams


class ConditionalLogitNode(BaseNode):
    """
    Main class for Conditional Logit Node.

    One classifier fit is stored per combination of values of the node's
    discrete parents (``self.disc_parents``); the continuous parents
    (``self.cont_parents``) are used as the classifier's features.
    """

    def __init__(self, name: str, classifier: Optional[object] = None):
        """
        params:
        name: node name, forwarded to the base class
        classifier: estimator used for every parent combination; when omitted
            a multinomial LogisticRegression is created
        """
        super().__init__(name)
        if classifier is None:
            # NOTE(review): `multi_class` is reported as deprecated in recent
            # scikit-learn releases — confirm against the pinned version.
            classifier = linear_model.LogisticRegression(
                multi_class="multinomial", solver="newton-cg", max_iter=100
            )
        self.classifier = classifier
        self.type = "ConditionalLogit" + f" ({type(self.classifier).__name__})"

    @staticmethod
    def _split_parent_values(pvals):
        """Split parent values into (discrete, continuous): str vs. non-str."""
        dispvals = [pval for pval in pvals if isinstance(pval, str)]
        lgpvals = [pval for pval in pvals if not isinstance(pval, str)]
        return dispvals, lgpvals

    @staticmethod
    def _load_classifier(lgdistribution):
        """Rebuild the fitted classifier stored in one ``hybcprob`` entry."""
        if lgdistribution["serialization"] == "joblib":
            return joblib.load(lgdistribution["classifier_obj"])
        # The pickle payload was stored as a latin1-decoded string.
        # SECURITY: unpickling executes arbitrary code; only load node_info
        # that comes from a trusted source.
        return pickle.loads(lgdistribution["classifier_obj"].encode("latin1"))

    def fit_parameters(self, data: DataFrame) -> Dict[str, Dict[str, LogitParams]]:
        """
        Train params on data.

        For every combination of observed values of the discrete parents the
        rows matching that combination are selected and:
        - no rows: classes is [nan] and no classifier is stored;
        - a single observed target value: classes holds that value and no
          classifier is stored;
        - otherwise: the classifier is fitted on the continuous parents and
          stored serialized (pickle string, or a joblib file path when
          ``choose_serialization`` does not answer "pickle").

        params:
        data: training data with a column for this node and for each parent

        Return:
        {"hybcprob": {<combination of outputs from discrete parents> : LogitParams}}
        """
        hycprob = {}
        classifier_name = type(self.classifier).__name__
        parent_levels = [np.unique(data[d_p].values) for d_p in self.disc_parents]

        for comb in map(list, itertools.product(*parent_levels)):
            mask = np.full(len(data), True)
            for col, val in zip(self.disc_parents, comb):
                mask = mask & (data[col] == val)
            new_data = data[mask]
            key = str([str(x) for x in comb])

            if new_data.shape[0] == 0:
                # Combination never observed: nothing to fit.
                hycprob[key] = {
                    "classes": [np.nan],
                    "classifier": classifier_name,
                    "classifier_obj": None,
                    "serialization": None,
                }
                continue

            observed = set(new_data[self.name])
            if len(observed) <= 1:
                # A single target value: the outcome is deterministic.
                hycprob[key] = {
                    "classes": list(observed),
                    "classifier": classifier_name,
                    "classifier_obj": None,
                    "serialization": None,
                }
                continue

            # The same classifier instance is refit for every combination; it
            # is serialized right after each fit, so each stored snapshot
            # reflects only its own combination.
            model = self.classifier
            model.fit(
                X=new_data[self.cont_parents].values,
                y=new_data[self.name].values,
            )
            classes = list(model.classes_)
            serialization = self.choose_serialization(model)

            if serialization == "pickle":
                model_ser = pickle.dumps(model, protocol=4).decode("latin1")
                hycprob[key] = {
                    "classes": classes,
                    "classifier_obj": model_ser,
                    "classifier": classifier_name,
                    "serialization": "pickle",
                }
            else:
                # Anything other than "pickle" is treated as the pickling
                # error; its first argument is logged.
                logger_nodes.warning(
                    f"{self.name} {comb}::Pickle failed. BAMT will use Joblib. | "
                    + str(serialization.args[0])
                )
                path = self.get_path_joblib(
                    node_name=self.name.replace(" ", "_"), specific=comb
                )
                joblib.dump(model, path, compress=True, protocol=4)
                hycprob[key] = {
                    "classes": classes,
                    "classifier_obj": path,
                    "classifier": classifier_name,
                    "serialization": "joblib",
                }
        return {"hybcprob": hycprob}

    @staticmethod
    def get_dist(node_info, pvals, **kwargs):
        """
        Return the class distribution of the node for the given parent values.

        params:
        node_info: nodes info from distributions
        pvals: parent values; str values are treated as discrete parents, the
            rest as continuous ones
        inner (keyword): when truthy, return (distribution, hybcprob entry)
            instead of the distribution alone

        Return:
        np.nan (a bare scalar, even with inner=True) if any discrete parent
        value is "nan"; otherwise the predicted probabilities, or
        np.array([1.0]) when the entry has at most one class.
        """
        dispvals, lgpvals = ConditionalLogitNode._split_parent_values(pvals)

        if "nan" in dispvals:
            return np.nan

        lgdistribution = node_info["hybcprob"][str(dispvals)]

        if len(lgdistribution["classes"]) > 1:
            model = ConditionalLogitNode._load_classifier(lgdistribution)
            distribution = model.predict_proba(np.array(lgpvals).reshape(1, -1))[0]
        else:
            distribution = np.array([1.0])

        if kwargs.get("inner", False):
            return distribution, lgdistribution
        return distribution

    def choose(
        self,
        node_info: Dict[str, Dict[str, LogitParams]],
        pvals: List[Union[str, float]],
    ) -> str:
        """
        Return value from ConditionalLogit node

        params:
        node_info: nodes info from distributions
        pvals: parent values

        Return:
        a class sampled from the predicted distribution, as a string; "nan"
        when a discrete parent value is "nan".
        """
        result = self.get_dist(node_info, pvals, inner=True)
        if not isinstance(result, tuple):
            # get_dist yields a bare np.nan for a "nan" discrete parent;
            # unpacking it would raise TypeError. Report "nan", the same
            # string produced below for an entry whose classes are [nan].
            return str(result)

        distribution, lgdistribution = result
        classes = lgdistribution["classes"]
        if len(classes) <= 1:
            return str(classes[0])

        rand = random.random()
        # Default to the last class: if rounding leaves the cumulative sum
        # slightly below `rand`, the draw belongs to the final interval.
        rindex = len(classes) - 1
        ubound = 0
        for interval, prob in enumerate(distribution):
            ubound += prob
            if rand < ubound:
                rindex = interval
                break
        return str(classes[rindex])

    @staticmethod
    def predict(
        node_info: Dict[str, Dict[str, LogitParams]], pvals: List[Union[str, float]]
    ) -> str:
        """
        Return value from ConditionalLogit node

        params:
        node_info: nodes info from distributions
        pvals: parent values

        Return:
        the classifier's predicted class as a string, or the entry's only
        class when no classifier was stored for the parent combination.
        """
        dispvals, lgpvals = ConditionalLogitNode._split_parent_values(pvals)
        lgdistribution = node_info["hybcprob"][str(dispvals)]

        if len(lgdistribution["classes"]) > 1:
            model = ConditionalLogitNode._load_classifier(lgdistribution)
            pred = model.predict(np.array(lgpvals).reshape(1, -1))[0]
            return str(pred)
        return str(lgdistribution["classes"][0])