aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsuperhero-7 <537093830@qq.com>2023-10-01 12:25:19 +0800
committersuperhero-7 <537093830@qq.com>2023-10-01 12:25:19 +0800
commit2d947175b902d6838c803036d9757e7d3226b41d (patch)
treed62d9f599c94734d5e40b5ed711be13e0d4e081d
parentf8f4ff2bb8f56877dede466934dd8ddf25c21063 (diff)
fix linter issues
-rw-r--r--modules/sd_hijack.py4
-rw-r--r--modules/sd_models_config.py3
-rw-r--r--modules/xlmr_m18.py12
3 files changed, 9 insertions, 10 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 4b36c0e9..0689699c 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -212,7 +212,7 @@ class StableDiffusionModelHijack:
model_embeddings = m.cond_stage_model.roberta.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
-
+
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
@@ -258,7 +258,7 @@ class StableDiffusionModelHijack:
if hasattr(m, 'cond_stage_model'):
delattr(m, 'cond_stage_model')
-
+
elif type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
m.cond_stage_model = m.cond_stage_model.wrapped
diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
index 9ba89dfc..deab2f6e 100644
--- a/modules/sd_models_config.py
+++ b/modules/sd_models_config.py
@@ -95,8 +95,7 @@ def guess_model_config_from_state_dict(sd, filename):
if diffusion_model_input.shape[1] == 8:
return config_instruct_pix2pix
-
- # import pdb; pdb.set_trace()
+
if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:
return config_alt_diffusion_m18
diff --git a/modules/xlmr_m18.py b/modules/xlmr_m18.py
index 18785692..a727e865 100644
--- a/modules/xlmr_m18.py
+++ b/modules/xlmr_m18.py
@@ -1,4 +1,4 @@
-from transformers import BertPreTrainedModel,BertModel,BertConfig
+from transformers import BertPreTrainedModel,BertConfig
import torch.nn as nn
import torch
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
@@ -28,7 +28,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
config_class = BertSeriesConfig
def __init__(self, config=None, **kargs):
- # modify initialization for autoloading
+ # modify initialization for autoloading
if config is None:
config = XLMRobertaConfig()
config.attention_probs_dropout_prob= 0.1
@@ -80,7 +80,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
text["attention_mask"] = torch.tensor(
text['attention_mask']).to(device)
features = self(**text)
- return features['projection_state']
+ return features['projection_state']
def forward(
self,
@@ -147,8 +147,8 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
"hidden_states": outputs.hidden_states,
"attentions": outputs.attentions,
}
-
-
+
+
# return {
# 'pooler_output':pooler_output,
# 'last_hidden_state':outputs.last_hidden_state,
@@ -161,4 +161,4 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
base_model_prefix = 'roberta'
- config_class= RobertaSeriesConfig \ No newline at end of file
+ config_class= RobertaSeriesConfig