# Contrastive learning (对比学习)

# coding=utf-8
"""PyTorch RoBERTa model. """

import math
import warnings
import fitlog

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss, MarginRankingLoss

from transformers.activations import ACT2FN, gelu
from transformers.configuration_roberta import RobertaConfig
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.modeling_roberta import (
    RobertaModel,
    RobertaPreTrainedModel,
)

from .CVAEModel import CVAEModel
from .Attention import AttentionInArgs
from .SelfAttention import SelfAttention
from .GATModel import GAT

import logging
logger = logging.getLogger(__name__)

import numpy as np

class RobertaLMHead(nn.Module):
    """Masked-language-modelling head: dense -> GELU -> LayerNorm -> vocabulary logits.

    The output bias is a standalone parameter (``self.bias``) that is also installed
    as ``self.decoder.bias``, so both attribute names refer to the same tensor.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        vocab = config.vocab_size

        self.dense = nn.Linear(width, width)
        self.layer_norm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.decoder = nn.Linear(width, vocab, bias=False)
        self.bias = nn.Parameter(torch.zeros(vocab))

        # Tie both names to one tensor so the bias is resized together with the
        # decoder by `resize_token_embeddings`.
        self.decoder.bias = self.bias

    def forward(self, features, **kwargs):
        """Map hidden states ``features`` (last dim = hidden_size) to vocabulary logits."""
        transformed = self.layer_norm(gelu(self.dense(features)))
        # The tied bias is applied inside `decoder`.
        return self.decoder(transformed)

class RobertaClassificationHead(nn.Module):
    """Sentence-level classification head on the first token's hidden state.

    Pipeline: position 0 (RoBERTa's ``<s>``, the [CLS] equivalent) -> dropout ->
    dense -> tanh -> dropout -> linear projection to ``config.num_labels`` logits.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(width, config.num_labels)

    def forward(self, features, **kwargs):
        """Return logits computed from ``features[:, 0, :]``."""
        first_token = self.dropout(features[:, 0, :])
        squashed = torch.tanh(self.dense(first_token))
        return self.out_proj(self.dropout(squashed))

class RobertaPDTBModel(RobertaPreTrainedModel):
    """RoBERTa sequence classifier with auxiliary MLM, CVAE and contrastive objectives.

    The encoded sequence is split into two equal halves, called Arg1 and Arg2 in
    this code. The class name suggests PDTB discourse-relation classification;
    confirm against the data pipeline.

    When ``forward`` is called with ``Training=True`` it adds to the
    classification loss:

    * an MLM loss on a copy of the input in which the positions ranked highest
      by ``_select_attention`` are replaced with the ``<mask>`` id;
    * a CVAE loss (``CVAEModel.loss_function``);
    * a margin-ranking contrastive loss. Its positive is a dropout-perturbed
      copy of the representation, and its extra negative is a CVAE sample.
    """

    authorized_missing_keys = [r"position_ids"]

    # RoBERTa's `<mask>` token id; used by `_mlm_attention` when no tokenizer is given.
    _DEFAULT_MASK_TOKEN_ID = 50264

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        hidden = config.hidden_size

        # NOTE: construction order is kept as-is because it determines the random
        # initialisation of parameters that `init_weights` may not reset.
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.lm_head = RobertaLMHead(config)
        self.classifier = RobertaClassificationHead(config)
        self.laynorm = nn.LayerNorm(hidden)

        # Attention modules used to score positions for MLM masking.
        self.self_attention = SelfAttention(input_size=hidden, embedding_dim=256, output_size=hidden)
        self.attention = AttentionInArgs(input_size=hidden, embedding_dim=256, output_size=hidden)
        # Standard multi-head attention (default layout: (seq, batch, embed)).
        self.self_attention2 = nn.MultiheadAttention(hidden, 2, dropout=0.2)
        self.attention2 = nn.MultiheadAttention(hidden, 2)

        self.attn_fc = nn.Linear(hidden, 128)
        # `atten_fc` and `projector` are not referenced by any method below; they
        # are kept so the parameter set / state_dict keys stay unchanged.
        self.atten_fc = nn.Linear(128, 256)
        self.projector = nn.Sequential(
            nn.Linear(hidden, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Linear(256, hidden),
        )

        self.cvae = CVAEModel(config.hidden_size, config.hidden_size)

        self.init_weights()

    @staticmethod
    def _split_args(sequence_output):
        """Split ``sequence_output`` along dim 1 into its first and second halves.

        With an odd length the second half is one position longer.
        """
        arg_len = sequence_output.shape[1] // 2
        return sequence_output[:, :arg_len, :], sequence_output[:, arg_len:, :]

    def _select_attention(self, sequence_output, attention_mask, arg1_first=False):
        """Score positions for MLM masking via self- then cross-argument attention.

        Each half goes through ``self.self_attention``. Then ``self.attention`` is
        applied with Arg1 first (``arg1_first=True``) or Arg2 first, and the result
        is projected by ``attn_fc`` to 128 features. ``attention_mask`` is accepted
        but not used.
        """
        arg1, arg2 = self._split_args(sequence_output)
        arg1 = self.self_attention(arg1, arg1)
        arg2 = self.self_attention(arg2, arg2)

        if arg1_first:
            seq_out = self.attention(arg1, arg2)
        else:
            seq_out = self.attention(arg2, arg1)
        return self.attn_fc(seq_out)

    def _random_mask(self, sequence_output, arg1_first=False):
        """Return ``attn_fc`` applied to one half of ``sequence_output``.

        Despite its name, this performs no random masking. It selects Arg1 when
        ``arg1_first`` is true and Arg2 otherwise. Not called by ``forward``.
        """
        arg1, arg2 = self._split_args(sequence_output)
        return self.attn_fc(arg1 if arg1_first else arg2)

    def _add_attention(self, sequence_output, attention_mask):
        """Self-attend each half, then combine them with ``self.attention`` (mask passed through).

        Not called by ``forward``.
        """
        arg1, arg2 = self._split_args(sequence_output)
        arg1 = self.self_attention(arg1, arg1)
        arg2 = self.self_attention(arg2, arg2)
        return self.attention(arg1, arg2, attention_mask)

    def _add_attention_main(self, sequence_output, attention_mask=None):
        """Multi-head self-attention per half, then cross-attention; halves re-concatenated on dim 1.

        Not called by ``forward``. ``attention_mask`` is accepted but not used.
        """
        arg1, arg2 = self._split_args(sequence_output)
        # nn.MultiheadAttention (default layout) expects (seq, batch, embed).
        arg1 = arg1.transpose(0, 1)
        arg2 = arg2.transpose(0, 1)
        arg1, _ = self.self_attention2(arg1, arg1, arg1)
        arg2, _ = self.self_attention2(arg2, arg2, arg2)
        # NOTE(review): query and key come from one argument and value from the
        # other. Conventional cross-attention would be (arg1, arg2, arg2).
        # Kept as-is; it also requires both halves to have equal length.
        arg1_attn, _ = self.attention2(arg1, arg1, arg2)
        arg2_attn, _ = self.attention2(arg2, arg2, arg1)
        return torch.cat([arg1_attn.transpose(0, 1), arg2_attn.transpose(0, 1)], dim=1)

    def _top_attended_positions(self, sequence_output, attention_mask, k, arg1_first):
        """Per row, return the indices of the ``k`` largest summed ``_select_attention`` scores."""
        scores = torch.sum(
            self._select_attention(sequence_output, attention_mask, arg1_first=arg1_first), dim=2
        )
        return torch.argsort(scores, dim=1, descending=True)[:, :k]

    def _mlm_attention(self, sequence_output, attention_mask, input_ids, args=None, tokenizer=None):
        """Build MLM inputs by masking the most-attended positions of ``input_ids``.

        Two rankings are computed from ``sequence_output``: Arg1-first, then
        Arg2-first. For each, the top ``args.mask_num`` indices per row are set to
        the mask id. The mask id is ``tokenizer.mask_token_id`` when available,
        otherwise ``_DEFAULT_MASK_TOKEN_ID``.

        Returns:
            ``(masked_input_ids, input_ids)``. The first is a detached copy; the
            second is the unchanged original, used as MLM labels.
        """
        mask_id = self._DEFAULT_MASK_TOKEN_ID
        if tokenizer is not None and getattr(tokenizer, "mask_token_id", None) is not None:
            mask_id = tokenizer.mask_token_id

        masked_input_ids = input_ids.clone().detach()
        for arg1_first in (True, False):
            mask_idx = self._top_attended_positions(
                sequence_output, attention_mask, args.mask_num, arg1_first
            )
            # Row-wise `masked_input_ids[i, mask_idx[i]] = mask_id`.
            masked_input_ids.scatter_(1, mask_idx, mask_id)

        return masked_input_ids, input_ids

    def get_contrastive_loss(self, self_sample, positive_sample, reverse_sample):
        """Experimental contrastive loss based on cosine similarity (not called by ``forward``).

        NOTE(review): ``.diag()`` requires the similarity tensor to be 2-D (or
        1-D), and the normalisation ``temp1 / (temp1 + temp2)`` is unusual.
        Verify before use.
        """
        self_and_pos = F.cosine_similarity(self_sample, positive_sample, dim=-1)
        self_and_neg = F.cosine_similarity(self_sample, reverse_sample, dim=-1)

        temp1 = self_and_pos / 0.5
        temp2 = self_and_neg / 0.5

        return -F.log_softmax(temp1 / (temp1 + temp2), dim=0).diag().sum()

    def loss_hardest_from_batchneg_and_nonclick(self, gap_value, self_sample, positive_sample,
                                                reverse_sample=None, labels=None, device=None):
        """Margin-ranking loss: the positive must beat the hardest negative by ``gap_value``.

        Each input is mean-pooled over dim 1, flattened to ``(batch, features)`` and
        L2-normalised, so dot products are cosine similarities.

        The hardest negative for an anchor is the most similar of:

        * the other rows' positives (in-batch negatives);
        * ``reverse_sample`` (e.g. a CVAE sample), if given.

        A second term ranks each positive against the most similar other anchor.

        Args:
            gap_value: margin of the ranking loss.
            self_sample, positive_sample, reverse_sample: tensors with batch on
                dim 0, pooled over dim 1.
            labels: unused. Kept for backward compatibility; the ranking target is
                always +1.
            device: unused. Kept for backward compatibility; tensors are created
                on the inputs' device.

        Returns:
            A scalar loss tensor.
        """
        batch_size = self_sample.size(0)

        def _pool_norm(x):
            # Mean over dim 1, flatten to (batch, features), L2-normalise.
            return F.normalize(torch.mean(x, 1).reshape(batch_size, -1), dim=-1, p=2)

        anchor = _pool_norm(self_sample)
        positive = _pool_norm(positive_sample)

        # Exclude "self" pairs. A unit vector's similarity with itself is ~1, so
        # sign(sim + 1e-4 - 1) is +1 on the diagonal (and for exact duplicates)
        # and -1 elsewhere. The mask is therefore 2 there and 0 elsewhere, and
        # subtracting 10 * mask removes those entries from the max below.
        self_sim = torch.matmul(anchor, anchor.transpose(0, 1))
        ones = torch.ones_like(self_sim, dtype=torch.float32)
        self_mask = ones + torch.sign(self_sim + 1e-4 - ones)

        pos_score = torch.sum(anchor * positive, dim=1, keepdim=True)
        cosdist_pos_self = torch.matmul(positive, anchor.transpose(0, 1))
        cosdist_self_pos = torch.matmul(anchor, positive.transpose(0, 1))

        # Concatenate the optional extra negative with the in-batch negatives.
        candidates = cosdist_self_pos - 10 * self_mask
        if reverse_sample is not None:
            reverse = _pool_norm(reverse_sample)
            neg_score_sup = torch.sum(anchor * reverse, dim=1, keepdim=True)
            candidates = torch.cat([candidates, neg_score_sup], dim=1)

        neg_score_hardest, _ = torch.max(candidates, dim=1, keepdim=True)
        neg_score_for_pos, _ = torch.max(cosdist_pos_self - 10 * self_mask, dim=1, keepdim=True)

        # MarginRankingLoss computes max(0, -y * (x1 - x2) + margin). y = +1 asks
        # for pos_score > negative + margin. Class labels must not be used here.
        target = torch.ones_like(pos_score)
        margin_loss_func = MarginRankingLoss(margin=gap_value)
        rank_loss_from_self = margin_loss_func(pos_score, neg_score_hardest, target)
        rank_loss_from_pos = margin_loss_func(pos_score, neg_score_for_pos, target)

        return torch.mean(rank_loss_from_self + rank_loss_from_pos)

    def _inter_attention(self, sequence_output, attention_mask):
        """Run ``attention2`` self-attention over the whole sequence.

        Input and output are batch-first. ``attention_mask`` is accepted but not
        used, so padded positions take part in the attention.
        NOTE(review): consider passing it as ``key_padding_mask``.
        """
        seq_first = sequence_output.transpose(0, 1)
        attended, _ = self.attention2(seq_first, seq_first, seq_first)
        return attended.transpose(0, 1)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        Training=False,
        tokenizer=None,
        args=None,
        global_step=0
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        Training (:obj:`bool`):
            Enables the auxiliary MLM / CVAE / contrastive branches. When true, ``args``
            must provide ``do_mlm``, ``mask_num``, ``mlm_theta``, ``con_theta``,
            ``cvae_theta`` and ``device``.
        tokenizer (`optional`):
            Used for the MLM mask id and for the debug printout every 100 steps.
        global_step (:obj:`int`):
            Controls the every-100-steps debug printout and fitlog logging.

        Returns a tuple ``(loss?, logits, *extra_encoder_outputs)`` when ``return_dict`` is
        false, otherwise a ``SequenceClassifierOutput``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Shared encoder arguments for the main pass and the masked (MLM) pass.
        encoder_kwargs = dict(
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        outputs = self.roberta(input_ids, **encoder_kwargs)
        sequence_output = self.laynorm(outputs[0])

        # ---- Auxiliary MLM task ------------------------------------------------
        masked_mlm_loss = None
        if Training and args.do_mlm > 0:
            masked_ids, input_ids_label = self._mlm_attention(
                sequence_output, attention_mask, input_ids, args, tokenizer=tokenizer
            )
            mlm_sequence_output = self.roberta(masked_ids, **encoder_kwargs)[0]
            prediction_scores = self.lm_head(mlm_sequence_output)

            # Debug printout of predicted vs. original tokens.
            if global_step % 100 == 0 and tokenizer is not None:
                pre = torch.argmax(prediction_scores, dim=2)
                print('pre: ', tokenizer.convert_ids_to_tokens(pre[0]))
                print('label: ', tokenizer.convert_ids_to_tokens(input_ids[0]))

            # NOTE(review): the loss covers every position (including padding),
            # not only the masked ones.
            masked_mlm_loss = CrossEntropyLoss()(
                prediction_scores.view(-1, self.config.vocab_size), input_ids_label.view(-1)
            )

        # ---- Main task -----------------------------------------------------------
        sequence_output = self._inter_attention(sequence_output, attention_mask)

        # The CVAE and contrastive terms only enter the loss (and the logs) when
        # Training. Skipping them otherwise lets inference run with `args=None`.
        contrastive_loss = None
        cvae_loss = None
        if Training:
            out, mu, logvar = self.cvae(x=sequence_output, y=labels, Training=True, device=args.device)
            cvae_loss = CVAEModel.loss_function(recon_x=out, x=sequence_output, mu=mu, logvar=logvar)

            # Positive sample: a dropout-perturbed copy of the representation.
            positive_sample = F.dropout(sequence_output, p=0.1, training=True)
            # Negative sample: drawn from the CVAE.
            reverse_sample, _, _ = self.cvae(
                x=sequence_output, y=labels, Training=True, Use=True, device=args.device
            )
            contrastive_loss = self.loss_hardest_from_batchneg_and_nonclick(
                0.2, sequence_output, positive_sample, reverse_sample, labels, args.device
            )

        sequence_output = self.laynorm(sequence_output)
        logits = self.classifier(sequence_output)

        # ---- Loss --------------------------------------------------------------
        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # Regression.
                loss = MSELoss()(logits.view(-1), labels.view(-1))
            else:
                loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))

                # Main loss + weighted MLM loss (only set when Training and do_mlm > 0).
                if masked_mlm_loss is not None:
                    loss = loss + args.mlm_theta * masked_mlm_loss

                if Training:
                    if global_step % 100 == 0:
                        fitlog.add_loss(contrastive_loss, name='contrastive loss', step=global_step)
                        fitlog.add_loss(cvae_loss, name='cvae_loss', step=global_step)
                    loss = loss + args.con_theta * contrastive_loss + args.cvae_theta * cvae_loss

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Posted 2021-09-22 14:28 by douzujun (views: 105, comments: 0).