 modules/devices.py        | 4 ++--
 modules/interrogate.py    | 5 ++---
 modules/sd_models_xl.py   | 4 ++--
 modules/upscaler_utils.py | 5 ++---
 modules/xlmr.py           | 4 ++--
 modules/xlmr_m18.py       | 5 ++---
 test/test_torch_utils.py  | 4 ++--
 7 files changed, 14 insertions(+), 17 deletions(-)
diff --git a/modules/devices.py b/modules/devices.py
index bd6bd579..ff279ac5 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -4,7 +4,7 @@ from functools import lru_cache
import torch
from modules import errors, shared
-from modules.torch_utils import get_param
+from modules import torch_utils
if sys.platform == "darwin":
from modules import mac_specific
@@ -132,7 +132,7 @@ patch_module_list = [
def manual_cast_forward(self, *args, **kwargs):
- org_dtype = get_param(self).dtype
+ org_dtype = torch_utils.get_param(self).dtype
self.to(dtype)
args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 5be5a10f..35a627ca 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -10,8 +10,7 @@ import torch.hub
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
-from modules import devices, paths, shared, lowvram, modelloader, errors
-from modules.torch_utils import get_param
+from modules import devices, paths, shared, lowvram, modelloader, errors, torch_utils
blip_image_eval_size = 384
clip_model_name = 'ViT-L/14'
@@ -132,7 +131,7 @@ class InterrogateModels:
self.clip_model = self.clip_model.to(devices.device_interrogate)
- self.dtype = get_param(self.clip_model).dtype
+ self.dtype = torch_utils.get_param(self.clip_model).dtype
def send_clip_to_ram(self):
if not shared.opts.interrogate_keep_models_in_memory:
diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py
index c3602a7e..0de17af3 100644
--- a/modules/sd_models_xl.py
+++ b/modules/sd_models_xl.py
@@ -6,7 +6,7 @@ import sgm.models.diffusion
import sgm.modules.diffusionmodules.denoiser_scaling
import sgm.modules.diffusionmodules.discretizer
from modules import devices, shared, prompt_parser
-from modules.torch_utils import get_param
+from modules import torch_utils
def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
@@ -91,7 +91,7 @@ sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt
def extend_sdxl(model):
"""this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
- dtype = get_param(model.model.diffusion_model).dtype
+ dtype = torch_utils.get_param(model.model.diffusion_model).dtype
model.model.diffusion_model.dtype = dtype
model.model.conditioning_key = 'crossattn'
model.cond_stage_key = 'txt'
diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py
index c60e3beb..f5cb92d5 100644
--- a/modules/upscaler_utils.py
+++ b/modules/upscaler_utils.py
@@ -6,8 +6,7 @@ import torch
import tqdm
from PIL import Image
-from modules import images, shared
-from modules.torch_utils import get_param
+from modules import images, shared, torch_utils
logger = logging.getLogger(__name__)
@@ -18,7 +17,7 @@ def upscale_without_tiling(model, img: Image.Image):
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
img = torch.from_numpy(img).float()
- param = get_param(model)
+ param = torch_utils.get_param(model)
img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype)
with torch.no_grad():
diff --git a/modules/xlmr.py b/modules/xlmr.py
index 6e000a56..319771b7 100644
--- a/modules/xlmr.py
+++ b/modules/xlmr.py
@@ -5,7 +5,7 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta
from transformers import XLMRobertaModel,XLMRobertaTokenizer
from typing import Optional
-from modules.torch_utils import get_param
+from modules import torch_utils
class BertSeriesConfig(BertConfig):
@@ -65,7 +65,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
self.post_init()
def encode(self,c):
- device = get_param(self).device
+ device = torch_utils.get_param(self).device
text = self.tokenizer(c,
truncation=True,
max_length=77,
diff --git a/modules/xlmr_m18.py b/modules/xlmr_m18.py
index e3e81961..f6055504 100644
--- a/modules/xlmr_m18.py
+++ b/modules/xlmr_m18.py
@@ -4,8 +4,7 @@ import torch
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
from transformers import XLMRobertaModel,XLMRobertaTokenizer
from typing import Optional
-
-from modules.torch_utils import get_param
+from modules import torch_utils
class BertSeriesConfig(BertConfig):
@@ -71,7 +70,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
self.post_init()
def encode(self,c):
- device = get_param(self).device
+ device = torch_utils.get_param(self).device
text = self.tokenizer(c,
truncation=True,
max_length=77,
diff --git a/test/test_torch_utils.py b/test/test_torch_utils.py
index f1aec832..23ccb93a 100644
--- a/test/test_torch_utils.py
+++ b/test/test_torch_utils.py
@@ -3,7 +3,7 @@ import types
import pytest
import torch
-from modules.torch_utils import get_param
+from modules import torch_utils
@pytest.mark.parametrize("wrapped", [True, False])
@@ -14,6 +14,6 @@ def test_get_param(wrapped):
if wrapped:
# more or less how spandrel wraps a thing
mod = types.SimpleNamespace(model=mod)
- p = get_param(mod)
+ p = torch_utils.get_param(mod)
assert p.dtype == torch.float16
assert p.device == cpu
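
For reference, every call site in this diff now uses module-qualified access (torch_utils.get_param) instead of importing get_param directly. The sketch below illustrates that pattern; the body of get_param is an approximation inferred from the test above, not a copy of the repository's modules/torch_utils.py, so treat its details as assumptions.

import types
import torch

def get_param(model) -> torch.nn.Parameter:
    """Return the first parameter of a module, unwrapping a spandrel-style
    wrapper that stores the real torch.nn.Module on a .model attribute.
    Illustrative approximation, not the repository's exact implementation."""
    if hasattr(model, "model") and hasattr(model.model, "parameters"):
        model = model.model
    for param in model.parameters():
        return param
    raise ValueError(f"no parameters found in {model!r}")

# After this change, call sites import the module and qualify the call:
#     from modules import torch_utils
#     dtype = torch_utils.get_param(self.clip_model).dtype
mod = torch.nn.Linear(1, 1).to(torch.float16)
wrapped = types.SimpleNamespace(model=mod)  # how test_torch_utils.py wraps a module
assert get_param(wrapped).dtype == torch.float16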