机器学习入门01-K临近(KNN)的java实现
K臨近(KNN)算法是一種原理比較簡單的機器學習算法,其原理是將待分類數據與所有樣本數據計算距離,根據距離由近到遠選取K個臨近點,根據臨近點占比和距離權重對待分類點進行分類。
由于需要做距離計算,樣本數據每個特征必須為數值型數據。加入我們需要對不同鳥進行分類,從翼展、身高、體重三個方面對老鷹、鴿子、麻雀三種鳥進行分類計算。下面給出一組假設的樣本數據:
| 分類 | 翼展 | 體重 | 身高 |
| 老鷹 | 2米 | 5.0kg | 1.0米 |
| 鴿子 | 0.5米 | 0.5kg | 0.3米 |
| 麻雀 | 0.2米 | 0.05kg | 0.1米 |
從數據中可以看出,由于不同特征的值具跨度范圍不一致,如果直接進行計算,容易造成權重失衡,為了消除權重失衡需要對每個特征內部進行歸一化,即特征內每個值除以其中的最大值。那么歸一化后老鷹(1.0,1.0,1.0),鴿子(0.25,0.1,0.3),麻雀(0.1,0.01,0.1)。我們可以將這三個特征數據想象為一個個三維空間中的點,那么待分類對象就是計算一個三維坐標距離樣本點的距離。假設一個待分類數據(x,y,z),采用KNN算法進行分類,通過歐式距離可以計算出它離某個樣本點(x1,y1,z1)的距離。
計算公式:距離=sqrt((x - x1)^2 + (y - y1)^2 + (z - z1)^2)。
實際實現為了降低計算消耗可以忽略開方運算,只做平方計算,消除值為負數的差值即可。
實現代碼:
distance = Math.pow(Double.parseDouble(testData[j]) - Double.parseDouble(sample[j + 1]), 2);從原理和實現上不難看出,KNN算法沒有訓練過程,拿到樣本數據后就可以直接使用,雖然計算簡單,由于需要對每個樣本進行距離計算,當樣本數量過大后,將會消耗極大的計算時間和內存空間。針對這種問題,可以采用先取出距離較近的一些點,再進行距離計算。即根據待分類數據(x,y,z),我們增加一個參數,查找半徑,當樣本數據中超過K個數據處于半徑范圍內,則停止查找。
實現代碼:
private List<String[]> findNearestNeibor(List<String[]> modelList, String[] testData, double radius, int k) {List<String[]> result = new ArrayList<String[]>();double step = radius;while(true) {for(int i = 0; i < modelList.size(); i++) {String[] modelSample = modelList.get(i);List<Boolean> tempResult = new ArrayList<Boolean>();for(int j = 0; j < testData.length; j++) {double sampleMin = Double.parseDouble(testData[j]) - step;double sampleMax = Double.parseDouble(testData[j]) + step;double modelSampleIndex = Double.parseDouble(modelSample[j + 1]);if (modelSampleIndex >= sampleMin && modelSampleIndex <= sampleMax) {tempResult.add(true);}else {tempResult.add(false);}}if (!tempResult.contains(false)) {result.add(modelSample);}}if (result.size() >= k) {return result;}else {step += radius;}}當查找到大于K個值后,再進行距離計算,找出最近的K個值并給出結果。假設K=1時,即取離待分類點最近的樣本點作為分類結果。
實現代碼:
private String getResultTag(List<String[]> nearestList, String[] testData) {String result = new String();double min = testData.length;for(int i = 0; i < nearestList.size(); i++) {String[] nearSample = nearestList.get(i);double distance = 0.0;for(int j = 1; j < testData.length; j++) {distance += Math.pow(Double.parseDouble(testData[j]) - Double.parseDouble(nearSample[j]), 2);}if (distance < min) {result = nearSample[0];min = distance;}}return result;}接下來,進行算法測試,隨機生成一個包含10000個樣本三種分類的文本文件,分類A的特征一在0.9左右,特征二0.5左右,特征三0.3左右;分類B的特征一在0.3左右,特征二0.6左右,特征三0.9左右;分類C的特征一在0.6左右,特征二0.9左右,特征三0.3左右;
如圖:
同樣,為了提高計算速度,默認K為1情況下,采用一邊讀取一邊計算距離,當完成整個樣本文件讀取后,即完成計算。
實現代碼:
public String predict(File model, String[] testData) {String result = new String();double min = testData.length;try {BufferedReader reader = new BufferedReader(new FileReader(model));String line;while ((line = reader.readLine()) != null) {String[] sample = line.split(",");double distance = 0.0;for(int j = 0; j < testData.length; j++) {distance += Math.pow(Double.parseDouble(testData[j]) - Double.parseDouble(sample[j + 1]), 2);}if (distance < min) {result = sample[0];min = distance;}}reader.close();}catch (Exception e) {e.printStackTrace();}return result;}測試代碼及測試結果:
public static void main(String[] args) throws Exception{KNN knn = new KNN();String[] testData = new String[] {"0.32","0.65","0.83"};long time1 = System.currentTimeMillis();String result = knn.predict(new File("C:/Users/admin/Desktop/test/sample.csv"), testData);long time2 = System.currentTimeMillis();System.out.println("計算用時:" + (time2 - time1) + "毫秒");System.out.println(result);}總結
以上是生活随笔為你收集整理的机器学习入门01-K临近(KNN)的java实现的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 应用机器学习进行无人机航拍影像质量评估
- 下一篇: 图像处理之添加文字水印