Apply PERF401 autofixes from ruff (#140980)
Summary:
* Automatically applies ruff rule PERF401, which turns append loops into equivalent list comprehensions; these are faster and do not leak the loop variable into the enclosing scope (see the sketch after this list).
* List comprehensions not only often have better typing, but are also 50%+ faster than for loops in terms of loop overhead. They also preserve length information and are easier for the interpreter to optimize.
* Manually went back and made mypy happy after the change.
* Also fixed style lints in files covered by flake8 but not by pyfmt
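
A minimal sketch of the kind of rewrite PERF401 performs; the names `values` and `squares` are illustrative only and do not come from this commit:

    values = range(10)

    # Before: explicit loop that builds a list via append
    squares = []
    for value in values:
        if value % 2 == 0:
            squares.append(value * value)

    # After: equivalent list comprehension suggested by ruff's PERF401 autofix
    squares = [value * value for value in values if value % 2 == 0]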

X-link: pytorch/pytorch#140980
Approved by: https://github.com/justinchuby, https://github.com/malfet

Reviewed By: izaitsevfb

Differential Revision: D66262948

fbshipit-source-id: 4d871761b25633da20bca8d1f37b9842144f2218
Skylion007 authored and facebook-github-bot committed Nov 21, 2024
1 parent 598f5f9 commit 1de25b5
Showing 2 changed files with 5 additions and 6 deletions.
4 changes: 1 addition & 3 deletions userbenchmark/dynamo/dynamobench/_dynamo/testing.py
@@ -95,9 +95,7 @@ def collect_results(
     results.append(buffers)
     for example in example_inputs:
         if isinstance(example, (tuple, list)):
-            for inp in example:
-                if isinstance(inp, torch.Tensor):
-                    results.append(inp.grad)
+            results.extend(inp.grad for inp in example if isinstance(inp, torch.Tensor))
         else:
             if isinstance(example, torch.Tensor):
                 results.append(example.grad)
7 changes: 4 additions & 3 deletions userbenchmark/dynamo/dynamobench/_dynamo/utils.py
@@ -1535,9 +1535,10 @@ def checkpoint_params(gm):
     rng_state = torch.clone(torch.random.get_rng_state())
     if torch.cuda.is_available():
         cuda_rng_state = torch.clone(torch.cuda.get_rng_state())
-    saved_state = []
-    for param in itertools.chain(gm.parameters(), gm.buffers()):
-        saved_state.append((param, param._version, torch.clone(param)))
+    saved_state = [
+        (param, param._version, torch.clone(param))
+        for param in itertools.chain(gm.parameters(), gm.buffers())
+    ]
 
     def restore():
         with torch.no_grad():
