Embedding proteins with the Prot_T5 model

Notes

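After embedding with ESM, the output contains a start token (<cls>) and an end token (<eos>) around the residues, and both must be trimmed off.
Prot_T5 behaves differently: its tokenizer (T5Tokenizer) adds no start token and appends only an end-of-sequence token </s>. A sequence of length L therefore yields L+1 embedding rows, and only the last row needs to be dropped.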
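A minimal sketch of this trimming rule, assuming ESM-1b/ESM-2 style tokenization (<cls> prepended, <eos> appended) and ProtT5 with only </s> appended. The helper residue_slice and the token lists are hypothetical and for illustration only; the same slice applies to the first dimension of a per-token embedding matrix.

```python
def residue_slice(model_family: str, seq_len: int) -> slice:
    """Return the slice that keeps only the per-residue positions.

    Hypothetical helper. Assumes ESM adds <cls> at the start and <eos> at the
    end, while ProtT5 adds only </s> at the end.
    """
    if model_family == "esm":
        return slice(1, seq_len + 1)  # skip <cls>, stop before <eos> (and any padding)
    if model_family == "prot_t5":
        return slice(0, seq_len)  # stop before </s> (and any padding)
    raise ValueError(f"unknown model family: {model_family}")


# Token lists standing in for the rows of an embedding matrix
esm_tokens = ["<cls>", "M", "K", "V", "<eos>"]
t5_tokens = ["M", "K", "V", "</s>"]
print(esm_tokens[residue_slice("esm", 3)])     # ['M', 'K', 'V']
print(t5_tokens[residue_slice("prot_t5", 3)])  # ['M', 'K', 'V']
```

For a ProtT5 embedding of shape (L+1, 1024), indexing its first dimension with residue_slice("prot_t5", L) leaves an L x 1024 matrix, one row per residue.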

Download the model from Hugging Face

https://huggingface.co/Rostlab/prot_t5_xl_half_uniref50-enc

Tutorials

For detailed interactive examples of using the model for common tasks, see the Google Colab notebook: https://colab.research.google.com/drive/1TUj-ayG3WO52n5N50S7KH9vtt6zRkdmj?usp=sharing#scrollTo=ET2v51slC5ui

Ready-to-run code

pip install torch transformers sentencepiece h5py

from transformers import T5EncoderModel, T5Tokenizer
import torch
import os
import re
import numpy as np

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print("Using {}".format(device))


data=[] #[("P27204","AKKRSRSRKRSASRKRSRSRKRSASKKSSKKHVRKALAAGMKNHLLAHPKGSNNFILAKKKAPRRRRRVAKKVKKAPPKARRRVRVAKSRRSRTARSRRRR")]       #[(id,seq)....]
save_path="datasets/PDB14189/prot_feature"
os.makedirs(save_path, exist_ok=True)  # np.save fails if the output directory does not exist
#------------
# txt=open("datasets/PDB14189/PDB14189_N.txt","r").read()
# for line in txt.split("\n"):
#     if line.startswith(">"):
#         id=line[1:]
#     else:
#         seq=line
#         data.append((id,seq))
        
# txt=open("datasets/PDB14189/PDB14189_P.txt","r").read()
# for line in txt.split("\n"):
#     if line.startswith(">"):
#         id=line[1:]
#     else:
#         seq=line
#         data.append((id,seq))
#------------

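def get_T5_model():
    model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")
    model = model.to(device)  # move model to GPU if one is available
    # This checkpoint is the half-precision variant; run fp16 on GPU and fall back to fp32 on CPU
    model = model.half() if device.type == "cuda" else model.float()
    model = model.eval()  # inference mode (disables dropout)
    tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
    return model, tokenizer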

model, tokenizer=get_T5_model()


for item in data:
    sequence_examples = [item[1]]
    # ProtT5 expects space-separated residues; rare/ambiguous amino acids U, Z, O, B are mapped to X
    sequence_examples = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in sequence_examples]

    ids = tokenizer.batch_encode_plus(sequence_examples, add_special_tokens=True, padding="longest")
    input_ids = torch.tensor(ids['input_ids']).to(device)
    attention_mask = torch.tensor(ids['attention_mask']).to(device)
    # generate embeddings
    with torch.no_grad():
        embedding_repr = model(input_ids=input_ids, attention_mask=attention_mask)
    # Keep only the per-residue rows of the first (here the only) sequence in the batch:
    # the tokenizer appends a single </s>, so the first len(sequence) rows are the residues
    emb_0 = embedding_repr.last_hidden_state[0, :len(item[1])]  # shape (L x 1024), L = sequence length
    print(item[0], emb_0.shape)
    np.save(f"{save_path}/{item[0]}.npy",emb_0.cpu())
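
The commented-out loader near the top assumes every sequence sits on a single line, and a trailing blank line in the file would add an extra record with an empty sequence. A slightly more robust sketch, assuming standard FASTA input where a sequence may span several lines; the function name read_fasta and the demo records are hypothetical.

```python
import io
from typing import Iterable, List, Optional, Tuple


def read_fasta(lines: Iterable[str]) -> List[Tuple[str, str]]:
    """Parse FASTA text into (id, sequence) pairs (hypothetical helper).

    Sequence lines after a header are concatenated; blank lines are skipped.
    """
    records: List[Tuple[str, str]] = []
    header: Optional[str] = None
    chunks: List[str] = []
    for raw in lines:
        line = raw.strip()
        if not line:
            continue  # e.g. the empty string left by a trailing newline
        if line.startswith(">"):
            if header is not None:
                records.append((header, "".join(chunks)))
            header = line[1:].strip()
            chunks = []
        else:
            chunks.append(line)
    if header is not None:
        records.append((header, "".join(chunks)))
    return records


# Demo on an in-memory file; with the script one would pass an open file instead, e.g.
# with open("datasets/PDB14189/PDB14189_N.txt") as f: data += read_fasta(f)
demo = io.StringIO(">seqA\nMKTAY\nIAKQR\n\n>seqB\nGSHMA\n")
print(read_fasta(demo))  # [('seqA', 'MKTAYIAKQR'), ('seqB', 'GSHMA')]
```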

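The preprocessing line in the loop does two things before tokenization: it maps the rare or ambiguous amino acids U, Z, O and B to X, and it puts a space between residues so the tokenizer treats each one as its own token. A stand-alone sketch of that step (the function name prepare_sequence is hypothetical) makes it easy to check that the substitution is one-for-one, so the residue count, and hence the length used to cut off the trailing </s>, is unchanged.

```python
import re


def prepare_sequence(seq: str) -> str:
    """Replace U, Z, O, B with X and separate residues by single spaces.

    Hypothetical helper mirroring the list comprehension in the script.
    """
    return " ".join(re.sub(r"[UZOB]", "X", seq))


seq = "MKUZOBL"
spaced = prepare_sequence(seq)
print(spaced)  # M K X X X X L
# One token per residue, so len(seq) still equals the number of residues
print(len(spaced.split()) == len(seq))  # True
```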
posted @ 2025-08-25 01:14  ylifs