From a6adc22f0711c8ab78c6ef8fc78715f815cc750f Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 1 Sep 2022 21:20:25 +0300
Subject: added interrupt button; added save button;
 --always-batch-cond-uncond as a workaround for a performance regression
 affecting low memory users; specify gradio version as 3.1.5 because of what
 looks like a bug

---
 requirements.txt |   2 +-
 script.js        |   7 +-
 webui.py         | 244 ++++++++++++++++++++++++++++++++++++++++---------------
 3 files changed, 183 insertions(+), 70 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 91b21222..8538310b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 basicsr
 gfpgan
-gradio
+gradio==3.1.5
 numpy
 Pillow
 realesrgan
diff --git a/script.js b/script.js
index 9d409f86..1d7de7da 100644
--- a/script.js
+++ b/script.js
@@ -1,5 +1,3 @@
-console.log("running")
-
 titles = {
     "Sampling steps": "How many times to improve the generated image iteratively; higher values take longer; very low values can produce bad results",
     "Sampling method": "Which algorithm to use to produce the image",
@@ -29,6 +27,9 @@ titles = {
     "Inpaint at full resolution": "Upscale masked region to target resolution, do inpainting, downscale back and paste into original image",
 
     "Denoising Strength": "Determines how little respect the algorithm should have for image's content. At 0, nothing will change, and at 1 you'll get an unrelated image.",
+
+    "Interrupt": "Stop processing images and return any results accumulated so far.",
+    "Save": "Write image to a directory (default - log/images) and generation parameters into csv file.",
 }
 
 function gradioApp(){
@@ -36,7 +37,7 @@ function gradioApp(){
 }
 
 function addTitles(root){
-    root.querySelectorAll('span').forEach(function(span){
+    root.querySelectorAll('span, button').forEach(function(span){
        tooltip = titles[span.textContent];
        if(tooltip){
            span.title = tooltip;
diff --git a/webui.py b/webui.py
index e5a21b2f..e56900a1 100644
--- a/webui.py
+++ b/webui.py
@@ -68,6 +68,7 @@ parser.add_argument("--embeddings-dir", type=str, default='embeddings', help="em
 parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
 parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRAM usage")
 parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRAM usage")
+parser.add_argument("--always-batch-cond-uncond", action='store_true', help="a workaround test; may help with speed if you use --lowvram")
 parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
 parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
 cmd_opts = parser.parse_args()
@@ -75,9 +76,20 @@ cmd_opts = parser.parse_args()
 cpu = torch.device("cpu")
 gpu = torch.device("cuda")
 device = gpu if torch.cuda.is_available() else cpu
-batch_cond_uncond = not (cmd_opts.lowvram or cmd_opts.medvram)
+batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
 queue_lock = threading.Lock()
 
+
+class State:
+    interrupted = False
+    job = ""
+
+    def interrupt(self):
+        self.interrupted = True
+
+
+state = State()
+
 if not cmd_opts.share:
     # fix gradio phoning home
     gradio.utils.version_check = lambda: None
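For context on the new flag: batch_cond_uncond decides whether the conditional and unconditional halves of a classifier-free-guidance step run as one doubled batch (one model call per step, roughly twice the activation memory) or as two separate calls (half the memory, about twice as slow). A rough, self-contained sketch of that trade-off, with a hypothetical denoiser callable standing in for the model (this is not the file's actual API):

    import torch

    def cfg_denoise(denoiser, x, sigma, cond, uncond, cond_scale, batch_cond_uncond):
        # Classifier-free guidance needs two denoiser evaluations per step:
        # one conditioned on the prompt, one unconditioned.
        if batch_cond_uncond:
            # One call on a doubled batch: faster, ~2x activation memory.
            x_in = torch.cat([x, x])
            sigma_in = torch.cat([sigma, sigma])
            cond_in = torch.cat([uncond, cond])
            uncond_out, cond_out = denoiser(x_in, sigma_in, cond=cond_in).chunk(2)
        else:
            # Two smaller calls: fits in less VRAM, roughly twice as slow.
            uncond_out = denoiser(x, sigma, cond=uncond)
            cond_out = denoiser(x, sigma, cond=cond)
        # Move from the unconditional prediction toward the conditional one,
        # scaled by the guidance strength.
        return uncond_out + (cond_out - uncond_out) * cond_scale

With --lowvram or --medvram the batched path used to be forced off; the new --always-batch-cond-uncond flag lets low-memory users turn it back on when it is faster for them.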
@@ -198,6 +210,7 @@ class Options:
     "outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids'),
     "save_to_dirs": OptionInfo(False, "When writing images/grids, create a directory with name derived from the prompt"),
     "save_to_dirs_prompt_len": OptionInfo(10, "When using above, how many words from prompt to put into directory name", gr.Slider, {"minimum": 1, "maximum": 32, "step": 1}),
+    "outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button"),
     "samples_save": OptionInfo(True, "Save individual samples"),
     "samples_format": OptionInfo('png', 'File format for individual samples'),
     "grid_save": OptionInfo(True, "Save image grids"),
@@ -400,8 +413,6 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
 
         image.save(f"{fullfn_without_extension}.jpg", quality=opts.jpeg_quality, pnginfo=pnginfo)
 
-
-
 def sanitize_filename_part(text):
     return text.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]
 
@@ -410,6 +421,7 @@ def plaintext_to_html(text):
     text = "".join([f"<p>{html.escape(x)}</p>\n" for x in text.split('\n')])
     return text
 
+
 def image_grid(imgs, batch_size=1, rows=None):
     if rows is None:
         if opts.n_rows > 0:
@@ -652,18 +664,29 @@ def wrap_gradio_gpu_call(func):
 
         return res
 
-    return f
+    return wrap_gradio_call(f)
 
 
 def wrap_gradio_call(func):
     def f(*args, **kwargs):
         t = time.perf_counter()
-        res = list(func(*args, **kwargs))
+
+        try:
+            res = list(func(*args, **kwargs))
+        except Exception as e:
+            print("Error completing request", file=sys.stderr)
+            print("Arguments:", args, kwargs, file=sys.stderr)
+            print(traceback.format_exc(), file=sys.stderr)
+
+            res = [None, f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
+
         elapsed = time.perf_counter() - t
 
         # last item is always HTML
         res[-1] = res[-1] + f"<p class='performance'>Time taken: {elapsed:.2f}s</p>"
 
+        state.interrupted = False
+
         return tuple(res)
 
     return f
@@ -883,7 +906,6 @@ class StableDiffusionProcessing:
         self.extra_generation_params: dict = extra_generation_params
         self.overlay_images = overlay_images
         self.paste_to = None
-        self.progress_info = ""
 
     def init(self):
         pass
@@ -959,6 +981,15 @@ class CFGDenoiser(nn.Module):
 
         return denoised
 
+
+def extended_trange(*args, **kwargs):
+    for x in tqdm.trange(*args, desc=state.job, **kwargs):
+        if state.interrupted:
+            break
+
+        yield x
+
+
 class KDiffusionSampler:
     def __init__(self, funcname):
         self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model)
@@ -980,7 +1011,7 @@ class KDiffusionSampler:
         self.model_wrap_cfg.init_latent = p.init_latent
 
         if hasattr(k_diffusion.sampling, 'trange'):
-            k_diffusion.sampling.trange = lambda *args, **kwargs: tqdm.tqdm(range(*args), desc=p.progress_info, **kwargs)
+            k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(*args, **kwargs)
 
         return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False)
 
@@ -989,13 +1020,36 @@ class KDiffusionSampler:
         x = x * sigmas[0]
 
         if hasattr(k_diffusion.sampling, 'trange'):
-            k_diffusion.sampling.trange = lambda *args, **kwargs: tqdm.tqdm(range(*args), desc=p.progress_info, **kwargs)
+            k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(*args, **kwargs)
 
         samples_ddim = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False)
         return samples_ddim
 
 
-Processed = namedtuple('Processed', ['images', 'seed', 'info'])
+class Processed:
+    def __init__(self, p: StableDiffusionProcessing, images, seed, info):
+        self.images = images
+        self.prompt = p.prompt
+        self.seed = seed
+        self.info = info
+        self.width = p.width
+        self.height = p.height
+        self.sampler = samplers[p.sampler_index].name
+        self.cfg_scale = p.cfg_scale
+        self.steps = p.steps
+
+    def js(self):
+        obj = {
+            "prompt": self.prompt,
+            "seed": int(self.seed),
+            "width": self.width,
+            "height": self.height,
+            "sampler": self.sampler,
+            "cfg_scale": self.cfg_scale,
+            "steps": self.steps,
+        }
+
+        return json.dumps(obj)
 
 
 def process_images(p: StableDiffusionProcessing) -> Processed:
@@ -1063,6 +1117,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
     p.init()
 
     for n in range(p.n_iter):
+        if state.interrupted:
+            break
+
         prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
         seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
 
@@ -1075,7 +1132,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
         # we manually generate all input noises because each one should have a specific seed
         x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds)
 
-        p.progress_info = f"Batch {n+1} out of {p.n_iter}"
+        if p.n_iter > 0:
+            state.job = f"Batch {n+1} out of {p.n_iter}"
+
         samples_ddim = p.sample(x=x, conditioning=c, unconditional_conditioning=uc)
 
         x_samples_ddim = model.decode_first_stage(samples_ddim)
@@ -1137,7 +1196,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
             save_image(grid, p.outpath_grids, "grid", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
 
     torch_gc()
-    return Processed(output_images, seed, infotext())
+    return Processed(p, output_images, seed, infotext())
 
 
 class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
@@ -1188,52 +1247,47 @@ def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, u
         module.display = display
 
         exec(compiled, module.__dict__)
 
-        processed = Processed(*display_result_data)
+        processed = Processed(p, *display_result_data)
     else:
         processed = process_images(p)
 
-    return processed.images, processed.seed, plaintext_to_html(processed.info)
+    return processed.images, processed.js(), plaintext_to_html(processed.info)
 
 
+def save_files(js_data, images):
+    import csv
 
-class Flagging(gr.FlaggingCallback):
+    os.makedirs(opts.outdir_save, exist_ok=True)
 
-    def setup(self, components, flagging_dir: str):
-        pass
-
-    def flag(self, flag_data, flag_option=None, flag_index=None, username=None):
-        import csv
+    filenames = []
 
-        os.makedirs("log/images", exist_ok=True)
+    data = json.loads(js_data)
 
-        # those must match the "txt2img" function
-        prompt, steps, sampler_index, use_gfpgan, prompt_matrix, n_iter, batch_size, cfg_scale, seed, height, width, code, images, seed, comment = flag_data
+    with open("log/log.csv", "a", encoding="utf8", newline='') as file:
+        import time
+        import base64
 
-        filenames = []
+        at_start = file.tell() == 0
+        writer = csv.writer(file)
+        if at_start:
+            writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename"])
 
-        with open("log/log.csv", "a", encoding="utf8", newline='') as file:
-            import time
-            import base64
+        filename_base = str(int(time.time() * 1000))
+        for i, filedata in enumerate(images):
+            filename = filename_base + ("" if len(images) == 1 else "-" + str(i + 1)) + ".png"
+            filepath = os.path.join(opts.outdir_save, filename)
 
-            at_start = file.tell() == 0
-            writer = csv.writer(file)
-            if at_start:
-                writer.writerow(["prompt", "seed", "width", "height", "cfgs", "steps", "filename"])
+            if filedata.startswith("data:image/png;base64,"):
+                filedata = filedata[len("data:image/png;base64,"):]
 
-            filename_base = str(int(time.time() * 1000))
-            for i, filedata in enumerate(images):
-                filename = "log/images/"+filename_base + ("" if len(images) == 1 else "-"+str(i+1)) + ".png"
+            with open(filepath, "wb") as imgfile:
+                imgfile.write(base64.decodebytes(filedata.encode('utf-8')))
 
-                if filedata.startswith("data:image/png;base64,"):
-                    filedata = filedata[len("data:image/png;base64,"):]
+            filenames.append(filename)
 
-                with open(filename, "wb") as imgfile:
-                    imgfile.write(base64.decodebytes(filedata.encode('utf-8')))
+        writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0]])
 
-                filenames.append(filename)
+    return '', '', plaintext_to_html(f"Saved: {filenames[0]}")
 
-            writer.writerow([prompt, seed, width, height, cfg_scale, steps, filenames[0]])
-
-        print("Logged:", filenames[0])
 
 with gr.Blocks(analytics_enabled=False) as txt2img_interface:
     with gr.Row():
@@ -1267,8 +1321,15 @@ with gr.Blocks(analytics_enabled=False) as txt2img_interface:
         with gr.Column(variant='panel'):
             with gr.Group():
                 gallery = gr.Gallery(label='Output')
-                output_seed = gr.Number(label='Seed', visible=False)
+
+            with gr.Group():
+                with gr.Row():
+                    interrupt = gr.Button('Interrupt')
+                    save = gr.Button('Save')
+
+            with gr.Group():
                 html_info = gr.HTML()
+                generation_info = gr.Textbox(visible=False)
 
     txt2img_args = dict(
         fn=wrap_gradio_gpu_call(txt2img),
@@ -1289,7 +1350,7 @@
         ],
         outputs=[
             gallery,
-            output_seed,
+            generation_info,
             html_info
         ]
    )
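One non-obvious detail in save_files: the gallery component hands the images back as base64-encoded data URLs rather than file paths, so the handler strips the data:image/png;base64, prefix and decodes the remainder. A minimal sketch of just that step (the function name is illustrative, not from the patch):

    import base64

    def data_url_to_png(filedata: str, filepath: str) -> None:
        # Gradio gallery entries arrive as "data:image/png;base64,..." strings.
        prefix = "data:image/png;base64,"
        if filedata.startswith(prefix):
            filedata = filedata[len(prefix):]
        with open(filepath, "wb") as imgfile:
            imgfile.write(base64.decodebytes(filedata.encode("utf-8")))

The generation parameters travel alongside the images: txt2img returns Processed.js() into the hidden generation_info textbox, and save_files runs json.loads on that same string to fill the CSV row.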
@@ -1297,6 +1358,25 @@
     prompt.submit(**txt2img_args)
     submit.click(**txt2img_args)
 
+    interrupt.click(
+        fn=lambda: state.interrupt(),
+        inputs=[],
+        outputs=[],
+    )
+
+    save.click(
+        fn=wrap_gradio_call(save_files),
+        inputs=[
+            generation_info,
+            gallery,
+        ],
+        outputs=[
+            html_info,
+            html_info,
+            html_info,
+        ]
+    )
+
 
 def get_crop_region(mask, pad=0):
     h, w = mask.shape
@@ -1508,6 +1588,7 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
             p.batch_size = 1
             p.do_not_save_grid = True
 
+            state.job = f"Batch {i + 1} out of {n_iter}"
             processed = process_images(p)
 
             if initial_seed is None:
@@ -1523,13 +1604,13 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
 
             save_image(grid, p.outpath_grids, "grid", initial_seed, prompt, opts.grid_format, info=info, short_filename=not opts.grid_extended_filename)
 
-        processed = Processed(history, initial_seed, initial_info)
+        processed = Processed(p, history, initial_seed, initial_info)
 
     elif is_upscale:
         initial_seed = None
         initial_info = None
 
-        upscaler = sd_upscalers[upscaler_name]
+        upscaler = sd_upscalers.get(upscaler_name, next(iter(sd_upscalers.values())))
         img = upscaler(init_img)
 
         torch_gc()
@@ -1553,6 +1634,7 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
         for i in range(batch_count):
             p.init_images = work[i*p.batch_size:(i+1)*p.batch_size]
+            state.job = f"Batch {i + 1} out of {batch_count}"
 
             processed = process_images(p)
 
             if initial_seed is None:
@@ -1565,19 +1647,19 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
         image_index = 0
         for y, h, row in grid.tiles:
             for tiledata in row:
-                tiledata[2] = work_results[image_index]
+                tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height))
         info = f"<div><p>{message}</p></div>"
 
-    return [info]
+    return '', '', info
 
 
 pnginfo_interface = gr.Interface(
@@ -1789,6 +1897,8 @@ pnginfo_interface = gr.Interface(
     ],
     outputs=[
         gr.HTML(),
+        gr.HTML(),
+        gr.HTML(),
     ],
     allow_flagging="never",
     analytics_enabled=False,
@@ -1809,7 +1919,7 @@ def run_settings(*args):
 
     opts.save(config_filename)
 
-    return 'Settings saved.', ''
+    return 'Settings saved.', '', ''
 
 
 def create_setting_component(key):
@@ -1839,6 +1949,7 @@ settings_interface = gr.Interface(
     outputs=[
         gr.Textbox(label='Result'),
         gr.HTML(),
+        gr.HTML(),
    ],
     title=None,
     description=None,
@@ -1863,17 +1974,18 @@ try:
 except Exception:
     pass
 
-sd_config = OmegaConf.load(cmd_opts.config)
-sd_model = load_model_from_config(sd_config, cmd_opts.ckpt)
-sd_model = (sd_model if cmd_opts.no_half else sd_model.half())
+if False:
+    sd_config = OmegaConf.load(cmd_opts.config)
+    sd_model = load_model_from_config(sd_config, cmd_opts.ckpt)
+    sd_model = (sd_model if cmd_opts.no_half else sd_model.half())
 
-if cmd_opts.lowvram or cmd_opts.medvram:
-    setup_for_low_vram(sd_model)
-else:
-    sd_model = sd_model.to(device)
+    if cmd_opts.lowvram or cmd_opts.medvram:
+        setup_for_low_vram(sd_model)
+    else:
+        sd_model = sd_model.to(device)
 
-model_hijack = StableDiffusionModelHijack()
-model_hijack.hijack(sd_model)
+    model_hijack = StableDiffusionModelHijack()
+    model_hijack.hijack(sd_model)
 
 with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file:
     css = file.read()
-- 
cgit v1.2.1