Bag of Tricks for Efficient Text Classification(Fasttext)
Fasttext歷史意義:
1、提出一種新的文本分類方法-Fasttext,能夠快速進行文本分類,效果較好
2、提出一種新的使用子詞的詞向量訓練方法,能夠在一定程度上解決oov問題
3、將Fasttext開源使得工業界和學術界能夠快速的使用Fasttext
?
深度學習文本分類模型:
優點:效果好,能達到非常好的效果,不用做特征工程,模型簡潔
缺點:速度比較慢,無法在大規模的文本分類任務上應用
?
機器學習文本分類模型:
優點:速度一般都很快,模型都是線性分類器,比較簡單;效果還可以,在某些任務上可以取得比較好的結果
缺點:需要做特征工程,分類效果依賴于有效特征的提取
?
本文主要結構:
一、Abstract
? ? ? ?提出一種簡單的高效文本分類模型,效果和其它深度學習模型相當,但是速度快很多倍
二、Inrtroduction
? ? ? ?文本分類是自然語言處理的重要任務,可以用于信息檢索,網頁搜索、文檔分類等;基于深度學習可以達到非常的好的效果,但是速度慢限制文本分類的應用;基于機器學習的線性分類器效果也很好,有用于大規模分類任務的潛力;從現在詞向量中得到靈感,提出一種使用新的文本分類方法Fasttext,這種方法能夠快速的訓練和測試并且達到和最優結果相似的結果。
三、Model architerture
? ? ? ?詳細介紹Fasttext的模型結構以及兩個技巧,分別是層次softmax和n-gram特征? ? ?
模型結構如上圖,與CBOW的模型結果相同,與CBOW模型的區別和聯系如下所示:
聯系:
? ? ?1)都是log-line模型,模型簡單
? ? ?2)都是對輸入的詞向量做平均,然后進行預測
? ? ?3)模型結構完全一樣
區別:
? ? ?1)? Fasttext提取的是句子特征,CBOW提取的是上下文特征?
? ? ? 2)Fasttext需要標注語料,是監督學習,CBOW不需要人工標注語料,是無監督學習
Fasttext存在的問題:
? ? ? 1)? 當類別非常多的時候,最后的softmax速度比較慢(因為要構造詞表大小的數據)
? ? ? 2)? 使用的是詞袋模型,沒有詞序信息
解決辦法:
? ? ? ?1) 層次softmax
? ? ? ? ? ? ?和word2vec中的層次softmax一樣,可以減少參數由原來的H*V -> H*log2V (V表示詞表大小)
? ? ? ? 2) 添加使用n-gram特征
? ? ? ? ? ? ?輸入模型的數據中添加了n-gram特征,并且用到了hash
? ? ? ? ? ? ?如果每一個詞對應一個向量,那么詞表太大;如果多個詞對應一個向量,不夠準確,所以構建hash方法
假如詞表大小限制為10w,1-gram單詞個數為3w,2-gram詞組個數為10w,3-gram詞組個數為40w,1-gram不用做hash,所以說詞表10w中前3w是留給1-gram的,剩余7w個位置還有50w個詞組沒有位置安放,所以50w/7w約等于7,也就是說這50w個詞組中大約有7個詞對應同一個詞向量。
? ? ? ? ? ? ? Fasttext另一篇文章中提到subword,主要是根據n-gram把詞拆開進行預測
?
四、Experiments
? ? ? ?在文本分類任務上和tag預測任務上都取得了非常好的結果,效果和其它深度模型相差不多,但是速度上會快很多
五、Discussion and conclusion
? ? ? ?對論文進行一些總結
? ? ? ?關鍵點:
? ? ? ? ? ? ?基于深度學習的文本分類方法效果好,但速度比較慢;
? ? ? ? ? ? ?基于線性分類器的機器學習方法速度比較快,但是需要做更多的特征工程;
? ? ? ? ? ? ?提出Fasttext模型
? ? ? ? 創新點:
? ? ? ? ? ? ?提出一種新的文本分類模型Fasttext;
? ? ? ? ? ? ?提出一些加快和使得文本分類效果更高的技巧-層次softmax和n-gram特征;
? ? ? ? ? ? ?在文本分類任務上和tag預測兩個任務上都取得了又快又好的結果。
? ? ? ? 啟發點:
? ? ? ? ? ? 雖然深度學習能夠取得非常好的結果,但是在訓練和測試的時候,非常慢限制了他們在大數據集上的應用(模型不一定在效果上大幅度提升,效果差不多,速度大幅度提升也是一種創新);
? ? ? ? ? ? 然而線性分類器不同特征和類別之間不共享參數,可能限制了一些只有少量樣本類別的泛化能力(共享詞向量);
? ? ? ? ? ? 大部分詞向量方法對每個詞分配一個獨立的詞向量,沒有共享參數,特別是這些方法忽略之間的聯系,而對于形態學豐富的語言更加重要。
?
六、代碼實現
# ****** 數據預處理 *****# 主要包括幾個部分-數據集加載、讀取標簽和數據、創建word2id、將數據轉化為id, 本次實驗還是使用AG數據集合,數據集下載位置 AG News: https://s3.amazonaws.com/fast-ai-nlp/ag_news_csv.tgz# encoding = 'utf-8'from torch.utils import data import os import csv import nltk import numpy as np# 數據集加載f = open("./data/AG/train.csv") rows = csv.reader(f,delimiter=',',quotechar='"') rows = list(rows) rows[1:5][['3','Carlyle Looks Toward Commercial Aerospace (Reuters)','Reuters - Private investment firm Carlyle Group,\\which has a reputation for making well-timed and occasionally\\controversial plays in the defense industry, has quietly placed\\its bets on another part of the market.'],['3',"Oil and Economy Cloud Stocks' Outlook (Reuters)",'Reuters - Soaring crude prices plus worries\\about the economy and the outlook for earnings are expected to\\hang over the stock market next week during the depth of the\\summer doldrums.'],['3','Iraq Halts Oil Exports from Main Southern Pipeline (Reuters)','Reuters - Authorities have halted oil export\\flows from the main pipeline in southern Iraq after\\intelligence showed a rebel militia could strike\\infrastructure, an oil official said on Saturday.'],['3','Oil prices soar to all-time record, posing new menace to US economy (AFP)','AFP - Tearaway world oil prices, toppling records and straining wallets, present a new economic menace barely three months before the US presidential elections.']]# 讀取標簽和數據n_gram,lowercase,label,datas = 2,True,[],[]for row in rows:label.append(int(row[0])-1)txt = " ".join(row[1:])if lowercase:txt = txt.lower()txt = nltk.word_tokenize(txt) #將句子轉化為詞new_txt = []for i in range(len(txt)):for j in range(n_gram):if j<=i:new_txt.append(" ".join(txt[i-j:i+1]))datas.append(new_txt)# word2idmin_count,word_freq = 3,{} for data in datas:for word in data:if word not in word_freq:word_freq[word] = 1else:word_freq[word] += 1# 首先構建uni-gram,不需要hashword2id = {"<pad>":0,"<unk>":1} for word in word_freq:if word_freq[word] < min_count or " " in word:continueword2id[word] = len(word2id)uniwords_num = len(word2id)# 構建2-gram以上的詞,需要hashfor word in word_freq:if word_freq[word] < min_count or " " not in word:continueword2id[word] = len(word2id)# 將文本中的詞都轉化為id,設置句子長度為100,詞表最大限制為1w max_length = 100for i,data in enumerate(datas):for j,word in enumerate(data):if " " not in word:datas[i][j] = word2id.get(word,1)else:datas[i][j] = word2id.get(word,1)%10000 + uniwords_numdatas[i] = datas[i][0:max_length] + [0]*(max_length - len(datas[i]))模型細節:
# """ 模型代碼 """# encoding='utf-8'import torch import torch.nn as nn import numpy as npclass Fasttext(nn.Module):def __init__(self,vocab_size,embedding_size,max_length,label_num):super(Fasttext,self).__init__()self.embedding = nn.Embedding(vocab_size,embedding_size)self.avg_pool = nn.AvgPool1d(kernel_size=max_length,stride=1)self.fc = nn.Linear(embedding_size,label_num)def forward(self,x):x = x.long()out = self.embedding(x) # batch_size * length * embedding_sizeout = out.transpose(1,2).contiguous() # batch_size * embedding_size * lengthout = self.avg_pool(out).squeeze() # batch_size * embedding_sizeout = self.fc(out) # batch_size * label_numreturn outfasttext = Fasttext(vocab_size=1000,embedding_size=10,max_length=100,label_num=4) test = torch.zeros([64,100]).long() out = fasttext(test)""" 查看網絡參數 """ from torchsummary import summarysummary(fasttext,input_size=(100,)) """ 模型訓練 """# encoding = 'utf-8'import torch import torch.autograd as autograd import torch.nn as nn import torch.optim as optim from model import Fasttext from data import AG_Data import numpy as np from tqdm import tqdmimport config as argumentparser config = argumentparser.ArgumentParser()""" 加載數據集 """training_set = AG_Data("/AG/train.csv",min_count = config.min_count,max_length=config.max_length,n_gram=config.n_gram)train_iter = torch.utils.data.DataLoader(dataset=training_set,batch_size=config.batch_size,shuffle=True,num_workers=0)test_set = AG_Data(data_path="/AG/test.csv",min_count=config.min_count,max_length = config.max_length,n_gram = config.n_gram,word2id = training_set.word2id,uniwords_num=training_set.uniwords_num)test_iter = torch.utils.data.DataLoader(dataset=test_set,batch_size=config.batch_size,shuffle = True,num_workers=0)""" 構建模型 """ model = Fasttext(vocab_size=training_set.uniwords_num+100000,embedding_size=config.embed_size,max_length=config.max_length,label_num=config.label_num)if config.cuda and torch.cuda.is_available():model.cuda()criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(),lr=config.learning_rate) loss = -1def get_test_result(data_iter,data_set):mode.eval()true_sample_num = 0for data,label in data_iter:if config.cuda and torch.cuda.is_available():data = data.cuda()label = label.cuda()else:data = torch.autograd.Variable(data).long()out = model(data)true_sample_num += np.sum((torch.argmax(out,1)==label.long()).cpu().numpy())acc = true_sample_num/data_set.__len__()return accfor epoch in range(1):model.train()process_bar = tqdm(train_iter)for data,label in process_bar:if config.cuda and torch.cuda.is_available():data = data.cuda()label = label.cuda()else:data = torch.autograd.Variable(data).long()label = torch.autograd.Variable(label).squeeze()out = model(data)loss_now = criterion(out,autograd.Variable(label.long()))if loss == -1:loss = loss_now.data.item()else:loss = 0.95*loss + 0.05*loss_now.data.item()process_bar.set_postfix(loss=loss_now.data.item())process_bar.update()optimizer.zero_grad()loss_now.backward()optimizer.step()test_acc = get_test_result(test_iter,test_set)print("The test acc is: %.5f" % test_acc)完整代碼,詳見git:
總結
以上是生活随笔為你收集整理的Bag of Tricks for Efficient Text Classification(Fasttext)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Character-level Conv
- 下一篇: Hierarchical Attenti