author     AUTOMATIC1111 <16777216c@gmail.com>      2023-12-31 22:33:32 +0300
committer  GitHub <noreply@github.com>              2023-12-31 22:33:32 +0300
commit     be5f1acc8f6e5bfc7f8234fd570d663b2fde9c27 (patch)
tree       a07510354ce2dddf0afaf37ba256eba6d820900d /modules
parent     f3af8c8d04d6be58ddb3b55f77d8006241dca8f6 (diff)
parent     5768afc776a66bb94e77a9c1daebeea58fa731d5 (diff)
Merge pull request #14478 from akx/dtype-inspect
Add utility to inspect a model's dtype/device
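The change replaces the next(model.parameters()) idiom used at several call sites with the new modules.torch_utils.get_param helper. A minimal usage sketch (illustrative only; sd_model here stands in for any loaded torch module or a wrapper exposing one via a .model attribute):

    from modules.torch_utils import get_param

    # Grab the first parameter and read the model's dtype/device from it.
    param = get_param(sd_model)
    print(param.dtype, param.device)   # e.g. torch.float16 cuda:0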
Diffstat (limited to 'modules')
-rw-r--r--  modules/devices.py         |  3
-rw-r--r--  modules/interrogate.py     |  3
-rw-r--r--  modules/sd_models_xl.py    |  3
-rw-r--r--  modules/torch_utils.py     | 17
-rw-r--r--  modules/upscaler_utils.py  |  5
-rw-r--r--  modules/xlmr.py            |  5
-rw-r--r--  modules/xlmr_m18.py        |  5
7 files changed, 34 insertions, 7 deletions
diff --git a/modules/devices.py b/modules/devices.py
index c956207f..bd6bd579 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -4,6 +4,7 @@ from functools import lru_cache
import torch
from modules import errors, shared
+from modules.torch_utils import get_param
if sys.platform == "darwin":
from modules import mac_specific
@@ -131,7 +132,7 @@ patch_module_list = [
def manual_cast_forward(self, *args, **kwargs):
- org_dtype = next(self.parameters()).dtype
+ org_dtype = get_param(self).dtype
self.to(dtype)
args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 3045560d..5be5a10f 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -11,6 +11,7 @@ from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from modules import devices, paths, shared, lowvram, modelloader, errors
+from modules.torch_utils import get_param
blip_image_eval_size = 384
clip_model_name = 'ViT-L/14'
@@ -131,7 +132,7 @@ class InterrogateModels:
self.clip_model = self.clip_model.to(devices.device_interrogate)
- self.dtype = next(self.clip_model.parameters()).dtype
+ self.dtype = get_param(self.clip_model).dtype
def send_clip_to_ram(self):
if not shared.opts.interrogate_keep_models_in_memory:
diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py
index 1de31b0d..c3602a7e 100644
--- a/modules/sd_models_xl.py
+++ b/modules/sd_models_xl.py
@@ -6,6 +6,7 @@ import sgm.models.diffusion
import sgm.modules.diffusionmodules.denoiser_scaling
import sgm.modules.diffusionmodules.discretizer
from modules import devices, shared, prompt_parser
+from modules.torch_utils import get_param
def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
@@ -90,7 +91,7 @@ sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt
def extend_sdxl(model):
"""this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
- dtype = next(model.model.diffusion_model.parameters()).dtype
+ dtype = get_param(model.model.diffusion_model).dtype
model.model.diffusion_model.dtype = dtype
model.model.conditioning_key = 'crossattn'
model.cond_stage_key = 'txt'
diff --git a/modules/torch_utils.py b/modules/torch_utils.py
new file mode 100644
index 00000000..e5b52393
--- /dev/null
+++ b/modules/torch_utils.py
@@ -0,0 +1,17 @@
+from __future__ import annotations
+
+import torch.nn
+
+
+def get_param(model) -> torch.nn.Parameter:
+ """
+ Find the first parameter in a model or module.
+ """
+ if hasattr(model, "model") and hasattr(model.model, "parameters"):
+ # Unpeel a model descriptor to get at the actual Torch module.
+ model = model.model
+
+ for param in model.parameters():
+ return param
+
+ raise ValueError(f"No parameters found in model {model!r}")
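A note on the helper's shape (an observation on the new file above, not part of the patch text): the for/return loop means a parameter-less module falls through to the descriptive ValueError rather than surfacing a bare StopIteration from next(iter(...)), and the hasattr check unpeels wrapper objects that keep the torch module under .model, as the upscaler change below relies on. Illustrative behaviour, assuming a plain torch.nn module:

    import torch.nn
    from modules.torch_utils import get_param

    get_param(torch.nn.Linear(2, 2)).dtype   # torch.float32
    get_param(torch.nn.Identity())            # raises ValueError: No parameters found in model Identity()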
diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py
index 8e413854..c60e3beb 100644
--- a/modules/upscaler_utils.py
+++ b/modules/upscaler_utils.py
@@ -7,6 +7,7 @@ import tqdm
from PIL import Image
from modules import images, shared
+from modules.torch_utils import get_param
logger = logging.getLogger(__name__)
@@ -17,8 +18,8 @@ def upscale_without_tiling(model, img: Image.Image):
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
img = torch.from_numpy(img).float()
- model_weight = next(iter(model.model.parameters()))
- img = img.unsqueeze(0).to(device=model_weight.device, dtype=model_weight.dtype)
+ param = get_param(model)
+ img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype)
with torch.no_grad():
output = model(img)
diff --git a/modules/xlmr.py b/modules/xlmr.py
index a407a3ca..6e000a56 100644
--- a/modules/xlmr.py
+++ b/modules/xlmr.py
@@ -5,6 +5,9 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta
from transformers import XLMRobertaModel,XLMRobertaTokenizer
from typing import Optional
+from modules.torch_utils import get_param
+
+
class BertSeriesConfig(BertConfig):
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
@@ -62,7 +65,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
self.post_init()
def encode(self,c):
- device = next(self.parameters()).device
+ device = get_param(self).device
text = self.tokenizer(c,
truncation=True,
max_length=77,
diff --git a/modules/xlmr_m18.py b/modules/xlmr_m18.py
index a727e865..e3e81961 100644
--- a/modules/xlmr_m18.py
+++ b/modules/xlmr_m18.py
@@ -5,6 +5,9 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta
from transformers import XLMRobertaModel,XLMRobertaTokenizer
from typing import Optional
+from modules.torch_utils import get_param
+
+
class BertSeriesConfig(BertConfig):
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
@@ -68,7 +71,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
self.post_init()
def encode(self,c):
- device = next(self.parameters()).device
+ device = get_param(self).device
text = self.tokenizer(c,
truncation=True,
max_length=77,