Skip to content

Commit

Permalink
Bump minimum TorchAO version to 0.7.0 (#10293)
Browse files Browse the repository at this point in the history
* bump min torchao version to 0.7.0

* update
  • Loading branch information
a-r-r-o-w authored Dec 23, 2024
1 parent 3c2e2aa commit ffc0eaa
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 51 deletions.
5 changes: 5 additions & 0 deletions src/diffusers/quantizers/torchao/torchao_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ def validate_environment(self, *args, **kwargs):
raise ImportError(
"Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`"
)
torchao_version = version.parse(importlib.metadata.version("torch"))
if torchao_version < version.parse("0.7.0"):
raise RuntimeError(
f"The minimum required version of `torchao` is 0.7.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
)

self.offload = False

Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/utils/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,11 +490,11 @@ def decorator(test_case):
return decorator


def require_torchao_version_greater(torchao_version):
def require_torchao_version_greater_or_equal(torchao_version):
def decorator(test_case):
correct_torchao_version = is_torchao_available() and version.parse(
version.parse(importlib.metadata.version("torchao")).base_version
) > version.parse(torchao_version)
) >= version.parse(torchao_version)
return unittest.skipUnless(
correct_torchao_version, f"Test requires torchao with version greater than {torchao_version}."
)(test_case)
Expand Down
94 changes: 45 additions & 49 deletions tests/quantization/torchao/test_torchao.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
nightly,
require_torch,
require_torch_gpu,
require_torchao_version_greater,
require_torchao_version_greater_or_equal,
slow,
torch_device,
)
Expand Down Expand Up @@ -74,13 +74,13 @@ def forward(self, input, *args, **kwargs):

if is_torchao_available():
from torchao.dtypes import AffineQuantizedTensor
from torchao.dtypes.affine_quantized_tensor import TensorCoreTiledLayoutType
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
from torchao.utils import get_model_size_in_bytes


@require_torch
@require_torch_gpu
@require_torchao_version_greater("0.6.0")
@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoConfigTest(unittest.TestCase):
def test_to_dict(self):
"""
Expand Down Expand Up @@ -125,7 +125,7 @@ def test_repr(self):
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_gpu
@require_torchao_version_greater("0.6.0")
@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoTest(unittest.TestCase):
def tearDown(self):
gc.collect()
Expand All @@ -139,11 +139,13 @@ def get_dummy_components(self, quantization_config: TorchAoConfig):
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
text_encoder_2 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
text_encoder_2 = T5EncoderModel.from_pretrained(
model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
)
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16)
scheduler = FlowMatchEulerDiscreteScheduler()

return {
Expand Down Expand Up @@ -212,7 +214,7 @@ def get_dummy_tensor_inputs(self, device=None, seed: int = 0):
def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: List[float]):
components = self.get_dummy_components(quantization_config)
pipe = FluxPipeline(**components)
pipe.to(device=torch_device, dtype=torch.bfloat16)
pipe.to(device=torch_device)

inputs = self.get_dummy_inputs(torch_device)
output = pipe(**inputs)[0]
Expand Down Expand Up @@ -276,7 +278,6 @@ def test_int4wo_quant_bfloat16_conversion(self):
self.assertTrue(isinstance(weight, AffineQuantizedTensor))
self.assertEqual(weight.quant_min, 0)
self.assertEqual(weight.quant_max, 15)
self.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType))

def test_device_map(self):
"""
Expand Down Expand Up @@ -341,21 +342,33 @@ def test_device_map(self):

def test_modules_to_not_convert(self):
quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
quantized_model = FluxTransformer2DModel.from_pretrained(
quantized_model_with_not_convert = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)

unquantized_layer = quantized_model.transformer_blocks[0].ff.net[2]
unquantized_layer = quantized_model_with_not_convert.transformer_blocks[0].ff.net[2]
self.assertTrue(isinstance(unquantized_layer, torch.nn.Linear))
self.assertFalse(isinstance(unquantized_layer.weight, AffineQuantizedTensor))
self.assertEqual(unquantized_layer.weight.dtype, torch.bfloat16)

quantized_layer = quantized_model.proj_out
quantized_layer = quantized_model_with_not_convert.proj_out
self.assertTrue(isinstance(quantized_layer.weight, AffineQuantizedTensor))
self.assertEqual(quantized_layer.weight.layout_tensor.data.dtype, torch.int8)

quantization_config = TorchAoConfig("int8_weight_only")
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)

size_quantized_with_not_convert = get_model_size_in_bytes(quantized_model_with_not_convert)
size_quantized = get_model_size_in_bytes(quantized_model)

self.assertTrue(size_quantized < size_quantized_with_not_convert)

def test_training(self):
quantization_config = TorchAoConfig("int8_weight_only")
Expand Down Expand Up @@ -406,23 +419,6 @@ def test_torch_compile(self):
# Note: Seems to require higher tolerance
self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3))

@staticmethod
def _get_memory_footprint(module):
quantized_param_memory = 0.0
unquantized_param_memory = 0.0

for param in module.parameters():
if param.__class__.__name__ == "AffineQuantizedTensor":
data, scale, zero_point = param.layout_tensor.get_plain()
quantized_param_memory += data.numel() + data.element_size()
quantized_param_memory += scale.numel() + scale.element_size()
quantized_param_memory += zero_point.numel() + zero_point.element_size()
else:
unquantized_param_memory += param.data.numel() * param.data.element_size()

total_memory = quantized_param_memory + unquantized_param_memory
return total_memory, quantized_param_memory, unquantized_param_memory

def test_memory_footprint(self):
r"""
A simple test to check if the model conversion has been done correctly by checking on the
Expand All @@ -433,20 +429,18 @@ def test_memory_footprint(self):
transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"))["transformer"]
transformer_bf16 = self.get_dummy_components(None)["transformer"]

total_int4wo, quantized_int4wo, unquantized_int4wo = self._get_memory_footprint(transformer_int4wo)
total_int4wo_gs32, quantized_int4wo_gs32, unquantized_int4wo_gs32 = self._get_memory_footprint(
transformer_int4wo_gs32
)
total_int8wo, quantized_int8wo, unquantized_int8wo = self._get_memory_footprint(transformer_int8wo)
total_bf16, quantized_bf16, unquantized_bf16 = self._get_memory_footprint(transformer_bf16)

self.assertTrue(quantized_bf16 == 0 and total_bf16 == unquantized_bf16)
# int4wo_gs32 has smaller group size, so more groups -> more scales and zero points
self.assertTrue(total_int8wo < total_bf16 < total_int4wo_gs32)
# int4 with default group size quantized very few linear layers compared to a smaller group size of 32
self.assertTrue(quantized_int4wo < quantized_int4wo_gs32 and unquantized_int4wo > unquantized_int4wo_gs32)
total_int4wo = get_model_size_in_bytes(transformer_int4wo)
total_int4wo_gs32 = get_model_size_in_bytes(transformer_int4wo_gs32)
total_int8wo = get_model_size_in_bytes(transformer_int8wo)
total_bf16 = get_model_size_in_bytes(transformer_bf16)

# Latter has smaller group size, so more groups -> more scales and zero points
self.assertTrue(total_int4wo < total_int4wo_gs32)
# int8 quantizes more layers compare to int4 with default group size
self.assertTrue(quantized_int8wo < quantized_int4wo)
self.assertTrue(total_int8wo < total_int4wo)
# int4wo does not quantize too many layers because of default group size, but for the layers it does
# there is additional overhead of scales and zero points
self.assertTrue(total_bf16 < total_int4wo)

def test_wrong_config(self):
with self.assertRaises(ValueError):
Expand All @@ -456,7 +450,7 @@ def test_wrong_config(self):
# This class is not to be run as a test by itself. See the tests that follow this class
@require_torch
@require_torch_gpu
@require_torchao_version_greater("0.6.0")
@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoSerializationTest(unittest.TestCase):
model_name = "hf-internal-testing/tiny-flux-pipe"
quant_method, quant_method_kwargs = None, None
Expand Down Expand Up @@ -565,7 +559,7 @@ class TorchAoSerializationINTA16W8CPUTest(TorchAoSerializationTest):
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_gpu
@require_torchao_version_greater("0.6.0")
@require_torchao_version_greater_or_equal("0.7.0")
@slow
@nightly
class SlowTorchAoTests(unittest.TestCase):
Expand All @@ -581,11 +575,13 @@ def get_dummy_components(self, quantization_config: TorchAoConfig):
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
text_encoder_2 = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
text_encoder_2 = T5EncoderModel.from_pretrained(
model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16
)
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16)
scheduler = FlowMatchEulerDiscreteScheduler()

return {
Expand Down Expand Up @@ -617,7 +613,7 @@ def get_dummy_inputs(self, device: torch.device, seed: int = 0):

def _test_quant_type(self, quantization_config, expected_slice):
components = self.get_dummy_components(quantization_config)
pipe = FluxPipeline(**components).to(dtype=torch.bfloat16)
pipe = FluxPipeline(**components)
pipe.enable_model_cpu_offload()

inputs = self.get_dummy_inputs(torch_device)
Expand Down

0 comments on commit ffc0eaa

Please sign in to comment.