aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--modules/interrogate.py1
-rw-r--r--modules/textual_inversion/preprocess.py75
-rw-r--r--modules/textual_inversion/textual_inversion.py1
-rw-r--r--modules/textual_inversion/ui.py14
-rw-r--r--modules/ui.py36
5 files changed, 124 insertions, 3 deletions
diff --git a/modules/interrogate.py b/modules/interrogate.py
index f62a4745..eed87144 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -21,6 +21,7 @@ Category = namedtuple("Category", ["name", "topn", "items"])
re_topn = re.compile(r"\.top(\d+)\.")
+
class InterrogateModels:
blip_model = None
clip_model = None
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
new file mode 100644
index 00000000..209e928f
--- /dev/null
+++ b/modules/textual_inversion/preprocess.py
@@ -0,0 +1,75 @@
+import os
+from PIL import Image, ImageOps
+import tqdm
+
+from modules import shared, images
+
+
+def preprocess(process_src, process_dst, process_flip, process_split, process_caption):
+ size = 512
+ src = os.path.abspath(process_src)
+ dst = os.path.abspath(process_dst)
+
+ assert src != dst, 'same directory specified as source and desitnation'
+
+ os.makedirs(dst, exist_ok=True)
+
+ files = os.listdir(src)
+
+ shared.state.textinfo = "Preprocessing..."
+ shared.state.job_count = len(files)
+
+ if process_caption:
+ shared.interrogator.load()
+
+ def save_pic_with_caption(image, index):
+ if process_caption:
+ caption = "-" + shared.interrogator.generate_caption(image)
+ else:
+ caption = ""
+
+ image.save(os.path.join(dst, f"{index:05}-{subindex[0]}{caption}.png"))
+ subindex[0] += 1
+
+ def save_pic(image, index):
+ save_pic_with_caption(image, index)
+
+ if process_flip:
+ save_pic_with_caption(ImageOps.mirror(image), index)
+
+ for index, imagefile in enumerate(tqdm.tqdm(files)):
+ subindex = [0]
+ filename = os.path.join(src, imagefile)
+ img = Image.open(filename).convert("RGB")
+
+ if shared.state.interrupted:
+ break
+
+ ratio = img.height / img.width
+ is_tall = ratio > 1.35
+ is_wide = ratio < 1 / 1.35
+
+ if process_split and is_tall:
+ img = img.resize((size, size * img.height // img.width))
+
+ top = img.crop((0, 0, size, size))
+ save_pic(top, index)
+
+ bot = img.crop((0, img.height - size, size, img.height))
+ save_pic(bot, index)
+ elif process_split and is_wide:
+ img = img.resize((size * img.width // img.height, size))
+
+ left = img.crop((0, 0, size, size))
+ save_pic(left, index)
+
+ right = img.crop((img.width - size, 0, img.width, size))
+ save_pic(right, index)
+ else:
+ img = images.resize_image(1, img, size, size)
+ save_pic(img, index)
+
+ shared.state.nextjob()
+
+ if process_caption:
+ shared.interrogator.send_blip_to_ram()
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 1183aab7..d4e250d8 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -7,6 +7,7 @@ import tqdm
import html
import datetime
+
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py
index 633037d8..f19ac5e0 100644
--- a/modules/textual_inversion/ui.py
+++ b/modules/textual_inversion/ui.py
@@ -2,24 +2,31 @@ import html
import gradio as gr
-import modules.textual_inversion.textual_inversion as ti
+import modules.textual_inversion.textual_inversion
+import modules.textual_inversion.preprocess
from modules import sd_hijack, shared
def create_embedding(name, initialization_text, nvpt):
- filename = ti.create_embedding(name, nvpt, init_text=initialization_text)
+ filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
+def preprocess(*args):
+ modules.textual_inversion.preprocess.preprocess(*args)
+
+ return "Preprocessing finished.", ""
+
+
def train_embedding(*args):
try:
sd_hijack.undo_optimizations()
- embedding, filename = ti.train_embedding(*args)
+ embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
res = f"""
Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps.
@@ -30,3 +37,4 @@ Embedding saved to {html.escape(filename)}
raise
finally:
sd_hijack.apply_optimizations()
+
diff --git a/modules/ui.py b/modules/ui.py
index 8912deff..e7bde53b 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -961,6 +961,8 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Row().style(equal_height=False):
with gr.Column():
with gr.Group():
+ gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")
+
gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new embedding</p>")
new_embedding_name = gr.Textbox(label="Name")
@@ -975,6 +977,24 @@ def create_ui(wrap_gradio_gpu_call):
create_embedding = gr.Button(value="Create", variant='primary')
with gr.Group():
+ gr.HTML(value="<p style='margin-bottom: 0.7em'>Preprocess images</p>")
+
+ process_src = gr.Textbox(label='Source directory')
+ process_dst = gr.Textbox(label='Destination directory')
+
+ with gr.Row():
+ process_flip = gr.Checkbox(label='Flip')
+ process_split = gr.Checkbox(label='Split into two')
+ process_caption = gr.Checkbox(label='Add caption')
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ gr.HTML(value="")
+
+ with gr.Column():
+ run_preprocess = gr.Button(value="Preprocess", variant='primary')
+
+ with gr.Group():
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 512x512 images</p>")
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
@@ -1018,6 +1038,22 @@ def create_ui(wrap_gradio_gpu_call):
]
)
+ run_preprocess.click(
+ fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
+ _js="start_training_textual_inversion",
+ inputs=[
+ process_src,
+ process_dst,
+ process_flip,
+ process_split,
+ process_caption,
+ ],
+ outputs=[
+ ti_output,
+ ti_outcome,
+ ],
+ )
+
train_embedding.click(
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",