Skip to content

Commit

Permalink
add log=grad as default option to eval
Browse files Browse the repository at this point in the history
  • Loading branch information
hage1005 committed Jan 5, 2024
1 parent a2c882e commit d5f4779
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
2 changes: 0 additions & 2 deletions analog/analog.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,8 +324,6 @@ def initialize_from_log(self) -> None:
lora_state = torch.load(os.path.join(lora_dir, "lora_state_dict.pt"))
if not is_lora(self.model):
self.add_lora(lora_state=lora_state)
if not any("analog_lora_A" in name for name in self.model.state_dict()):
self.add_lora(lora_state=lora_state)
for name in lora_state:
assert name in self.model.state_dict(), f"{name} not in model!"
self.model.load_state_dict(lora_state, strict=False)
Expand Down
7 changes: 6 additions & 1 deletion analog/logging/option.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,16 @@ def _sanity_check(self):
)
self._log["grad"] = True

def eval(self):
def eval(self, log="grad"):
"""
Enable the evaluation mode. This will turn of saving and updating
statistic.
"""
if isinstance(log, str):
self._log[log] = True
else:
raise ValueError(f"Unsupported log type for eval: {type(log)}")

self.clear(log=False, save=True, statistic=True)

def clear(self, log=True, save=True, statistic=True):
Expand Down
2 changes: 1 addition & 1 deletion examples/mnist_influence/compute_influences_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,6 @@

# Save
if_scores = if_scores.numpy().tolist()[0]
torch.save(if_scores, "if_analog_scheduler_init_from_log_0.8.pt")
torch.save(if_scores, "examples/mnist_influence/if_analog_scheduler.pt")
print("Computation time:", time.time() - start)
print("Top influential data indices:", top_influential_data.numpy().tolist())

0 comments on commit d5f4779

Please sign in to comment.