Instructions to use nikraf/directionality_probe with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use nikraf/directionality_probe with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="nikraf/directionality_probe", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("nikraf/directionality_probe", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import entrypoint_setup | |
| import os | |
| import torch | |
| import warnings | |
| import sqlite3 | |
| import gzip | |
| from torch.utils.data import DataLoader | |
| from tqdm.auto import tqdm | |
| from dataclasses import dataclass | |
| from typing import Optional, Callable, List | |
| from huggingface_hub import hf_hub_download | |
| try: | |
| from seed_utils import seed_worker, dataloader_generator, get_global_seed | |
| from data.dataset_classes import SimpleProteinDataset | |
| from base_models.get_base_models import get_base_model | |
| from pooler import Pooler | |
| from utils import torch_load, print_message, maybe_compile | |
| except ImportError: | |
| from .seed_utils import seed_worker, dataloader_generator, get_global_seed | |
| from .data.dataset_classes import SimpleProteinDataset | |
| from .base_models.get_base_models import get_base_model | |
| from .pooler import Pooler | |
| from .utils import torch_load, print_message, maybe_compile | |
def build_collator(tokenizer) -> Callable[[List[str]], tuple[torch.Tensor, torch.Tensor]]:
    """
    Build a DataLoader collate function bound to ``tokenizer``.
    Args:
        tokenizer: Callable invoked as ``tokenizer(sequences, **kwargs)``.
    Returns:
        A function that takes a list of sequence strings and returns the
        tokenizer's output unchanged. The batch is padded to the longest
        sequence, rounded up to a multiple of 8, and requested as PyTorch
        tensors.
    """
    # Fixed tokenizer options; unpacked on every call so the tokenizer
    # always receives its own fresh kwargs dict.
    tokenizer_options = {
        'return_tensors': "pt",
        'padding': 'longest',
        'pad_to_multiple_of': 8,
    }

    def _collate_fn(sequences: List[str]) -> tuple[torch.Tensor, torch.Tensor]:
        """Tokenize one batch of raw sequences."""
        batch = tokenizer(sequences, **tokenizer_options)
        return batch

    return _collate_fn
def get_embedding_filename(model_name: str, matrix_embed: bool, pooling_types: List[str], extension: str = 'pth') -> str:
    """
    Build the on-disk filename used for a model's cached embeddings.
    Args:
        model_name: Name of the model
        matrix_embed: True for per-residue (matrix) embeddings, False for pooled vectors
        pooling_types: Pooling types in use; only contributes to the name for vector embeddings
        extension: File extension ('pth' or 'db')
    Returns:
        ``{model_name}_{matrix_embed}.{extension}`` for matrix embeddings or when no
        pooling types are given, otherwise
        ``{model_name}_{matrix_embed}_{pooling types joined by '_'}.{extension}``.
    """
    name_parts = [f'{model_name}_{matrix_embed}']
    uses_pooling = bool(pooling_types) and not matrix_embed
    if uses_pooling:
        # Sorted so the same set of pooling types always maps to the same file
        name_parts.extend(sorted(pooling_types))
    stem = '_'.join(name_parts)
    return f'{stem}.{extension}'
class EmbeddingArguments:
    """
    Plain settings container consumed by ``Embedder``.

    The constructor parameters carry an ``embedding_`` prefix on several names,
    but are stored under shorter attribute names (``batch_size``,
    ``num_workers``, ``pooling_types``).

    Args:
        embedding_batch_size: Stored as ``batch_size``.
        embedding_num_workers: Stored as ``num_workers``.
        download_embeddings: Whether precomputed embeddings should be fetched.
        download_dir: Hub dataset repo id used as the download source.
        matrix_embed: True for per-residue embeddings, False for pooled vectors.
        embedding_pooling_types: Stored as ``pooling_types``; defaults to ``['mean']``.
        save_embeddings: Whether computed embeddings are written to disk.
        embed_dtype: dtype embeddings are cast to before being stored.
        model_dtype: dtype forwarded to the model loader (``None`` leaves it unset).
        sql: Whether embeddings are stored in SQLite rather than a ``.pth`` file.
        embedding_save_dir: Directory holding the embedding files.
        **kwargs: Accepted and ignored. NOTE: this means a misspelled or
            un-prefixed keyword (e.g. ``batch_size=``) is silently dropped and
            the default is used instead.
    """

    def __init__(
        self,
        embedding_batch_size: int = 4,
        embedding_num_workers: int = 0,
        download_embeddings: bool = False,
        download_dir: str = 'Synthyra/vector_embeddings',
        matrix_embed: bool = False,
        embedding_pooling_types: Optional[List[str]] = None,
        save_embeddings: bool = False,
        embed_dtype: torch.dtype = torch.float32,
        model_dtype: Optional[torch.dtype] = None,
        sql: bool = False,
        embedding_save_dir: str = 'embeddings',
        **kwargs
    ):
        # A fresh list per instance: a list literal in the signature would be
        # one shared object, so mutating it on one instance would leak into
        # every later instance created with the default.
        if embedding_pooling_types is None:
            embedding_pooling_types = ['mean']

        self.batch_size = embedding_batch_size
        self.num_workers = embedding_num_workers
        self.download_embeddings = download_embeddings
        self.download_dir = download_dir
        self.matrix_embed = matrix_embed
        self.pooling_types = embedding_pooling_types
        self.save_embeddings = save_embeddings
        self.embed_dtype = embed_dtype
        self.model_dtype = model_dtype
        self.sql = sql
        self.embedding_save_dir = embedding_save_dir
class Embedder:
    """
    Compute, cache and reload embeddings for a fixed list of sequences.

    Embeddings are keyed by the sequence string itself. They are stored either
    in a ``.pth`` dictionary file or, when ``args.sql`` is set, in an SQLite
    table ``embeddings(sequence, embedding)``. Calling the instance with a
    model name embeds only the sequences that are not already on disk.
    """

    def __init__(self, args: "EmbeddingArguments", all_seqs: List[str]):
        """
        Args:
            args: Settings object; its attributes are copied onto the instance.
            all_seqs: Every sequence that should end up with an embedding.
        """
        self.args = args
        self.all_seqs = all_seqs
        self.batch_size = args.batch_size
        self.num_workers = args.num_workers
        self.matrix_embed = args.matrix_embed
        self.pooling_types = args.pooling_types
        self.download_embeddings = args.download_embeddings
        self.download_dir = args.download_dir
        self.save_embeddings = args.save_embeddings
        self.embed_dtype = args.embed_dtype
        self.model_dtype = args.model_dtype
        self.sql = args.sql
        self.embedding_save_dir = args.embedding_save_dir
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print_message(f'Device {self.device} found')

    def _download_embeddings(self, model_name: str) -> Optional[str]:
        """
        Fetch precomputed embeddings from the hub and merge them into the local file.

        Downloads ``embeddings/{filename}.gz`` from the ``download_dir`` dataset
        repo, decompresses it next to the downloaded file, and writes the result
        to ``embedding_save_dir``. If a local file already exists, the two
        dictionaries are merged with the local entries taking precedence, and
        every entry is cast to ``embed_dtype``.

        Args:
            model_name: Model whose embeddings should be fetched.
        Returns:
            Path of the local embedding file, or None when the download failed.
        """
        filename = get_embedding_filename(model_name, self.matrix_embed, self.pooling_types, 'pth')
        try:
            local_path = hf_hub_download(
                repo_id=self.download_dir,
                filename=f'embeddings/{filename}.gz',
                repo_type='dataset'
            )
        except Exception:
            # Deliberately best-effort: a missing file or network failure just
            # means everything is embedded locally. `Exception` (rather than a
            # bare except) lets KeyboardInterrupt/SystemExit propagate.
            print(f'No embeddings found for {model_name} in {self.download_dir}')
            return

        # Strip only the trailing suffix; str.replace would also rewrite any
        # '.gz' occurring earlier in the cache path.
        unzipped_path = local_path.removesuffix('.gz')
        print_message(f'Unzipping {local_path}')
        with gzip.open(local_path, 'rb') as f_in:
            with open(unzipped_path, 'wb') as f_out:
                # Stream in 1 MiB chunks instead of holding the whole
                # decompressed file in memory.
                while chunk := f_in.read(1 << 20):
                    f_out.write(chunk)

        # The save directory is otherwise only created when embedding starts,
        # so make sure it exists before writing into it.
        os.makedirs(self.embedding_save_dir, exist_ok=True)
        final_path = os.path.join(self.embedding_save_dir, filename)
        if os.path.exists(final_path):
            print_message(f'Found existing embeddings in {final_path}')
            downloaded_embeddings = torch_load(unzipped_path)
            existing_embeddings = torch_load(final_path)
            # NOTE(review): uploaded embeddings are assumed to be float16 (the
            # upload script's default --embed_dtype); this is not read from the file.
            download_dtype = torch.float16
            if self.embed_dtype != download_dtype:
                print_message(f"Warning:\nDownloaded embeddings are {download_dtype} but the current setting is {self.embed_dtype}\nWhen combining with existing embeddings, this could result in unintended biases or reductions in performance")
            print_message('Combining and casting')
            # Local entries overwrite downloaded ones for shared sequences.
            downloaded_embeddings.update(existing_embeddings)
            for seq in downloaded_embeddings:
                downloaded_embeddings[seq] = downloaded_embeddings[seq].to(self.embed_dtype)
            print_message(f'Saving combined embeddings to {final_path}')
            torch.save(downloaded_embeddings, final_path)
        else:
            print_message(f'Downloading embeddings from {self.download_dir}, no previous embeddings found')
            # Same loader as the merge branch above, so both branches read the
            # file identically.
            downloaded_embeddings = torch_load(unzipped_path)
            torch.save(downloaded_embeddings, final_path)
        return final_path

    def _read_sequences_from_db(self, db_path: str) -> set[str]:
        """
        Return the set of sequences stored in the ``embeddings`` table.

        Args:
            db_path: Path to the SQLite database file.
        Returns:
            Set of values from the ``sequence`` column.
        """
        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.execute("SELECT sequence FROM embeddings")
            return {row[0] for row in cursor}
        finally:
            # `with sqlite3.connect(...)` only manages the transaction; the
            # connection has to be closed explicitly.
            conn.close()

    def _read_embeddings_from_disk(self, model_name: str):
        """
        Work out which sequences still need embedding for ``model_name``.

        Args:
            model_name: Model whose cache file should be inspected.
        Returns:
            Tuple ``(to_embed, save_path, embeddings_dict)``. ``embeddings_dict``
            holds the already-loaded embeddings in ``.pth`` mode and is always
            empty in SQL mode (embeddings stay in the database).
        """
        if self.sql:
            filename = get_embedding_filename(model_name, self.matrix_embed, self.pooling_types, 'db')
            save_path = os.path.join(self.embedding_save_dir, filename)
            if not os.path.exists(save_path):
                print_message(f"No embeddings found in {save_path}")
                return self.all_seqs, save_path, {}

            # Make sure the table exists so the SELECT below cannot fail on an
            # empty database file, then release the connection.
            conn = sqlite3.connect(save_path)
            try:
                conn.execute('CREATE TABLE IF NOT EXISTS embeddings (sequence text PRIMARY KEY, embedding blob)')
                conn.commit()
            finally:
                conn.close()
            already_embedded = self._read_sequences_from_db(save_path)
            to_embed = [seq for seq in self.all_seqs if seq not in already_embedded]
            print_message(f"Loaded {len(already_embedded)} already embedded sequences from {save_path}\nEmbedding {len(to_embed)} new sequences")
            return to_embed, save_path, {}

        filename = get_embedding_filename(model_name, self.matrix_embed, self.pooling_types, 'pth')
        save_path = os.path.join(self.embedding_save_dir, filename)
        if not os.path.exists(save_path):
            print_message(f"No embeddings found in {save_path}")
            return self.all_seqs, save_path, {}

        print_message(f"Loading embeddings from {save_path}")
        # Loaded embeddings keep whatever dtype they were saved with; they are
        # not recast to embed_dtype here.
        embeddings_dict = torch_load(save_path)
        print_message(f"Loaded {len(embeddings_dict)} embeddings from {save_path}")
        to_embed = [seq for seq in self.all_seqs if seq not in embeddings_dict]
        return to_embed, save_path, embeddings_dict

    def _embed_sequences(
            self,
            to_embed: List[str],
            save_path: str,
            embedding_model: any,
            tokenizer: any,
            embeddings_dict: dict[str, torch.Tensor]) -> Optional[dict[str, torch.Tensor]]:
        """
        Run the model over ``to_embed`` and store one embedding per sequence.

        Args:
            to_embed: Sequences to embed, in the order they are batched.
            save_path: ``.db`` or ``.pth`` file the results are written to.
            embedding_model: Model called as ``model(**batch)``; with 'parti'
                pooling it is called with ``output_attentions=True`` and is
                expected to return ``(residue_embeddings, attentions)``.
            tokenizer: Tokenizer handed to ``build_collator``.
            embeddings_dict: Existing embeddings; updated in place in ``.pth`` mode.
        Returns:
            ``embeddings_dict``. In SQL mode new embeddings are written to the
            database only, so the dictionary is returned unchanged.
        """
        os.makedirs(self.embedding_save_dir, exist_ok=True)
        model = embedding_model.to(self.device).eval()
        model = maybe_compile(model)
        device = self.device
        collate_fn = build_collator(tokenizer)
        print_message(f'Pooling types: {self.pooling_types}')
        # Rebound below if 'parti' pooling fails; the closure picks up the new value.
        pooler = None if self.matrix_embed else Pooler(self.pooling_types)

        def _get_embeddings(
                residue_embeddings: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                attentions: Optional[torch.Tensor] = None
        ) -> torch.Tensor:
            """Pool residue embeddings unless they are already vectors or matrices are wanted."""
            if residue_embeddings.ndim == 2 or self.matrix_embed:  # sometimes already vector emb
                return residue_embeddings
            return pooler(emb=residue_embeddings, attention_mask=attention_mask, attentions=attentions)

        dataset = SimpleProteinDataset(to_embed)
        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            prefetch_factor=2 if self.num_workers > 0 else None,
            collate_fn=collate_fn,
            # Must stay unshuffled: sequences are matched to batches by
            # slicing `to_embed` with the batch index below.
            shuffle=False,
            # Pinned memory only helps host-to-GPU copies.
            pin_memory=device.type == 'cuda',
            worker_init_fn=seed_worker,
            generator=dataloader_generator(get_global_seed())
        )

        conn = None
        cursor = None
        if self.sql:
            conn = sqlite3.connect(save_path)
            cursor = conn.cursor()
            cursor.execute('CREATE TABLE IF NOT EXISTS embeddings (sequence text PRIMARY KEY, embedding blob)')

        try:
            # Embedding extraction needs no gradients; without this every stored
            # tensor would keep its autograd graph alive and `.numpy()` would
            # fail on tensors that require grad.
            with torch.no_grad():
                for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
                    seqs = to_embed[i * self.batch_size:(i + 1) * self.batch_size]
                    batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
                    if 'attention_mask' in batch:
                        attention_mask = batch['attention_mask']
                    elif 'sequence_ids' in batch:
                        # Positions marked -1 are treated as padding.
                        attention_mask = (batch['sequence_ids'] != -1).long().to(device)
                    else:
                        attention_mask = torch.ones_like(batch['input_ids'], device=device)

                    if 'parti' in self.pooling_types:
                        try:
                            residue_embeddings, attentions = model(**batch, output_attentions=True)
                            embeddings = _get_embeddings(residue_embeddings, attention_mask=attention_mask, attentions=attentions).cpu()
                        except Exception as e:
                            # Best-effort fallback: switch to mean pooling for
                            # this and all remaining batches.
                            print_message(f"Error in parti pooling: {e}\nDefaulting to mean pooling")
                            self.pooling_types = ['mean']
                            pooler = Pooler(self.pooling_types)
                            residue_embeddings = model(**batch)
                            embeddings = _get_embeddings(residue_embeddings, attention_mask=attention_mask).cpu()
                    else:
                        residue_embeddings = model(**batch)
                        embeddings = _get_embeddings(residue_embeddings, attention_mask=attention_mask).cpu()

                    for seq, emb, mask in zip(seqs, embeddings, attention_mask.cpu()):
                        if self.matrix_embed:
                            # Drop padded positions so each matrix has one row per real token.
                            emb = emb[mask.bool()]
                        if self.sql:
                            cursor.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
                                           (seq, emb.numpy().tobytes()))  # only supports float32
                        else:
                            embeddings_dict[seq] = emb.to(self.embed_dtype)

                    # Commit every 100 batches so a crash loses a bounded amount of work.
                    if (i + 1) % 100 == 0 and self.sql:
                        conn.commit()

            if conn is not None:
                conn.commit()
        finally:
            # Close the database even if a batch fails part-way through.
            if conn is not None:
                conn.close()

        if self.sql:
            return embeddings_dict

        if self.save_embeddings:
            print_message(f"Saving embeddings to {save_path}")
            torch.save(embeddings_dict, save_path)
        return embeddings_dict

    def __call__(self, model_name: str, model_type: Optional[str] = None, model_path: Optional[str] = None):
        """
        Return embeddings for ``all_seqs``, embedding whatever is missing on disk.

        Args:
            model_name: Name used for cache filenames and, unless ``model_type``
                is given, for selecting the model.
            model_type: Optional name passed to ``get_base_model`` instead of ``model_name``.
            model_path: Optional path forwarded to ``get_base_model``.
        Returns:
            Dictionary mapping sequence to embedding (empty in SQL mode).
        """
        if self.download_embeddings:
            self._download_embeddings(model_name)

        # Compare the device *type*: a torch.device never compares equal to the
        # plain string 'cpu', so the previous check could not trigger.
        if self.device.type == 'cpu':
            warnings.warn("Downloading embeddings is recommended for CPU usage - Embedding on CPU will be extremely slow!")

        to_embed, save_path, embeddings_dict = self._read_embeddings_from_disk(model_name)
        if len(to_embed) > 0:
            print_message(f"Embedding {len(to_embed)} sequences with {model_name}")
            dispatch_name = model_type or model_name
            model, tokenizer = get_base_model(dispatch_name, dtype=self.model_dtype, model_path=model_path)
            return self._embed_sequences(to_embed, save_path, model, tokenizer, embeddings_dict)

        print_message(f"No sequences to embed with {model_name}")
        return embeddings_dict
if __name__ == '__main__':
    ### Embed all supported datasets with all supported models
    # py -m embedder
    # For each requested model: merge any embeddings already published on the
    # hub, embed the missing sequences, then gzip and upload the result.
    import argparse
    import shutil
    from huggingface_hub import upload_file, login
    from data.supported_datasets import vector_benchmark
    from data.data_mixin import DataArguments, DataMixin
    from base_models.get_base_models import BaseModelArguments, get_base_model
    from seed_utils import set_global_seed

    parser = argparse.ArgumentParser()
    parser.add_argument('--token', default=None, help='Huggingface token')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--embed_dtype', type=str, default='float16')
    parser.add_argument('--model_names', nargs='+', default=['standard'])
    parser.add_argument('--models_to_skip', nargs='+', default=[], help='When checking for existing embeddings, skip these models.')
    parser.add_argument('--embedding_save_dir', type=str, default='embeddings')
    parser.add_argument('--download_dir', type=str, default='Synthyra/vector_embeddings')
    parser.add_argument('--embedding_pooling_types', nargs='+', default=['mean', 'var'], help='Pooling types for embeddings.')
    args = parser.parse_args()

    set_global_seed()

    if args.token is not None:
        login(args.token)

    supported_dtypes = {
        'float16': torch.float16,
        'bfloat16': torch.bfloat16,
        'float32': torch.float32,
    }
    if args.embed_dtype not in supported_dtypes:
        raise ValueError(f"Invalid embedding dtype: {args.embed_dtype}")
    dtype = supported_dtypes[args.embed_dtype]

    # Get data
    data_args = DataArguments(
        data_names=vector_benchmark,
        max_length=1024,
        trim=False
    )
    all_seqs = DataMixin(data_args).get_data()[1]

    # Embed for each model
    model_args = BaseModelArguments(model_names=args.model_names)
    for model_name in model_args.model_names:
        embedder_args = EmbeddingArguments(
            # EmbeddingArguments only recognises the `embedding_`-prefixed
            # names; un-prefixed keywords fall into **kwargs and are ignored,
            # which left --batch_size/--num_workers without effect.
            embedding_batch_size=args.batch_size,
            embedding_num_workers=args.num_workers,
            download_embeddings=model_name not in args.models_to_skip,
            # Download from the same repo that is uploaded to below.
            download_dir=args.download_dir,
            matrix_embed=False,
            embedding_pooling_types=args.embedding_pooling_types,
            save_embeddings=True,
            embed_dtype=dtype,
            sql=False,
            # Must match the directory used to locate the file for upload.
            embedding_save_dir=args.embedding_save_dir
        )
        embedder = Embedder(embedder_args, all_seqs)
        _ = embedder(model_name)

        filename = get_embedding_filename(model_name, False, embedder_args.pooling_types, 'pth')
        save_path = os.path.join(args.embedding_save_dir, filename)
        compressed_path = f"{save_path}.gz"
        print(f"Compressing {save_path} to {compressed_path}")
        with open(save_path, 'rb') as f_in:
            with gzip.open(compressed_path, 'wb') as f_out:
                # Stream the copy rather than reading the whole file into memory.
                shutil.copyfileobj(f_in, f_out)

        upload_file(
            path_or_fileobj=compressed_path,
            path_in_repo=f'embeddings/{filename}.gz',
            repo_id=args.download_dir,
            repo_type='dataset'
        )
    print('Done')