Source code for bamt.nodes.conditional_logit_node

import itertools
import random
from typing import Optional, List, Union, Dict
from bamt.result_models.node_result import ConditionalLogitNodeResult
import numpy as np
from pandas import DataFrame
from sklearn import linear_model
from sklearn.base import clone

from .base import BaseNode
from .schema import LogitParams


class ConditionalLogitNode(BaseNode):
    """
    Node whose value is a class label predicted from its continuous parents,
    with a separate classifier fitted for every combination of values of its
    discrete parents.

    At inference time, parent values of type ``str`` are treated as discrete
    parents and all other values as continuous parents. A discrete parent
    value equal to the string ``"nan"`` is treated as missing.
    """

    def __init__(self, name: str, classifier: Optional[object] = None):
        """
        :param name: node name.
        :param classifier: unfitted sklearn-style classifier used as a template
            (it is cloned per discrete-parent combination). Defaults to
            ``LogisticRegression(solver="newton-cg", max_iter=100)``.
        """
        super().__init__(name)
        if classifier is None:
            classifier = linear_model.LogisticRegression(
                solver="newton-cg", max_iter=100
            )
        self.classifier = classifier
        self.type = f"ConditionalLogit ({type(self.classifier).__name__})"

    @staticmethod
    def _split_parent_values(
        pvals: List[Union[str, float]]
    ) -> "tuple[List[str], List[float]]":
        """Split parent values into (discrete ``str`` values, other values), keeping order."""
        dispvals = [pval for pval in pvals if isinstance(pval, str)]
        lgpvals = [pval for pval in pvals if not isinstance(pval, str)]
        return dispvals, lgpvals

    def fit_parameters(self, data: DataFrame) -> Dict[str, Dict[str, LogitParams]]:
        """
        Train one classifier per combination of discrete-parent values.

        Every combination in the Cartesian product of the unique values of each
        discrete parent gets an entry, even if it never occurs in ``data``:

        * no matching rows             -> ``classes == [nan]``, no model;
        * exactly one observed class   -> that class only, no model;
        * more than one observed class -> a clone of ``self.classifier`` fitted
          on the continuous parents of the matching rows.

        :param data: training data containing this node and all its parents.
        :return: ``{"hybcprob": {str(list of str parent values): entry}}`` where
            each entry has the keys ``classes``, ``classifier``,
            ``classifier_obj`` and ``serialization``.
        """
        classifier_name = type(self.classifier).__name__
        parent_levels = [np.unique(data[parent].values) for parent in self.disc_parents]

        hybcprob = {}
        for comb in itertools.product(*parent_levels):
            mask = np.full(len(data), True)
            for col, val in zip(self.disc_parents, comb):
                mask = mask & (data[col] == val)
            subset = data[mask]

            model = None
            if subset.shape[0] == 0:
                # Combination never observed: mark its classes as unknown.
                classes = [np.nan]
            else:
                observed = set(subset[self.name])
                if len(observed) > 1:
                    model = clone(self.classifier)
                    model.fit(
                        X=subset[self.cont_parents].values,
                        y=subset[self.name].values,
                    )
                    classes = list(model.classes_)
                else:
                    # A classifier cannot be fitted on a single class.
                    classes = list(observed)

            # Keys are stringified lists of stringified parent values so they
            # match ``str(dispvals)`` built from string parent values in get_dist.
            key = str([str(x) for x in comb])
            hybcprob[key] = {
                "classes": classes,
                "classifier": classifier_name,
                "classifier_obj": model,
                "serialization": None,
            }
        return {"hybcprob": hybcprob}

    @staticmethod
    def get_dist(node_info, pvals, **kwargs):
        """
        Return the class distribution of the node given its parent values.

        :param node_info: node parameters as produced by ``fit_parameters``.
        :param pvals: parent values; ``str`` items are discrete parents, the
            rest are fed to the classifier as continuous features.
        :param inner: (keyword) if true, return the raw ``(distribution,
            hybcprob entry)`` tuple instead of a ``ConditionalLogitNodeResult``.
        :return: ``ConditionalLogitNodeResult`` or, with ``inner=True``, a tuple
            ``(probabilities ndarray, entry dict)``. If any discrete parent is
            ``"nan"``, the probabilities are NaN.
        """
        inner = kwargs.get("inner", False)
        dispvals, lgpvals = ConditionalLogitNode._split_parent_values(pvals)
        hybcprob = node_info["hybcprob"]
        key = str(dispvals)

        if "nan" in dispvals:
            # Check missing parents before the strict lookup: a "nan"
            # combination is usually not among the fitted keys. Fall back to
            # the same placeholder fit_parameters uses for unseen combinations.
            lgdistribution = hybcprob.get(
                key, {"classes": [np.nan], "classifier_obj": None}
            )
            if not inner:
                return ConditionalLogitNodeResult(
                    probs=(np.nan, np.nan), values=lgdistribution["classes"]
                )
            return np.full(len(lgdistribution["classes"]), np.nan), lgdistribution

        lgdistribution = hybcprob[key]
        classes = lgdistribution["classes"]
        # JOBLIB
        if len(classes) > 1:
            model = lgdistribution["classifier_obj"]
            distribution = model.predict_proba(np.array(lgpvals).reshape(1, -1))[0]
        else:
            distribution = np.array([1.0])

        if inner:
            return distribution, lgdistribution
        return ConditionalLogitNodeResult(probs=distribution, values=classes)

    def choose(
        self,
        node_info: Dict[str, Dict[str, LogitParams]],
        pvals: List[Union[str, float]],
    ) -> str:
        """
        Sample a value of the ConditionalLogit node.

        :param node_info: node parameters as produced by ``fit_parameters``.
        :param pvals: parent values (``str`` = discrete, others = continuous).
        :return: sampled class as ``str``; ``"nan"`` when a discrete parent is
            ``"nan"``.
        """
        distribution, lgdistribution = self.get_dist(node_info, pvals, inner=True)
        classes = lgdistribution["classes"]

        if np.isnan(distribution).any():
            # Missing discrete parent: propagate the missing marker.
            return "nan"
        # JOBLIB
        if len(classes) <= 1:
            return str(classes[0])

        # Inverse-CDF sampling: the first class whose cumulative probability
        # exceeds the draw. Clamp to the last class in case rounding leaves the
        # total slightly below 1 and the draw lands above it.
        cumulative = np.cumsum(distribution)
        rindex = int(np.searchsorted(cumulative, random.random(), side="right"))
        return str(classes[min(rindex, len(classes) - 1)])

    @staticmethod
    def predict(
        node_info: Dict[str, Dict[str, LogitParams]], pvals: List[Union[str, float]]
    ) -> str:
        """
        Predict the most likely value of the ConditionalLogit node.

        :param node_info: node parameters as produced by ``fit_parameters``.
        :param pvals: parent values (``str`` = discrete, others = continuous).
        :return: predicted class as ``str``; ``"nan"`` when a discrete parent
            is ``"nan"`` (consistent with ``get_dist``/``choose``).
        """
        dispvals, lgpvals = ConditionalLogitNode._split_parent_values(pvals)
        if "nan" in dispvals:
            return "nan"

        lgdistribution = node_info["hybcprob"][str(dispvals)]
        # JOBLIB
        if len(lgdistribution["classes"]) > 1:
            model = lgdistribution["classifier_obj"]
            pred = model.predict(np.array(lgpvals).reshape(1, -1))[0]
            return str(pred)
        return str(lgdistribution["classes"][0])