| """Recommendation engine that loads the trained MLPMetric checkpoint plus the |
| pre-built model pool, and exposes ``Recommender.recommend`` for the Gradio app. |
| """ |
| from __future__ import annotations |
|
|
| import json |
| import os |
| import re |
| import threading |
| from dataclasses import dataclass |
| from types import SimpleNamespace |
| from typing import List, Optional |
|
|
| import numpy as np |
| import torch |
|
|
| from inference_lib import MLPMetric, MLPMetricFull |
|
|
|
|
# Embedding model requested from the OpenAI API for dataset descriptions, and
# the vector width that `Recommender.embed_description` validates responses
# against before handing them to the scorer.
EMBEDDING_MODEL = "text-embedding-3-small"
EMBEDDING_DIM = 1536
|
|
|
|
| |
| |
# Lower-cased organisation prefixes (the part before "/" in an "org/model"
# name) that `_is_official_name` accepts as "official" publishers.
# Membership is tested against the lower-cased org, so every entry here must
# be lower-case.  The blank-line grouping mirrors the original layout; the
# group labels are a reviewer's reading, not something the code relies on.
OFFICIAL_ORGS: set[str] = {
    # -- group 1 (apparently large labs / widely used open-source orgs) -------
    "meta-llama", "facebook", "facebookresearch", "huggyllama", "codellama",
    "mistralai", "google", "google-bert", "google-t5", "google-research",
    "google-deepmind", "microsoft", "openai", "openai-community", "nvidia",
    "anthropic", "allenai", "salesforce", "salesforceresearch", "apple",
    "xai-org", "x-ai", "huggingfaceh4", "huggingfacem4", "huggingface",
    "stabilityai", "stability-ai", "ibm-granite", "ibm", "cohere",
    "cohereforai", "snowflake", "databricks", "mixedbread-ai", "jinaai",
    "jina-ai", "nomic-ai", "intfloat", "baai", "intel", "amd",
    "bigscience", "bigcode", "eleutherai", "togethercomputer", "mosaicml",
    "nousresearch", "tiiuae", "answerdotai", "arcee-ai", "black-forest-labs",
    "runwayml", "compvis", "segmind", "laion", "timm", "llava-hf", "llava-vl",
    "sentence-transformers",

    # -- group 2 --------------------------------------------------------------
    "deepseek-ai", "qwen", "qwenlm", "alibaba-nlp", "alibaba-pai", "damo-nlp",
    "01-ai", "stepfun-ai", "thudm", "zhipuai", "baichuan-inc", "internlm",
    "moonshotai", "minimaxai", "shanghai-ai-laboratory", "bytedance",
    "bytedance-research", "bytedance-seed", "tencent", "tencent-arc",
    "tencent-hunyuan", "baidu", "iflytek", "sensetime", "sensenova",
    "opengvlab", "m-a-p", "iic", "hkunlp", "llm360", "silma-ai",

    # -- group 3 --------------------------------------------------------------
    "elyza", "rinna", "vinai", "aubmindlab", "upstage", "42dot",
    "kakaobrain", "kakao", "beomi", "heegyu", "sakanaai", "sakana-ai",
    "pfnet", "pfnet-research", "naver", "naver-clova-ix", "clovaai",
    "kaist-ai", "kaist", "snunlp", "llm-jp", "cyberagent", "rakuten",
    "stockmark", "sionic-ai", "openthaigpt",

    # -- group 4 --------------------------------------------------------------
    "helsinki-nlp", "speechbrain", "pyannote", "suno", "suno-ai", "cardiffnlp",
    "phind", "lmsys", "lmstudio-community", "lighteternal", "vilm",

    # -- group 5 --------------------------------------------------------------
    "wizardlmteam", "wizardlm", "openchat", "migtissera", "teknium",
    "flax-community", "distilbert", "xlnet", "tinyllama",
    "princeton-nlp", "stanford", "stanford-nlp", "tatsu-lab", "open-orca",
    "amazon", "amazon-science",
}
|
|
| |
| |
| |
| |
# Matches bare model names (no "org/" prefix) that START with a well-known
# architecture/family prefix.  Each tuple entry below is one or more regex
# alternatives; they are OR-ed together inside a single non-capturing group.
# The trailing `(?:[-._\s]|$)` requires the prefix to be followed by a
# separator or end-of-string, so "bert-base" matches but "berty" does not.
_OFFICIAL_BARE_PREFIX_RE = re.compile(
    r"^(?:"
    + "|".join((
        # encoder-style language models
        r"bert(?:-(?:base|large|tiny|small|mini))?",
        r"roberta(?:-(?:base|large))?",
        r"distilbert",
        r"albert(?:-(?:base|large|xlarge|xxlarge))?",
        r"electra",
        r"xlm(?:-roberta)?(?:-(?:base|large))?",
        # generative / seq2seq language models
        r"gpt-?2|gpt-?j|gpt-?neo(?:x)?|gpt-?oss",
        r"t5(?:-(?:base|large|small|3b|11b))?",
        r"bart(?:-(?:base|large))?",
        r"bloom(?:-(?:560m|1b1|1b7|3b|7b1|14b))?",
        r"opt(?:-(?:125m|350m|1\.3b|2\.7b|6\.7b|13b|30b|66b))?",
        r"llama-?[1-4](?:\.[0-9])?",
        r"qwen[1-3]?(?:\.[0-9])?",
        r"mistral(?:-)?",
        r"mixtral",
        r"gemma[1-3]?",
        r"phi-?[1-4](?:\.[0-9])?",
        r"deepseek(?:-(?:coder|math|llm|moe|v[1-3]|r1|chat|prover|vl))?",
        r"falcon",
        r"yi-?(?:[0-9])?",
        r"olmo(?:e)?",
        r"granite",
        r"starcoder[12]?",
        # vision backbones
        r"vit(?:-(?:tiny|small|base|large|huge))?",
        r"deit(?:-(?:tiny|small|base))?",
        r"swin(?:-(?:tiny|small|base|large))?",
        r"convnext(?:-(?:t|s|b|l|xl|tiny|small|base|large))?",
        r"beit",
        r"clip(?:-vit)?",
        r"resnet[-_]?(?:18|34|50|101|152)",
        # speech models
        r"whisper(?:-(?:tiny|base|small|medium|large))?",
        r"wav2vec2",
        r"hubert",
        # sentence / embedding models
        r"all-minilm|all-mpnet|all-distilroberta",
        r"sentence-t5|sentence-bert",
        r"e5-(?:small|base|large|mistral)",
        r"bge-(?:small|base|large|m3|reranker)",
        r"gte-(?:small|base|large)",
        r"nomic-embed",
    ))
    + r")(?:[-._\s]|$)",
    re.IGNORECASE,
)
|
|
| |
| |
# Substrings that disqualify a bare (org-less) model name in
# `_is_official_name`, even when it starts with a recognised family prefix.
# The check is a plain lower-cased substring test, so e.g. "quant" also
# rejects "quantized".  Row labels are a reviewer's reading of the entries.
_BARE_NAME_BLOCKLIST = (
    # derived / compressed variants
    "finetune", "distill", "pruneofa", "qat",
    # quantisation formats
    "gguf", "awq", "gptq", "ggml", "exl2", "quant",
    # adapter / post-training markers
    "lora", "qlora", "dpo", "sft", "rlhf", "rpmax",
    # model-merge markers
    "merge", "stock", "frankenstein", "slerp",
    # low-precision markers
    "int8", "int4", "fp8", "fp4", "8bit", "4bit",
    # bracketed annotations inside the name
    "(", "[",
)
|
|
|
|
def _is_official_name(name: str) -> bool:
    """Decide whether a pool model name counts as an "official" release.

    * ``org/model`` names are official iff the lower-cased ``org`` (text
      before the first "/") is in ``OFFICIAL_ORGS``.
    * Bare names (no "/") are official iff they are at most 70 characters,
      contain none of the ``_BARE_NAME_BLOCKLIST`` substrings, and start with
      a prefix recognised by ``_OFFICIAL_BARE_PREFIX_RE``.
    * Empty / whitespace-only names are never official.
    """
    candidate = name.strip()
    if not candidate:
        return False

    org, slash, _ = candidate.partition("/")
    if slash:
        # Namespaced repo id: only the organisation decides.
        return org.lower() in OFFICIAL_ORGS

    lowered = candidate.lower()
    if len(lowered) > 70:
        return False
    for marker in _BARE_NAME_BLOCKLIST:
        if marker in lowered:
            return False
    return bool(_OFFICIAL_BARE_PREFIX_RE.match(lowered))
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
# Capability vocabulary.  The position of a name in this tuple fixes its bit
# in every capability mask used in this module, so the order must not change.
_ALL_CAPS = (
    "text",
    "text_embedding",
    "vision",
    "vision_generate",
    "vision_language",
    "audio",
    "video",
    "document",
)
# capability name -> single-bit mask (1, 2, 4, ...)
_CAP_BIT = dict(zip(_ALL_CAPS, (1 << shift for shift in range(len(_ALL_CAPS)))))


def _bits(*caps: str) -> int:
    """Return the OR of the bit masks for the given capability names.

    Raises ``KeyError`` for a name that is not in ``_ALL_CAPS``.  Called with
    no arguments it returns 0.
    """
    # Every capability is a distinct power of two, so summing the *distinct*
    # bits is the same as OR-ing them together.
    return sum({_CAP_BIT[cap] for cap in caps})
|
|
|
|
| |
| |
| |
# family name (lower-cased) -> capability bitmask.  Filled in by the
# `_assign(...)` calls that follow; a later assignment for the same family
# replaces the earlier one.
_FAMILY_CAPS: dict[str, int] = {}


def _assign(families, caps_bits):
    """Register ``caps_bits`` as the capability mask of every name in ``families``.

    Names are lower-cased before being stored in ``_FAMILY_CAPS``; an existing
    entry for the same family is overwritten.
    """
    _FAMILY_CAPS.update((family.lower(), caps_bits) for family in families)
|
|
|
|
# ---------------------------------------------------------------------------
# Capability masks attached to model families.
# A model is considered compatible with a task when its mask shares at least
# one bit with the task's required mask (see `Recommender._score`).
# ---------------------------------------------------------------------------
_CAP_TEXT = _bits("text")
_CAP_VIS = _bits("vision")
_CAP_VL = _bits("vision_language", "vision")      # VL models also serve pure-vision tasks
_CAP_VGEN = _bits("vision_generate")
_CAP_AUDIO = _bits("audio")
_CAP_VIDEO = _bits("video")
_CAP_EMB = _bits("text_embedding", "text")        # embedders also serve text tasks
_CAP_DOC = _bits("document", "vision")
# NOTE(review): audio generators carry the "vision_generate" bit, so they are
# compatible with image-generation tasks as well.  Kept as in the original;
# confirm this is intended.
_CAP_AUDIO_GEN = _bits("audio", "vision_generate")

# IMPORTANT: `_assign` overwrites, so when a family appears in two lists the
# LAST call wins.  Keep every family in exactly one list (or add an explicit
# override at the bottom, as is done for "seamlessm4t").

# Vision backbones / dense-prediction models.
_assign([
    "vit", "deit", "swin", "beit", "convnext", "resnet", "resnext", "cnn",
    "conv", "densenet", "efficientnet", "efficientvit", "fastervit",
    "mobilenet", "mnasnet", "mixnet", "muxnet", "mvit", "nat", "nfnet", "vgg",
    "regnety", "regvit", "rexnet", "sequencer2d", "segformer", "segnext",
    "deeplabv3", "mask", "point", "sam", "dinat", "dat", "caformer", "cmx",
    "cswin", "cvt", "cloformer", "derivative", "gdi", "graphormer", "hrnet",
    "hrformer", "internimage", "inception", "googlenet", "layernas", "laneaf",
    "motip", "moat", "osvos", "panoptic", "pit", "proto", "pvtv2", "q2l",
    "r50", "r[2+1]d", "rdnet", "rednet", "revbifpn", "scrfd", "scalenet",
    "shift", "svtr", "tinyvit", "transnext", "twist", "uniformer", "uninet",
    "unireplknet", "van", "xcit", "aerialformer", "asymmnet", "alphanet",
    "aot", "autoformer", "bit", "cait", "crisscross", "deaot", "debiformer",
    "handreader", "hitnet", "hybrid", "ipt", "kronos", "layoutmask", "lit",
    "litv2", "mavil", "mim", "moganet", "odise", "redimnet", "resmlp",
    "resnest", "rstt", "san", "sombrero", "unet", "xmem",
    "fmixia", "gladiator", "mfann3bv0", "mfann3bv1",
], _CAP_VIS)

# Image / video generators.
_assign([
    "diffusion", "gan", "imagen", "dall", "pixart", "dit", "edm2", "nuwa",
    "pdo", "sid", "sida",
], _CAP_VGEN)

# Speech / audio understanding.
_assign([
    "whisper", "wav2vec", "wav2vec2", "hubert", "ast", "conformer", "wavlm",
    "beats", "speechstew", "titannet", "seamlessm4t", "m2d", "eat", "erann",
    "mossformer2",
], _CAP_AUDIO)

# Audio generators.
_assign(["musicgen", "audioldm"], _CAP_AUDIO_GEN)

# Video models (also usable for vision tasks).
_assign(["video", "slowfast", "tarsier"], _CAP_VIDEO | _CAP_VIS)

# Vision-language models.
_assign([
    "clip", "llava", "kosmos", "blip", "idefics", "flava", "flamingo",
    "openflamingo", "instructblip", "lxmert", "vilt", "vinvl", "uniter",
    "vlm", "mm1", "mplug", "ofa", "pali", "imp", "vila", "aurora", "sapiens",
    "pllava", "ppllava", "plm", "mgm", "janus", "aimv2", "altclip", "aya",
    "albef", "apollo", "captain",
], _CAP_VL)

# Text-embedding / retrieval models.
_assign([
    "bge", "e5", "gte", "mpnet", "minilm", "retrieval",
], _CAP_EMB)

# Document-understanding models.
_assign([
    "layoutlmv3", "trocr", "git", "donut",
], _CAP_DOC)

# Text-only language models.
#
# Fix: "openflamingo" (vision-language) and "mossformer2" (audio) used to be
# repeated in this list; because `_assign` overwrites, they silently ended up
# text-only.  They are no longer listed here, so the masks assigned above
# stand.  Duplicate entries inside a single list were dropped as well (no
# behavioural effect).
#
# NOTE(review): "mavil" and "twist" are listed both under vision and here;
# the text assignment wins, exactly as before.  Confirm which modality is
# intended and delete the other entry.
_assign([
    # general LLM families
    "llama", "qwen", "mistral", "ministral", "mixtral", "gemma",
    "recurrentgemma", "deepseek", "phi", "falcon", "yi", "granite", "t5",
    "bart", "bloom", "opt", "palm", "smollm", "smoltulu", "wizard", "magnum",
    "ece", "transformer", "moe", "gemini", "claude", "command", "cohere",
    "dolly", "mpt", "pythia", "dolphin", "hermes", "openbuddy", "openchat",
    "baichuan", "internlm", "olmo", "olmner", "dbrx", "exaone", "nemotron",
    "nova", "orca", "solar", "stablelm", "tinymistral", "lucie", "lyra",
    "megatron", "miniuslight", "miscii", "neural", "nanolm", "pllum",
    "prymmal", "qandoraexp", "redpajama", "reflexis", "rewiz", "rusted",
    "rwkv", "saba1", "saka", "sauerkrautlm", "summer", "triangulum", "ul2",
    "ultiima", "una", "unifiedqa", "vicious", "winter", "zephyr", "zeus",
    "lamarck", "llemma", "llm", "light", "magnolia", "mita",
    "pegasus", "h2o", "flan", "ernie", "gopher", "chinchilla", "grok",
    "blossom", "branch", "bellatrix", "chatwaifu", "chocolatine", "cleverboi",
    "cursa", "cybernet", "dans", "deepmind", "gal", "glam", "incoder",
    "inexpertus", "josie", "kstc", "lexora", "linkbricks", "longformer",
    "mamba", "mavil", "mindact", "minerva", "mugglemath",
    "neo", "pantheon", "qanet", "reasoningcore", "rft",
    "rnn", "sjt", "stm", "thea", "tora", "tsunami", "twist",
    "vlama", "wide", "xlm", "xlm-roberta", "xlnet",

    # encoder families
    "bert", "distilbert", "roberta", "albert", "electra",

    # code / math families
    "code", "codellama", "starcoder", "metamath", "damomath",
    "acemath", "aceinstruct",

    # translation families
    "marian", "seamlessm4t",

    "molmo",
], _CAP_TEXT)

# Explicit override: seamlessm4t is listed under both audio and text above and
# should keep both capabilities.
_FAMILY_CAPS["seamlessm4t"] = _CAP_AUDIO | _CAP_TEXT
|
|
|
|
def _family_caps_bits(family: str) -> int:
    """Return the capability mask registered for ``family``.

    The lookup is whitespace- and case-insensitive.  ``None``, an empty string
    or an unregistered family all fall back to the text-only mask.
    """
    key = family.strip().lower() if family else ""
    return _FAMILY_CAPS.get(key, _CAP_TEXT)
|
|
|
|
| |
| |
| |
| |
| |
# Single-bit masks used on the *task* side: a task rule ORs together the
# capabilities that would satisfy it.  One name per entry of `_ALL_CAPS`.
(
    _REQ_TEXT,
    _REQ_EMB,
    _REQ_VIS,
    _REQ_VGEN,
    _REQ_VL,
    _REQ_AUDIO,
    _REQ_VIDEO,
    _REQ_DOC,
) = (
    _CAP_BIT[cap]
    for cap in (
        "text",
        "text_embedding",
        "vision",
        "vision_generate",
        "vision_language",
        "audio",
        "video",
        "document",
    )
)
|
|
|
|
| |
| |
# Ordered (pattern, acceptable-capabilities) rules used by
# `_task_required_caps_bits`.  Rules are tried top to bottom and the FIRST
# pattern found anywhere in the task label wins, so more specific rules must
# come before broader ones.  All patterns are case-insensitive.
#
# Fixes relative to the previous table:
#   * word stems followed by the closing `\b` could never match their
#     inflected forms ("inpainting", "outpainting", "3D ... classification",
#     "3D ... detection", ...); the stems now carry `\w*`.
#   * "audio super-resolution" was unreachable because the generation rule
#     matched "super-resolution" first; the generation rule now skips it and
#     the audio rule accepts the same spellings (hyphen, space or none).
#   * the redundant trailing "image-to-image" alternative was dropped (it is
#     already covered by `image[- ]to[- ]image`).
_TASK_RULES: list[tuple[re.Pattern, int]] = [
    (re.compile(pattern, re.I), acceptable)
    for pattern, acceptable in (
        # Image / video generation and restoration -> generators only.
        (
            r"\b(text[- ]to[- ]image|image[- ]to[- ]image|image generation|"
            r"conditional image generation|personalized image generation|"
            r"image synthesis|image editing|inpaint\w*|outpaint\w*|"
            r"(?<!audio )super[- ]?resolution|denoising|colorization|style transfer|"
            r"text[- ]to[- ]video|video generation|video synthesis|"
            r"image restoration)\b",
            _REQ_VGEN,
        ),
        # Speech / audio generation.
        (
            r"\b(text[- ]to[- ]speech|speech synthesis|\btts\b|"
            r"voice conversion|voice cloning|audio generation|"
            r"music generation|music synthesis|sound generation)\b",
            _REQ_AUDIO,
        ),
        # Speech / audio understanding.
        (
            r"\b(speech recognition|automatic speech recognition|\basr\b|"
            r"audio classification|audio captioning|audio retrieval|"
            r"audio.*detection|audio source separation|audio tagging|"
            r"audio super[- ]?resolution|spoken|phoneme|acoustic|"
            r"voice activity|voice.*recognition|emotion recognition.*audio|"
            r"music.*(?:auto-tagging|source|transcription|recognition|"
            r"classification|tagging)|sound classification|"
            r"bandwidth extension|speech.*(?:enhancement|separation|"
            r"translation))\b",
            _REQ_AUDIO,
        ),
        # Video understanding.
        (
            r"\b(video classification|video captioning|"
            r"video question answering|video retrieval|video understanding|"
            r"action recognition.*video|video.*recognition|"
            r"temporal action|video segmentation)\b",
            _REQ_VIDEO,
        ),
        # Vision-language tasks.
        (
            r"\b(visual question answering|\bvqa\b|image captioning|"
            r"visual reasoning|image[- ]text|visual grounding|"
            r"referring expression|chart.*(?:understanding|qa|vqa)|"
            r"document.*understanding|docvqa|infovqa|chartqa|"
            r"visual entailment|visual commonsense|image retrieval|"
            r"cross-modal retrieval|scene text|multimodal)\b",
            _REQ_VL,
        ),
        # OCR / document processing -> document models or VL models.
        (
            r"\b(\bocr\b|optical character|handwriting recognition|"
            r"document.*(?:classification|layout|parsing)|"
            r"table.*(?:detection|recognition|extraction)|"
            r"form understanding|receipt|invoice)\b",
            _REQ_DOC | _REQ_VL,
        ),
        # Pure vision tasks -> vision backbones or VL models.
        # NOTE(review): this rule contains very generic words ("food", "crop",
        # "plant", "driving", ...) that will also capture text tasks whose
        # label mentions them; kept as in the original.
        (
            r"\b(image classification|object detection|"
            r"(semantic|instance|panoptic) segmentation|pose estimation|"
            r"depth estimation|optical flow|scene.*(?:classification|"
            r"recognition|parsing)|action recognition|action detection|"
            r"action segmentation|object tracking|tracking|"
            r"object discovery|object localization|keypoint|landmark|"
            r"face.*(?:recognition|detection|verification|generation|"
            r"reconstruction)|person re-identification|reid|"
            r"3d.*(?:detect\w*|classif\w*|segment\w*|reconstruct\w*|pose|tracking|"
            r"completion|generation)|point cloud|stereo|monocular|"
            r"saliency|edge detection|anomaly detection|image matting|"
            r"crowd counting|gaze estimation|hand.*estimation|"
            r"animal pose|aerial|remote sensing|medical image|"
            r"histology|radiology|tumor|lesion|breast|cancer image|"
            r"covid.*(?:image|x-ray|ct)|x-ray|fundus|retinal|"
            r"satellite|change detection|crop|bird|flower|food|"
            r"plant|skin|dermoscopy|driving|lane|nuscenes|kitti|"
            r"object.*(?:counting|grasping|6d))\b",
            _REQ_VIS | _REQ_VL,
        ),
        # Retrieval / similarity -> embedders or general text models.
        (
            r"\b(text retrieval|document retrieval|passage retrieval|"
            r"semantic search|sentence similarity|sentence[- ]?embedding|"
            r"paraphrase identification|semantic textual similarity|"
            r"\bsts\b|reranking|clustering|bitext mining)\b",
            _REQ_EMB | _REQ_TEXT,
        ),
    )
]
|
|
| |
| |
| |
| |
# Mask used when no rule in `_TASK_RULES` matches (or no task is given):
# text models and vision-language models are both acceptable.
_DEFAULT_TASK_CAPS = _REQ_TEXT | _REQ_VL


def _task_required_caps_bits(task: str) -> int:
    """Map a task label to the bitmask of capabilities that can serve it.

    The label is stringified and stripped, then checked against
    ``_TASK_RULES`` in order; the mask of the first rule whose pattern is
    found in the label is returned.  A falsy ``task`` or a label that matches
    no rule yields ``_DEFAULT_TASK_CAPS``.
    """
    if task:
        label = str(task).strip()
        return next(
            (mask for pattern, mask in _TASK_RULES if pattern.search(label)),
            _DEFAULT_TASK_CAPS,
        )
    return _DEFAULT_TASK_CAPS
|
|
|
|
def caps_bits_to_labels(bits: int) -> list[str]:
    """Translate a capability bitmask into display labels.

    Labels come back in ``_ALL_CAPS`` order (e.g. ``['text', 'vision']``);
    a mask of 0 gives an empty list.
    """
    display = {
        "text": "text",
        "text_embedding": "text-embedding",
        "vision": "vision",
        "vision_generate": "vision-generation",
        "vision_language": "vision-language",
        "audio": "audio",
        "video": "video",
        "document": "document",
    }
    labels: list[str] = []
    for cap in _ALL_CAPS:
        if bits & _CAP_BIT[cap]:
            labels.append(display[cap])
    return labels
|
|
|
|
def _slug(s: str) -> str:
    """Lower-case ``s`` and drop everything except ASCII letters and digits."""
    return re.sub(r"[^a-z0-9]+", "", str(s).strip().lower())


def _build_alias_map(name2id: dict[str, int]) -> dict[str, int]:
    """Build a forgiving ``alias -> id`` lookup from an exact ``name -> id`` map.

    For every name the following aliases are registered:

    * the exact name,
    * its stripped, lower-cased form,
    * its slug (see ``_slug``),
    * and, for composite names containing ``"::"``, the same three forms of
      the part after the first ``"::"``.

    Precedence: exact names always win.  They are registered first, so a
    loose alias derived from one entry can never shadow another entry's real
    name (previously ``{"A-B": 1, "ab": 2}`` resolved ``"ab"`` to 1).  Among
    loose aliases the first entry in ``name2id`` order wins.  Empty aliases
    are never registered.
    """
    # Pass 1: exact names.
    out: dict[str, int] = {name: idx for name, idx in name2id.items() if name}

    # Pass 2: loose aliases, first-come-first-served, never replacing an
    # exact name or an earlier alias.
    for name, idx in name2id.items():
        variants = [name.strip().lower(), _slug(name)]
        if "::" in name:
            tail = name.split("::", 1)[1]
            variants.extend((tail, tail.strip().lower(), _slug(tail)))
        for alias in variants:
            if alias:
                out.setdefault(alias, idx)
    return out
|
|
|
|
@dataclass
class Recommendation:
    """One row of the ranked list returned by ``Recommender.recommend``."""

    # 1-based position in the returned ranking.
    rank: int
    # Name of the candidate as stored in the model pool.
    model_name: str
    # Raw model score for this candidate (before any popularity blending).
    score: float
    # Size-bucket id of the candidate, taken from the pool's ``size_ids``.
    size_bucket: int
    # Parameter count in billions from the pool's ``sizes_b``; NaN if unknown.
    size_b: float
    # Family id from the pool, and its name ("" when the id has no known name).
    family_id: int
    family: str
    # Popularity value stored in the pool for this candidate.
    popularity: int
    # URL stored in the pool for this candidate (may be empty).
    hf_url: str
|
|
|
|
class Recommender:
    """Load the checkpoint, model pool and ID maps; expose ``recommend``.

    Construction reads everything from disk once (training args, task/metric/
    family id maps, the ``.npz`` model pool, the checkpoint) and pre-computes
    the per-candidate model cache.  ``recommend`` then embeds a dataset
    description through the OpenAI API and ranks every pool candidate.
    """

    def __init__(
        self,
        checkpoint_path: str,
        args_path: str,
        data_dir: str,
        pool_path: str,
        device: str = "cpu",
    ):
        """Load all artefacts and build the scoring model.

        Args:
            checkpoint_path: ``torch.save``-d checkpoint; either a bare state
                dict or a dict holding it under the ``"model"`` key.
            args_path: JSON file with the training arguments used to rebuild
                the network.
            data_dir: Directory containing ``task2id.json`` and
                ``metric2id.json`` (``family2id.json`` is optional).
            pool_path: ``.npz`` archive with ``names``, ``size_ids``,
                ``family_ids``, ``popularities``, ``urls`` and optionally
                ``sizes_b``.
            device: Torch device string the model and cache are placed on.
        """
        self.device = torch.device(device)

        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(args_path, encoding="utf-8") as f:
            self._train_args = json.load(f)
        with open(os.path.join(data_dir, "task2id.json"), encoding="utf-8") as f:
            self.task2id: dict[str, int] = json.load(f)
        with open(os.path.join(data_dir, "metric2id.json"), encoding="utf-8") as f:
            self.metric2id: dict[str, int] = json.load(f)

        # Family names are optional: without the map every family resolves to
        # "" (and therefore to the text-only capability mask).
        family2id_path = os.path.join(data_dir, "family2id.json")
        self._id2family: dict[int, str] = {}
        if os.path.isfile(family2id_path):
            with open(family2id_path, encoding="utf-8") as f:
                family2id: dict[str, int] = json.load(f)
            self._id2family = {v: k for k, v in family2id.items()}

        self.task_alias = _build_alias_map(self.task2id)
        self.metric_alias = _build_alias_map(self.metric2id)

        # NOTE(security): allow_pickle=True unpickles object arrays, which can
        # execute arbitrary code.  Only point `pool_path` at trusted files.
        # The archive is closed as soon as the arrays have been read.
        with np.load(pool_path, allow_pickle=True) as pool:
            self.model_names: list[str] = list(pool["names"].tolist())
            self.size_ids = torch.tensor(pool["size_ids"], dtype=torch.long)
            if "sizes_b" in pool.files:
                self.sizes_b: np.ndarray = pool["sizes_b"].astype(np.float32)
            else:
                # Older pools carry no raw sizes: treat every size as unknown.
                self.sizes_b = np.full(len(self.model_names), np.nan, dtype=np.float32)
            self.family_ids = torch.tensor(pool["family_ids"], dtype=torch.long)
            self.popularities: np.ndarray = pool["popularities"]
            self.urls: list[str] = list(pool["urls"].tolist())

        # Per-candidate flags / masks used by the filters in `_score`.
        self.is_official: np.ndarray = np.array(
            [_is_official_name(n) for n in self.model_names], dtype=bool
        )
        family_ids_np = self.family_ids.cpu().numpy()
        self.model_caps_bits: np.ndarray = np.array(
            [_family_caps_bits(self._id2family.get(int(fid), ""))
             for fid in family_ids_np],
            dtype=np.int64,
        )

        # Rebuild the network from the training arguments.
        cfg = self._train_args
        model_name = str(cfg.get("model_name", "MLPMetric"))
        model_args = SimpleNamespace(
            num_models=cfg.get("num_models", len(self.model_names)),
            num_tasks=cfg.get("num_tasks"),
            num_metrics=cfg.get("num_metrics"),
            num_size_buckets=cfg.get("num_size_buckets"),
            num_families=cfg.get("num_families"),
            num_datasets=cfg.get("num_datasets", 100000),
            token_dim=cfg["token_dim"],
            model_dim=cfg["model_dim"],
            task_dim=cfg["task_dim"],
            metric_dim=cfg.get("metric_dim", cfg["task_dim"]),
            size_dim=cfg["size_dim"],
            family_dim=cfg.get("family_dim", cfg["size_dim"]),
            dataset_desp_dim=cfg["dataset_desp_dim"],
            dataset_id_emb_dim=cfg.get("dataset_id_emb_dim", 256),
            dataset_desp_emb_dim=cfg.get("dataset_desp_emb_dim", 1536),
            model_desp_emb_dim=cfg.get("model_desp_emb_dim", 1536),
            hidden_dim=cfg["hidden_dim"],
            dropout_rate=cfg.get("dropout_rate", 0.0),
            use_id_emb=bool(cfg.get("use_id_emb", False)),
            use_size_prior=bool(cfg.get("use_size_prior", True)),
            use_family_prior=bool(cfg.get("use_family_prior", False)),
            use_size_feature=bool(cfg.get("use_size_feature", True)),
            use_metric_feature=bool(cfg.get("use_metric_feature", True)),
            use_model_id_emb=bool(cfg.get("use_model_id_emb", True)),
            use_model_name_emb=bool(cfg.get("use_model_name_emb", True)),
            use_model_desc_emb=bool(cfg.get("use_model_desc_emb", True)),
            use_dataset_id_emb=bool(cfg.get("use_dataset_id_emb", True)),
            use_dataset_desc_emb=bool(cfg.get("use_dataset_desc_emb", True)),
            unknown_metric_id=int(cfg.get("unknown_metric_id", 0)),
        )
        model_cls = MLPMetricFull if model_name == "MLPMetricFull" else MLPMetric
        self.model = model_cls(model_args).to(self.device).eval()
        self._model_name = model_name

        # NOTE(security): torch.load unpickles the file; only load trusted
        # checkpoints.
        raw = torch.load(checkpoint_path, map_location="cpu")
        state = raw.get("model", raw) if isinstance(raw, dict) else raw
        # strict=False: tolerate key drift, but report it so it is not silent.
        missing, unexpected = self.model.load_state_dict(state, strict=False)
        if missing or unexpected:
            print(f"[Recommender] loaded with missing={len(missing)} unexpected={len(unexpected)}")
            if missing:
                print(" e.g. missing:", missing[:3])
            if unexpected:
                print(" e.g. unexpected:", unexpected[:3])

        # The model cache is built once here and shared by every request;
        # `_score` takes this lock around each use of it.
        self._cache_lock = threading.Lock()
        with torch.no_grad():
            self.model_cache = self.model.build_model_cache(
                self.model_names,
                self.size_ids,
                all_model_family_ids=self.family_ids if self.model.use_family_prior else None,
                device=self.device,
            )

        # Default OpenAI client, created lazily on first use.
        self._oai_client = None

    # ------------------------------------------------------------------ #
    # Embedding
    # ------------------------------------------------------------------ #

    def _make_openai_client(self, api_key: Optional[str] = None):
        """Return an OpenAI client.

        With ``api_key`` a fresh client bound to that key is returned and NOT
        cached, so a caller-supplied key is never reused for other requests.
        Without it, a single default-configured client is created on first use
        and reused afterwards.
        """
        # Imported lazily so the module can be imported without `openai`.
        from openai import OpenAI

        if api_key:
            return OpenAI(api_key=api_key)

        if self._oai_client is None:
            self._oai_client = OpenAI()
        return self._oai_client

    def embed_description(self, text: str, api_key: Optional[str] = None) -> np.ndarray:
        """Embed a dataset description with ``EMBEDDING_MODEL``.

        Returns:
            A float32 vector of length ``EMBEDDING_DIM``.

        Raises:
            ValueError: if ``text`` is empty/blank, the client cannot be
                created, or the API call fails (original error is chained).
            RuntimeError: if the returned vector has an unexpected length.
        """
        text = (text or "").strip()
        if not text:
            raise ValueError("Dataset description must be non-empty.")
        try:
            client = self._make_openai_client(api_key)
        except Exception as e:
            raise ValueError(
                "OpenAI client could not be created. Paste an API key into "
                "the 'OpenAI API key' field above. Original error: " + str(e)
            ) from e
        try:
            resp = client.embeddings.create(model=EMBEDDING_MODEL, input=text)
        except Exception as e:
            # Surface auth/quota/network problems as one user-facing error type.
            raise ValueError(f"OpenAI embedding call failed: {e}") from e
        vec = np.asarray(resp.data[0].embedding, dtype=np.float32)
        if vec.shape[-1] != EMBEDDING_DIM:
            raise RuntimeError(
                f"Expected {EMBEDDING_DIM}-dim embedding, got {vec.shape[-1]}. "
                f"Make sure the API key has access to {EMBEDDING_MODEL}."
            )
        return vec

    # ------------------------------------------------------------------ #
    # Label resolution
    # ------------------------------------------------------------------ #

    @staticmethod
    def _lookup_alias(alias_map: dict, text: str) -> Optional[int]:
        """Probe ``alias_map`` with ``text`` as-is, stripped+lower-cased, then slugged.

        Returns the first id found, or ``None``.
        """
        if text in alias_map:
            return alias_map[text]
        lowered = text.strip().lower()
        if lowered in alias_map:
            return alias_map[lowered]
        return alias_map.get(_slug(text))

    def resolve_task(self, task: str) -> int:
        """Return the id of ``task`` (loose matching via the alias map).

        Raises:
            ValueError: if ``task`` is ``None`` or not a known task label.
        """
        if task is None:
            raise ValueError("Task must be provided.")
        # str(): tolerate non-string UI values instead of failing on .strip().
        found = self._lookup_alias(self.task_alias, str(task))
        if found is None:
            raise ValueError(
                f"Unknown task '{task}'. Pick one from the dropdown — the model has only seen {len(self.task2id)} task labels."
            )
        return found

    def resolve_metric(self, metric: str) -> int:
        """Return the id of ``metric``, or the model's unknown-metric id.

        ``None``, blank and unrecognised metrics all map to
        ``self.model.unknown_metric_id``; this method never raises for an
        unknown metric.
        """
        if metric is None:
            return int(self.model.unknown_metric_id)
        # str(): the blank check below already stringified; do the same for
        # the lookup so non-string values cannot fail on .strip().
        text = str(metric)
        if not text.strip():
            return int(self.model.unknown_metric_id)
        found = self._lookup_alias(self.metric_alias, text)
        if found is None:
            return int(self.model.unknown_metric_id)
        return found

    # ------------------------------------------------------------------ #
    # Ranking
    # ------------------------------------------------------------------ #

    def recommend(
        self,
        dataset_description: str,
        task: str,
        metric: Optional[str] = None,
        top_k: int = 20,
        popularity_weight: float = 0.0,
        hf_only: bool = True,
        min_size_b: Optional[float] = None,
        max_size_b: Optional[float] = None,
        official_only: bool = False,
        api_key: Optional[str] = None,
    ) -> "List[Recommendation]":
        """Score all candidate models and return the top-k.

        ``popularity_weight`` (0..1) blends a log(downloads) signal into the
        ranking, useful when several models have near-tied scores. Default 0
        means "pure model output".

        ``hf_only`` (default True) drops candidates that have no URL in the
        pool (e.g. paper baselines that cannot be downloaded).

        ``min_size_b`` / ``max_size_b`` (optional, in B params) restrict
        results to candidates whose raw parameter count falls in the range.
        ``None`` (or 0 from the UI) means "no limit". Models with unknown
        size are excluded once any size bound is set.

        ``official_only`` (default False) restricts to candidates accepted by
        ``_is_official_name``.

        ``api_key`` (optional) — OpenAI API key supplied by the caller (e.g.
        from a Gradio textbox). When given, used for this single request only;
        otherwise the default-configured client is used.

        Candidates whose family capabilities do not overlap with what the
        task requires are always excluded.

        Returns:
            Up to ``top_k`` recommendations, best first.  The list is shorter
            (possibly empty) when fewer candidates pass the filters.

        Raises:
            ValueError: unknown/missing task, empty description, or a failed
                embedding call.
        """
        task_id = self.resolve_task(task)
        metric_id = self.resolve_metric(metric)
        task_caps_bits = _task_required_caps_bits(task)
        emb = self.embed_description(dataset_description, api_key=api_key)
        return self._score(
            emb, task_id, metric_id, top_k, popularity_weight, hf_only,
            min_size_b=min_size_b, max_size_b=max_size_b,
            official_only=official_only,
            task_caps_bits=task_caps_bits,
        )

    @torch.no_grad()
    def _score(
        self,
        desp_emb: np.ndarray,
        task_id: int,
        metric_id: int,
        top_k: int,
        popularity_weight: float,
        hf_only: bool = True,
        min_size_b: Optional[float] = None,
        max_size_b: Optional[float] = None,
        official_only: bool = False,
        task_caps_bits: Optional[int] = None,
    ) -> "List[Recommendation]":
        """Rank every pool candidate for one (description, task, metric) query.

        Steps: score all candidates with the model, optionally blend in
        popularity, apply the eligibility filters, then return the best
        ``top_k`` eligible candidates (fewer if not enough are eligible).

        Fix: filtered-out candidates used to be marked with ``-inf`` and could
        still be returned whenever fewer than ``top_k`` candidates survived
        the filters; selection is now restricted to eligible candidates, and
        an empty eligible set yields ``[]`` instead of an indexing error.
        """
        device = self.device
        task_t = torch.tensor([task_id], dtype=torch.long, device=device)
        metric_t = torch.tensor([metric_id], dtype=torch.long, device=device)
        desp_t = torch.tensor(desp_emb, dtype=torch.float32, device=device).unsqueeze(0)

        # presumably score_matrix returns one row per query, one column per
        # candidate — TODO confirm against inference_lib.
        with self._cache_lock:
            scores = self.model.score_matrix(
                task_t, desp_t, self.model_cache, metric_ids=metric_t
            ).squeeze(0)
        scores_np = scores.detach().cpu().numpy().astype(np.float32)

        if popularity_weight > 0.0:
            # log1p popularity scaled to [0, 1] ...
            pop = np.log1p(self.popularities.astype(np.float32))
            if pop.max() > 0:
                pop = pop / pop.max()
            # ... added to standardised scores so the weight is comparable
            # across queries.
            s_norm = scores_np - scores_np.mean()
            spread = s_norm.std()
            if spread > 1e-6:
                s_norm = s_norm / spread
            ranking_scores = s_norm + popularity_weight * pop
        else:
            ranking_scores = scores_np

        # ---- eligibility mask ------------------------------------------- #
        keep = np.ones(ranking_scores.shape[0], dtype=bool)

        if hf_only:
            keep &= np.array([bool(u) for u in self.urls], dtype=bool)

        # 0 from the UI means "no bound", same as None.
        has_min = min_size_b not in (None, 0)
        has_max = max_size_b not in (None, 0)
        if has_min or has_max:
            sizes = self.sizes_b
            in_range = ~np.isnan(sizes)  # unknown size never passes a size filter
            if has_min:
                in_range &= sizes >= float(min_size_b)
            if has_max:
                in_range &= sizes <= float(max_size_b)
            keep &= in_range

        if official_only:
            keep &= self.is_official

        if task_caps_bits:
            # Compatible when the model has at least one acceptable capability.
            keep &= (self.model_caps_bits & task_caps_bits) != 0

        # ---- top-k among eligible candidates ----------------------------- #
        eligible = np.flatnonzero(keep)
        if eligible.size == 0:
            return []
        k = max(1, min(int(top_k), int(eligible.size)))
        eligible_scores = ranking_scores[eligible]
        # argpartition picks the k best in O(n); only those k are then sorted.
        best = np.argpartition(-eligible_scores, k - 1)[:k]
        best = best[np.argsort(-eligible_scores[best])]
        top_idx = eligible[best]

        out: list[Recommendation] = []
        for rank, idx in enumerate(top_idx, start=1):
            i = int(idx)
            out.append(
                Recommendation(
                    rank=rank,
                    model_name=self.model_names[i],
                    score=float(scores_np[i]),  # raw score, not the blended one
                    size_bucket=int(self.size_ids[i]),
                    size_b=float(self.sizes_b[i]),
                    family_id=int(self.family_ids[i]),
                    family=self._id2family.get(int(self.family_ids[i]), ""),
                    popularity=int(self.popularities[i]),
                    hf_url=self.urls[i],
                )
            )
        return out
|
|
|
|
def default_recommender() -> Recommender:
    """Build a ``Recommender`` from conventional locations.

    Checkpoint, args file and data directory are each resolved in this order:

    1. the environment variable (``MODEL_CKPT``, ``MODEL_ARGS``, ``DATA_DIR``)
       when set to a non-empty value;
    2. the self-contained layout next to this file (``checkpoint/`` and
       ``data/``), when the path exists — ``MLPMetricFull.pt`` is preferred
       over ``MLPMetric.pt``;
    3. the development tree one directory above this file.

    The pool path comes from ``POOL_PATH`` (default ``assets/model_pool.npz``
    next to this file) and the device from ``DEVICE`` (default ``"cpu"``).
    """
    here = os.path.dirname(os.path.abspath(__file__))
    root = os.path.dirname(here)

    def _resolve(env_key: str, primary: str, fallback: str) -> str:
        # Environment override > bundled file (if present) > development tree.
        override = os.environ.get(env_key)
        if override:
            return override
        if os.path.exists(primary):
            return primary
        return fallback

    # Bundled layout (next to this file).
    bundled_full = os.path.join(here, "checkpoint/MLPMetricFull.pt")
    if os.path.exists(bundled_full):
        bundled_ckpt = bundled_full
    else:
        bundled_ckpt = os.path.join(here, "checkpoint/MLPMetric.pt")
    bundled_args = os.path.join(here, "checkpoint/args.json")
    bundled_data = os.path.join(here, "data")

    # Development tree (one level up).
    dev_ckpt = os.path.join(root, "checkpoint/mlp/unified_augmented_v2/FinalModel_v2_full_data_deployment/MLPMetricFull.pt")
    dev_args = os.path.join(root, "checkpoint/mlp/unified_augmented_v2/FinalModel_v2_full_data_deployment/args.json")
    dev_data = os.path.join(root, "data/unified_augmented_v2")

    default_pool = os.path.join(here, "assets/model_pool.npz")

    return Recommender(
        checkpoint_path=_resolve("MODEL_CKPT", bundled_ckpt, dev_ckpt),
        args_path=_resolve("MODEL_ARGS", bundled_args, dev_args),
        data_dir=_resolve("DATA_DIR", bundled_data, dev_data),
        pool_path=os.environ.get("POOL_PATH", default_pool),
        device=os.environ.get("DEVICE", "cpu"),
    )
|
|
|
|
if __name__ == "__main__":
    # Offline smoke test: load everything, then rank with a random embedding
    # so that no OpenAI call is needed.
    recommender = default_recommender()
    print(
        f"Loaded {len(recommender.model_names)} candidate models, "
        f"{len(recommender.task2id)} tasks, {len(recommender.metric2id)} metrics."
    )

    first_task = next(iter(recommender.task2id))
    print(f"\nSmoke test: ranking for task={first_task!r}")

    random_emb = np.random.randn(EMBEDDING_DIM).astype(np.float32)
    top_five = recommender._score(
        random_emb,
        recommender.task2id[first_task],
        recommender.model.unknown_metric_id,
        5,
        0.0,
    )
    for item in top_five:
        print(f" #{item.rank} {item.model_name:<60} score={item.score:+.4f} pop={item.popularity}")
|
|