Skip to content

Commit

Permalink
Support module-wise influence computations (#99)
Browse files Browse the repository at this point in the history
* WIP

* WIP

* WIP

* fix bug

* minor clean

* support module-wise influence computations
  • Loading branch information
sangkeun00 authored Apr 22, 2024
1 parent f24712b commit 4410152
Showing 1 changed file with 100 additions and 19 deletions.
119 changes: 100 additions & 19 deletions logix/analysis/influence_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ def compute_influence(
tgt_log: Tuple[str, Dict[str, Dict[str, torch.Tensor]]],
mode: Optional[str] = "dot",
precondition: Optional[bool] = True,
precondition_hessian: Optional[str] = "auto",
hessian: Optional[str] = "auto",
influence_groups: Optional[List[str]] = None,
damping: Optional[float] = None,
):
"""
Expand All @@ -92,12 +93,18 @@ def compute_influence(
result = {}
if precondition:
src_log = self.precondition(
src_log=src_log, damping=damping, hessian=precondition_hessian
src_log=src_log, damping=damping, hessian=hessian
)

src_ids, src = src_log
tgt_ids, tgt = tgt_log

# Initialize influence
total_influence = 0
if influence_groups is not None:
total_influence = {"total": 0}
for influence_group in influence_groups:
total_influence[influence_group] = 0

# Compute influence scores. By default, we should compute the basic influence
# scores, which is essentially the inner product between the source and target
Expand All @@ -113,27 +120,58 @@ def compute_influence(
else:
synchronize_device(src, tgt)
for module_name in src.keys():
total_influence += cross_dot_product(
module_influence = cross_dot_product(
src[module_name]["grad"], tgt[module_name]["grad"]
)
if influence_groups is None:
total_influence += module_influence
else:
total_influence["total"] += module_influence
in_groups = [ig for ig in influence_groups if ig in module_name]
for group in in_groups:
total_influence[group] += module_influence

if mode == "cosine":
tgt_norm = self.compute_self_influence(
tgt_log, precondition=True, damping=damping
tgt_log,
precondition=True,
hessian=hessian,
influence_groups=influence_groups,
damping=damping,
)
total_influence /= torch.sqrt(tgt_norm.unsqueeze(0))
if influence_groups is None:
total_influence /= torch.sqrt(tgt_norm.unsqueeze(0))
else:
for key in total_influence.keys():
total_influence[key] /= tgt_norm[key]
elif mode == "l2":
tgt_norm = self.compute_self_influence(
tgt_log, precondition=True, damping=damping
tgt_log,
precondition=True,
hessian=hessian,
influence_groups=influence_groups,
damping=damping,
)
total_influence = 2 * total_influence - tgt_norm.unsqueeze(0)
if influence_groups is None:
total_influence -= 0.5 * tgt_norm.unsqueeze(0)
else:
for key in total_influence.keys():
total_influence[key] -= 0.5 * tgt_norm[key].unsqueeze(0)

assert total_influence.shape[0] == len(src_ids)
assert total_influence.shape[1] == len(tgt_ids)
# Move influence scores to CPU to save memory
if influence_groups is None:
assert total_influence.shape[0] == len(src_ids)
assert total_influence.shape[1] == len(tgt_ids)
total_influence = total_influence.cpu()
else:
for key, value in total_influence.items():
assert value.shape[0] == len(src_ids)
assert value.shape[1] == len(tgt_ids)
total_influence[key] = value.cpu()

result["src_ids"] = src_ids
result["tgt_ids"] = tgt_ids
result["influence"] = total_influence.cpu()
result["influence"] = total_influence

return result

Expand All @@ -142,6 +180,8 @@ def compute_self_influence(
self,
src_log: Tuple[str, Dict[str, Dict[str, torch.Tensor]]],
precondition: Optional[bool] = True,
hessian: Optional[str] = "auto",
influence_groups: Optional[List[str]] = None,
damping: Optional[float] = None,
):
"""
Expand All @@ -160,18 +200,44 @@ def compute_self_influence(
src = unflatten_log(
log=src, path=self._state.get_state("model_module")["path"]
)
tgt = self.precondition(src_log, damping)[1] if precondition else src

# Compute self-influence scores
tgt = src
if precondition:
tgt = self.precondition(src_log, hessian=hessian, damping=damping)[1]

# Initialize influence
total_influence = 0
if influence_groups is not None:
total_influence = {"total": 0}
for influence_group in influence_groups:
total_influence[influence_group] = 0

# Compute self-influence scores
for module_name in src.keys():
src_module = src[module_name]["grad"]
tgt_module = tgt[module_name]["grad"] if tgt is not None else src_module
module_influence = reduce(src_module * tgt_module, "n a b -> n", "sum")
total_influence += module_influence.reshape(-1)
module_influence = reduce(
src_module * tgt_module, "n a b -> n", "sum"
).reshape(-1)
if influence_groups is None:
total_influence += module_influence
else:
total_influence["total"] += module_influence
in_groups = [ig for ig in influence_groups if ig in module_name]
for group in in_groups:
total_influence[group] += module_influence

# Move influence scores to CPU to save memory
if influence_groups is not None:
assert len(total_influence) == len(src_ids)
total_influence = total_influence.cpu()
else:
for key, value in total_influence.items():
assert len(value) == len(src_ids)
total_influence[key] = value.cpu()

result["src_ids"] = src_ids
result["influence"] = total_influence.cpu()
result["influence"] = total_influence

return result

Expand All @@ -181,6 +247,8 @@ def compute_influence_all(
loader: torch.utils.data.DataLoader,
mode: Optional[str] = "dot",
precondition: Optional[bool] = True,
hessian: Optional[str] = "auto",
influence_groups: Optional[List[str]] = None,
damping: Optional[float] = None,
):
"""
Expand All @@ -195,21 +263,34 @@ def compute_influence_all(
damping (Optional[float], optional): Damping parameter for preconditioning. Defaults to None.
"""
if precondition:
src_log = self.precondition(src_log, damping)
src_log = self.precondition(src_log, hessian=hessian, damping=damping)

result_all = None
for tgt_log in tqdm(loader, desc="Compute IF"):
result = self.compute_influence(
src_log, tgt_log, mode=mode, precondition=False, damping=damping
src_log,
tgt_log,
mode=mode,
precondition=False,
hessian=hessian,
influence_groups=influence_groups,
damping=damping,
)

# Merge results
if result_all is None:
result_all = result
else:
result_all["tgt_ids"].extend(result["tgt_ids"])
continue
result_all["tgt_ids"].extend(result["tgt_ids"])
if influence_groups is None:
result_all["influence"] = torch.cat(
[result_all["influence"], result["influence"]], dim=1
)
else:
for key in result_all["influence"].keys():
result_all["influence"][key] = torch.cat(
[result_all["influence"][key], result["influence"][key]],
dim=1,
)

return result_all

0 comments on commit 4410152

Please sign in to comment.