Instructions to use MilaDeepGraph/ProtST-ESM1b-LocalizationPrediction with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use MilaDeepGraph/ProtST-ESM1b-LocalizationPrediction with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="MilaDeepGraph/ProtST-ESM1b-LocalizationPrediction", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("MilaDeepGraph/ProtST-ESM1b-LocalizationPrediction", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import math | |
| import torch | |
| import torch.nn as nn | |
| from typing import Optional, Tuple, Union | |
| from dataclasses import dataclass | |
| from transformers import PreTrainedModel | |
| from transformers.modeling_outputs import ModelOutput | |
| from transformers.models.esm import EsmPreTrainedModel, EsmModel | |
| from transformers.models.bert import BertPreTrainedModel, BertModel | |
| from .configuration_protst import ProtSTConfig | |
@dataclass
class EsmProteinRepresentationOutput(ModelOutput):
    """Output container of the ESM protein encoder.

    `ModelOutput` subclasses must be dataclasses so that keyword construction
    populates the fields and the object supports attribute, key and index access.
    """

    # Mean-pooled feature over the non-special tokens of each sequence.
    protein_feature: Optional[torch.FloatTensor] = None
    # Per-token hidden states taken from the last encoder layer.
    residue_feature: Optional[torch.FloatTensor] = None
@dataclass
class BertTextRepresentationOutput(ModelOutput):
    """Output container of the BERT text encoder.

    `ModelOutput` subclasses must be dataclasses so that keyword construction
    populates the fields and the object supports attribute, key and index access.
    """

    # Projected mean-pooled feature over the non-special tokens of each text.
    text_feature: Optional[torch.FloatTensor] = None
    # Projected per-token features.
    word_feature: Optional[torch.FloatTensor] = None
@dataclass
class ProtSTClassificationOutput(ModelOutput):
    """Output container of the ProtST protein property classifier.

    `ModelOutput` subclasses must be dataclasses so that keyword construction
    populates the fields and the object supports attribute, key and index access.
    """

    # Cross-entropy loss; only set when labels are passed to the model.
    loss: Optional[torch.FloatTensor] = None
    # Unnormalized class scores produced by the classification head.
    logits: Optional[torch.FloatTensor] = None
class ProtSTHead(nn.Module):
    """Two-layer MLP head: linear -> ReLU -> linear.

    The hidden layer keeps the width `config.hidden_size`; the output layer
    projects to `out_dim` features.
    """

    def __init__(self, config, out_dim=512):
        super().__init__()
        hidden_size = config.hidden_size
        # Layers are created in the same order as before so that random
        # initialization consumes the RNG identically.
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.out_proj = nn.Linear(hidden_size, out_dim)

    def forward(self, x):
        """Apply the head to the last dimension of `x` and return the projection."""
        hidden = nn.functional.relu(self.dense(x))
        return self.out_proj(hidden)
class BertForPubMed(BertPreTrainedModel):
    """BERT text encoder with ProtST projection heads.

    Produces a pooled text-level feature (mean over non-special tokens, then
    projected by `text_mlp`) and per-token features (projected by `word_mlp`).
    """

    def __init__(self, config):
        super().__init__(config)
        # Token ids that are excluded from the mean pooling in `forward`.
        self.pad_token_id = config.pad_token_id
        self.cls_token_id = config.cls_token_id
        self.sep_token_id = config.sep_token_id

        self.bert = BertModel(config, add_pooling_layer=False)
        self.text_mlp = ProtSTHead(config)
        self.word_mlp = ProtSTHead(config)

        self.post_init()  # NOTE

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], ModelOutput]:
        """Encode text and return pooled and per-token projected features.

        All arguments are forwarded unchanged to the wrapped `BertModel`.

        The pooled feature is the mean of the last hidden states over the
        positions that are not CLS, SEP or PAD tokens. When `input_ids` is not
        given (e.g. only `inputs_embeds` is passed) special tokens cannot be
        identified, so pooling falls back to the positions selected by
        `attention_mask`, or to all positions if that is missing too.

        Returns:
            `(text_feature, word_feature)` as a tuple when `return_dict` is
            false, otherwise a `BertTextRepresentationOutput`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Positional access works for both the tuple returned when
        # return_dict is false and the ModelOutput returned otherwise;
        # attribute access would raise on the tuple.
        word_feature = outputs[0]

        if input_ids is not None:
            is_special = (
                (input_ids == self.cls_token_id)
                | (input_ids == self.sep_token_id)
                | (input_ids == self.pad_token_id)
            )
            keep = ~is_special
        elif attention_mask is not None:
            keep = attention_mask.to(torch.bool)
        else:
            keep = torch.ones(word_feature.shape[:2], dtype=torch.bool, device=word_feature.device)
        special_mask = keep.to(torch.int64).unsqueeze(-1)

        # Masked mean; the epsilon keeps the division finite when every
        # position of a row is masked out.
        pooled_feature = (
            (word_feature * special_mask).sum(1) / (special_mask.sum(1) + 1.0e-6)
        ).to(word_feature.dtype)
        pooled_feature = self.text_mlp(pooled_feature)
        word_feature = self.word_mlp(word_feature)

        if not return_dict:
            return (pooled_feature, word_feature)

        return BertTextRepresentationOutput(text_feature=pooled_feature, word_feature=word_feature)
class EsmForProteinRepresentation(EsmPreTrainedModel):
    """ESM protein encoder returning residue-level and mean-pooled protein-level features."""

    def __init__(self, config):
        super().__init__(config)
        # Token ids that are excluded from the mean readout in `forward`.
        self.cls_token_id = config.cls_token_id
        self.pad_token_id = config.pad_token_id
        self.eos_token_id = config.eos_token_id

        self.esm = EsmModel(config, add_pooling_layer=False)

        self.post_init()  # NOTE

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, EsmProteinRepresentationOutput]:
        """Encode protein sequences and return pooled and per-residue features.

        All arguments are forwarded unchanged to the wrapped `EsmModel`.

        The protein feature is the mean of the last hidden states over the
        positions that are not CLS, EOS or PAD tokens. When `input_ids` is not
        given (e.g. only `inputs_embeds` is passed) special tokens cannot be
        identified, so pooling falls back to the positions selected by
        `attention_mask`, or to all positions if that is missing too.

        Returns:
            `(protein_feature, residue_feature)` as a tuple when `return_dict`
            is false, otherwise an `EsmProteinRepresentationOutput`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Positional access works for both the tuple returned when
        # return_dict is false and the ModelOutput returned otherwise;
        # attribute access would raise on the tuple.
        residue_feature = outputs[0]  # [batch_size, seq_len, hidden_dim]

        # mean readout
        if input_ids is not None:
            is_special = (
                (input_ids == self.cls_token_id)
                | (input_ids == self.eos_token_id)
                | (input_ids == self.pad_token_id)
            )
            keep = ~is_special
        elif attention_mask is not None:
            keep = attention_mask.to(torch.bool)
        else:
            keep = torch.ones(residue_feature.shape[:2], dtype=torch.bool, device=residue_feature.device)
        special_mask = keep.to(torch.int64).unsqueeze(-1)

        # The epsilon keeps the division finite when every position of a row
        # is masked out.
        protein_feature = (
            (residue_feature * special_mask).sum(1) / (special_mask.sum(1) + 1.0e-6)
        ).to(residue_feature.dtype)

        if not return_dict:
            return (protein_feature, residue_feature)

        return EsmProteinRepresentationOutput(
            protein_feature=protein_feature, residue_feature=residue_feature
        )
class ProtSTPreTrainedModel(PreTrainedModel):
    """Common base class of the ProtST models in this module."""

    # Configuration class associated with ProtST models.
    config_class = ProtSTConfig
class ProtSTForProteinPropertyPrediction(ProtSTPreTrainedModel):
    """Protein-level classifier: ESM protein encoder followed by an MLP classification head."""

    def __init__(self, config):
        # The base-class constructor already stores `config` on the instance.
        super().__init__(config)

        self.protein_model = EsmForProteinRepresentation(config.protein_config)
        self.classifier = ProtSTHead(config.protein_config, out_dim=config.num_labels)

        self.post_init()  # NOTE

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ProtSTClassificationOutput]:
        r"""Classify each protein sequence in the batch.

        Args:
            labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                Labels for computing the protein classification loss. Indices should be in
                `[0, ..., config.num_labels - 1]`. One label per sequence, since the logits
                have one row per sequence.

        All other arguments are forwarded unchanged to the protein encoder.

        Returns:
            `(logits,)`, or `(loss, logits)` when `labels` is given, as a tuple when
            `return_dict` is false; otherwise a `ProtSTClassificationOutput`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # The encoder output is read by attribute below, so always request the
        # structured output here; `return_dict` only controls what this method
        # itself returns.
        outputs = self.protein_model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        logits = self.classifier(outputs.protein_feature)  # [bsz, xxx] -> [bsz, num_labels]

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # Labels may live on a different device than the logits.
            labels = labels.to(logits.device)
            loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return ProtSTClassificationOutput(loss=loss, logits=logits)