Source code for bamt.nodes.logit_node

import random
from typing import Optional, List, Union
from bamt.result_models.node_result import LogitNodeResult
import numpy as np
from pandas import DataFrame
from sklearn import linear_model

from .base import BaseNode
from .schema import LogitParams


[docs] class LogitNode(BaseNode): """ Main class for logit node """ def __init__(self, name, classifier: Optional[object] = None): super(LogitNode, self).__init__(name) if classifier is None: classifier = linear_model.LogisticRegression( solver="newton-cg", max_iter=100 ) self.classifier = classifier self.type = "Logit" + f" ({type(self.classifier).__name__})"
[docs] def fit_parameters(self, data: DataFrame, **kwargs) -> LogitParams: parents = self.cont_parents + self.disc_parents self.classifier.fit(X=data[parents].values, y=data[self.name].values, **kwargs) return { "classes": list(self.classifier.classes_), "classifier_obj": self.classifier, "classifier": type(self.classifier).__name__, "serialization": None, }
[docs] def get_dist(self, node_info, pvals): if len(node_info["classes"]) > 1: model = node_info["classifier_obj"] if type(self).__name__ == "CompositeDiscreteNode": pvals = [int(item) if isinstance(item, str) else item for item in pvals] probs = model.predict_proba(np.array(pvals).reshape(1, -1))[0] else: probs = np.array([1.0]) return LogitNodeResult(probs=probs, values=node_info["classes"])
[docs] def choose(self, node_info: LogitParams, pvals: List[Union[float]]) -> str: """ Return value from Logit node params: node_info: nodes info from distributions pvals: parent values """ rindex = 0 distribution, _ = self.get_dist(node_info, pvals).get() if len(node_info["classes"]) > 1: rand = random.random() lbound = 0 ubound = 0 for interval in range(len(node_info["classes"])): ubound += distribution[interval] if lbound <= rand < ubound: rindex = interval break else: lbound = ubound return str(node_info["classes"][rindex]) else: return str(node_info["classes"][0])
[docs] @staticmethod def predict(node_info: LogitParams, pvals: List[Union[float]]) -> str: """ Return prediction from Logit node params: node_info: nodes info from distributions pvals: parent values """ if len(node_info["classes"]) > 1: model = node_info["classifier_obj"] pred = model.predict(np.array(pvals).reshape(1, -1))[0] return str(pred) else: return str(node_info["classes"][0])