aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--launch.py15
1 files changed, 14 insertions, 1 deletions
diff --git a/launch.py b/launch.py
index bcbb792c..668548f1 100644
--- a/launch.py
+++ b/launch.py
@@ -7,6 +7,7 @@ import shlex
import platform
import argparse
import json
+import detection
dir_repos = "repositories"
dir_extensions = "extensions"
@@ -15,6 +16,12 @@ git = os.environ.get('GIT', "git")
index_url = os.environ.get('INDEX_URL', "")
stored_commit_hash = None
+# Get the GPU vendor and the operating system
+gpu = detection.check_gpu()
+if os.name == "posix":
+ os_name = platform.uname().system
+else:
+ os_name = os.name
def commit_hash():
global stored_commit_hash
@@ -173,7 +180,11 @@ def run_extensions_installers(settings_file):
def prepare_environment():
- torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
+ if gpu == "AMD" and os_name !="nt":
+ torch_command = os.environ.get('TORCH_COMMAND', "pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2")
+ else:
+ torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
+
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
@@ -295,6 +306,8 @@ def tests(test_dir):
def start():
+ print(f"Operating System: {os_name}")
+ print(f"GPU: {gpu}")
print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
import webui
if '--nowebui' in sys.argv: