From e3b216338000aabfdda007cb6a410d1092a4d998 Mon Sep 17 00:00:00 2001 From: Daniel Nowak <63320957+danielnowakassis@users.noreply.github.com> Date: Fri, 6 Sep 2024 16:47:33 -0300 Subject: [PATCH] Adding Local Adaptive Streaming Tree (#1610) * add first last with out-of-date docs * update LAST to detect change in the data distribution + iter_arff none class Docs are also updated * update docs * Update hoeffding_adaptive_tree_classifier.py * changes after tests * solving inheritance and small fixes * tests + current_merit method * Update river/tree/hoeffding_adaptive_tree_classifier.py * update docs * change docs * change last * Update docs/releases/unreleased.md * Update river/tree/hoeffding_adaptive_tree_classifier.py * Update river/tree/last_classifier.py * add disclamer * Update river/tree/last_classifier.py --------- Co-authored-by: Saulo Martiello Mastelini --- docs/releases/unreleased.md | 6 + river/stream/iter_arff.py | 7 +- river/tree/__init__.py | 2 + river/tree/last_classifier.py | 363 ++++++++++++++++++ river/tree/nodes/last_nodes.py | 206 ++++++++++ river/tree/split_criterion/base.py | 14 + .../split_criterion/gini_split_criterion.py | 3 + .../hellinger_distance_criterion.py | 3 + .../info_gain_split_criterion.py | 3 + ...ster_variance_reduction_split_criterion.py | 3 + .../variance_ratio_split_criterion.py | 3 + .../variance_reduction_split_criterion.py | 3 + 12 files changed, 614 insertions(+), 2 deletions(-) create mode 100644 river/tree/last_classifier.py create mode 100644 river/tree/nodes/last_nodes.py diff --git a/docs/releases/unreleased.md b/docs/releases/unreleased.md index dfb7dd4b19..e0536f435b 100644 --- a/docs/releases/unreleased.md +++ b/docs/releases/unreleased.md @@ -17,3 +17,9 @@ ## tree - Instead of letting trees grow indefinitely, setting the `max_depth` parameter to `None` will stop the trees from growing when they reach the system recursion limit. + +-Added `tree.LASTClassifier` (Local Adaptive Streaming Tree Classifier). + +## stream + +- `stream.iter_arff` now supports blank values (treated as missing values). \ No newline at end of file diff --git a/river/stream/iter_arff.py b/river/stream/iter_arff.py index 4464e3f74a..f9eec46ded 100644 --- a/river/stream/iter_arff.py +++ b/river/stream/iter_arff.py @@ -176,7 +176,7 @@ def iter_arff( x = { name: cast(val) if cast else val for name, cast, val in zip(names, casts, r.rstrip().split(",")) - if val != "?" + if val != "?" and val != "" } # Handle target @@ -185,7 +185,10 @@ def iter_arff( if isinstance(target, list): y = {name: x.pop(name, 0) for name in target} else: - y = x.pop(target) if target else None + try: + y = x.pop(target) if target else None + except KeyError: + y = None yield x, y diff --git a/river/tree/__init__.py b/river/tree/__init__.py index 9e8b887217..ad53a1ecf8 100755 --- a/river/tree/__init__.py +++ b/river/tree/__init__.py @@ -59,6 +59,7 @@ from .hoeffding_tree_classifier import HoeffdingTreeClassifier from .hoeffding_tree_regressor import HoeffdingTreeRegressor from .isoup_tree_regressor import iSOUPTreeRegressor +from .last_classifier import LASTClassifier from .stochastic_gradient_tree import SGTClassifier, SGTRegressor __all__ = [ @@ -70,6 +71,7 @@ "HoeffdingTreeRegressor", "HoeffdingAdaptiveTreeRegressor", "iSOUPTreeRegressor", + "LASTClassifier", "SGTClassifier", "SGTRegressor", ] diff --git a/river/tree/last_classifier.py b/river/tree/last_classifier.py new file mode 100644 index 0000000000..391ba83668 --- /dev/null +++ b/river/tree/last_classifier.py @@ -0,0 +1,363 @@ +from __future__ import annotations + +from river import base, drift + +from .hoeffding_tree_classifier import HoeffdingTreeClassifier +from .nodes.branch import DTBranch +from .nodes.last_nodes import ( + LeafMajorityClassWithDetector, + LeafNaiveBayesAdaptiveWithDetector, + LeafNaiveBayesWithDetector, +) +from .nodes.leaf import HTLeaf +from .split_criterion import GiniSplitCriterion, HellingerDistanceCriterion, InfoGainSplitCriterion +from .splitter import Splitter + + +class LASTClassifier(HoeffdingTreeClassifier, base.Classifier): + """Local Adaptive Streaming Tree Classifier. + + Local Adaptive Streaming Tree [^1] (LAST) is an incremental decision tree with + adaptive splitting mechanisms. LAST maintains a change detector at each leaf and splits + this node if a change is detected in the error or the leaf`s data distribution. + + LAST is still not suitable for use as a base classifier in ensembles due to the change detectors. + The authors in [^1] are working on a version of LAST that overcomes this limitation. + + Parameters + ---------- + max_depth + The maximum depth a tree can reach. If `None`, the tree will grow until the system recursion limit. + split_criterion + Split criterion to use.
+ - 'gini' - Gini
+ - 'info_gain' - Information Gain
+ - 'hellinger' - Helinger Distance
+ leaf_prediction + Prediction mechanism used at leafs.
+ - 'mc' - Majority Class
+ - 'nb' - Naive Bayes
+ - 'nba' - Naive Bayes Adaptive
+ change_detector + Change detector that will be created at each leaf of the tree. + track_error + If True, the change detector will have binary inputs for error predictions, + otherwise the input will be the split criteria. + nb_threshold + Number of instances a leaf should observe before allowing Naive Bayes. + nominal_attributes + List of Nominal attributes identifiers. If empty, then assume that all numeric + attributes should be treated as continuous. + splitter + The Splitter or Attribute Observer (AO) used to monitor the class statistics of numeric + features and perform splits. Splitters are available in the `tree.splitter` module. + Different splitters are available for classification and regression tasks. Classification + and regression splitters can be distinguished by their property `is_target_class`. + This is an advanced option. Special care must be taken when choosing different splitters. + By default, `tree.splitter.GaussianSplitter` is used if `splitter` is `None`. + binary_split + If True, only allow binary splits. + min_branch_fraction + The minimum percentage of observed data required for branches resulting from split + candidates. To validate a split candidate, at least two resulting branches must have + a percentage of samples greater than `min_branch_fraction`. This criterion prevents + unnecessary splits when the majority of instances are concentrated in a single branch. + max_share_to_split + Only perform a split in a leaf if the proportion of elements in the majority class is + smaller than this parameter value. This parameter avoids performing splits when most + of the data belongs to a single class. + max_size + The max size of the tree, in Megabytes (MB). + memory_estimate_period + Interval (number of processed instances) between memory consumption checks. + stop_mem_management + If True, stop growing as soon as memory limit is hit. + remove_poor_attrs + If True, disable poor attributes to reduce memory usage. + merit_preprune + If True, enable merit-based tree pre-pruning. + + References + ---------- + + [^1]: Daniel Nowak Assis, Jean Paul Barddal, and Fabrício Enembreck. + Just Change on Change: Adaptive Splitting Time for Decision Trees in + Data Stream Classification . In Proceedings of ACM SAC Conference (SAC’24). + + Examples + -------- + + >>> from river.datasets import synth + >>> from river import evaluate + >>> from river import metrics + >>> from river import tree + + >>> gen = synth.ConceptDriftStream(stream=synth.SEA(seed=42, variant=0), + ... drift_stream=synth.SEA(seed=42, variant=1), + ... seed=1, position=1500, width=50) + >>> dataset = iter(gen.take(3000)) + + >>> model = tree.LASTClassifier() + + >>> metric = metrics.Accuracy() + + >>> evaluate.progressive_val_score(dataset, model, metric) + Accuracy: 91.10% + + """ + + def __init__( + self, + max_depth: int | None = None, + split_criterion: str = "info_gain", + leaf_prediction: str = "nba", + change_detector: base.DriftDetector | None = None, + track_error: bool = True, + nb_threshold: int = 0, + nominal_attributes: list | None = None, + splitter: Splitter | None = None, + binary_split: bool = False, + min_branch_fraction: float = 0.01, + max_share_to_split: float = 0.99, + max_size: float = 100.0, + memory_estimate_period: int = 1000000, + stop_mem_management: bool = False, + remove_poor_attrs: bool = False, + merit_preprune: bool = True, + ): + super().__init__( + grace_period=1, # no usage + max_depth=max_depth, + split_criterion=split_criterion, + delta=1.0, # no usage + tau=1, # no usage + leaf_prediction=leaf_prediction, + nb_threshold=nb_threshold, + binary_split=binary_split, + max_size=max_size, + memory_estimate_period=memory_estimate_period, + stop_mem_management=stop_mem_management, + remove_poor_attrs=remove_poor_attrs, + merit_preprune=merit_preprune, + nominal_attributes=nominal_attributes, + splitter=splitter, + min_branch_fraction=min_branch_fraction, + max_share_to_split=max_share_to_split, + ) + self.change_detector = change_detector if change_detector is not None else drift.ADWIN() + self.track_error = track_error + + # To keep track of the observed classes + self.classes: set = set() + + @property + def _mutable_attributes(self): + return {} + + def _new_leaf(self, initial_stats=None, parent=None): + if initial_stats is None: + initial_stats = {} + if parent is None: + depth = 0 + else: + depth = parent.depth + 1 + + if self._leaf_prediction == self._MAJORITY_CLASS: + return LeafMajorityClassWithDetector( + initial_stats, + depth, + self.splitter, + self.change_detector.clone(), + split_criterion=self._new_split_criterion() if not self.track_error else None, + ) + elif self._leaf_prediction == self._NAIVE_BAYES: + return LeafNaiveBayesWithDetector( + initial_stats, + depth, + self.splitter, + self.change_detector.clone(), + split_criterion=self._new_split_criterion() if not self.track_error else None, + ) + else: # Naives Bayes Adaptive (default) + return LeafNaiveBayesAdaptiveWithDetector( + initial_stats, + depth, + self.splitter, + self.change_detector.clone(), + split_criterion=self._new_split_criterion() if not self.track_error else None, + ) + + def _new_split_criterion(self): + if self._split_criterion == self._GINI_SPLIT: + split_criterion = GiniSplitCriterion(self.min_branch_fraction) + elif self._split_criterion == self._INFO_GAIN_SPLIT: + split_criterion = InfoGainSplitCriterion(self.min_branch_fraction) + elif self._split_criterion == self._HELLINGER_SPLIT: + if not self.track_error: + raise ValueError( + "The Heillinger distance cannot estimate the purity of a single distribution.\ + Use another split criterion or set track_error to True" + ) + split_criterion = HellingerDistanceCriterion(self.min_branch_fraction) + else: + split_criterion = InfoGainSplitCriterion(self.min_branch_fraction) + + return split_criterion + + def _attempt_to_split(self, leaf: HTLeaf, parent: DTBranch, parent_branch: int, **kwargs): + """Attempt to split a leaf. + + If the samples seen so far are not from the same class then: + + 1. Find split candidates and select the top 1. + 2. If the top1 is greater than zero: + 3.1 Replace the leaf node by a split node (branch node). + 3.2 Add a new leaf node on each branch of the new split node. + 3.3 Update tree's metrics + + Optional: Disable poor attributes. Depends on the tree's configuration. + + Parameters + ---------- + leaf + The leaf to evaluate. + parent + The leaf's parent. + parent_branch + Parent leaf's branch index. + kwargs + Other parameters passed to the new branch. + """ + if not leaf.observed_class_distribution_is_pure(): # type: ignore + split_criterion = self._new_split_criterion() + + best_split_suggestions = leaf.best_split_suggestions(split_criterion, self) + should_split = False + if len(best_split_suggestions) < 2: + should_split = len(best_split_suggestions) > 0 + else: + best_suggestion = max(best_split_suggestions) + should_split = best_suggestion.merit > 0.0 + if self.remove_poor_attrs: + poor_atts = set() + # Add any poor attribute to set + for suggestion in best_split_suggestions: + poor_atts.add(suggestion.feature) + for poor_att in poor_atts: + leaf.disable_attribute(poor_att) + if should_split: + split_decision = max(best_split_suggestions) + if split_decision.feature is None: + # Pre-pruning - null wins + leaf.deactivate() + self._n_inactive_leaves += 1 + self._n_active_leaves -= 1 + else: + branch = self._branch_selector( + split_decision.numerical_feature, split_decision.multiway_split + ) + leaves = tuple( + self._new_leaf(initial_stats, parent=leaf) + for initial_stats in split_decision.children_stats # type: ignore + ) + + new_split = split_decision.assemble( + branch, leaf.stats, leaf.depth, *leaves, **kwargs + ) + + self._n_active_leaves -= 1 + self._n_active_leaves += len(leaves) + if parent is None: + self._root = new_split + else: + parent.children[parent_branch] = new_split + + # Manage memory + self._enforce_size_limit() + + def learn_one(self, x, y, *, w=1.0): + """Train the model on instance x and corresponding target y. + + Parameters + ---------- + x + Instance attributes. + y + Class label for sample x. + w + Sample weight. + + Notes + ----- + Training tasks: + + * If the tree is empty, create a leaf node as the root. + * If the tree is already initialized, find the corresponding leaf for + the instance and update the leaf node statistics. + * Update the leaf change detector with (1 if the tree misclassified the instance, + or 0 if it correctly classified) or the data distribution purity + * If growth is allowed then attempt + to split. + """ + + # Updates the set of observed classes + self.classes.add(y) + + self._train_weight_seen_by_model += w + + if self._root is None: + self._root = self._new_leaf() + self._n_active_leaves = 1 + + p_node = None + node = None + if isinstance(self._root, DTBranch): + path = iter(self._root.walk(x, until_leaf=False)) + while True: + aux = next(path, None) + if aux is None: + break + p_node = node + node = aux + else: + node = self._root + + if isinstance(node, HTLeaf): + node.learn_one(x, y, w=w, tree=self) + if self._growth_allowed and node.is_active(): + if node.depth >= self.max_depth: # Max depth reached + node.deactivate() + self._n_active_leaves -= 1 + self._n_inactive_leaves += 1 + else: + weight_seen = node.total_weight + # check if the change detector triggered a change + if node.change_detector.drift_detected: + p_branch = p_node.branch_no(x) if isinstance(p_node, DTBranch) else None + self._attempt_to_split(node, p_node, p_branch) + node.last_split_attempt_at = weight_seen + else: + while True: + # Split node encountered a previously unseen categorical value (in a multi-way + # test), so there is no branch to sort the instance to + if node.max_branches() == -1 and node.feature in x: + # Create a new branch to the new categorical value + leaf = self._new_leaf(parent=node) + node.add_child(x[node.feature], leaf) + self._n_active_leaves += 1 + node = leaf + # The split feature is missing in the instance. Hence, we pass the new example + # to the most traversed path in the current subtree + else: + _, node = node.most_common_path() + # And we keep trying to reach a leaf + if isinstance(node, DTBranch): + node = node.traverse(x, until_leaf=False) + # Once a leaf is reached, the traversal can stop + if isinstance(node, HTLeaf): + break + # Learn from the sample + node.learn_one(x, y, w=w, tree=self) + + if self._train_weight_seen_by_model % self.memory_estimate_period == 0: + self._estimate_model_size() diff --git a/river/tree/nodes/last_nodes.py b/river/tree/nodes/last_nodes.py new file mode 100644 index 0000000000..f130fe17ac --- /dev/null +++ b/river/tree/nodes/last_nodes.py @@ -0,0 +1,206 @@ +from __future__ import annotations + +from ..utils import do_naive_bayes_prediction +from .htc_nodes import LeafMajorityClass + + +class LeafMajorityClassWithDetector(LeafMajorityClass): + """Leaf that always predicts the majority class. + + Parameters + ---------- + stats + Initial class observations. + depth + The depth of the node. + splitter + The numeric attribute observer algorithm used to monitor target statistics + and perform split attempts. + change_detector + Change detector that monitors the leaf error rate or class distribution and + determines when the leaf will split. + split_criterion + Split criterion used in the tree for updating the change detector if it + monitors the class distribution. + kwargs + Other parameters passed to the learning node. + """ + + def __init__(self, stats, depth, splitter, change_detector, split_criterion=None, **kwargs): + super().__init__(stats, depth, splitter, **kwargs) + self.change_detector = change_detector + # change this in future PR's by acessing the tree parameter in the leaf + self.split_criterion = ( + split_criterion # if None, the change detector will have binary inputs + ) + + def learn_one(self, x, y, *, w=1, tree=None): + self.update_stats(y, w) + if self.is_active(): + if self.split_criterion is None: + mc_pred = self.prediction(x) + detector_input = max(mc_pred, key=mc_pred.get) != y + self.change_detector.update(detector_input) + else: + detector_input = self.split_criterion.current_merit(self.stats) + self.change_detector.update(detector_input) + self.update_splitters(x, y, w, tree.nominal_attributes) + + +class LeafNaiveBayesWithDetector(LeafMajorityClassWithDetector): + """Leaf that uses Naive Bayes models. + + Parameters + ---------- + stats + Initial class observations. + depth + The depth of the node. + splitter + The numeric attribute observer algorithm used to monitor target statistics + and perform split attempts. + change_detector + Change detector that monitors the leaf error rate or class distribution and + determines when the leaf will split. + split_criterion + Split criterion used in the tree for updating the change detector if it + monitors the class distribution. + kwargs + Other parameters passed to the learning node. + """ + + def __init__(self, stats, depth, splitter, change_detector, split_criterion=None, **kwargs): + super().__init__(stats, depth, splitter, change_detector, split_criterion, **kwargs) + + def learn_one(self, x, y, *, w=1, tree=None): + self.update_stats(y, w) + if self.is_active(): + if self.split_criterion is None: + nb_pred = self.prediction(x) + detector_input = max(nb_pred, key=nb_pred.get) == y + self.change_detector.update(detector_input) + else: + detector_input = self.split_criterion.current_merit(self.stats) + self.change_detector.update(detector_input) + self.update_splitters(x, y, w, tree.nominal_attributes) + + def prediction(self, x, *, tree=None): + if self.is_active() and self.total_weight >= tree.nb_threshold: + return do_naive_bayes_prediction(x, self.stats, self.splitters) + else: + return super().prediction(x) + + def disable_attribute(self, att_index): + """Disable an attribute observer. + + Disabled in Nodes using Naive Bayes, since poor attributes are used in + Naive Bayes calculation. + + Parameters + ---------- + att_index + Attribute index. + """ + pass + + +class LeafNaiveBayesAdaptiveWithDetector(LeafMajorityClassWithDetector): + """Learning node that uses Adaptive Naive Bayes models. + + Parameters + ---------- + stats + Initial class observations. + depth + The depth of the node. + splitter + The numeric attribute observer algorithm used to monitor target statistics + and perform split attempts. + change_detector + Change detector that monitors the leaf error rate or class distribution and + determines when the leaf will split. + split_criterion + Split criterion used in the tree for updating the change detector if it + monitors the class distribution. + kwargs + Other parameters passed to the learning node. + """ + + def __init__(self, stats, depth, splitter, change_detector, split_criterion=None, **kwargs): + super().__init__(stats, depth, splitter, change_detector, split_criterion, **kwargs) + self._mc_correct_weight = 0.0 + self._nb_correct_weight = 0.0 + + def learn_one(self, x, y, *, w=1.0, tree=None): + """Update the node with the provided instance. + + Parameters + ---------- + x + Instance attributes for updating the node. + y + Instance class. + w + The instance's weight. + tree + The tree to update. + + """ + detector_input_mc = 1 + detector_input_nb = 1 + if self.is_active(): + mc_pred = super().prediction(x) + # Empty node (assume the majority class will be the best option) or majority + # class prediction is correct + if len(self.stats) == 0 or max(mc_pred, key=mc_pred.get) == y: + self._mc_correct_weight += w + detector_input_mc = 0 + nb_pred = do_naive_bayes_prediction(x, self.stats, self.splitters) + if len(nb_pred) > 0 and max(nb_pred, key=nb_pred.get) == y: + self._nb_correct_weight += w + detector_input_nb = 0 + + self.update_stats(y, w) + if self.is_active(): + if self.split_criterion is None: + if self._nb_correct_weight >= self._mc_correct_weight: + self.change_detector.update(detector_input_nb) + else: + self.change_detector.update(detector_input_mc) + else: + detector_input = self.split_criterion.current_merit(self.stats) + self.change_detector.update(detector_input) + self.update_splitters(x, y, w, tree.nominal_attributes) + + def prediction(self, x, *, tree=None): + """Get the probabilities per class for a given instance. + + Parameters + ---------- + x + Instance attributes. + tree + LAST Tree. + + Returns + ------- + Class votes for the given instance. + + """ + if self.is_active() and self._nb_correct_weight >= self._mc_correct_weight: + return do_naive_bayes_prediction(x, self.stats, self.splitters) + else: + return super().prediction(x) + + def disable_attribute(self, att_index): + """Disable an attribute observer. + + Disabled in Nodes using Naive Bayes, since poor attributes are used in + Naive Bayes calculation. + + Parameters + ---------- + att_index + Attribute index. + """ + pass diff --git a/river/tree/split_criterion/base.py b/river/tree/split_criterion/base.py index d329bf6d58..84da388b4b 100644 --- a/river/tree/split_criterion/base.py +++ b/river/tree/split_criterion/base.py @@ -32,6 +32,20 @@ def merit_of_split(self, pre_split_dist, post_split_dist): Value of the merit of splitting """ + @abc.abstractmethod + def current_merit(self, dist): + """Compute the merit of the distribution. + + Parameters + ---------- + dist + The data distribution. + + Returns + ------- + Value of merit of the distribution according to the splitting criterion + """ + @staticmethod @abc.abstractmethod def range_of_merit(pre_split_dist): diff --git a/river/tree/split_criterion/gini_split_criterion.py b/river/tree/split_criterion/gini_split_criterion.py index 9b4c6e3187..1a71c1b651 100644 --- a/river/tree/split_criterion/gini_split_criterion.py +++ b/river/tree/split_criterion/gini_split_criterion.py @@ -28,6 +28,9 @@ def merit_of_split(self, pre_split_dist, post_split_dist): ) return 1.0 - gini + def current_merit(self, dist): + return self.compute_gini(dist, sum(dist.values())) + @staticmethod def compute_gini(dist, dist_sum_of_weights): gini = 1.0 diff --git a/river/tree/split_criterion/hellinger_distance_criterion.py b/river/tree/split_criterion/hellinger_distance_criterion.py index 5ad379b6a9..236564f9bd 100644 --- a/river/tree/split_criterion/hellinger_distance_criterion.py +++ b/river/tree/split_criterion/hellinger_distance_criterion.py @@ -28,6 +28,9 @@ def merit_of_split(self, pre_split_dist, post_split_dist): return -math.inf return self.compute_hellinger(post_split_dist) + def current_merit(self, dist): + raise ValueError("The Heillinger distance is for 2 or more sets of data.") + @staticmethod def compute_hellinger(dist): try: diff --git a/river/tree/split_criterion/info_gain_split_criterion.py b/river/tree/split_criterion/info_gain_split_criterion.py index 0863e6ac84..b112ff829d 100644 --- a/river/tree/split_criterion/info_gain_split_criterion.py +++ b/river/tree/split_criterion/info_gain_split_criterion.py @@ -39,6 +39,9 @@ def compute_entropy(self, dist): elif isinstance(dist, list): return self._compute_entropy_list(dist) + def current_merit(self, dist): + return self.compute_entropy(dist) + @staticmethod def _compute_entropy_dict(dist): entropy = 0.0 diff --git a/river/tree/split_criterion/intra_cluster_variance_reduction_split_criterion.py b/river/tree/split_criterion/intra_cluster_variance_reduction_split_criterion.py index 1436e817e2..0a1af74773 100644 --- a/river/tree/split_criterion/intra_cluster_variance_reduction_split_criterion.py +++ b/river/tree/split_criterion/intra_cluster_variance_reduction_split_criterion.py @@ -27,6 +27,9 @@ def merit_of_split(self, pre_split_dist, post_split_dist): icvr -= n_i / n * self.compute_var(dist) return icvr + def current_merit(self, dist): + return self.compute_var(dist) + @staticmethod def compute_var(dist): icvr = [vr.get() for vr in dist.values()] diff --git a/river/tree/split_criterion/variance_ratio_split_criterion.py b/river/tree/split_criterion/variance_ratio_split_criterion.py index c51df25a14..dfdff8ea55 100644 --- a/river/tree/split_criterion/variance_ratio_split_criterion.py +++ b/river/tree/split_criterion/variance_ratio_split_criterion.py @@ -34,6 +34,9 @@ def merit_of_split(self, pre_split_dist, post_split_dist): vr -= (n_i / n) * (self.compute_var(post_split_dist[i]) / var) return vr + def current_merit(self, dist): + return self.compute_var(dist) + @staticmethod def compute_var(dist): return dist.get() diff --git a/river/tree/split_criterion/variance_reduction_split_criterion.py b/river/tree/split_criterion/variance_reduction_split_criterion.py index f52cfa7bd3..147ead573b 100644 --- a/river/tree/split_criterion/variance_reduction_split_criterion.py +++ b/river/tree/split_criterion/variance_reduction_split_criterion.py @@ -35,6 +35,9 @@ def merit_of_split(self, pre_split_dist, post_split_dist): vr -= n_i / n * self.compute_var(post_split_dist[i]) return vr + def current_merit(self, dist): + return self.compute_var(dist) + @staticmethod def compute_var(dist): return dist.get()