Support module-wise influence computations #99

Merged · 7 commits · Apr 22, 2024
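
In brief, this change renames the `precondition_hessian` argument to `hessian` and threads a new `influence_groups` argument through `compute_influence`, `compute_self_influence`, and `compute_influence_all`. Each group name is matched as a substring of module names, and when groups are given, `result["influence"]` becomes a dict keyed by `"total"` and each group name instead of a single tensor. Below is a minimal, self-contained sketch of that grouping rule, not the library's API: the module names, tensor shapes, and the `einsum` call (standing in for the `cross_dot_product` helper used in the diff) are made up for illustration.

import torch

# Toy per-module gradients: 4 source and 5 target examples, each with an
# (a, b)-shaped gradient per module. Everything here is illustrative.
torch.manual_seed(0)
module_names = ["layer1.attn", "layer1.mlp", "layer2.attn"]
src = {name: torch.randn(4, 2, 3) for name in module_names}
tgt = {name: torch.randn(5, 2, 3) for name in module_names}

influence_groups = ["attn", "mlp"]

# Same accumulation pattern as the diff: every module contributes to "total",
# plus to every group whose name appears as a substring of the module name.
total_influence = {"total": 0}
for group in influence_groups:
    total_influence[group] = 0

for module_name in module_names:
    # Pairwise dot products between source and target gradients of this module
    # (the role played by cross_dot_product in the actual code).
    module_influence = torch.einsum("nab,mab->nm", src[module_name], tgt[module_name])
    total_influence["total"] += module_influence
    for group in (g for g in influence_groups if g in module_name):
        total_influence[group] += module_influence

print({key: value.shape for key, value in total_influence.items()})
# {'total': torch.Size([4, 5]), 'attn': torch.Size([4, 5]), 'mlp': torch.Size([4, 5])}
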
logix/analysis/influence_function.py (119 changes: 100 additions & 19 deletions)
@@ -74,7 +74,8 @@ def compute_influence(
tgt_log: Tuple[str, Dict[str, Dict[str, torch.Tensor]]],
mode: Optional[str] = "dot",
precondition: Optional[bool] = True,
precondition_hessian: Optional[str] = "auto",
hessian: Optional[str] = "auto",
influence_groups: Optional[List[str]] = None,
damping: Optional[float] = None,
):
"""
@@ -92,12 +93,18 @@ def compute_influence(
result = {}
if precondition:
src_log = self.precondition(
src_log=src_log, damping=damping, hessian=precondition_hessian
src_log=src_log, damping=damping, hessian=hessian
)

src_ids, src = src_log
tgt_ids, tgt = tgt_log

# Initialize influence
total_influence = 0
if influence_groups is not None:
total_influence = {"total": 0}
for influence_group in influence_groups:
total_influence[influence_group] = 0

# Compute influence scores. By default, we should compute the basic influence
# scores, which is essentially the inner product between the source and target
@@ -113,27 +120,58 @@ def compute_influence(
else:
synchronize_device(src, tgt)
for module_name in src.keys():
total_influence += cross_dot_product(
module_influence = cross_dot_product(
src[module_name]["grad"], tgt[module_name]["grad"]
)
if influence_groups is None:
total_influence += module_influence
else:
total_influence["total"] += module_influence
in_groups = [ig for ig in influence_groups if ig in module_name]
for group in in_groups:
total_influence[group] += module_influence

if mode == "cosine":
tgt_norm = self.compute_self_influence(
tgt_log, precondition=True, damping=damping
tgt_log,
precondition=True,
hessian=hessian,
influence_groups=influence_groups,
damping=damping,
)
total_influence /= torch.sqrt(tgt_norm.unsqueeze(0))
if influence_groups is None:
total_influence /= torch.sqrt(tgt_norm.unsqueeze(0))
else:
for key in total_influence.keys():
total_influence[key] /= tgt_norm[key]
elif mode == "l2":
tgt_norm = self.compute_self_influence(
tgt_log, precondition=True, damping=damping
tgt_log,
precondition=True,
hessian=hessian,
influence_groups=influence_groups,
damping=damping,
)
total_influence = 2 * total_influence - tgt_norm.unsqueeze(0)
if influence_groups is None:
total_influence -= 0.5 * tgt_norm.unsqueeze(0)
else:
for key in total_influence.keys():
total_influence[key] -= 0.5 * tgt_norm[key].unsqueeze(0)

assert total_influence.shape[0] == len(src_ids)
assert total_influence.shape[1] == len(tgt_ids)
# Move influence scores to CPU to save memory
if influence_groups is None:
assert total_influence.shape[0] == len(src_ids)
assert total_influence.shape[1] == len(tgt_ids)
total_influence = total_influence.cpu()
else:
for key, value in total_influence.items():
assert value.shape[0] == len(src_ids)
assert value.shape[1] == len(tgt_ids)
total_influence[key] = value.cpu()

result["src_ids"] = src_ids
result["tgt_ids"] = tgt_ids
result["influence"] = total_influence.cpu()
result["influence"] = total_influence

return result

@@ -142,6 +180,8 @@ def compute_self_influence(
self,
src_log: Tuple[str, Dict[str, Dict[str, torch.Tensor]]],
precondition: Optional[bool] = True,
hessian: Optional[str] = "auto",
influence_groups: Optional[List[str]] = None,
damping: Optional[float] = None,
):
"""
@@ -160,18 +200,44 @@ def compute_self_influence(
src = unflatten_log(
log=src, path=self._state.get_state("model_module")["path"]
)
tgt = self.precondition(src_log, damping)[1] if precondition else src

# Compute self-influence scores
tgt = src
if precondition:
tgt = self.precondition(src_log, hessian=hessian, damping=damping)[1]

# Initialize influence
total_influence = 0
if influence_groups is not None:
total_influence = {"total": 0}
for influence_group in influence_groups:
total_influence[influence_group] = 0

# Compute self-influence scores
for module_name in src.keys():
src_module = src[module_name]["grad"]
tgt_module = tgt[module_name]["grad"] if tgt is not None else src_module
module_influence = reduce(src_module * tgt_module, "n a b -> n", "sum")
total_influence += module_influence.reshape(-1)
module_influence = reduce(
src_module * tgt_module, "n a b -> n", "sum"
).reshape(-1)
if influence_groups is None:
total_influence += module_influence
else:
total_influence["total"] += module_influence
in_groups = [ig for ig in influence_groups if ig in module_name]
for group in in_groups:
total_influence[group] += module_influence

# Move influence scores to CPU to save memory
if influence_groups is None:
assert len(total_influence) == len(src_ids)
total_influence = total_influence.cpu()
else:
for key, value in total_influence.items():
assert len(value) == len(src_ids)
total_influence[key] = value.cpu()

result["src_ids"] = src_ids
result["influence"] = total_influence.cpu()
result["influence"] = total_influence

return result

@@ -181,6 +247,8 @@ def compute_influence_all(
loader: torch.utils.data.DataLoader,
mode: Optional[str] = "dot",
precondition: Optional[bool] = True,
hessian: Optional[str] = "auto",
influence_groups: Optional[List[str]] = None,
damping: Optional[float] = None,
):
"""
@@ -195,21 +263,34 @@ def compute_influence_all(
damping (Optional[float], optional): Damping parameter for preconditioning. Defaults to None.
"""
if precondition:
src_log = self.precondition(src_log, damping)
src_log = self.precondition(src_log, hessian=hessian, damping=damping)

result_all = None
for tgt_log in tqdm(loader, desc="Compute IF"):
result = self.compute_influence(
src_log, tgt_log, mode=mode, precondition=False, damping=damping
src_log,
tgt_log,
mode=mode,
precondition=False,
hessian=hessian,
influence_groups=influence_groups,
damping=damping,
)

# Merge results
if result_all is None:
result_all = result
else:
result_all["tgt_ids"].extend(result["tgt_ids"])
continue
result_all["tgt_ids"].extend(result["tgt_ids"])
if influence_groups is None:
result_all["influence"] = torch.cat(
[result_all["influence"], result["influence"]], dim=1
)
else:
for key in result_all["influence"].keys():
result_all["influence"][key] = torch.cat(
[result_all["influence"][key], result["influence"][key]],
dim=1,
)

return result_all
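
For completeness, here is a small sketch of the merge step that `compute_influence_all` performs across batches when `influence_groups` is set: `tgt_ids` are extended and each per-group score matrix is concatenated along the target axis, mirroring the `torch.cat(..., dim=1)` calls above. The helper name and the toy shapes are illustrative, not part of the library.

import torch

def merge_grouped_results(result_all, result):
    # First batch: just keep it as the running result.
    if result_all is None:
        return result
    # Later batches: extend target ids and concatenate every per-group
    # (n_src x n_tgt_batch) score matrix along dim=1, the target axis.
    result_all["tgt_ids"].extend(result["tgt_ids"])
    for key in result_all["influence"]:
        result_all["influence"][key] = torch.cat(
            [result_all["influence"][key], result["influence"][key]], dim=1
        )
    return result_all

# Two toy "batches", each with 3 source and 2 target examples.
batch1 = {"src_ids": [0, 1, 2], "tgt_ids": [10, 11],
          "influence": {"total": torch.ones(3, 2), "attn": torch.ones(3, 2)}}
batch2 = {"src_ids": [0, 1, 2], "tgt_ids": [12, 13],
          "influence": {"total": torch.zeros(3, 2), "attn": torch.zeros(3, 2)}}

merged = merge_grouped_results(None, batch1)
merged = merge_grouped_results(merged, batch2)
print(merged["tgt_ids"])                  # [10, 11, 12, 13]
print(merged["influence"]["total"].shape) # torch.Size([3, 4])

When `influence_groups` is None, the same merge happens on a single tensor instead of one per key, which is the branch kept from the original code.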