| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import inspect |
| import tempfile |
| import traceback |
| import unittest |
| import unittest.mock as mock |
| from typing import Dict, List, Tuple |
|
|
| import numpy as np |
| import requests_mock |
| import torch |
| from requests.exceptions import HTTPError |
|
|
| from diffusers.models import UNet2DConditionModel |
| from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0, XFormersAttnProcessor |
| from diffusers.training_utils import EMAModel |
| from diffusers.utils import logging, torch_device |
| from diffusers.utils.testing_utils import CaptureLogger, require_torch_2, require_torch_gpu, run_test_in_subprocess |
|
|
|
|
| |
| def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout): |
| error = None |
| try: |
| init_dict, model_class = in_queue.get(timeout=timeout) |
|
|
| model = model_class(**init_dict) |
| model.to(torch_device) |
| model = torch.compile(model) |
|
|
| with tempfile.TemporaryDirectory() as tmpdirname: |
| model.save_pretrained(tmpdirname) |
| new_model = model_class.from_pretrained(tmpdirname) |
| new_model.to(torch_device) |
|
|
| assert new_model.__class__ == model_class |
| except Exception: |
| error = f"{traceback.format_exc()}" |
|
|
| results = {"error": error} |
| out_queue.put(results, timeout=timeout) |
| out_queue.join() |
|
|
|
|
class ModelUtilsTest(unittest.TestCase):
    """Checks for generic ``from_pretrained`` behaviour: error messages,
    offline cache reuse, request counting, and weight-shape overrides."""

    def tearDown(self):
        super().tearDown()

        import diffusers

        # Some tests flip the safetensors availability flag; always restore it.
        diffusers.utils.import_utils._safetensors_available = True

    def test_accelerate_loading_error_message(self):
        """A broken checkpoint must name the missing keys in the ValueError."""
        with self.assertRaises(ValueError) as error_context:
            UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet")

        assert "conv_out.bias" in str(error_context.exception)

    def test_cached_files_are_used_when_no_internet(self):
        """With the network down, cached weights must still load identically."""
        # Every HTTP request fails with a 500 to simulate an outage.
        failing_response = mock.Mock()
        failing_response.status_code = 500
        failing_response.headers = {}
        failing_response.raise_for_status.side_effect = HTTPError
        failing_response.json.return_value = {}

        # Warm the cache while the network is still reachable.
        orig_model = UNet2DConditionModel.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet"
        )

        with mock.patch("requests.request", return_value=failing_response):
            model = UNet2DConditionModel.from_pretrained(
                "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", local_files_only=True
            )

        # The cached copy must be parameter-for-parameter identical.
        for cached_param, offline_param in zip(orig_model.parameters(), model.parameters()):
            if cached_param.data.ne(offline_param.data).sum() > 0:
                assert False, "Parameters not the same!"

    def test_one_request_upon_cached(self):
        """Second load from a warm cache should issue a single HEAD request."""
        if torch_device == "mps":
            return

        import diffusers

        diffusers.utils.import_utils._safetensors_available = False

        with tempfile.TemporaryDirectory() as cache_root:
            with requests_mock.mock(real_http=True) as recorder:
                UNet2DConditionModel.from_pretrained(
                    "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", cache_dir=cache_root
                )

            download_requests = [request.method for request in recorder.request_history]
            assert download_requests.count("HEAD") == 2, "2 HEAD requests one for config, one for model"
            assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model"

            with requests_mock.mock(real_http=True) as recorder:
                UNet2DConditionModel.from_pretrained(
                    "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet", cache_dir=cache_root
                )

            cache_requests = [request.method for request in recorder.request_history]
            assert (
                "HEAD" == cache_requests[0] and len(cache_requests) == 1
            ), "We should call only `model_info` to check for _commit hash and `send_telemetry`"

        diffusers.utils.import_utils._safetensors_available = True

    def test_weight_overwrite(self):
        """Mismatched `in_channels` must fail loudly unless the user opts in
        with `ignore_mismatched_sizes=True`."""
        with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context:
            UNet2DConditionModel.from_pretrained(
                "hf-internal-testing/tiny-stable-diffusion-torch",
                subfolder="unet",
                cache_dir=tmpdirname,
                in_channels=9,
            )

        assert "Cannot load" in str(error_context.exception)

        with tempfile.TemporaryDirectory() as tmpdirname:
            model = UNet2DConditionModel.from_pretrained(
                "hf-internal-testing/tiny-stable-diffusion-torch",
                subfolder="unet",
                cache_dir=tmpdirname,
                in_channels=9,
                low_cpu_mem_usage=False,
                ignore_mismatched_sizes=True,
            )

        assert model.config.in_channels == 9
|
|
|
|
class UNetTesterMixin:
    """Shared checks for UNet-style models whose output matches the input shape."""

    def test_forward_signature(self):
        """``forward`` must take ``sample`` and ``timestep`` as its first two args."""
        init_dict, _ = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        parameter_names = list(inspect.signature(model.forward).parameters)

        self.assertListEqual(parameter_names[:2], ["sample", "timestep"])

    def test_forward_with_norm_groups(self):
        """A forward pass with a custom ``norm_num_groups`` keeps the input shape."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["norm_num_groups"] = 16
        init_dict["block_out_channels"] = (16, 32)

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)

        if isinstance(output, dict):
            output = output.to_tuple()[0]

        self.assertIsNotNone(output)
        self.assertEqual(output.shape, inputs_dict["sample"].shape, "Input and output shapes do not match")
|
|
|
|
class ModelTesterMixin:
    """Common behavioural tests for diffusers models.

    Concrete tester classes must provide ``model_class``,
    ``prepare_init_args_and_inputs_for_common()`` and (for the training tests)
    ``output_shape``, and may override the class attributes below.
    """

    # Key in `inputs_dict` that holds the primary input tensor (e.g. "sample");
    # must be overridden by the concrete tester class.
    main_input_name = None
    # Absolute tolerance used by `test_set_attn_processor_for_determinism`.
    base_precision = 1e-3

    def test_from_save_pretrained(self):
        """Save and reload the model; forward outputs must match within 5e-5."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        # Pin the default attention processor so both models use the same path.
        if hasattr(model, "set_default_attn_processor"):
            model.set_default_attn_processor()
        model.to(torch_device)
        model.eval()

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            new_model = self.model_class.from_pretrained(tmpdirname)
            if hasattr(new_model, "set_default_attn_processor"):
                new_model.set_default_attn_processor()
            new_model.to(torch_device)

        with torch.no_grad():
            image = model(**inputs_dict)
            if isinstance(image, dict):
                image = image.to_tuple()[0]

            new_image = new_model(**inputs_dict)

            if isinstance(new_image, dict):
                new_image = new_image.to_tuple()[0]

        # NOTE(review): despite the name, this is the *sum* of absolute
        # differences, not the max.
        max_diff = (image - new_image).abs().sum().item()
        self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")

    def test_getattr_is_correct(self):
        """Attribute access: silent for real attributes, FutureWarning for
        deprecated config attributes, AttributeError otherwise."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)

        # Plant one plain attribute and one config attribute to probe both paths.
        model.dummy_attribute = 5
        model.register_to_config(test_attribute=5)

        logger = logging.get_logger("diffusers.models.modeling_utils")
        # 30 == logging.WARNING, so only warnings and above are captured.
        logger.setLevel(30)
        with CaptureLogger(logger) as cap_logger:
            assert hasattr(model, "dummy_attribute")
            assert getattr(model, "dummy_attribute") == 5
            assert model.dummy_attribute == 5

        # Plain attribute access must not log anything.
        assert cap_logger.out == ""

        logger = logging.get_logger("diffusers.models.modeling_utils")
        logger.setLevel(30)
        with CaptureLogger(logger) as cap_logger:
            assert hasattr(model, "save_pretrained")
            fn = model.save_pretrained
            fn_1 = getattr(model, "save_pretrained")

            assert fn == fn_1
        # Method access must not log anything either.
        assert cap_logger.out == ""

        # Config attributes looked up on the model are deprecated and must warn.
        with self.assertWarns(FutureWarning):
            assert model.test_attribute == 5

        with self.assertWarns(FutureWarning):
            assert getattr(model, "test_attribute") == 5

        with self.assertRaises(AttributeError) as error:
            model.does_not_exist

        assert str(error.exception) == f"'{type(model).__name__}' object has no attribute 'does_not_exist'"

    @require_torch_gpu
    def test_set_attn_processor_for_determinism(self):
        """All equivalent attention processors (default, 2.0, xformers) must
        produce outputs that agree within `base_precision`.

        NOTE(review): `torch.use_deterministic_algorithms` is toggled globally
        and only re-enabled on the success path; a failure mid-test leaves it
        disabled for subsequent tests — consider try/finally.
        """
        torch.use_deterministic_algorithms(False)
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        if not hasattr(model, "set_attn_processor"):
            # Model does not use attention processors; nothing to check.
            return

        # Freshly constructed models are expected to default to AttnProcessor2_0.
        assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values())
        with torch.no_grad():
            output_1 = model(**inputs_dict)[0]

        model.set_default_attn_processor()
        assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
        with torch.no_grad():
            output_2 = model(**inputs_dict)[0]

        model.enable_xformers_memory_efficient_attention()
        assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
        with torch.no_grad():
            output_3 = model(**inputs_dict)[0]

        model.set_attn_processor(AttnProcessor2_0())
        assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values())
        with torch.no_grad():
            output_4 = model(**inputs_dict)[0]

        model.set_attn_processor(AttnProcessor())
        assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
        with torch.no_grad():
            output_5 = model(**inputs_dict)[0]

        model.set_attn_processor(XFormersAttnProcessor())
        assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
        with torch.no_grad():
            output_6 = model(**inputs_dict)[0]

        torch.use_deterministic_algorithms(True)

        # All processor variants are compared against the default processor.
        assert torch.allclose(output_2, output_1, atol=self.base_precision)
        assert torch.allclose(output_2, output_3, atol=self.base_precision)
        assert torch.allclose(output_2, output_4, atol=self.base_precision)
        assert torch.allclose(output_2, output_5, atol=self.base_precision)
        assert torch.allclose(output_2, output_6, atol=self.base_precision)

    def test_from_save_pretrained_variant(self):
        """Round-trip with a weight variant ("fp16"); loading without the
        variant must fail with a clear error."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        if hasattr(model, "set_default_attn_processor"):
            model.set_default_attn_processor()

        model.to(torch_device)
        model.eval()

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname, variant="fp16")
            new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16")
            if hasattr(new_model, "set_default_attn_processor"):
                new_model.set_default_attn_processor()

            # Only variant weights were saved, so a non-variant load must fail.
            with self.assertRaises(OSError) as error_context:
                self.model_class.from_pretrained(tmpdirname)

            # The error must state which file is missing.
            assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(error_context.exception)

        new_model.to(torch_device)

        with torch.no_grad():
            image = model(**inputs_dict)
            if isinstance(image, dict):
                image = image.to_tuple()[0]

            new_image = new_model(**inputs_dict)

            if isinstance(new_image, dict):
                new_image = new_image.to_tuple()[0]

        # NOTE(review): sum of absolute differences, as in test_from_save_pretrained.
        max_diff = (image - new_image).abs().sum().item()
        self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")

    @require_torch_2
    def test_from_save_pretrained_dynamo(self):
        """Run the torch.compile round-trip in a subprocess (dynamo state is
        process-global, so isolate it)."""
        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
        inputs = [init_dict, self.model_class]
        run_test_in_subprocess(test_case=self, target_func=_test_from_save_pretrained_dynamo, inputs=inputs)

    def test_from_save_pretrained_dtype(self):
        """Saving in a given dtype and reloading with `torch_dtype` must
        preserve that dtype, with and without low_cpu_mem_usage."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        for dtype in [torch.float32, torch.float16, torch.bfloat16]:
            # bfloat16 is skipped on mps — presumably unsupported there; confirm.
            if torch_device == "mps" and dtype == torch.bfloat16:
                continue
            with tempfile.TemporaryDirectory() as tmpdirname:
                model.to(dtype)
                model.save_pretrained(tmpdirname)
                new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype)
                assert new_model.dtype == dtype
                new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype)
                assert new_model.dtype == dtype

    def test_determinism(self, expected_max_diff=1e-5):
        """Two forward passes on identical inputs must agree elementwise
        (NaN positions are ignored)."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            first = model(**inputs_dict)
            if isinstance(first, dict):
                first = first.to_tuple()[0]

            second = model(**inputs_dict)
            if isinstance(second, dict):
                second = second.to_tuple()[0]

        out_1 = first.cpu().numpy()
        out_2 = second.cpu().numpy()
        # Drop NaNs before comparing; NaN != NaN would always fail.
        out_1 = out_1[~np.isnan(out_1)]
        out_2 = out_2[~np.isnan(out_2)]
        max_diff = np.amax(np.abs(out_1 - out_2))
        self.assertLessEqual(max_diff, expected_max_diff)

    def test_output(self):
        """The model output must be non-None and match the main input's shape."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)

        if isinstance(output, dict):
            output = output.to_tuple()[0]

        self.assertIsNotNone(output)

        # Shape equality is checked against `main_input_name` — the tester
        # class must set it.
        input_tensor = inputs_dict[self.main_input_name]
        expected_shape = input_tensor.shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_model_from_pretrained(self):
        """Reloaded model must have matching parameter shapes and output shape."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            new_model = self.model_class.from_pretrained(tmpdirname)
            new_model.to(torch_device)
            new_model.eval()

        # Every parameter must survive the round-trip with the same shape.
        for param_name in model.state_dict().keys():
            param_1 = model.state_dict()[param_name]
            param_2 = new_model.state_dict()[param_name]
            self.assertEqual(param_1.shape, param_2.shape)

        with torch.no_grad():
            output_1 = model(**inputs_dict)

            if isinstance(output_1, dict):
                output_1 = output_1.to_tuple()[0]

            output_2 = new_model(**inputs_dict)

            if isinstance(output_2, dict):
                output_2 = output_2.to_tuple()[0]

        # NOTE(review): only shapes are compared here; value equality is
        # covered by test_from_save_pretrained.
        self.assertEqual(output_1.shape, output_2.shape)

    @unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
    def test_training(self):
        """A training step (forward + MSE loss + backward) must run cleanly."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.train()
        output = model(**inputs_dict)

        if isinstance(output, dict):
            output = output.to_tuple()[0]

        # `self.output_shape` (per-sample) must be defined by the tester class.
        input_tensor = inputs_dict[self.main_input_name]
        noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
        loss = torch.nn.functional.mse_loss(output, noise)
        loss.backward()

    @unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
    def test_ema_training(self):
        """Same as test_training, plus an EMA update after the backward pass."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.train()
        ema_model = EMAModel(model.parameters())

        output = model(**inputs_dict)

        if isinstance(output, dict):
            output = output.to_tuple()[0]

        input_tensor = inputs_dict[self.main_input_name]
        noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
        loss = torch.nn.functional.mse_loss(output, noise)
        loss.backward()
        ema_model.step(model.parameters())

    def test_outputs_equivalence(self):
        """`return_dict=False` tuple output must equal the dict output,
        element by element (NaNs zeroed before comparison)."""

        def set_nan_tensor_to_zero(t):
            # Zero out NaNs so equal-NaN positions don't fail allclose; the
            # cpu round-trip on mps presumably works around missing boolean
            # mask assignment there — confirm.
            device = t.device
            if device.type == "mps":
                t = t.to("cpu")
            t[t != t] = 0
            return t.to(device)

        def recursive_check(tuple_object, dict_object):
            # Walk both structures in parallel: sequences, mappings, then leaves.
            if isinstance(tuple_object, (List, Tuple)):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif isinstance(tuple_object, Dict):
                for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
                    recursive_check(tuple_iterable_value, dict_iterable_value)
            elif tuple_object is None:
                return
            else:
                self.assertTrue(
                    torch.allclose(
                        set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
                    ),
                    # NOTE(review): `torch.isinf(tuple_object)` below lacks
                    # `.any()` (unlike its neighbours), so it prints the whole
                    # tensor — likely unintended.
                    msg=(
                        "Tuple and dict output are not equal. Difference:"
                        f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                        f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
                        f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
                    ),
                )

        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            outputs_dict = model(**inputs_dict)
            outputs_tuple = model(**inputs_dict, return_dict=False)

        recursive_check(outputs_tuple, outputs_dict)

    @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
    def test_enable_disable_gradient_checkpointing(self):
        """Gradient checkpointing must be off by default and toggle cleanly."""
        if not self.model_class._supports_gradient_checkpointing:
            return

        init_dict, _ = self.prepare_init_args_and_inputs_for_common()

        # Fresh models start with checkpointing disabled.
        model = self.model_class(**init_dict)
        self.assertFalse(model.is_gradient_checkpointing)

        model.enable_gradient_checkpointing()
        self.assertTrue(model.is_gradient_checkpointing)

        model.disable_gradient_checkpointing()
        self.assertFalse(model.is_gradient_checkpointing)

    def test_deprecated_kwargs(self):
        """`**kwargs` in __init__ and the `_deprecated_kwargs` class attribute
        must be declared together — one without the other is an error."""
        has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
        has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0

        if has_kwarg_in_model_class and not has_deprecated_kwarg:
            raise ValueError(
                f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs"
                " under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are"
                " no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
                " [<deprecated_argument>]`"
            )

        if not has_kwarg_in_model_class and has_deprecated_kwarg:
            raise ValueError(
                f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs"
                " under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to"
                f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument"
                " from `_deprecated_kwargs = [<deprecated_argument>]`"
            )
|