aboutsummaryrefslogtreecommitdiff
path: root/modules
diff options
context:
space:
mode:
Diffstat (limited to 'modules')
-rw-r--r--modules/api/api.py44
-rw-r--r--modules/call_queue.py2
-rw-r--r--modules/extras.py3
-rw-r--r--modules/interrogate.py3
-rw-r--r--modules/postprocessing.py3
-rw-r--r--modules/shared.py13
6 files changed, 36 insertions, 32 deletions
diff --git a/modules/api/api.py b/modules/api/api.py
index 41adaef7..9d33b9a9 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -330,7 +330,7 @@ class Api:
p.outpath_grids = opts.outdir_txt2img_grids
p.outpath_samples = opts.outdir_txt2img_samples
- shared.state.begin()
+ shared.state.begin(job="scripts_txt2img")
if selectable_scripts is not None:
p.script_args = script_args
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
@@ -387,7 +387,7 @@ class Api:
p.outpath_grids = opts.outdir_img2img_grids
p.outpath_samples = opts.outdir_img2img_samples
- shared.state.begin()
+ shared.state.begin(job="scripts_img2img")
if selectable_scripts is not None:
p.script_args = script_args
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
@@ -396,7 +396,6 @@ class Api:
processed = process_images(p)
shared.state.end()
-
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
if not img2imgreq.include_init_images:
@@ -603,44 +602,42 @@ class Api:
def create_embedding(self, args: dict):
try:
- shared.state.begin()
+ shared.state.begin(job="create_embedding")
filename = create_embedding(**args) # create empty embedding
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
- shared.state.end()
return models.CreateResponse(info=f"create embedding filename: {filename}")
except AssertionError as e:
- shared.state.end()
return models.TrainResponse(info=f"create embedding error: {e}")
+ finally:
+ shared.state.end()
+
def create_hypernetwork(self, args: dict):
try:
- shared.state.begin()
+ shared.state.begin(job="create_hypernetwork")
filename = create_hypernetwork(**args) # create empty embedding
- shared.state.end()
return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
except AssertionError as e:
- shared.state.end()
return models.TrainResponse(info=f"create hypernetwork error: {e}")
+ finally:
+ shared.state.end()
def preprocess(self, args: dict):
try:
- shared.state.begin()
+ shared.state.begin(job="preprocess")
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
shared.state.end()
- return models.PreprocessResponse(info = 'preprocess complete')
+ return models.PreprocessResponse(info='preprocess complete')
except KeyError as e:
- shared.state.end()
return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
- except AssertionError as e:
- shared.state.end()
+ except Exception as e:
return models.PreprocessResponse(info=f"preprocess error: {e}")
- except FileNotFoundError as e:
+ finally:
shared.state.end()
- return models.PreprocessResponse(info=f'preprocess error: {e}')
def train_embedding(self, args: dict):
try:
- shared.state.begin()
+ shared.state.begin(job="train_embedding")
apply_optimizations = shared.opts.training_xattention_optimizations
error = None
filename = ''
@@ -653,15 +650,15 @@ class Api:
finally:
if not apply_optimizations:
sd_hijack.apply_optimizations()
- shared.state.end()
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
- except AssertionError as msg:
- shared.state.end()
+ except Exception as msg:
return models.TrainResponse(info=f"train embedding error: {msg}")
+ finally:
+ shared.state.end()
def train_hypernetwork(self, args: dict):
try:
- shared.state.begin()
+ shared.state.begin(job="train_hypernetwork")
shared.loaded_hypernetworks = []
apply_optimizations = shared.opts.training_xattention_optimizations
error = None
@@ -679,9 +676,10 @@ class Api:
sd_hijack.apply_optimizations()
shared.state.end()
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
- except AssertionError:
+ except Exception as exc:
+ return models.TrainResponse(info=f"train embedding error: {exc}")
+ finally:
shared.state.end()
- return models.TrainResponse(info=f"train embedding error: {error}")
def get_memory(self):
try:
diff --git a/modules/call_queue.py b/modules/call_queue.py
index 69bf63d2..3b94f8a4 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -30,7 +30,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
id_task = None
with queue_lock:
- shared.state.begin()
+ shared.state.begin(job=id_task)
progress.start_task(id_task)
try:
diff --git a/modules/extras.py b/modules/extras.py
index 830b53aa..e9c0263e 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -73,8 +73,7 @@ def to_half(tensor, enable):
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
- shared.state.begin()
- shared.state.job = 'model-merge'
+ shared.state.begin(job="model-merge")
def fail(message):
shared.state.textinfo = message
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 9b2c5b60..a3ae1dd5 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -184,8 +184,7 @@ class InterrogateModels:
def interrogate(self, pil_image):
res = ""
- shared.state.begin()
- shared.state.job = 'interrogate'
+ shared.state.begin(job="interrogate")
try:
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
diff --git a/modules/postprocessing.py b/modules/postprocessing.py
index 38544c38..136e9c88 100644
--- a/modules/postprocessing.py
+++ b/modules/postprocessing.py
@@ -9,8 +9,7 @@ from modules.shared import opts
def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
devices.torch_gc()
- shared.state.begin()
- shared.state.job = 'extras'
+ shared.state.begin(job="extras")
image_data = []
image_names = []
diff --git a/modules/shared.py b/modules/shared.py
index 203ee1b9..9ab9d98b 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -4,6 +4,7 @@ import os
import sys
import threading
import time
+import logging
import gradio as gr
import torch
@@ -18,6 +19,8 @@ from modules.paths_internal import models_path, script_path, data_path, sd_confi
from ldm.models.diffusion.ddpm import LatentDiffusion
from typing import Optional
+log = logging.getLogger(__name__)
+
demo = None
parser = cmd_args.parser
@@ -144,12 +147,15 @@ class State:
def request_restart(self) -> None:
self.interrupt()
self.server_command = "restart"
+ log.info("Received restart request")
def skip(self):
self.skipped = True
+ log.info("Received skip request")
def interrupt(self):
self.interrupted = True
+ log.info("Received interrupt request")
def nextjob(self):
if opts.live_previews_enable and opts.show_progress_every_n_steps == -1:
@@ -173,7 +179,7 @@ class State:
return obj
- def begin(self):
+ def begin(self, job: str = "(unknown)"):
self.sampling_step = 0
self.job_count = -1
self.processing_has_refined_job_count = False
@@ -187,10 +193,13 @@ class State:
self.interrupted = False
self.textinfo = None
self.time_start = time.time()
-
+ self.job = job
devices.torch_gc()
+ log.info("Starting job %s", job)
def end(self):
+ duration = time.time() - self.time_start
+ log.info("Ending job %s (%.2f seconds)", self.job, duration)
self.job = ""
self.job_count = 0