第一个极小的机器学习的应用
? 現在給出一個Web統計信息,他們存儲著每小時的訪問次數。每一行包含連續的小時和信息,以及該小時Web的訪問次數。現在要解決的問題是,估計在何時訪問量達到基礎設施的極限。極限數據是每小時100000次訪問。
1.讀取數據:
# 獲取數據 filepath = r'C:\Users\TD\Desktop\data\Machine Learning\1400OS_01_Codes\data\web_traffic.tsv' data = sp.genfromtxt(filepath,delimiter = '\t') x = data[:,0] y = data[:,1]其中,x表示小時,y表示訪問量。
2.預處理和清洗數據:
print sp.sum(sp.isnan(y))結果顯示含有8個控值,為了方便,在此處理缺失值辦法是直接剔除。
x = x[~sp.isnan(y)] y = y[~sp.isnan(y)]接下來,畫出散點圖,觀察數據的規律:
# 可視化,觀察數據規律 plt.scatter(x,y) plt.title('Web traffic over the last month') plt.xlabel('Time') plt.ylabel('Hits/hours') plt.xticks([w*24*7 for w in range(5)],['week {}'.format(i) for i in range(5)]) plt.autoscale(tight = True) plt.grid() plt.show()3 選擇正確的模型和學習算法:
回答原始問題需要明確以下幾點:
1)找到噪聲數據后真正的模型
2)使用這個模型預測未來,一遍解決我們的問題
1.首先需要明白模型與實際數據區別。模型可以理解為對復雜現實世界簡化的理論近似。它總會包含一些劣質的類容,這個就叫做近似誤差。我們用真實數據與模型預測的數據之間的平方距離來計算這個誤差,對于一個訓練好的模型f,按照下面函數來計算誤差:
def error(f,x,y):return sp.sum((f(x)-y)**2)?2.簡單的線性模型
現在用一個線性模型來擬合上面數據,看看可以得到什么。
plt.scatter(x,y) plt.title('Web traffic over the last month') plt.xlabel('Time') plt.ylabel('Hits/hours') plt.xticks([w*24*7 for w in range(5)],['week {}'.format(i) for i in range(5)]) plt.autoscale(tight = True) plt.grid()# 開始構建模型,使用1階多項式擬合 p1 = sp.polyfit(x,y,1) f1 = sp.poly1d(p1) # 將擬合系數傳入ployld函數創建模型函數f1 fx = sp.linspace(0,x[-1],1000) plt.plot(fx,f1(fx),linewidth = 3) plt.legend(["d = {}".format(f1.order)], loc = "upper left") plt.show()
上圖顯示了第一個訓練的模型,發現前四個星期好像沒有偏差很多,可以清楚的看到直線模型的假設是有問題的。
3 接下來用3階,10階,50階多項式來擬合:
colors = ['g', 'k', 'b', 'm', 'r']def error(f,x,y):return sp.sum((f(x)-y)**2)def plot_models(x, y, models, fname, mx = None):plt.clf()plt.scatter(x, y, s=10)plt.title("Web traffic over the last month")plt.xlabel("Time")plt.ylabel("Hits/hour")plt.xticks([w * 7 * 24 for w in range(10)], ['week %i' % w for w in range(10)])if models:if mx is None:mx = sp.linspace(0, x[-1], 1000)for model, color in zip(models, colors):# print "Model:",model# print "Coeffs:",model.coeffsplt.plot(mx, model(mx), c = color,linewidth = 1.5)plt.legend(["d = {}".format(m.order) for m in models], loc="upper left")plt.autoscale(tight=True)plt.grid(True, linestyle='-', color='0.75')plt.savefig(fname)# create and plot models os.chdir(r'C:\Users\TD\Desktop\data\Machine Learning\1400OS_01_Codes\data') f1 = sp.poly1d(sp.polyfit(x,y,1)) f3 = sp.poly1d(sp.polyfit(x, y, 3)) f10 = sp.poly1d(sp.polyfit(x, y, 10)) f50 = sp.poly1d(sp.polyfit(x, y, 50)) plot_models(x, y, [f1,f3, f10, f50],"2.png")# error indices = [1,3,10,50] for index,model in zip(indices,[f1,f3,f10,f50]):print 'Error d= {} : {}'.format(index,error(model,x,y))可以看出多項式越復雜,數據逼近越好。他們誤差如下:
Error d= 1 : 317389767.34
Error d= 3 : 139350144.032
Error d= 10 : 121942326.364
Error d= 50 : 109504607.366
看看10階和100階的多項式,我們發現了巨大的震蕩。似乎這樣擬合的太過了,他不斷捕捉到背后數據的生成,還把噪聲數據也考慮進去了。這樣叫做過擬合。然而1階的顯然太簡單了,不能反映數據的規律,這種叫做欠擬合。不管是欠擬合還是過擬合,都不適合進行預測。
4 已退為進,另眼看數據
觀察數據,似乎第三周和第四周之間有一個拐點。這可以讓我們將3.5周作為分界點,把數據分為兩份,并訓練出兩條直線。
plt.scatter(x, y, s=10) plt.title("Web traffic over the last month") plt.xlabel("Time") plt.ylabel("Hits/hour") plt.xticks([w * 7 * 24 for w in range(10)], ['week %i' % w for w in range(10)]) inflection = int(3.5*7*24) xa = x[:inflection] ya = y[:inflection] xb = x[inflection:] yb = y[inflection:] fa = sp.poly1d(sp.polyfit(xa,ya,1)) fb = sp.poly1d(sp.polyfit(xb,yb,1)) plt.scatter(x,y) fax = sp.linspace(0,x[-1],1000) fbx = sp.linspace(x[inflection]/1.1,x[-1]*1.1,1000) plt.plot(fax,fa(fax),c = 'g',linewidth = 2.5) plt.plot(fbx,fb(fbx),c = 'r',linewidth = 2.5) plt.show()
很明顯,這兩條直線組合起來似乎比之前任何模型都可以更好的擬合數據,雖然組合后的誤差高于高階多項式的誤差。為什么僅僅在最后一周上更相信線性模型呢?這是因為我們認為他更好的符合未來數據。10階和100階多項式在此沒有光明的未來,他們只是非常努力的對給定的數據進行擬合,但是他們卻無法推廣到將來的數據上,這就是過擬合,另外低階模型也不能恰好的模擬數據,叫做欠擬合。
5 訓練與測試
如果有些外來數據用于模型評估,那么僅從近似誤差結果就可以判斷出我們的選擇的模型是好還是壞。盡管我們找不到未來數據,但是可以從現有的數據中拿出一部分,來判斷我們的結果是好還是壞了。利用拐點后的數據進行訓練,得到的二階模型的誤差最小,這個模型很適中,既不欠擬合也不過擬合。
6 回答最初的問題
得到了訓練的模型。只需要帶入數值就可以計算得到我們所求結果。
實驗代碼:
#!/usr/bin/env python # -*- coding: utf-8 -*- # __author__ : '小糖果'import scipy as sp import matplotlib.pyplot as plt from scipy.optimize import fsolve import os# 獲取數據 filepath = r'C:\Users\TD\Desktop\data\Machine Learning\1400OS_01_Codes\data' data = sp.genfromtxt(os.path.join(filepath,'web_traffic.tsv'),delimiter = '\t') x = data[:,0] y = data[:,1]# 缺失數據處理,用相鄰數據平均數代替 print sp.sum(sp.isnan(y)) x = x[~sp.isnan(y)] y = y[~sp.isnan(y)]colors = ['g', 'k', 'b', 'm', 'r']def error(f,x,y):return sp.sum((f(x)-y)**2)def plot_model(x, y,models = None,fname = None,mx = None):plt.clf()plt.scatter(x,y)plt.title("Web traffic over the last month")plt.xlabel("Time")plt.ylabel("Hits/hour")plt.xticks([w*24*7 for w in range(10)],['week {}'.format(i) for i in range(10)])if models:if mx is None:mx = sp.linspace(0,x[-1],1000)for (model,color) in zip(models,colors):plt.plot(mx,model(mx),c = color,linewidth = 2)plt.legend(['d = {}'.format(m.order) for m in models],loc = 'upper left')plt.autoscale(tight = True)plt.grid(True)if fname:plt.savefig(fname)else:plt.show()# 查看初始數據 plot_model(x,y,fname = os.path.join(filepath,'1.jpg'))#分別用1,2,10,50階多項式擬合 f1 = sp.poly1d(sp.polyfit(x,y,1)) f2 = sp.poly1d(sp.polyfit(x,y,2)) f10 = sp.poly1d(sp.polyfit(x,y,10)) f50 = sp.poly1d(sp.polyfit(x,y,50)) plot_model(x,y,models = [f1,f2,f10,f50],fname = os.path.join(filepath,'2.jpg'))# 線性分段擬合 plt.clf() inflection = int(3.5*7*24) xa = x[:inflection] ya = y[:inflection] xb = x[inflection:] yb = y[inflection:] fa = sp.poly1d(sp.polyfit(xa,ya,1)) fb = sp.poly1d(sp.polyfit(xb,yb,1)) plt.scatter(x,y) fax = sp.linspace(0,x[-1],1000) fbx = sp.linspace(x[inflection]/1.1,x[-1]*1.1,1000) plt.plot(fax,fa(fax),c = 'g',linewidth = 2.5) plt.plot(fbx,fb(fbx),c = 'r',linewidth = 2.5) plt.grid(True) plt.savefig(os.path.join(filepath,'3.jpg'))# 只使用后面部分數據 f1 = sp.poly1d(sp.polyfit(xb,yb,1)) f2 = sp.poly1d(sp.polyfit(xb,yb,2)) f10 = sp.poly1d(sp.polyfit(xb,yb,10)) f50 = sp.poly1d(sp.polyfit(xb,yb,50)) plot_model(xb,yb,models = [f1,f2,f10,f50],mx = sp.linspace(xb[0],xb[-1],100),fname = os.path.join(filepath,'4,jpg'))# 求問題的解,使用二次多項式模型 ans = fsolve(f2 - 100000,800)/7/24 print ans
轉載于:https://www.cnblogs.com/td15980891505/p/5996062.html
總結
以上是生活随笔為你收集整理的第一个极小的机器学习的应用的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: hdu 4496 并查集 逆向 并查集删
- 下一篇: 6.4(2)