Source code for bamt.utils.GraphUtils

from typing import Dict, List, Tuple, Type

import networkx as nx
from pandas import DataFrame

from bamt.log import logger_preprocessor
from bamt.nodes.base import BaseNode


def nodes_types(data: DataFrame) -> Dict[str, str]:
    """
    Define the type of every column (node) of the dataset.

    Node types:
        disc - discrete node (string, object, categorical or boolean dtype)
        cont - continuous node (float32 / float64 dtype)
        disc_num - discrete numerical node (int32 / int64 dtype)

    Args:
        data: input dataset

    Returns:
        dict: output dictionary where 'key' - node name and 'value' - node type.
        Columns with any other dtype are reported through the preprocessor
        logger and left out of the result.
    """
    # Matching is done on ``Series.dtype.name``. pandas names a categorical
    # dtype "category" (not "categorical") and the nullable string dtype
    # "string", so both are listed explicitly; the older entries are kept so
    # nothing that was previously accepted is rejected now.
    disc = {"str", "string", "O", "b", "category", "categorical", "object", "bool"}
    disc_numerical = {"int32", "int64"}
    cont = {"float32", "float64"}

    column_type = dict()
    for c in data.columns.to_list():
        dtype_name = data[c].dtype.name
        if dtype_name in disc:
            column_type[c] = "disc"
        elif dtype_name in cont:
            column_type[c] = "cont"
        elif dtype_name in disc_numerical:
            column_type[c] = "disc_num"
        else:
            logger_preprocessor.error(f"Unsupported data type. Dtype: {data[c].dtypes}")
    return column_type
def nodes_signs(nodes_types: dict, data: DataFrame) -> Dict[str, str]:
    """
    Define the sign of every continuous node.

        neg - the column contains at least one negative value
        pos - the column contains no negative values (zeros count as pos)

    Args:
        nodes_types: mapping 'node name' -> 'node type' (e.g. the output of
            the module-level ``nodes_types`` function); it must contain
            exactly the columns of ``data``, in any order.
        data (pd.DataFrame): input dataset

    Returns:
        dict: output dictionary where 'key' - name of a "cont" node and
        'value' - its sign; nodes of other types are omitted.
        If ``nodes_types`` does not cover exactly the columns of ``data``,
        an error is logged and None is returned.
    """
    columns = data.columns.to_list()
    # Compare as sets: a dictionary holding the same nodes in a different
    # order is still complete and must not be rejected.
    if set(nodes_types) != set(columns):
        logger_preprocessor.error("Nodes_types dictionary is not full.")
        return None

    columns_sign = dict()
    for c in columns:
        if nodes_types[c] == "cont":
            columns_sign[c] = "neg" if (data[c] < 0).any() else "pos"
    return columns_sign
def get_descriptor(data: DataFrame) -> Dict[str, Dict[str, str]]:
    """
    Build the descriptor of a dataset.

    Args:
        data: input dataset

    Returns:
        dict with key "types" (output of ``nodes_types(data)``) and key
        "signs" (output of ``nodes_signs`` for those types).
    """
    # Compute the node types once and reuse them, so the dtype scan (and any
    # "Unsupported data type" error it logs) happens a single time.
    types = nodes_types(data)
    return {"types": types, "signs": nodes_signs(types, data)}
def toporder(nodes: List[BaseNode], edges: List[Tuple]) -> List[str]:
    """
    Topologically sort the nodes of a directed graph.

    Args:
        nodes: node objects; only their ``name`` attribute is read.
        edges: pairs (u, v) of node names, added as directed edges u -> v.
            Endpoints missing from ``nodes`` are added to the graph by
            networkx.

    Returns:
        list: node names in a topological order.

    Raises:
        networkx.NetworkXUnfeasible: if the graph contains a cycle.
    """
    G = nx.DiGraph()
    G.add_nodes_from([node.name for node in nodes])
    G.add_edges_from(edges)
    return list(nx.topological_sort(G))
class GraphAnalyzer:
    """
    Object to analyze DAG.

    As used by the methods below, ``bn`` is indexed by node name
    (``bn[name]`` yields an object with ``cont_parents``, ``disc_parents``
    and ``children`` lists of names) and exposes ``bn.edges`` as
    (source, target) pairs.
    """

    def __init__(self, bn):
        self.bn = bn

    def _isolate_structure(self, nodes):
        """Return the edges of ``self.bn`` whose both endpoints are in ``nodes``."""
        # A set makes each membership test O(1) instead of scanning the list
        # once per edge.
        node_set = set(nodes)
        return [
            edge
            for edge in self.bn.edges
            if edge[0] in node_set and edge[1] in node_set
        ]

    def markov_blanket(self, node_name: str):
        """
        Return the Markov blanket of ``node_name`` as a subgraph.

        The blanket holds the node itself, its parents, its children and the
        other parents of its children.

        Returns:
            dict with "nodes" (unique node names) and "edges" (edges of the
            network between those nodes).
        """
        node = self.bn[node_name]
        parents = node.cont_parents + node.disc_parents
        children = node.children

        co_parents = []
        for child in children:
            child_node = self.bn[child]
            child_parents = child_node.cont_parents + child_node.disc_parents
            # A child whose only parent is node_name contributes nobody new.
            if child_parents != [node_name]:
                co_parents.extend(child_parents)

        # dict.fromkeys drops duplicates while keeping a deterministic order
        # (a set could order string names differently between interpreter runs).
        nodes = list(dict.fromkeys(parents + children + co_parents + [node_name]))
        return {"nodes": nodes, "edges": self._isolate_structure(nodes)}

    def _collect_height(self, node_name, height):
        """Return ancestors of ``node_name`` up to ``height`` generations up (may repeat)."""
        if height <= 0:
            return []
        node = self.bn[node_name]
        if height == 1:
            return node.disc_parents + node.cont_parents
        nodes = []
        for parent in node.cont_parents + node.disc_parents:
            nodes.append(parent)
            nodes.extend(self._collect_height(parent, height=height - 1))
        return nodes

    def _collect_depth(self, node_name, depth):
        """Return descendants of ``node_name`` down to ``depth`` generations (may repeat)."""
        if depth <= 0:
            return []
        node = self.bn[node_name]
        if depth == 1:
            return node.children
        nodes = []
        for child in node.children:
            nodes.append(child)
            nodes.extend(self._collect_depth(child, depth=depth - 1))
        return nodes

    def find_family(self, *args):
        """
        Return the subgraph around a node, bounded by height and depth.

        Positional arguments (exactly four):
            node_name: the central node.
            height: number of ancestor generations to include.
            depth: number of descendant generations to include.
            with_nodes: extra node names to include, or a falsy value for none.

        Returns:
            dict with "nodes" (unique node names) and "edges" (edges of the
            network between those nodes).
        """
        node_name, height, depth, with_nodes = args
        extra = list(with_nodes) if with_nodes else []

        nodes = list(
            dict.fromkeys(
                self._collect_depth(node_name, depth)
                + self._collect_height(node_name, height)
                + [node_name]
                + extra
            )
        )
        return {"nodes": nodes, "edges": self._isolate_structure(nodes)}