| from __future__ import annotations |
|
|
| import os |
| import sys |
|
|
| from collections.abc import MutableSequence |
|
|
|
|
| def _arg_name_to_option(arg_name: str) -> str: |
| arg_name = str(arg_name or "").strip() |
| if not arg_name: |
| return "" |
| return arg_name if arg_name.startswith("--") else f"--{arg_name}" |
|
|
|
|
| def _cuda_visible_device(device: str) -> str: |
| device = str(device or "").strip().lower() |
| if device.startswith("cuda:"): |
| device = device.split(":", 1)[1] |
| return device if device.isdigit() else "" |
|
|
|
|
| def _rewrite_arg_value(argv: MutableSequence[str], option: str, value: str) -> None: |
| for index, arg in enumerate(argv): |
| if arg == option and index + 1 < len(argv): |
| argv[index + 1] = value |
| return |
| if str(arg).startswith(f"{option}="): |
| argv[index] = f"{option}={value}" |
| return |
|
|
|
|
def set_default_cuda_device_from_arg(arg_name: str, default_device: str = "cuda:0") -> bool:
    """Pin the process to the CUDA device named on the command line.

    Scans ``sys.argv`` (skipping the program name at index 0) for the first
    occurrence of the option derived from ``arg_name``, given either as
    ``--opt VALUE`` or ``--opt=VALUE``. When that value names a single CUDA
    device index, the index is stored in the ``CUDA_VISIBLE_DEVICES``
    environment variable and the option's value in ``sys.argv`` is rewritten
    in place to ``default_device``.

    Args:
        arg_name: Option name, with or without the leading ``--``.
        default_device: Replacement written into ``sys.argv`` for the option.

    Returns:
        ``True`` if the environment variable was set and ``sys.argv`` was
        rewritten; ``False`` if the option name is blank, the option is not
        present with a value, or its value is not a CUDA device index.
    """
    flag = _arg_name_to_option(arg_name)
    if not flag:
        return False

    args = sys.argv
    inline_prefix = f"{flag}="
    raw_value = None
    for pos in range(1, len(args)):
        token = args[pos]
        # Separate-token form: the value is the next element, if there is one.
        if token == flag and pos + 1 < len(args):
            raw_value = args[pos + 1]
            break
        # Inline form: the value is everything after the first "=".
        text = str(token)
        if text.startswith(inline_prefix):
            raw_value = text.split("=", 1)[1]
            break

    if raw_value is None:
        # The option never appeared with a value.
        return False

    gpu_index = _cuda_visible_device(raw_value)
    if not gpu_index:
        return False

    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_index
    _rewrite_arg_value(args, flag, default_device)
    return True
|
|