SimCSE: Paper and Source Code Walkthrough
The idea of contrastive learning is to pull samples of the same class closer together and push samples of different classes further apart, with the goal of learning a good semantic representation space. SimCSE is a simple unsupervised contrastive learning framework: it passes the same sentence through the encoder with dropout twice to obtain a positive pair, and treats that sentence together with every other sentence in the same batch as negative pairs.
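To see what "two dropout passes" means in practice, here is a minimal sketch (the `bert-base-uncased` checkpoint and the use of the `[CLS]` vector are illustrative assumptions, not the exact training setup): keeping the model in train mode, encoding the same sentence twice yields two slightly different embeddings, which form a positive pair.

```python
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")
model.train()  # keep dropout active: two forward passes give two different "views"

batch = tokenizer(["The weather is nice today."], return_tensors="pt")
with torch.no_grad():
    h1 = model(**batch).last_hidden_state[:, 0]  # [CLS] embedding, first dropout mask
    h2 = model(**batch).last_hidden_state[:, 0]  # [CLS] embedding, second dropout mask

# h1 and h2 differ only because of dropout noise; they form a positive pair,
# while the embeddings of other sentences in the batch serve as negatives.
print(torch.cosine_similarity(h1, h2))
```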
The loss function is:
$$
\ell_{i} = -\log \frac{e^{\operatorname{sim}\left(\mathbf{h}_{i}^{z_{i}},\, \mathbf{h}_{i}^{z_{i}'}\right)/\tau}}{\sum_{j=1}^{N} e^{\operatorname{sim}\left(\mathbf{h}_{i}^{z_{i}},\, \mathbf{h}_{j}^{z_{j}'}\right)/\tau}}
$$
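Read as code, the formula is a cross-entropy over temperature-scaled cosine similarities, where the "correct class" for sentence $i$ is its own second dropout view. A compact sketch of that correspondence (the function name and $\tau = 0.05$ are illustrative):

```python
import torch
import torch.nn.functional as F

def info_nce(h, h_prime, tau=0.05):
    # h, h_prime: [N, d]; row i of h_prime is the second dropout view of row i of h
    h = F.normalize(h, dim=-1)
    h_prime = F.normalize(h_prime, dim=-1)
    logits = h @ h_prime.t() / tau            # [N, N] = sim(h_i, h'_j) / tau
    labels = torch.arange(h.size(0))          # the positive for row i is column i
    return F.cross_entropy(logits, labels)    # mean of ell_i over the batch
```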
Code Implementation
In the authors' code, a sentence is not actually fed into the model twice; instead, a copy of it is placed in the same batch. The core of the model is the cl_forward function:
```python
import torch
import torch.nn as nn
import torch.distributed as dist
from transformers.modeling_outputs import SequenceClassifierOutput


def cl_forward(cls,
               encoder,
               input_ids=None,
               attention_mask=None,
               token_type_ids=None,
               position_ids=None,
               head_mask=None,
               inputs_embeds=None,
               labels=None,
               output_attentions=None,
               output_hidden_states=None,
               return_dict=None,
               mlm_input_ids=None,
               mlm_labels=None,
               ):
    return_dict = return_dict if return_dict is not None else cls.config.use_return_dict
    ori_input_ids = input_ids  # shape [bs, num_sent, sent_len], e.g. bs=32
    batch_size = input_ids.size(0)
    # Number of sentences in one instance
    # 2: pair instance [itself, itself]; 3: pair instance with a hard negative [itself, itself, hard negative]
    num_sent = input_ids.size(1)

    mlm_outputs = None
    # Flatten input for encoding
    input_ids = input_ids.view((-1, input_ids.size(-1)))  # [bs * num_sent, sent_len]
    attention_mask = attention_mask.view((-1, attention_mask.size(-1)))  # [bs * num_sent, sent_len]
    if token_type_ids is not None:
        token_type_ids = token_type_ids.view((-1, token_type_ids.size(-1)))  # [bs * num_sent, sent_len]

    # Get raw embeddings; last_hidden_state has shape [bs * num_sent, sent_len, hidden_size]
    outputs = encoder(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=True if cls.model_args.pooler_type in ['avg_top2', 'avg_first_last'] else False,
        return_dict=True,
    )

    # MLM auxiliary objective
    if mlm_input_ids is not None:
        mlm_input_ids = mlm_input_ids.view((-1, mlm_input_ids.size(-1)))
        mlm_outputs = encoder(
            mlm_input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=True if cls.model_args.pooler_type in ['avg_top2', 'avg_first_last'] else False,
            return_dict=True,
        )

    # Pooling
    pooler_output = cls.pooler(attention_mask, outputs)
    pooler_output = pooler_output.view((batch_size, num_sent, pooler_output.size(-1)))  # (bs, num_sent, hidden_size)

    # If using "cls", we add an extra MLP layer
    # (same as BERT's original implementation) over the representation.
    if cls.pooler_type == "cls":
        pooler_output = cls.mlp(pooler_output)

    # Separate representations, each [bs, hidden_size]: the two sentence embeddings
    # of the same sample produced by two independent dropout passes
    z1, z2 = pooler_output[:, 0], pooler_output[:, 1]

    # Hard negative
    if num_sent == 3:
        z3 = pooler_output[:, 2]

    # Gather all embeddings if using distributed training
    if dist.is_initialized() and cls.training:
        # Gather hard negative
        if num_sent >= 3:
            z3_list = [torch.zeros_like(z3) for _ in range(dist.get_world_size())]
            dist.all_gather(tensor_list=z3_list, tensor=z3.contiguous())
            z3_list[dist.get_rank()] = z3
            z3 = torch.cat(z3_list, 0)

        # Dummy vectors for allgather
        z1_list = [torch.zeros_like(z1) for _ in range(dist.get_world_size())]
        z2_list = [torch.zeros_like(z2) for _ in range(dist.get_world_size())]
        # Allgather
        dist.all_gather(tensor_list=z1_list, tensor=z1.contiguous())
        dist.all_gather(tensor_list=z2_list, tensor=z2.contiguous())

        # Since allgather results do not have gradients, we replace the
        # current process's corresponding embeddings with original tensors
        z1_list[dist.get_rank()] = z1
        z2_list[dist.get_rank()] = z2
        # Get full batch embeddings: (bs x N, hidden)
        z1 = torch.cat(z1_list, 0)
        z2 = torch.cat(z2_list, 0)

    # [bs, bs]: similarity between every sample and every other sample in the batch
    cos_sim = cls.sim(z1.unsqueeze(1), z2.unsqueeze(0))
    # Hard negative
    if num_sent >= 3:
        z1_z3_cos = cls.sim(z1.unsqueeze(1), z3.unsqueeze(0))
        cos_sim = torch.cat([cos_sim, z1_z3_cos], 1)

    # [bs,], values [0, 1, ..., bs-1]: for each row, the column index of its positive sample
    labels = torch.arange(cos_sim.size(0)).long().to(cls.device)
    # This is where the contrastive loss differs from ordinary classification:
    # the logits are [bs, bs] (one "class" per in-batch sample) rather than [bs, num_labels]
    loss_fct = nn.CrossEntropyLoss()

    # Calculate loss with hard negatives
    if num_sent == 3:
        # Note that weights are actually logits of weights
        z3_weight = cls.model_args.hard_negative_weight
        weights = torch.tensor(
            [[0.0] * (cos_sim.size(-1) - z1_z3_cos.size(-1)) + [0.0] * i + [z3_weight] + [0.0] * (z1_z3_cos.size(-1) - i - 1)
             for i in range(z1_z3_cos.size(-1))]
        ).to(cls.device)
        cos_sim = cos_sim + weights

    loss = loss_fct(cos_sim, labels)

    # Calculate loss for MLM
    if mlm_outputs is not None and mlm_labels is not None:
        mlm_labels = mlm_labels.view(-1, mlm_labels.size(-1))
        prediction_scores = cls.lm_head(mlm_outputs.last_hidden_state)
        masked_lm_loss = loss_fct(prediction_scores.view(-1, cls.config.vocab_size), mlm_labels.view(-1))
        loss = loss + cls.model_args.mlm_weight * masked_lm_loss

    if not return_dict:
        output = (cos_sim,) + outputs[2:]
        return ((loss,) + output) if loss is not None else output
    return SequenceClassifierOutput(
        loss=loss,
        logits=cos_sim,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
```

The code above handles many scenarios (distributed training, hard-negative triplets, the auxiliary MLM objective), which makes it fairly involved.
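In cl_forward, cls.pooler and cls.sim are helper modules defined elsewhere in the repo's models.py. cls.sim is essentially cosine similarity divided by the temperature τ, which is what turns cos_sim into logits for the cross-entropy loss; it looks roughly like this:

```python
import torch.nn as nn

class Similarity(nn.Module):
    """Cosine similarity scaled by a temperature, used as cls.sim."""

    def __init__(self, temp):
        super().__init__()
        self.temp = temp
        self.cos = nn.CosineSimilarity(dim=-1)

    def forward(self, x, y):
        # x: [bs, 1, hidden], y: [1, bs, hidden] -> broadcast to a [bs, bs] similarity matrix
        return self.cos(x, y) / self.temp
```

With z1.unsqueeze(1) and z2.unsqueeze(0), broadcasting produces the full [bs, bs] matrix of sim(h_i, h'_j)/τ in a single call.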
Below is a simplified version that matches the paper's formulation more closely:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

loss_func = nn.CrossEntropyLoss()

def simcse_loss(batch_emb):
    """Loss for unsupervised SimCSE: rows (2i, 2i+1) are two dropout views of the same sentence."""
    batch_size = batch_emb.size(0)  # batch_emb: [bs, hidden_size], bs even, e.g. bs=64
    # Build labels [1, 0, 3, 2, 5, 4, ...]: the positive for row 2i is row 2i+1 and vice versa
    y_true = torch.cat([
        torch.arange(1, batch_size, step=2, dtype=torch.long).unsqueeze(1),
        torch.arange(0, batch_size, step=2, dtype=torch.long).unsqueeze(1)],
        dim=1).reshape([batch_size, ]).to(batch_emb.device)
    # Compute similarity scores and the loss
    norm_emb = F.normalize(batch_emb, dim=1, p=2)
    # [bs, bs]: cosine similarity between every pair of samples
    sim_score = torch.matmul(norm_emb, norm_emb.transpose(0, 1))
    # The diagonal is each sample's similarity with itself (always 1);
    # it must not contribute to the loss, so mask it with a large negative offset
    sim_score = sim_score - torch.eye(batch_size, device=batch_emb.device) * 1e12
    sim_score = sim_score * 20  # temperature: multiplying by 20 == dividing by tau = 0.05
    loss = loss_func(sim_score, y_true)
    return loss
```
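For context, a hedged usage sketch (reusing the illustrative `model` and `tokenizer` from the earlier snippet; the way the batch is built here is an assumption, not the repo's data collator): the only contract simcse_loss relies on is that rows 2i and 2i+1 are two dropout views of the same sentence.

```python
# Illustrative only: duplicate each sentence so that adjacent rows (2i, 2i+1) share the same text.
sentences = ["A man is playing guitar.", "The weather is nice today."]
doubled = [s for s in sentences for _ in range(2)]

model.train()  # dropout must stay active so the two copies get different embeddings
batch = tokenizer(doubled, padding=True, return_tensors="pt")
batch_emb = model(**batch).last_hidden_state[:, 0]  # [2 * len(sentences), hidden_size]

loss = simcse_loss(batch_emb)
loss.backward()
```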
FAQ
- If the same batch happens to contain other sentences that are semantically similar to a given sentence, they are treated as negatives here. Doesn't that also push apart samples that actually belong together?
References
- princeton-nlp/SimCSE
- “被玩壞了”的Dropout
- 細節滿滿!理解對比學習和SimCSE,就看這6個知識點
- SIMCSE算法源碼分析