Comp-Agg (A Compare-Aggregate Model for Matching Text Sequences)
CompareAggregate研究意義:
1、采用“比較-聚合”框架,并對(duì)此進(jìn)行改進(jìn)
2、采用多種數(shù)據(jù)集驗(yàn)證模型的泛化性
?
本文主要結(jié)構(gòu)如下所示:
一、Abstract
? ? ? 摘要部分主要介紹本文利用詞嵌入作為輸入,CNN網(wǎng)絡(luò)作為聚合函數(shù),提出比較聚合框架;關(guān)注于不同的比較函數(shù)來(lái)對(duì)文本向量進(jìn)行匹配;并且使用不同的幾份數(shù)據(jù)評(píng)估模型;
基于element-wise的比較函數(shù)可能會(huì)比復(fù)雜神經(jīng)網(wǎng)絡(luò)效果更好。
二、Introudction
? ? ? 首先提及了很多自然語(yǔ)言處理任務(wù)都需要對(duì)兩個(gè)或多個(gè)句子進(jìn)行匹配,然后作出決定。
? ? ? ?
三、Method
? ? ? ? ? ? ?主要介紹模型的結(jié)構(gòu)以及六個(gè)不同的比較函數(shù)
四、Experiment
? ? ? ? ? ? 實(shí)驗(yàn)部分主要介紹不同比較函數(shù)以及組合函數(shù)在四個(gè)不同任務(wù)數(shù)據(jù)集合的效果,證明組合比較函數(shù)模型的有效性
五、Related Work
? ? ? ? ? ?相關(guān)工作部分簡(jiǎn)單的描述了孿生網(wǎng)絡(luò)、注意力機(jī)制以及比較-聚合網(wǎng)絡(luò)相關(guān)的應(yīng)用
六、Conclusions
? ? ? ? ? ? 最后一部分總結(jié)了本文系統(tǒng)分析“比較-聚合”模型在四個(gè)不同任務(wù)數(shù)據(jù)集上的有效性,此外還提出了詞級(jí)別的比較函數(shù)element-wise 比較函數(shù)表現(xiàn)好于其它函數(shù),并且根據(jù)實(shí)驗(yàn)結(jié)果很多不同任務(wù)可以共享“比較-聚合”結(jié)構(gòu),在未來(lái)的任務(wù)中,可以把它使用在多任務(wù)學(xué)習(xí)中。
? ? ? ? ?關(guān)鍵點(diǎn): 采用“比較-聚合”框架;利用多種數(shù)據(jù)集證明模型的有效性;提出多種比較函數(shù)并探究了交互的最佳方式
? ? ? ? ?創(chuàng)新點(diǎn): 利用門控單元提取語(yǔ)義特征,利用注意力機(jī)制完成句子權(quán)重匹配,利用向量的差和積進(jìn)行特征提取
七、Code
# -*- coding: utf-8 -*-# @Time : 2021/2/14 下午2:07 # @Author : TaoWang # @Description : "比較-聚合" 模型結(jié)構(gòu)import torch import torch.nn as nn import numpy as np from torch.utils.data import DataLoader, Dataset from torch.autograd import Variable# 預(yù)處理層 class Preprocess(nn.Module):def __init__(self, in_features, out_features):""":param in_features: :param out_features: """super().__init__()self.Wi = nn.Parameter(torch.randn(in_features, out_features))self.bi = nn.Parameter(torch.randn(out_features))self.wu = nn.Parameter(torch.randn(in_features, out_features))self.bu = nn.Parameter(torch.randn(out_features))def forward(self, x):""":param x: :return: """gate = torch.matmul(x, self.Wi)gate = torch.sigmoid(gate + self.bi.expand_as(date))out = torch.matmul(x, self.Wu)out = torch.tanh(out + self.bu.expand_as(out))return gate * out# 注意力層 class Attention(nn.Module):def __init__(self):super().__init__()self.wg = nn.Parameter(torch.randn(hidden_size, hidden_size))self.bg = nn.Parameter(torch.randn(hidden_size))def forward(self, q, a):""":param q: :param a: :return: """G = torch.matmul(q, self.wg)G = G + self.bg.expand_as(G)G = torch.matmul(G, a.permute(0, 2, 1))G = torch.softmax(G, dim=1)H = torch.matmul(G.permute(0, 2, 1), q)return H# 模型比較層 class Compare(nn.Module):def __init__(self):super().__init__()self.W = nn.Parameter(torch.randn(2*hidden_size, hidden_size))self.b = nn.Parameter(torch.randn(hidden_size))def forward(self, h, a):""":param h: :param a: :return: """sub = (h - a) * (h - a)mult = h * aT = torch.matmul(torch.cat([sub, mult], dim=2), self.W)T = torch.relu(T + self.b.expand_as(T))return T# 模型比較聚合層匯總class CompAgg(torch.nn.Module):def __init__(self):super(CompAgg, self).__init__()self.embedding = nn.Embedding(vocab_size, embedding_size)self.embedding.weight.data.copy_(torch.from_numpy(embed))self.preprocess = Preprocess(embedding_size, hidden_size)self.attention = Attention()self.compare = Compare()self.aggregate = nn.Conv1d(in_channels=max_len, out_channels=window, kernel_size=3, stride=1, padding=1)self.predict = nn.Linear(window * hidden_size, classes)def forward(self, q, a):""":param q: 設(shè) q長(zhǎng)度 20:param a: 設(shè) a長(zhǎng)度 40:return: """# emb_q: batch * 20 * 200, emb_a: batch * 40 * 200emb_q, emb_a = self.embedding(q), self.embedding(a)# q_bar: batch * 20 * 100, a_bar: batch * 40 * 100q_bar, a_bar = self.preprocess(emb_q), self.preprocess(emb_a)# H: batch * 40 * 100H = self.attention(q_bar, a_bar)# T: batch * 40 * 100T = self.compare(H, a_bar)# r: batch * 3 * 100r = self.aggregate(T)# r: batch * 300r = r.view(-1, window * hidden_size)# out: batch * 3out = self.predict(r)return out?
總結(jié)
以上是生活随笔為你收集整理的Comp-Agg (A Compare-Aggregate Model for Matching Text Sequences)的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: SiameseNet(Learning
- 下一篇: ESIM (Enhanced LSTM