From dbca512154341bb13e1b15d207176f2d403aff30 Mon Sep 17 00:00:00 2001 From: siutin Date: Fri, 3 Feb 2023 03:13:03 +0800 Subject: add an internal API for obtaining current task id --- modules/progress.py | 8 ++++++++ 1 file changed, 8 insertions(+) (limited to 'modules/progress.py') diff --git a/modules/progress.py b/modules/progress.py index c69ecf3d..05032ac5 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -4,6 +4,7 @@ import time import gradio as gr from pydantic import BaseModel, Field +from typing import List from modules.shared import opts @@ -37,6 +38,9 @@ def add_task_to_queue(id_job): pending_tasks[id_job] = time.time() +class CurrentTaskResponse(BaseModel): + current_task: str = Field(default=None, title="Task ID", description="id of the current progress task") + class ProgressRequest(BaseModel): id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for") id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of last received last preview image") @@ -56,6 +60,8 @@ class ProgressResponse(BaseModel): def setup_progress_api(app): return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse) +def setup_current_task_api(app): + return app.add_api_route("/internal/current_task", current_task_api, methods=["GET"], response_model=CurrentTaskResponse) def progressapi(req: ProgressRequest): active = req.id_task == current_task @@ -97,3 +103,5 @@ def progressapi(req: ProgressRequest): return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo) +def current_task_api(): + return CurrentTaskResponse(current_task=current_task) \ No newline at end of file -- cgit v1.2.1 From 9407f1731aa8c112ffc0efaa611a76f7fead3d0c Mon Sep 17 00:00:00 2001 From: siutin Date: Mon, 6 Feb 2023 03:53:05 +0800 Subject: store the last generated result --- modules/progress.py | 10 ++++++++++ 1 file changed, 10 insertions(+) (limited to 'modules/progress.py') diff --git a/modules/progress.py b/modules/progress.py index 05032ac5..27a336ad 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -37,6 +37,16 @@ def finish_task(id_task): def add_task_to_queue(id_job): pending_tasks[id_job] = time.time() +last_task_id = None +last_task_result = None + +def set_last_task_result(id_job, result): + global last_task_id + global last_task_result + + last_task_id = id_job + last_task_result = result + class CurrentTaskResponse(BaseModel): current_task: str = Field(default=None, title="Task ID", description="id of the current progress task") -- cgit v1.2.1 From 4242e194e417ec5008d09ec6d756594ac65f77bd Mon Sep 17 00:00:00 2001 From: siutin Date: Mon, 6 Feb 2023 03:55:31 +0800 Subject: add a button to restore the current progress --- modules/progress.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) (limited to 'modules/progress.py') diff --git a/modules/progress.py b/modules/progress.py index 27a336ad..36963c92 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -48,6 +48,20 @@ def set_last_task_result(id_job, result): last_task_result = result +def restore_progress_call(task_tag): + if current_task is None or not current_task[5:-1].startswith(task_tag): + + # image, generation_info, html_info, html_log + return tuple(list([None, None, None, None])) + + else: + + t_task = current_task + while t_task != last_task_id: + time.sleep(2.5) + return last_task_result + + class CurrentTaskResponse(BaseModel): current_task: str = Field(default=None, title="Task ID", description="id of the current progress task") -- cgit v1.2.1 From e0b58527ff040f9c547ea45b5fcf1bfb7ab23cdd Mon Sep 17 00:00:00 2001 From: siutin Date: Mon, 6 Feb 2023 15:57:26 +0800 Subject: use condition to wait for result --- modules/progress.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'modules/progress.py') diff --git a/modules/progress.py b/modules/progress.py index 36963c92..1947c0fd 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -6,6 +6,7 @@ import gradio as gr from pydantic import BaseModel, Field from typing import List +from modules import call_queue from modules.shared import opts import modules.shared as shared @@ -57,8 +58,9 @@ def restore_progress_call(task_tag): else: t_task = current_task - while t_task != last_task_id: - time.sleep(2.5) + with call_queue.queue_lock_condition: + call_queue.queue_lock_condition.wait_for(lambda: t_task == last_task_id) + return last_task_result -- cgit v1.2.1 From 70ab21e67d128b953fbf4a360e02ac783f40dd55 Mon Sep 17 00:00:00 2001 From: siutin Date: Wed, 29 Mar 2023 00:17:19 +0800 Subject: keep randomId simpler --- modules/progress.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/progress.py') diff --git a/modules/progress.py b/modules/progress.py index 1947c0fd..e99267f5 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -49,8 +49,8 @@ def set_last_task_result(id_job, result): last_task_result = result -def restore_progress_call(task_tag): - if current_task is None or not current_task[5:-1].startswith(task_tag): +def restore_progress_call(): + if current_task is None: # image, generation_info, html_info, html_log return tuple(list([None, None, None, None])) -- cgit v1.2.1 From 984970068c2bdc14cff266129ca25a26fbccbf2e Mon Sep 17 00:00:00 2001 From: siutin Date: Mon, 17 Apr 2023 01:06:28 +0800 Subject: multi users support --- modules/progress.py | 60 ++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 43 insertions(+), 17 deletions(-) (limited to 'modules/progress.py') diff --git a/modules/progress.py b/modules/progress.py index e99267f5..13568701 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -4,7 +4,9 @@ import time import gradio as gr from pydantic import BaseModel, Field -from typing import List +from typing import Optional +from fastapi import Depends, Security +from fastapi.security import APIKeyCookie from modules import call_queue from modules.shared import opts @@ -12,57 +14,71 @@ from modules.shared import opts import modules.shared as shared +current_task_user = None current_task = None pending_tasks = {} finished_tasks = [] -def start_task(id_task): +def start_task(user, id_task): global current_task + global current_task_user + current_task_user = user current_task = id_task - pending_tasks.pop(id_task, None) + pending_tasks.pop((user, id_task), None) -def finish_task(id_task): +def finish_task(user, id_task): global current_task + global current_task_user if current_task == id_task: current_task = None - finished_tasks.append(id_task) + if current_task_user == user: + current_task_user = None + + finished_tasks.append((user, id_task)) if len(finished_tasks) > 16: finished_tasks.pop(0) -def add_task_to_queue(id_job): - pending_tasks[id_job] = time.time() +def add_task_to_queue(user, id_job): + pending_tasks[(user, id_job)] = time.time() last_task_id = None last_task_result = None +last_task_user = None + +def set_last_task_result(user, id_job, result): -def set_last_task_result(id_job, result): global last_task_id global last_task_result + global last_task_user last_task_id = id_job last_task_result = result + last_task_user = user -def restore_progress_call(): +def restore_progress_call(request: gr.Request): if current_task is None: # image, generation_info, html_info, html_log return tuple(list([None, None, None, None])) else: + user = request.username - t_task = current_task - with call_queue.queue_lock_condition: - call_queue.queue_lock_condition.wait_for(lambda: t_task == last_task_id) + if current_task_user == user: + t_task = current_task + with call_queue.queue_lock_condition: + call_queue.queue_lock_condition.wait_for(lambda: t_task == last_task_id) - return last_task_result + return last_task_result + return tuple(list([None, None, None, None])) class CurrentTaskResponse(BaseModel): current_task: str = Field(default=None, title="Task ID", description="id of the current progress task") @@ -87,6 +103,19 @@ def setup_progress_api(app): return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse) def setup_current_task_api(app): + + def get_current_user(token: Optional[str] = Security(APIKeyCookie(name="access-token", auto_error=False))): + return None if token is None else app.tokens.get(token) + + def current_task_api(current_user: str = Depends(get_current_user)): + + if app.auth is None or current_task_user == current_user: + current_user_task = current_task + else: + current_user_task = None + + return CurrentTaskResponse(current_task=current_user_task) + return app.add_api_route("/internal/current_task", current_task_api, methods=["GET"], response_model=CurrentTaskResponse) def progressapi(req: ProgressRequest): @@ -127,7 +156,4 @@ def progressapi(req: ProgressRequest): else: live_preview = None - return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo) - -def current_task_api(): - return CurrentTaskResponse(current_task=current_task) \ No newline at end of file + return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo) \ No newline at end of file -- cgit v1.2.1