Skip to content

Commit

Permalink
update
Browse files: browse the repository at this point in the history
  • Loading branch information
a-r-r-o-w committed Dec 24, 2024
1 parent 87bb2fe commit ba1269d
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions tests/quantization/torchao/test_torchao.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -131,7 +131,9 @@ def tearDown(self):
gc.collect()
torch.cuda.empty_cache()

def get_dummy_components(self, quantization_config: TorchAoConfig, model_id: str = "hf-internal-testing/tiny-flux-pipe"):
def get_dummy_components(
self, quantization_config: TorchAoConfig, model_id: str = "hf-internal-testing/tiny-flux-pipe"
):
transformer = FluxTransformer2DModel.from_pretrained(
model_id,
subfolder="transformer",
Expand Down Expand Up @@ -436,7 +438,9 @@ def test_memory_footprint(self):
"""
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]:
transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"), model_id=model_id)["transformer"]
transformer_int4wo_gs32 = self.get_dummy_components(TorchAoConfig("int4wo", group_size=32), model_id=model_id)["transformer"]
transformer_int4wo_gs32 = self.get_dummy_components(
TorchAoConfig("int4wo", group_size=32), model_id=model_id
)["transformer"]
transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"]
transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"]

Expand Down Expand Up @@ -654,7 +658,7 @@ def test_quantization(self):
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()

def test_serialization(self):
quantization_config = TorchAoConfig("int8wo")
components = self.get_dummy_components(quantization_config)
Expand All @@ -673,6 +677,6 @@ def test_serialization(self):

weight = loaded_pipe.transformer.x_embedder.weight
self.assertTrue(isinstance(weight, AffineQuantizedTensor))

loaded_output = loaded_pipe(**inputs)[0].flatten()
self.assertTrue(np.allclose(output, loaded_output, atol=1e-3, rtol=1e-3))

0 comments on commit ba1269d

Please sign in to comment.