diff --git a/logix/analysis/influence_function.py b/logix/analysis/influence_function.py index 45b0a4b3..5d89a29f 100644 --- a/logix/analysis/influence_function.py +++ b/logix/analysis/influence_function.py @@ -74,7 +74,8 @@ def compute_influence( tgt_log: Tuple[str, Dict[str, Dict[str, torch.Tensor]]], mode: Optional[str] = "dot", precondition: Optional[bool] = True, - precondition_hessian: Optional[str] = "auto", + hessian: Optional[str] = "auto", + influence_groups: Optional[List[str]] = None, damping: Optional[float] = None, ): """ @@ -92,12 +93,18 @@ def compute_influence( result = {} if precondition: src_log = self.precondition( - src_log=src_log, damping=damping, hessian=precondition_hessian + src_log=src_log, damping=damping, hessian=hessian ) src_ids, src = src_log tgt_ids, tgt = tgt_log + + # Initialize influence total_influence = 0 + if influence_groups is not None: + total_influence = {"total": 0} + for influence_group in influence_groups: + total_influence[influence_group] = 0 # Compute influence scores. 
By default, we should compute the basic influence # scores, which is essentially the inner product between the source and target @@ -113,27 +120,58 @@ def compute_influence( else: synchronize_device(src, tgt) for module_name in src.keys(): - total_influence += cross_dot_product( + module_influence = cross_dot_product( src[module_name]["grad"], tgt[module_name]["grad"] ) + if influence_groups is None: + total_influence += module_influence + else: + total_influence["total"] += module_influence + in_groups = [ig for ig in influence_groups if ig in module_name] + for group in in_groups: + total_influence[group] += module_influence if mode == "cosine": tgt_norm = self.compute_self_influence( - tgt_log, precondition=True, damping=damping + tgt_log, + precondition=True, + hessian=hessian, + influence_groups=influence_groups, + damping=damping, ) - total_influence /= torch.sqrt(tgt_norm.unsqueeze(0)) + if influence_groups is None: + total_influence /= torch.sqrt(tgt_norm.unsqueeze(0)) + else: + for key in total_influence.keys(): + total_influence[key] /= torch.sqrt(tgt_norm[key].unsqueeze(0)) elif mode == "l2": tgt_norm = self.compute_self_influence( - tgt_log, precondition=True, damping=damping + tgt_log, + precondition=True, + hessian=hessian, + influence_groups=influence_groups, + damping=damping, ) - total_influence = 2 * total_influence - tgt_norm.unsqueeze(0) + if influence_groups is None: + total_influence -= 0.5 * tgt_norm.unsqueeze(0) + else: + for key in total_influence.keys(): + total_influence[key] -= 0.5 * tgt_norm[key].unsqueeze(0) - assert total_influence.shape[0] == len(src_ids) - assert total_influence.shape[1] == len(tgt_ids) + # Move influence scores to CPU to save memory + if influence_groups is None: + assert total_influence.shape[0] == len(src_ids) + assert total_influence.shape[1] == len(tgt_ids) + total_influence = total_influence.cpu() + else: + for key, value in total_influence.items(): + assert value.shape[0] == len(src_ids) + assert value.shape[1] == len(tgt_ids) + 
total_influence[key] = value.cpu() result["src_ids"] = src_ids result["tgt_ids"] = tgt_ids - result["influence"] = total_influence.cpu() + result["influence"] = total_influence return result @@ -142,6 +180,8 @@ def compute_self_influence( self, src_log: Tuple[str, Dict[str, Dict[str, torch.Tensor]]], precondition: Optional[bool] = True, + hessian: Optional[str] = "auto", + influence_groups: Optional[List[str]] = None, damping: Optional[float] = None, ): """ @@ -160,18 +200,44 @@ def compute_self_influence( src = unflatten_log( log=src, path=self._state.get_state("model_module")["path"] ) - tgt = self.precondition(src_log, damping)[1] if precondition else src - # Compute self-influence scores + tgt = src + if precondition: + tgt = self.precondition(src_log, hessian=hessian, damping=damping)[1] + + # Initialize influence total_influence = 0 + if influence_groups is not None: + total_influence = {"total": 0} + for influence_group in influence_groups: + total_influence[influence_group] = 0 + + # Compute self-influence scores for module_name in src.keys(): src_module = src[module_name]["grad"] tgt_module = tgt[module_name]["grad"] if tgt is not None else src_module - module_influence = reduce(src_module * tgt_module, "n a b -> n", "sum") - total_influence += module_influence.reshape(-1) + module_influence = reduce( + src_module * tgt_module, "n a b -> n", "sum" + ).reshape(-1) + if influence_groups is None: + total_influence += module_influence + else: + total_influence["total"] += module_influence + in_groups = [ig for ig in influence_groups if ig in module_name] + for group in in_groups: + total_influence[group] += module_influence + + # Move influence scores to CPU to save memory + if influence_groups is None: + assert len(total_influence) == len(src_ids) + total_influence = total_influence.cpu() + else: + for key, value in total_influence.items(): + assert len(value) == len(src_ids) + total_influence[key] = value.cpu() result["src_ids"] = src_ids - 
result["influence"] = total_influence.cpu() + result["influence"] = total_influence return result @@ -181,6 +247,8 @@ def compute_influence_all( loader: torch.utils.data.DataLoader, mode: Optional[str] = "dot", precondition: Optional[bool] = True, + hessian: Optional[str] = "auto", + influence_groups: Optional[List[str]] = None, damping: Optional[float] = None, ): """ @@ -195,21 +263,34 @@ def compute_influence_all( damping (Optional[float], optional): Damping parameter for preconditioning. Defaults to None. """ if precondition: - src_log = self.precondition(src_log, damping) + src_log = self.precondition(src_log, hessian=hessian, damping=damping) result_all = None for tgt_log in tqdm(loader, desc="Compute IF"): result = self.compute_influence( - src_log, tgt_log, mode=mode, precondition=False, damping=damping + src_log, + tgt_log, + mode=mode, + precondition=False, + hessian=hessian, + influence_groups=influence_groups, + damping=damping, ) # Merge results if result_all is None: result_all = result - else: - result_all["tgt_ids"].extend(result["tgt_ids"]) + continue + result_all["tgt_ids"].extend(result["tgt_ids"]) + if influence_groups is None: result_all["influence"] = torch.cat( [result_all["influence"], result["influence"]], dim=1 ) + else: + for key in result_all["influence"].keys(): + result_all["influence"][key] = torch.cat( + [result_all["influence"][key], result["influence"][key]], + dim=1, + ) return result_all