"""Lightweight in-process API wrapper around WanGP generation.""" from __future__ import annotations import contextlib import copy import importlib import io import json import numpy as np import os import queue import re import sys import threading import time from dataclasses import dataclass, field from pathlib import Path from typing import Any, Iterator, Sequence from PIL import Image from shared.utils.process_locks import set_main_generation_running from shared.utils.virtual_media import get_virtual_media_vsource, parse_virtual_media_path, replace_virtual_media_source _RUNTIME_LOCK = threading.RLock() _GENERATION_LOCK = threading.RLock() _RUNTIME: "_WanGPRuntime | None" = None _BANNER_PRINTED = False _STATUS_STEP_PREFIX_RE = re.compile(r"^(?:prompt|sample|sliding window|window|chunk|task|step|phase|pass)\s+\d+\s*/\s*\d+\s*(?:,\s*)?", re.IGNORECASE) _STATUS_INDEX_RE = re.compile(r"^\[\s*\d+\s*/\s*\d+\s*\]\s*") _STATUS_TIME_ONLY_RE = re.compile(r"^[\d:.]+\s*[smh]?$", re.IGNORECASE) def extract_status_phase_label(text: str | None) -> str: raw_text = str(text or "").strip() if len(raw_text) == 0: return "" parts = [part.strip() for part in raw_text.split("|") if len(part.strip()) > 0] or [raw_text] stripped_wrapper = False for part in parts: phase_text = part while True: cleaned = _STATUS_INDEX_RE.sub("", phase_text) cleaned = _STATUS_STEP_PREFIX_RE.sub("", cleaned) cleaned = cleaned.lstrip(" -:,") if cleaned == phase_text: break stripped_wrapper = True phase_text = cleaned.strip() if len(phase_text) > 0 and not _STATUS_TIME_ONLY_RE.fullmatch(phase_text): return phase_text return "" if stripped_wrapper else raw_text @dataclass(frozen=True) class StreamMessage: stream: str text: str @dataclass(frozen=True) class ProgressUpdate: phase: str status: str progress: int current_step: int | None total_steps: int | None raw_phase: str | None = None unit: str | None = None @dataclass(frozen=True) class PreviewUpdate: image: Image.Image | None phase: str status: str progress: int current_step: int | None total_steps: int | None @dataclass(frozen=True) class SessionEvent: kind: str data: Any = None timestamp: float = field(default_factory=time.time) @dataclass(frozen=True) class GeneratedArtifact: path: str | None media_type: str client_id: str = "" video_tensor_uint8: Any = None video_tensor_hdr: Any = None hdr: bool = False audio_tensor: Any = None audio_sampling_rate: int | None = None fps: float | None = None flashvsr_continue_cache: Any = None @classmethod def from_payload(cls, payload: dict[str, Any], *, default_client_id: str = "") -> "GeneratedArtifact | None": if not isinstance(payload, dict): return None return cls( path=str(payload.get("path") or "") or None, media_type=str(payload.get("media_type") or "video"), client_id=str(payload.get("client_id") or default_client_id or "").strip(), video_tensor_uint8=payload.get("video_tensor_uint8"), video_tensor_hdr=payload.get("video_tensor_hdr"), hdr=bool(payload.get("hdr", False)), audio_tensor=payload.get("audio_tensor"), audio_sampling_rate=payload.get("audio_sampling_rate"), fps=payload.get("fps"), flashvsr_continue_cache=payload.get("flashvsr_continue_cache"), ) @dataclass(frozen=True) class GenerationResult: success: bool generated_files: list[str] errors: list["GenerationError"] total_tasks: int successful_tasks: int failed_tasks: int artifacts: tuple[GeneratedArtifact, ...] = () @property def cancelled(self) -> bool: return len(self.errors) > 0 and all(error.cancelled for error in self.errors) @dataclass(frozen=True) class GenerationError: message: str task_index: int | None = None task_id: Any = None stage: str | None = None def __str__(self) -> str: return self.message @property def cancelled(self) -> bool: stage = str(self.stage or "").strip().lower() if stage == "cancelled": return True return str(self.message or "").strip().lower() == "generation was cancelled" def get_api_output_options(plugin_data: Any) -> tuple[bool, bool]: api_options = {} if not isinstance(plugin_data, dict) else plugin_data.get("api", {}) if not isinstance(api_options, dict): return False, False return bool(api_options.get("return_video_uint8") or api_options.get("return_media")), bool(api_options.get("return_audio") or api_options.get("return_media")) def _coerce_api_video_tensor_uint8(output_video_frames: Any) -> Any: try: import torch except Exception: torch = None if torch is not None and torch.is_tensor(output_video_frames): if output_video_frames.dtype == torch.uint8: return output_video_frames return output_video_frames.detach().cpu().float().clamp(-1, 1).add(1.0).mul(127.5).round().to(torch.uint8) if isinstance(output_video_frames, list) and len(output_video_frames) == 1 and torch is not None and torch.is_tensor(output_video_frames[0]): return _coerce_api_video_tensor_uint8(output_video_frames[0]) if isinstance(output_video_frames, list) and torch is not None: tensors = [item for item in output_video_frames if torch.is_tensor(item)] if len(tensors) == len(output_video_frames) and tensors and all(item.dtype == torch.uint8 and item.ndim == 4 for item in tensors): return torch.cat(tensors, dim=1) if len(tensors) == len(output_video_frames) and tensors and all(item.dtype != torch.uint8 and item.ndim == 4 for item in tensors): return torch.cat([_coerce_api_video_tensor_uint8(item) for item in tensors], dim=1) return None def _coerce_api_video_tensor_hdr(output_video_frames: Any) -> Any: try: import torch except Exception: torch = None if torch is not None and torch.is_tensor(output_video_frames): return output_video_frames if output_video_frames.dtype != torch.uint8 else None if isinstance(output_video_frames, list) and len(output_video_frames) == 1 and torch is not None and torch.is_tensor(output_video_frames[0]): return output_video_frames[0] if output_video_frames[0].dtype != torch.uint8 else None if isinstance(output_video_frames, list) and torch is not None: tensors = [item for item in output_video_frames if torch.is_tensor(item)] if len(tensors) == len(output_video_frames) and tensors and all(item.dtype != torch.uint8 and item.ndim == 4 for item in tensors): return torch.cat(tensors, dim=1) return None def _coerce_api_audio_tensor(output_audio_data: Any) -> Any: return None if output_audio_data is None else np.asarray(output_audio_data, dtype=np.float32) def build_api_output_artifact_payload(client_id: str, video_path: Any, media_type: str, output_video_frames: Any, output_audio_data: Any, output_audio_sampling_rate: Any, output_fps: Any, *, hdr: bool = False, flashvsr_continue_cache: Any = None) -> dict[str, Any] | None: client_id = str(client_id or "").strip() if len(client_id) == 0: return None output_path = str(video_path[0]) if isinstance(video_path, list) and len(video_path) > 0 else str(video_path or "") return { "client_id": client_id, "path": output_path, "media_type": str(media_type or "video"), "video_tensor_uint8": None if hdr else _coerce_api_video_tensor_uint8(output_video_frames), "video_tensor_hdr": _coerce_api_video_tensor_hdr(output_video_frames) if hdr else None, "hdr": bool(hdr), "audio_tensor": _coerce_api_audio_tensor(output_audio_data), "audio_sampling_rate": int(output_audio_sampling_rate) if output_audio_sampling_rate else None, "fps": float(output_fps) if output_fps else None, "flashvsr_continue_cache": flashvsr_continue_cache, } def store_api_output_artifact(gen: dict[str, Any], client_id: str, video_path: Any, media_type: str, output_video_frames: Any, output_audio_data: Any, output_audio_sampling_rate: Any, output_fps: Any, *, hdr: bool = False, flashvsr_continue_cache: Any = None) -> bool: payload = build_api_output_artifact_payload(client_id, video_path, media_type, output_video_frames, output_audio_data, output_audio_sampling_rate, output_fps, hdr=hdr, flashvsr_continue_cache=flashvsr_continue_cache) if payload is None: return False gen.setdefault("api_output_artifacts", {})[payload["client_id"]] = payload return True class SessionStream: def __init__(self) -> None: self._queue: queue.Queue[SessionEvent | object] = queue.Queue() self._closed = threading.Event() self._sentinel = object() def put(self, kind: str, data: Any = None) -> None: if self._closed.is_set(): return self._queue.put(SessionEvent(kind=kind, data=data)) def close(self) -> None: if self._closed.is_set(): return self._closed.set() self._queue.put(self._sentinel) def get(self, timeout: float | None = None) -> SessionEvent | None: try: item = self._queue.get(timeout=timeout) except queue.Empty: return None if item is self._sentinel: return None return item def iter(self, timeout: float | None = None) -> Iterator[SessionEvent]: while True: event = self.get(timeout=timeout) if event is None: if self._closed.is_set(): break continue yield event @property def closed(self) -> bool: return self._closed.is_set() class _OutputCapture(io.TextIOBase): def __init__( self, stream_name: str, emit_line, console: io.TextIOBase | None = None, *, console_isatty: bool = True, ) -> None: self._stream_name = stream_name self._emit_line = emit_line self._console = console self._console_isatty = bool(console_isatty) self._buffer = "" def writable(self) -> bool: return True @property def encoding(self) -> str: return str(getattr(self._console, "encoding", "utf-8")) def isatty(self) -> bool: return self._console_isatty def write(self, text: str) -> int: if not text: return 0 if self._console is not None: self._console.write(text) self._buffer += text self._drain(False) return len(text) def flush(self) -> None: if self._console is not None: self._console.flush() self._drain(True) def _drain(self, flush_all: bool) -> None: while True: split_at = -1 for delimiter in ("\r", "\n"): index = self._buffer.find(delimiter) if index >= 0 and (split_at < 0 or index < split_at): split_at = index if split_at < 0: break line = self._buffer[:split_at] self._buffer = self._buffer[split_at + 1 :] if line: self._emit_line(self._stream_name, line) if flush_all and self._buffer: self._emit_line(self._stream_name, self._buffer) self._buffer = "" @dataclass(frozen=True) class _WanGPRuntime: module: Any root: Path config_path: Path cli_args: tuple[str, ...] class SessionJob: def __init__(self, session: "WanGPSession") -> None: self._session = session self._callbacks: object | None = None self.events = SessionStream() self._done = threading.Event() self._cancel_requested = threading.Event() self._webui_submission_ready = threading.Event() self._thread: threading.Thread | None = None self._result: GenerationResult | None = None self._webui_manifest: list[dict[str, Any]] = [] self._webui_client_ids: tuple[str, ...] = () self._webui_load_queue_token = "" self._webui_owner_call_id = "" def _bind_thread(self, thread: threading.Thread) -> None: self._thread = thread def _bind_callbacks(self, callbacks: object | None) -> None: self._callbacks = callbacks def _set_result(self, result: GenerationResult) -> None: self._result = result self._done.set() def _set_webui_bridge(self, *, manifest: Sequence[dict[str, Any]], client_ids: Sequence[str], load_queue_token: str) -> None: self._webui_manifest = copy.deepcopy(list(manifest)) self._webui_client_ids = tuple(str(client_id or "").strip() for client_id in client_ids if str(client_id or "").strip()) self._webui_load_queue_token = str(load_queue_token or "").strip() def release_input_payload(self) -> None: self._webui_manifest = [] def _mark_webui_submission_ready(self) -> None: self._webui_submission_ready.set() def _bind_webui_owner_call(self, call_id: str) -> None: self._webui_owner_call_id = str(call_id or "").strip() def cancel(self) -> None: self._cancel_requested.set() owner = getattr(self._session, "_gradio_session_proxy", None) capture = getattr(owner, "_capture_cancelled_job", None) if callable(capture): capture(self) def result(self, timeout: float | None = None) -> GenerationResult: if not self._done.wait(timeout=timeout): raise TimeoutError("WanGP session job timed out") return self._result or GenerationResult( success=False, generated_files=[], errors=[], total_tasks=0, successful_tasks=0, failed_tasks=0, artifacts=(), ) def join(self, timeout: float | None = None) -> GenerationResult: return self.result(timeout=timeout) @property def done(self) -> bool: return self._done.is_set() @property def cancel_requested(self) -> bool: return self._cancel_requested.is_set() @property def webui_manifest(self) -> list[dict[str, Any]]: return copy.deepcopy(self._webui_manifest) @property def webui_client_ids(self) -> tuple[str, ...]: return self._webui_client_ids @property def primary_client_id(self) -> str: return "" if not self._webui_client_ids else self._webui_client_ids[0] @property def webui_load_queue_token(self) -> str: return self._webui_load_queue_token @property def webui_submission_ready(self) -> bool: return self._webui_submission_ready.is_set() @property def webui_owner_call_id(self) -> str: return self._webui_owner_call_id class WanGPSession: def __init__( self, *, root: str | os.PathLike[str] | None = None, config_path: str | os.PathLike[str] | None = None, output_dir: str | os.PathLike[str] | None = None, callbacks: object | None = None, cli_args: Sequence[str] = (), console_output: bool = True, console_isatty: bool = True, webui_state: dict[str, Any] | None = None, ) -> None: self._root = Path(root or Path(__file__).resolve().parents[1]).resolve() self._config_path = Path(config_path).resolve() if config_path is not None else (self._root / "wgp_config.json").resolve() self._output_dir = Path(output_dir).resolve() if output_dir is not None else None self._callbacks = callbacks self._cli_args = tuple(str(arg) for arg in cli_args) self._console_output = bool(console_output) self._console_isatty = bool(console_isatty) self._use_webui_queue = isinstance(webui_state, dict) self._state = webui_state if isinstance(webui_state, dict) else self._create_headless_state() self._active_job: SessionJob | None = None self._job_lock = threading.Lock() self._attachment_keys: tuple[str, ...] | None = None def ensure_ready(self) -> "WanGPSession": self._ensure_runtime() return self def submit(self, source: str | os.PathLike[str] | dict[str, Any] | list[dict[str, Any]], callbacks: object | None = None) -> SessionJob: tasks = self._normalize_source(source, caller_base_path=self._get_caller_base_path()) return self._submit_tasks(tasks, callbacks=callbacks) def submit_task(self, settings: dict[str, Any], callbacks: object | None = None) -> SessionJob: caller_base_path = self._get_caller_base_path() task = self._normalize_task(settings, task_index=1) return self._submit_tasks([self._absolutize_task_paths(task, caller_base_path)], callbacks=callbacks) def submit_manifest(self, settings_list: list[dict[str, Any]], callbacks: object | None = None) -> SessionJob: caller_base_path = self._get_caller_base_path() tasks = [ self._absolutize_task_paths(self._normalize_task(settings, task_index=index + 1), caller_base_path) for index, settings in enumerate(settings_list) ] return self._submit_tasks(tasks, callbacks=callbacks) def run(self, source: str | os.PathLike[str] | dict[str, Any] | list[dict[str, Any]], callbacks: object | None = None) -> GenerationResult: return self.submit(source, callbacks=callbacks).result() def run_task(self, settings: dict[str, Any], callbacks: object | None = None) -> GenerationResult: return self.submit_task(settings, callbacks=callbacks).result() def run_manifest(self, settings_list: list[dict[str, Any]], callbacks: object | None = None) -> GenerationResult: return self.submit_manifest(settings_list, callbacks=callbacks).result() def close(self) -> None: if self._use_webui_queue: return runtime = self._ensure_runtime() with _GENERATION_LOCK, _pushd(runtime.root): runtime.module.release_model() def cancel(self) -> None: with self._job_lock: job = self._active_job if job is not None: job.cancel() @staticmethod def _create_headless_state() -> dict[str, Any]: return { "gen": { "queue": [], "in_progress": False, "file_list": [], "file_settings_list": [], "audio_file_list": [], "audio_file_settings_list": [], "selected": 0, "audio_selected": 0, "prompt_no": 0, "prompts_max": 0, "repeat_no": 0, "total_generation": 1, "window_no": 0, "total_windows": 0, "progress_status": "", "process_status": "process:main", "api_output_artifacts": {}, }, "loras": [], } def _submit_tasks(self, tasks: list[dict[str, Any]], callbacks: object | None = None) -> SessionJob: with self._job_lock: if self._active_job is not None and not self._active_job.done: raise RuntimeError("WanGP session already has a generation in progress") job = SessionJob(self) self._bind_callbacks_to_job(job, callbacks) prepared_tasks = copy.deepcopy(tasks) client_ids = self._ensure_task_client_ids(prepared_tasks, priority=self._use_webui_queue) if self._use_webui_queue: prepared_tasks, manifest, load_queue_token = self._prepare_webui_bridge(prepared_tasks) job._set_webui_bridge(manifest=manifest, client_ids=client_ids, load_queue_token=load_queue_token) thread = threading.Thread( target=self._run_job, args=(job, prepared_tasks), daemon=True, name="wangp-session-job", ) job._bind_thread(thread) self._active_job = job thread.start() return job def _bind_callbacks_to_job(self, job: SessionJob, callbacks: object | None = None) -> None: callback = self._callbacks if callbacks is None else callbacks job._bind_callbacks(callback) if callback is None: return binder = getattr(callback, "bind_job", None) if not callable(binder): return try: binder(session=self, job=job) except TypeError: binder(job) @staticmethod def _ensure_task_client_ids(tasks: list[dict[str, Any]], *, priority: bool = False) -> tuple[str, ...]: client_seed = time.time_ns() client_ids: list[str] = [] for index, task in enumerate(tasks, start=1): params = copy.deepcopy(WanGPSession._get_task_settings(task)) client_id = str(params.get("client_id", "") or "").strip() if len(client_id) == 0: client_id = f"api_{client_seed}_{index}" params["client_id"] = client_id if priority: params["priority"] = True elif "priority" in params and not params["priority"]: params.pop("priority", None) task["params"] = params client_ids.append(client_id) return tuple(client_ids) def _prepare_webui_bridge(self, tasks: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], list[dict[str, Any]], str]: manifest = [] for index, task in enumerate(tasks, start=1): params = copy.deepcopy(self._get_task_settings(task)) params["priority"] = True task["params"] = params manifest.append({ "id": task.get("id", index), "params": copy.deepcopy(params), "plugin_data": copy.deepcopy(task.get("plugin_data", {})), }) return tasks, manifest, str(time.time_ns()) def _run_job(self, job: SessionJob, tasks: list[dict[str, Any]]) -> None: if self._use_webui_queue: self._run_webui_job(job, tasks) return from shared.api_cli import run_cli_job run_cli_job(self, job, tasks) def _run_webui_job(self, job: SessionJob, tasks: list[dict[str, Any]]) -> None: from shared.api_webui import run_webui_job run_webui_job(self, job, tasks) def _build_progress_update(self, data: Any, *, include_state_fallback: bool = True) -> ProgressUpdate: current_step: int | None = None total_steps: int | None = None status = "" unit: str | None = None if isinstance(data, list) and data: head = data[0] if isinstance(head, tuple) and len(head) == 2: current_step = int(head[0]) total_steps = int(head[1]) status = str(data[1] if len(data) > 1 else "") if len(data) > 3: unit = str(data[3]) else: status = str(data[1] if len(data) > 1 else head) else: status = str(data or "") raw_phase = None if include_state_fallback: progress_phase = self._state["gen"].get("progress_phase") if isinstance(progress_phase, tuple) and progress_phase: raw_phase = extract_status_phase_label(progress_phase[0]) if current_step is None and len(progress_phase) > 1 and "denoising" in raw_phase.lower(): try: progress_step = int(progress_phase[1]) except (TypeError, ValueError): progress_step = -1 try: inference_steps = int(self._state["gen"].get("num_inference_steps") or 0) except (TypeError, ValueError): inference_steps = 0 if progress_step >= 0 and inference_steps > 0: current_step = progress_step total_steps = inference_steps if len(status) == 0: status = str(self._state["gen"].get("progress_status", "") or raw_phase or "") status_phase_label = extract_status_phase_label(status) if len(status_phase_label) > 0 and len(str(raw_phase or "").strip()) > 0 and current_step is None: normalized_status_phase = self._normalize_phase(status_phase_label) normalized_raw_phase = self._normalize_phase(raw_phase) if normalized_status_phase != normalized_raw_phase: raw_phase = None display_phase = raw_phase or status_phase_label phase = self._normalize_phase(display_phase or status) if not self._phase_supports_progress(phase): current_step = None total_steps = None progress = self._estimate_progress(phase, current_step, total_steps) return ProgressUpdate( phase=phase, status=status, progress=progress, current_step=current_step, total_steps=total_steps, raw_phase=display_phase or None, unit=unit, ) def _build_preview_update(self, wgp, tasks: list[dict[str, Any]], payload: Any) -> PreviewUpdate | None: progress = self._build_progress_update([0, self._state["gen"].get("progress_status", "")]) model_type = "" queue_tasks = self._state["gen"].get("queue") or tasks if queue_tasks: model_type = str(self._get_task_settings(queue_tasks[0]).get("model_type", "")) image = wgp.generate_preview(model_type, payload) if model_type else None return PreviewUpdate( image=image, phase=progress.phase, status=progress.status, progress=progress.progress, current_step=progress.current_step, total_steps=progress.total_steps, ) def _emit_stream(self, job: SessionJob, stream_name: str, line: str) -> None: message = StreamMessage(stream=stream_name, text=line) job.events.put("stream", message) self._emit_callback("on_stream", message, job=job) def _emit_callback(self, method_name: str, payload: Any, *, job: SessionJob | None = None) -> None: callback = self._callbacks if job is None or job._callbacks is None else job._callbacks if callback is None: return method = getattr(callback, method_name, None) if callable(method): method(payload) on_event = getattr(callback, "on_event", None) if callable(on_event): on_event(SessionEvent(kind=method_name.removeprefix("on_"), data=payload)) def _configure_runtime(self, runtime: _WanGPRuntime) -> None: runtime.module.server_config["notification_sound_enabled"] = 0 if self._output_dir is not None: self._output_dir.mkdir(parents=True, exist_ok=True) runtime.module.server_config["save_path"] = str(self._output_dir) runtime.module.server_config["image_save_path"] = str(self._output_dir) runtime.module.server_config["audio_save_path"] = str(self._output_dir) runtime.module.save_path = str(self._output_dir) runtime.module.image_save_path = str(self._output_dir) runtime.module.audio_save_path = str(self._output_dir) for output_path in ( runtime.module.save_path, runtime.module.image_save_path, runtime.module.audio_save_path, ): Path(output_path).mkdir(parents=True, exist_ok=True) def _prepare_state_for_run(self, tasks: list[dict[str, Any]]) -> None: gen = self._state["gen"] gen["queue"] = tasks set_main_generation_running(self._state, True) gen["process_status"] = "process:main" gen["progress_status"] = "" gen["progress_phase"] = ("", -1) gen["abort"] = False gen["early_stop"] = False gen["early_stop_forwarded"] = False gen["preview"] = None gen["status"] = "Generating..." gen["in_progress"] = True gen.setdefault("api_output_artifacts", {}) self._ensure_runtime().module.gen_in_progress = True def _reset_state_after_run(self) -> None: gen = self._state["gen"] gen["queue"] = [] set_main_generation_running(self._state, False) gen["process_status"] = "process:main" gen["progress_status"] = "" gen["progress_phase"] = ("", -1) gen["abort"] = False gen["early_stop"] = False gen["early_stop_forwarded"] = False gen.pop("in_progress", None) self._ensure_runtime().module.gen_in_progress = False def _collect_outputs(self, base_file_count: int, base_audio_count: int) -> list[str]: gen = self._state["gen"] files = gen["file_list"][base_file_count:] audio_files = gen["audio_file_list"][base_audio_count:] return [str(Path(path).resolve()) for path in [*files, *audio_files]] def _consume_output_artifact(self, client_id: str) -> GeneratedArtifact | None: gen = self._state["gen"] artifacts = gen.get("api_output_artifacts") if not isinstance(artifacts, dict): return None payload = artifacts.pop(str(client_id or "").strip(), None) return GeneratedArtifact.from_payload(payload, default_client_id=str(client_id or "").strip()) def _peek_output_artifact(self, client_id: str) -> GeneratedArtifact | None: gen = self._state["gen"] artifacts = gen.get("api_output_artifacts") if not isinstance(artifacts, dict): return None payload = artifacts.get(str(client_id or "").strip(), None) return GeneratedArtifact.from_payload(payload, default_client_id=str(client_id or "").strip()) def _consume_output_artifacts(self, tasks: Sequence[dict[str, Any]]) -> tuple[GeneratedArtifact, ...]: artifacts: list[GeneratedArtifact] = [] for task in tasks: client_id = str(self._get_task_settings(task).get("client_id", "") or "").strip() if len(client_id) == 0: continue artifact = self._consume_output_artifact(client_id) if artifact is not None: artifacts.append(artifact) return tuple(artifacts) def _request_cancel_unlocked(self, wgp) -> None: gen = self._state["gen"] gen["resume"] = True gen["abort"] = True if wgp.wan_model is not None: wgp.wan_model._interrupt = True def _normalize_source( self, source: str | os.PathLike[str] | dict[str, Any] | list[dict[str, Any]], *, caller_base_path: Path, ) -> list[dict[str, Any]]: if isinstance(source, (str, os.PathLike)): return self._load_tasks_from_path(self._resolve_source_path(Path(source), caller_base_path), caller_base_path) if isinstance(source, list): return [ self._absolutize_task_paths(self._normalize_task(task, task_index=index + 1), caller_base_path) for index, task in enumerate(source) ] if isinstance(source, dict): if isinstance(source.get("tasks"), list): tasks = source["tasks"] return [ self._absolutize_task_paths(self._normalize_task(task, task_index=index + 1), caller_base_path) for index, task in enumerate(tasks) ] return [self._absolutize_task_paths(self._normalize_task(source, task_index=1), caller_base_path)] raise TypeError("WanGP session source must be a path, a settings dict, or a manifest list") def _normalize_task(self, task: dict[str, Any], *, task_index: int) -> dict[str, Any]: if not isinstance(task, dict): raise TypeError(f"Task {task_index} must be a dictionary") normalized = copy.deepcopy(task) if "settings" in normalized and "params" not in normalized: normalized["params"] = normalized.pop("settings") if "params" not in normalized: normalized = {"id": task_index, "params": normalized, "plugin_data": {}} normalized.setdefault("id", task_index) normalized.setdefault("plugin_data", {}) normalized.setdefault("params", {}) if not isinstance(normalized["plugin_data"], dict): normalized["plugin_data"] = {} settings = normalized["params"] if isinstance(settings, dict): api_options = settings.pop("_api", None) if isinstance(api_options, dict): normalized["plugin_data"]["api"] = copy.deepcopy(api_options) runtime_settings_version = getattr(self._ensure_runtime().module, "settings_version", None) if runtime_settings_version is not None: settings.setdefault("settings_version", runtime_settings_version) self._normalize_settings_values(settings) normalized.setdefault("prompt", settings.get("prompt", "")) normalized.setdefault("length", settings.get("video_length")) normalized.setdefault("steps", settings.get("num_inference_steps")) normalized.setdefault("repeats", settings.get("repeat_generation", 1)) return normalized @staticmethod def _normalize_settings_values(settings: dict[str, Any]) -> None: force_fps = settings.get("force_fps") if isinstance(force_fps, (int, float)) and not isinstance(force_fps, bool): if isinstance(force_fps, float) and not force_fps.is_integer(): settings["force_fps"] = str(force_fps) else: settings["force_fps"] = str(int(force_fps)) @staticmethod def _get_task_settings(task: dict[str, Any]) -> dict[str, Any]: settings = task.get("params") if isinstance(settings, dict): return settings settings = task.get("settings") if isinstance(settings, dict): return settings return {} def _load_tasks_from_path(self, path: Path, caller_base_path: Path) -> list[dict[str, Any]]: runtime = self._ensure_runtime() if not path.exists(): raise FileNotFoundError(path) if path.suffix.lower() == ".json": return self._load_settings_json(path, caller_base_path) with _pushd(runtime.root): tasks, error = runtime.module._parse_queue_zip(str(path), self._state) if error: raise RuntimeError(error) return [self._normalize_task(task, task_index=index + 1) for index, task in enumerate(tasks)] def _load_settings_json(self, path: Path, caller_base_path: Path) -> list[dict[str, Any]]: with path.open("r", encoding="utf-8") as handle: payload = json.load(handle) if isinstance(payload, list): raw_tasks = payload elif isinstance(payload, dict) and isinstance(payload.get("tasks"), list): raw_tasks = payload["tasks"] elif isinstance(payload, dict): raw_tasks = [payload] else: raise RuntimeError("Settings file must contain a JSON object or a list of tasks") tasks = [self._normalize_task(task, task_index=index + 1) for index, task in enumerate(raw_tasks)] return [self._absolutize_task_paths(task, caller_base_path) for task in tasks] @staticmethod def _get_caller_base_path() -> Path: return Path.cwd().resolve() @staticmethod def _resolve_source_path(path: Path, caller_base_path: Path) -> Path: if path.is_absolute(): return path.resolve() return (caller_base_path / path).resolve() def _absolutize_task_paths(self, task: dict[str, Any], caller_base_path: Path) -> dict[str, Any]: normalized = copy.deepcopy(task) settings = normalized.get("params") if not isinstance(settings, dict): return normalized for key in self._get_attachment_keys(): if key not in settings: continue settings[key] = self._absolutize_setting_path(settings[key], caller_base_path) return normalized def _get_attachment_keys(self) -> tuple[str, ...]: if self._attachment_keys is None: runtime = self._ensure_runtime() keys = getattr(runtime.module, "ATTACHMENT_KEYS", ()) self._attachment_keys = tuple(str(key) for key in keys) return self._attachment_keys def _absolutize_setting_path(self, value: Any, caller_base_path: Path) -> Any: if isinstance(value, list): return [self._absolutize_setting_path(item, caller_base_path) for item in value] if isinstance(value, os.PathLike): value = os.fspath(value) if not isinstance(value, str) or not value.strip(): return value spec = parse_virtual_media_path(value) if spec is not None and get_virtual_media_vsource(spec) is not None: return value path = Path(spec.source_path if spec is not None else value) if path.is_absolute(): resolved = str(path.resolve()) else: resolved = str((caller_base_path / path).resolve()) return replace_virtual_media_source(value, resolved) if spec is not None else resolved @staticmethod def _make_generation_error( error: Any, *, task_index: int | None = None, task_id: Any = None, stage: str | None = None, ) -> GenerationError: if isinstance(error, GenerationError): return error if isinstance(error, BaseException): message = str(error) or error.__class__.__name__ else: message = str(error) return GenerationError(message=message, task_index=task_index, task_id=task_id, stage=stage) def _ensure_runtime(self) -> _WanGPRuntime: global _RUNTIME with _RUNTIME_LOCK: if _RUNTIME is not None: if _RUNTIME.root != self._root or _RUNTIME.config_path != self._config_path or _RUNTIME.cli_args != self._cli_args: raise RuntimeError("WanGP runtime already loaded with different root/config/cli args") return _RUNTIME argv = ["wgp.py", *self._cli_args] default_config_path = (self._root / "wgp_config.json").resolve() if self._config_path.name != "wgp_config.json": raise ValueError("config_path must point to a file named 'wgp_config.json'") if self._config_path != default_config_path: self._config_path.parent.mkdir(parents=True, exist_ok=True) if "--config" not in argv: argv.extend(["--config", str(self._config_path.parent)]) if str(self._root) not in sys.path: sys.path.insert(0, str(self._root)) with _pushd(self._root), _temporary_argv(argv): module = importlib.import_module("wgp") module_root = Path(module.__file__).resolve().parent if module_root != self._root: raise RuntimeError(f"WanGP module already loaded from {module_root}, expected {self._root}") if not hasattr(module, "app"): module.app = module.WAN2GPApplication() module.download_ffmpeg() _RUNTIME = _WanGPRuntime( module=module, root=self._root, config_path=self._config_path, cli_args=self._cli_args, ) _print_banner_once(module, enabled=not self._use_webui_queue) return _RUNTIME @staticmethod def _normalize_phase(text: str | None) -> str: lowered = extract_status_phase_label(text).lower() if "denoising first pass" in lowered or "denoising 1st pass" in lowered: return "inference_stage_1" if "denoising second pass" in lowered or "denoising 2nd pass" in lowered: return "inference_stage_2" if "denoising third pass" in lowered or "denoising 3rd pass" in lowered: return "inference_stage_3" if "loading model" in lowered or lowered.startswith("loading"): return "loading_model" if "enhancing prompt" in lowered or "encoding prompt" in lowered or "encoding" in lowered: return "encoding_text" if "vae decoding" in lowered or "decoding" in lowered: return "decoding" if "saved" in lowered or "completed" in lowered or "output" in lowered: return "downloading_output" if "cancel" in lowered or "abort" in lowered: return "cancelled" return "inference" @staticmethod def _phase_supports_progress(phase: str | None) -> bool: return str(phase or "") in {"inference", "inference_stage_1", "inference_stage_2", "inference_stage_3"} @staticmethod def _estimate_progress(phase: str, current_step: int | None, total_steps: int | None) -> int: if total_steps is None or total_steps <= 0 or current_step is None: if phase == "loading_model": return 10 if phase == "encoding_text": return 18 if phase == "inference_stage_1": return 25 if phase == "inference_stage_2": return 70 if phase == "inference_stage_3": return 80 if phase == "decoding": return 90 if phase == "downloading_output": return 95 if phase == "cancelled": return 0 return 15 ratio = max(0.0, min(1.0, current_step / total_steps)) if phase == "loading_model": return min(15, 5 + int(ratio * 10)) if phase == "encoding_text": return min(22, 12 + int(ratio * 10)) if phase == "inference_stage_1": return min(68, 20 + int(ratio * 48)) if phase == "inference_stage_2": return min(88, 68 + int(ratio * 20)) if phase == "inference_stage_3": return min(89, 80 + int(ratio * 9)) if phase == "decoding": return min(95, 85 + int(ratio * 10)) if phase == "downloading_output": return min(98, 92 + int(ratio * 6)) if phase == "cancelled": return 0 return min(90, 20 + int(ratio * 65)) def init( *, root: str | os.PathLike[str] | None = None, config_path: str | os.PathLike[str] | None = None, output_dir: str | os.PathLike[str] | None = None, callbacks: object | None = None, cli_args: Sequence[str] = (), console_output: bool = True, console_isatty: bool = True, webui_state: dict[str, Any] | None = None, ) -> WanGPSession: """Create and eagerly initialize a reusable WanGP session.""" return WanGPSession( root=root, config_path=config_path, output_dir=output_dir, callbacks=callbacks, cli_args=cli_args, console_output=console_output, console_isatty=console_isatty, webui_state=webui_state, ).ensure_ready() def create_gradio_webui_session(plugin) -> Any: from shared.api_webui import create_gradio_webui_session as _create_gradio_webui_session return _create_gradio_webui_session(plugin, init_fn=init) def create_gradio_progress_callbacks(progress) -> Any: from shared.api_webui import create_gradio_progress_callbacks as _create_gradio_progress_callbacks return _create_gradio_progress_callbacks(progress) @contextlib.contextmanager def _pushd(path: Path) -> Iterator[None]: previous = Path.cwd() os.chdir(path) try: yield finally: os.chdir(previous) @contextlib.contextmanager def _temporary_argv(argv: Sequence[str]) -> Iterator[None]: previous = list(sys.argv) sys.argv = list(argv) try: yield finally: sys.argv = previous def _print_banner_once(module, *, enabled: bool = True) -> None: global _BANNER_PRINTED if not enabled: return if _BANNER_PRINTED: return _BANNER_PRINTED = True banner = f"Powered by WanGP v{module.WanGP_version} - a DeepBeepMeep Production\n" console = sys.__stdout__ if sys.__stdout__ is not None else sys.stdout if console is not None: console.write(banner) console.flush()