Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update logging setup interface #106

Merged
merged 1 commit into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/bert/compute_influence.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def main():
log_loader = logix.build_log_dataloader()

# influence analysis
logix.setup({"log": "grad"})
logix.setup({"grad": ["log"]})
logix.eval()
for batch in test_loader:
data_id = tokenizer.batch_decode(batch["input_ids"])
Expand Down
2 changes: 1 addition & 1 deletion examples/cifar/compute_influences.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
log_loader = logix.build_log_dataloader()

logix.eval()
logix.setup({"log": "grad"})
logix.setup({"grad": ["log"]})
for test_input, test_target in test_loader:
with logix(data_id=id_gen(test_input)):
test_input, test_target = test_input.to(DEVICE), test_target.to(DEVICE)
Expand Down
2 changes: 1 addition & 1 deletion examples/language_modeling/compute_influence.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def main():
log_loader = logix.build_log_dataloader(batch_size=64)

# Influence analysis
logix.setup({"log": "grad"})
logix.setup({"grad": ["log"]})
logix.eval()
merged_test_logs = []
for idx, batch in enumerate(tqdm(data_loader)):
Expand Down
2 changes: 1 addition & 1 deletion examples/mnist/compute_influences.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
)

# logix.add_analysis({"influence": InfluenceFunction})
logix.setup({"log": "grad"})
logix.setup({"grad": ["log"]})
logix.eval()
for test_input, test_target in test_loader:
with logix(data_id=id_gen(test_input)):
Expand Down
2 changes: 1 addition & 1 deletion examples/mnist/compute_influences_manual.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
)

# logix.add_analysis({"influence": InfluenceFunction})
logix.setup({"log": "grad"})
logix.setup({"grad": ["log"]})
logix.eval()
for test_input, test_target in test_loader:
### Start
Expand Down
2 changes: 1 addition & 1 deletion logix/huggingface/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def on_train_begin(self, args, state, control, **kwargs):
self.logix.initialize_from_log()

if self.args.mode in ["influence", "self_influence"]:
self.logix.setup({"log": "grad"})
self.logix.setup({"grad": ["log"]})
self.logix.eval()

state.epoch = 0
Expand Down
72 changes: 28 additions & 44 deletions logix/logging/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from logix.batch_info import BatchInfo
from logix.config import LoggingConfig
from logix.state import LogIXState
from logix.statistic import Log
from logix.logging.option import LogOption
from logix.logging.log_saver import LogSaver
from logix.logging.utils import compute_per_sample_gradient
Expand Down Expand Up @@ -61,13 +62,12 @@ def log(self, data_id: Any, mask: Optional[torch.Tensor] = None):

def save_log(self):
# save log to disk
if any(self.opt.save.values()):
self.log_saver.buffer_write(binfo=self.binfo)
self.log_saver.flush()
self.log_saver.buffer_write(binfo=self.binfo)
self.log_saver.flush()

def update(self):
# Update statistics
for stat in self.opt.statistic["grad"]:
def update(self, save: bool = False):
# gradient plugin has to be executed after accumulating all gradients
for stat in self.opt.grad[1:]:
for module_name, _ in self.binfo.log.items():
stat.update(
state=self.state,
Expand All @@ -84,7 +84,8 @@ def update(self):
torch.cuda.current_stream().synchronize()

# Write and flush the buffer if necessary
self.save_log()
if save:
self.save_log()

def _forward_hook_fn(
self, module: nn.Module, inputs: Tuple[torch.Tensor], module_name: str
Expand All @@ -100,7 +101,6 @@ def _forward_hook_fn(
assert len(inputs) == 1

activations = inputs[0]
log = self.binfo.log[module_name]

# If `mask` is not None, apply the mask to activations. This is
# useful for example when you work with sequence models that use
Expand All @@ -118,14 +118,8 @@ def _forward_hook_fn(
if self.dtype is not None:
activations = activations.to(dtype=self.dtype)

if self.opt.log["forward"]:
if "forward" not in log:
log["forward"] = activations
else:
log["forward"] += activations

for stat in self.opt.statistic["forward"]:
stat.update(
for plugin in self.opt.forward:
plugin.update(
state=self.state,
binfo=self.binfo,
module=module,
Expand Down Expand Up @@ -154,19 +148,12 @@ def _backward_hook_fn(
assert len(grad_outputs) == 1

error = grad_outputs[0]
log = self.binfo.log[module_name]

if self.dtype is not None:
error = error.to(dtype=self.dtype)

if self.opt.log["backward"]:
if "backward" not in log:
log["backward"] = error
else:
log["backward"] += error

for stat in self.opt.statistic["backward"]:
stat.update(
for plugin in self.opt.backward:
plugin.update(
state=self.state,
binfo=self.binfo,
module=module,
Expand Down Expand Up @@ -194,24 +181,29 @@ def _grad_hook_fn(
"""
assert len(inputs) == 1

log = self.binfo.log[module_name]

# In case, the same module is used multiple times in the forward pass,
# we need to accumulate the gradients. We achieve this by using the
# additional tensor hook on the output of the module.
def _grad_backward_hook_fn(grad: torch.Tensor):
if self.opt.log["grad"]:
if len(self.opt.grad) > 0:
assert self.opt.grad[0] == Log
per_sample_gradient = compute_per_sample_gradient(
inputs[0], grad, module
)

if self.dtype is not None:
per_sample_gradient = per_sample_gradient.to(dtype=self.dtype)

if "grad" not in log:
log["grad"] = per_sample_gradient
else:
log["grad"] += per_sample_gradient
for plugin in self.opt.grad[:1]:
plugin.update(
state=self.state,
binfo=self.binfo,
module=module,
module_name=module_name,
log_type="grad",
data=per_sample_gradient,
cpu_offload=self.cpu_offload,
)

tensor_hook = outputs.register_hook(_grad_backward_hook_fn)
self.tensor_hooks.append(tensor_hook)
Expand All @@ -227,15 +219,11 @@ def _tensor_forward_hook_fn(self, tensor: torch.Tensor, tensor_name: str) -> Non
tensor: The tensor triggering the hook.
tensor_name (str): A string identifier for the tensor, useful for logging.
"""
log = self.binfo.log[tensor_name]

if self.dtype is not None:
tensor = tensor.to(dtype=self.dtype)

log["forward"] = tensor

for stat in self.opt.statistic["forward"]:
stat.update(
for plugin in self.opt.forward:
plugin.update(
state=self.state,
binfo=self.binfo,
module=None,
Expand All @@ -256,15 +244,11 @@ def _tensor_backward_hook_fn(self, grad: torch.Tensor, tensor_name: str) -> None
grad: The gradient tensor triggering the hook.
tensor_name (str): A string identifier for the tensor whose gradient is being tracked.
"""
log = self.binfo.log[tensor_name]

if self.dtype is not None:
grad = grad.to(dtype=self.dtype)

log["backward"] = grad

for stat in self.opt.statistic["backward"]:
stat.update(
for plugin in self.opt.backward:
plugin.update(
state=self.state,
binfo=self.binfo,
module=None,
Expand Down
Loading
Loading