aboutsummaryrefslogtreecommitdiff
path: root/modules/sd_disable_initialization.py
diff options
context:
space:
mode:
authorInvincibleDude <81354513+InvincibleDude@users.noreply.github.com>2023-02-05 18:02:44 +0300
committerGitHub <noreply@github.com>2023-02-05 18:02:44 +0300
commitf4b78e73a424299a496801930e6d8868d8d03e61 (patch)
tree48884e8a2ba070d8640f79c1676ffff3e35f37e7 /modules/sd_disable_initialization.py
parent3ec2eb8bf12ae629c292ed0e96f199669040c5de (diff)
parentea9bd9fc7409109adcd61b897abc2c8881161256 (diff)
Merge branch 'AUTOMATIC1111:master' into improved-hr-conflict-test
Diffstat (limited to 'modules/sd_disable_initialization.py')
-rw-r--r--modules/sd_disable_initialization.py17
1 files changed, 10 insertions, 7 deletions
diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py
index e90aa9fe..c4a09d15 100644
--- a/modules/sd_disable_initialization.py
+++ b/modules/sd_disable_initialization.py
@@ -20,8 +20,9 @@ class DisableInitialization:
```
"""
- def __init__(self):
+ def __init__(self, disable_clip=True):
self.replaced = []
+ self.disable_clip = disable_clip
def replace(self, obj, field, func):
original = getattr(obj, field, None)
@@ -75,12 +76,14 @@ class DisableInitialization:
self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)
- self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
- self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
- self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
- self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
- self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
- self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
+
+ if self.disable_clip:
+ self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
+ self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
+ self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
+ self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)
+ self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
+ self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
def __exit__(self, exc_type, exc_val, exc_tb):
for obj, field, original in self.replaced: