import random
from itertools import product
from typing import Type, Dict, Union, List
from bamt.result_models.node_result import DiscreteNodeResult
import numpy as np
from pandas import DataFrame, crosstab
from .base import BaseNode
from .schema import DiscreteParams
class DiscreteNode(BaseNode):
    """
    Main class of Discrete Node.

    Fitted parameters have the form ``{"cprob": cprob, "vals": vals}``:
    ``vals`` lists the observed node values (as strings, sorted), and
    ``cprob`` is either a probability list aligned with ``vals`` (no parents)
    or a dict mapping the string form of a parent-value combination,
    ``str([str(v) for v in combination])``, to such a list.
    """

    def __init__(self, name):
        super().__init__(name)
        self.type = "Discrete"

    def fit_parameters(self, data: DataFrame, num_workers: int = 1) -> DiscreteParams:
        """
        Train params for Discrete Node.

        data: DataFrame to train on; must contain this node's column and the
            columns of all its parents (discrete parents first, then
            continuous ones, each treated by its distinct values).
        num_workers: number of parallel workers. Currently unused -- fitting
            runs in the calling thread; kept for interface compatibility.

        Returns ``{"cprob": cprob, "vals": vals}``:
        - ``vals``: sorted observed node values, converted to strings;
        - ``cprob``: with no parents, the marginal probabilities aligned with
          ``vals``; otherwise a dict keyed by
          ``str([str(v) for v in parent_combination])`` whose values are
          probability lists aligned with ``vals``. Parent combinations from the
          cartesian product of observed parent values that never occur
          together in ``data`` keep a uniform distribution.
        """

        def combination_key(values) -> str:
            # One canonical key format for both the uniform initialization and
            # the fitted rows; it matches the ``str(pvals)`` lookups used by
            # get_dist/predict when pvals is a list of strings.
            return str([str(v) for v in values])

        def worker(node: BaseNode) -> DiscreteParams:
            parents = node.disc_parents + node.cont_parents
            dist = data[node.name].value_counts(normalize=True).sort_index()
            vals = [str(i) for i in dist.index.to_list()]
            if not parents:
                cprob = dist.to_list()
            else:
                # Start every combination of observed parent values at uniform;
                # combinations present in the data are overwritten below.
                cprob = {
                    combination_key(comb): [1 / len(vals) for _ in vals]
                    for comb in product(*[data[p].unique() for p in parents])
                }
                conditional_dist = crosstab(
                    data[node.name].to_list(),
                    [data[p] for p in parents],
                    normalize="columns",
                ).T
                # Align the columns with ``vals`` so every probability list has
                # the same length and order as ``vals`` even when crosstab
                # omits a value that value_counts reports (e.g. unobserved
                # categorical levels, or values seen only with NaN parents).
                conditional_dist = conditional_dist.reindex(
                    columns=dist.index.to_list(), fill_value=0.0
                )
                tight_form = conditional_dist.to_dict("tight")
                for comb, probs in zip(tight_form["index"], tight_form["data"]):
                    # One parent -> scalar index labels; several -> tuples.
                    parts = comb if len(parents) > 1 else [comb]
                    cprob[combination_key(parts)] = probs
            return {"cprob": cprob, "vals": vals}

        return worker(self)

    @staticmethod
    def get_dist(node_info, pvals):
        """
        Wrap the distribution for the given parent values in a result object.

        node_info: fitted parameters (``{"cprob": ..., "vals": ...}``).
        pvals: parent values in parent order (a list of strings, so that
            ``str(pvals)`` matches the fitted keys); empty or None for a node
            without parents.
        Raises KeyError if ``str(pvals)`` is not a key of ``cprob``.
        """
        if pvals:
            probs = node_info["cprob"][str(pvals)]
        else:
            probs = node_info["cprob"]
        return DiscreteNodeResult(probs=probs, values=node_info["vals"])

    def choose(self, node_info: Dict[str, Union[float, str]], pvals: List[str]) -> str:
        """
        Return value from discrete node, sampled from its (conditional)
        distribution.

        params:
        node_info: nodes info from distributions
        pvals: parent values
        Returns one element of ``node_info["vals"]``.
        """
        vals = node_info["vals"]
        # NOTE(review): assumes DiscreteNodeResult.get()[0] is the probability
        # list passed in as ``probs`` -- confirm against result_models.
        probs = self.get_dist(node_info, pvals).get()[0]
        cumulative_dist = np.cumsum(probs)
        # Scale the draw by the actual total so rounding (a sum slightly below
        # 1.0) cannot push it past the final bin and cause an IndexError.
        rand = np.random.random() * cumulative_dist[-1]
        # side="right" selects the first bin whose cumulative mass strictly
        # exceeds the draw, so a zero-probability value is never picked, even
        # for a draw of exactly 0.0.
        rindex = int(np.searchsorted(cumulative_dist, rand, side="right"))
        # Guard against the scaled draw rounding up to exactly the total.
        return vals[min(rindex, len(vals) - 1)]

    @staticmethod
    def predict(node_info: Dict[str, Union[float, str]], pvals: List[str]) -> str:
        """function for prediction based on evidence values in discrete node:
        returns the most probable value, breaking ties uniformly at random.

        Args:
            node_info (Dict[str, Union[float, str]]): parameters of node
            pvals (List[str]): values in parents nodes (empty for a root node)
        Returns:
            str: prediction
        """
        vals = node_info["vals"]
        if pvals:
            # noinspection PyTypeChecker
            dist = node_info["cprob"][str(pvals)]
        else:
            dist = node_info["cprob"]
        max_value = max(dist)
        indices = [index for index, value in enumerate(dist) if value == max_value]
        # Consult the RNG only on an actual tie, so a unique mode does not
        # advance the global ``random`` state.
        max_ind = indices[0] if len(indices) == 1 else random.choice(indices)
        return vals[max_ind]