import numpy as np
classCRF:def__init__(self,y=None,x=None,y_num=None,x_num=None,N=None):self.y = yself.x = xself.y_num = y_numself.x_num = x_numself.N = Nself.get_feature()self.build_Marix(self.x[0])defget_feature(self):self.ti =[lambda y_1, y, x, i:1if i ==2and y_1 ==1and y ==2else0,lambda y_1, y, x, i:1if i ==3and y_1 ==1and y ==2else0,lambda y_1, y, x, i:1if i ==2and y_1 ==1and y ==1else0,lambda y_1, y, x, i:1if i ==3and y_1 ==2and y ==1else0,lambda y_1, y, x, i:1if i ==2and y_1 ==2and y ==1else0,lambda y_1, y, x, i:1if i ==3and y_1 ==2and y ==2else0,]self.w_ti =[1,1,0.6,1,1,0.2]self.si =[lambda y_1, y, x, i:1if i ==1and y ==1else0,lambda y_1, y, x, i:1if i ==1and y ==2else0,lambda y_1, y, x, i:1if i ==2and y ==2else0,lambda y_1, y, x, i:1if i ==2and y ==1else0,lambda y_1, y, x, i:1if i ==3and y ==1else0,lambda y_1, y, x, i:1if i ==3and y ==2else0,]self.w_si =[1,0.5,0.5,0.8,0.8,0.5]self.fk = self.ti+self.siself.wk = self.w_ti+self.w_sidefbuild_Marix(self,x):self.Marix = np.zeros((self.N+1,self.y_num,self.y_num))for i inrange(self.N+1):for n inrange(self.y_num):if i == self.N:self.Marix[i][:,0]=1breakfor m inrange(self.y_num):for k inrange(len(self.fk)):if i ==0:if n==0:self.Marix[i][0][m]+= self.wk[k]*(self.fk[k](0,m+1,x[i],i+1))else:self.Marix[i][n][m]+= self.wk[k]*(self.fk[k](n+1,m+1,x[i],i+1))self.Marix = np.exp(self.Marix)defget_aiT(self):self.aiT = np.zeros((self.N+1,self.y_num))self.aiT[0,:]=1for i inrange(1,self.N+1):self.aiT[i]= self.aiT[i-1].dot(self.Marix[i])defget_biT(self):self.biT = np.zeros((self.N +1, self.y_num))self.biT[self.N,:]=1for i inrange(self.N-1,-1,-1):self.biT[i]= self.Marix[i+1].dot(self.biT[i+1])#--------概率計算問題---------------defpredict_Py_i(self,i,yi):self.get_aiT()self.get_biT()return self.aiT[i][yi-1]*self.biT[i][yi-1]/np.sum(self.aiT[self.N])defcompute_p(self,y):result =1for i inrange(self.N):if i ==0:Z = self.Marix[i]result *= self.Marix[i][0][y[i]-1]else:Z = Z.dot(self.Marix[i])result *= self.Marix[i][y[i-1]-1][y[i]-1]return result/np.sum(Z)#----------維特比預測算法-------------------------defViterbi(self,x):p_marix = np.zeros((len(x),self.y_num))rout =[[0]for i inrange(self.y_num)]for j inrange(self.y_num):p_marix[0][j]= np.sum(np.array([f(0,j+1,x[0],1)for f in self.fk])*self.wk)for i inrange(1,len(x)):for l inrange(self.y_num):max=0rout_tem =0for j inrange(self.y_num):wf = p_marix[i -1][j]+ np.sum(np.array([f(j+1, l+1, x[i], i +1)for f in self.fk])* self.wk)if wf >max:max= wfrout_tem = jp_marix[i][l]=maxrout[l].append(rout_tem)max_result = np.max(p_marix[len(x)-1])max_index =list(p_marix[len(x)-1]).index(np.max(p_marix[len(x)-1]))rout[max_index].append(max_index)return max_result,[i+1for i in rout[max_index][1:]]# -----------學習算法--------------------------------defcompute_pxy_f(self):#pfklist_pxyti = np.zeros(len(self.ti))list_pxysi = np.zeros(len(self.si))#p_fklist_countti = np.zeros(len(self.ti))list_countsi = np.zeros(len(self.si))for x,y inzip(self.x,self.y):self.build_Marix(x)self.get_aiT()self.get_biT()for i inrange(len(x)):list_temp = np.zeros(len(self.ti))for k inrange(len(self.ti)):if i ==0:right =1*self.Marix[0][0][y[i]-1]*self.biT[0][y[i]-1]left = self.ti[k](i,y[i],x[i],i+1)list_temp[k]+= left*right/np.sum(self.aiT[self.N])else:right = self.aiT[i][y[i-1]-1]* self.Marix[i][y[i-1]-1][y[i]-1]*self.biT[i][y[i]-1]left = self.ti[k](y[i-1],y[i],x[i],i+1)list_temp[k]+= left*right/np.sum(self.aiT[self.N])if left ==1:list_countti[k]+=1list_pxyti += list_tempfor i inrange(len(x)):list_temp = np.zeros(len(self.si))for k inrange(len(self.si)):right = self.aiT[i][y[i]-1]*self.biT[i][y[i]-1]left = self.si[k](y[i-1], y[i], x[i], i+1)list_temp[k]+= left * right / np.sum(self.aiT[self.N])if left ==1:list_countsi[k]+=1list_pxysi += list_tempreturn list_countti/len(self.x),list_countsi/len(self.x),list_pxyti/len(self.x),list_pxysi/len(self.x)defcompute_fw(self):left =0right =0for x in self.x:self.build_Marix(x)self.get_aiT()left += np.log(np.sum(self.aiT[self.N]))for x, y inzip(self.x, self.y):for i inrange(len(x)):for k inrange(len(self.fk)):if i ==0:right += self.wk[k]* self.fk[k](0,y[i],x[i],i+1)else:right += self.wk[k]* self.fk[k](y[i-1], y[i], x[i], i+1)return(left - right)/len(self.x)deffit(self,max_iter=3,how='IIS',lr=0.001):self.w_ti =[0]*len(self.ti)self.w_si =[0]*len(self.si)self.wk = self.w_ti + self.w_siself.x = np.array(self.x)self.y = np.array(self.y)if how =='IIS':S =20for i inrange(max_iter):ep_tk,ep_sk,eptk,epsk = self.compute_pxy_f()if np.linalg.norm(1/S*np.log(ep_tk/eptk),ord=2)+np.linalg.norm(1/S*np.log(ep_sk/epsk),ord=2)<0.1:print('when iter is '+str(i)+' shoulian')breakself.w_ti +=1/S*np.log(ep_tk/eptk)self.w_si +=1/S*np.log(ep_sk/epsk)self.wk =list(self.w_ti)+list(self.w_si)#按著最大熵模型的公式寫的,用EP - E_P為倒數,不知道對不對elif how =='GD':for i inrange(max_iter):ep_tk,ep_sk,eptk,epsk = self.compute_pxy_f()gt = eptk - ep_tkgs = epsk - ep_skfanshut = np.linalg.norm(gt,ord=2)fanshus = np.linalg.norm(gs,ord=2)if fanshut + fanshus <0.75:print('when iter is '+str(i)+' shoulian')breaktemp_w_ti = self.w_titemp_w_si = self.w_sifw_list =[]#線性搜索這里有問題,fw會一直縮小,因該是fw的計算出現了錯誤for k inrange(20):self.w_ti = temp_w_ti - lr * gt * kself.w_si = temp_w_si - lr * gs * kself.wk =list(self.w_ti)+list(self.w_si)fw = self.compute_fw()fw_list.append(fw)min_index = fw_list.index(min(fw_list))deta_wti = gt * lr * min_indexdeta_wsi = gs * lr * min_indexself.w_ti = temp_w_ti - deta_wtiself.w_si = temp_w_si - deta_wsiself.wk =list(self.w_ti)+list(self.w_si)defmain():y =[[1,2,2],[2,1,1],[1,1,1],[1,2,2],[2,2,2],[2,1,2],[1,1,2],[1,2,1],[1,2,2],[1,2,2],[1,2,2],[1,2,2]]x =[[1,1,1],[1,1,1],[1,1,1],[1,1,1],[1,1,1],[1,1,1],[1,1,1],[1,1,1],[1,1,1],[1,1,1],[1,1,1],[1,1,1]]CRF_test = CRF(y=y,y_num=2,N=3,x=x)print(CRF_test.Marix)print(CRF_test.compute_p(y[0]))print(CRF_test.predict_Py_i(2,1))print(CRF_test.Viterbi(x[0]))CRF_test.fit(50, how='IIS')print(CRF_test.wk)CRF_test.fit(500,how='GD',lr=0.001)print(CRF_test.wk)if __name__ =='__main__':main()#---------result--------------------
usr/bin/python3 /Users/zhengyanzhao/PycharmProjects/tongjixuexi/shixian2/CRF.py
[[[2.718281831.64872127][1.1.]][[4.055199974.48168907][6.049647461.64872127]][[2.225540934.48168907][6.049647462.01375271]][[2.718281831.][2.718281831.]]]0.064867835429079150.5082915274316868(4.3,[1,2,1])
when iteris11 shoulian
[0.18764402126707205,0.23235841506946808,0.21344554371583913,0.20112964244479178,0.24547679607078648,0.20600431503056774,0.3669918383125717,0.39593684451308797,0.385216372799946,0.377295537944856,0.3971223364931556,0.3658935638684794]
when iteris64 shoulian
[0.23036225378918423,0.0846981674992883,0.0935981916969745,0.03945814811020889,0.0812481327750673,0.16101306581281896,0.43884913282239046,0.15771695572586986,0.37418378149743614,0.2182414425846152,0.16699382351353245,0.4110185294594026]Process finished with exit code 0