Hands-on Kaggle Competitions (1): Leaf Classification
Competition page: https://www.kaggle.com/c/leaf-classification/rules
Full code: https://github.com/SPECTRELWF/kaggle_competition
Personal homepage: liuweifeng.top:8090
Task: classify leaf images by species. There are 99 species in total. Some example leaves:
[Figure: sample leaf images]
我也不知道怎么分類,反正總共有99中類別的樹葉。下載到的數(shù)據(jù)集解壓后如下:
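The images folder holds all the leaf images. train.csv contains the id and species of each training image, followed by a large set of pre-extracted features. I didn't use those: the competition has already ended, and I only used it to practice CNNs. test.csv contains the ids of the test images, and sample_submission.csv is the submission template, which looks like this:
[Figure: sample_submission.csv]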
第一列是id,后面的99列是對應的每個類別的概率,分類結果加上softmax就行。
Approach:
Fine-tune a ResNet-101 pretrained on ImageNet.
預處理
將訓練集的id和label寫到一個txt文件中,測試集的id寫入另一個txt文件:
```python
# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/12/8 10:27 AM
import pandas as pd

# The 99 species in alphabetical order; a label is the index of the species in this list.
classes = [
    'Acer_Capillipes', 'Acer_Circinatum', 'Acer_Mono', 'Acer_Opalus',
    'Acer_Palmatum', 'Acer_Pictum', 'Acer_Platanoids', 'Acer_Rubrum',
    'Acer_Rufinerve', 'Acer_Saccharinum', 'Alnus_Cordata', 'Alnus_Maximowiczii',
    'Alnus_Rubra', 'Alnus_Sieboldiana', 'Alnus_Viridis', 'Arundinaria_Simonii',
    'Betula_Austrosinensis', 'Betula_Pendula', 'Callicarpa_Bodinieri', 'Castanea_Sativa',
    'Celtis_Koraiensis', 'Cercis_Siliquastrum', 'Cornus_Chinensis', 'Cornus_Controversa',
    'Cornus_Macrophylla', 'Cotinus_Coggygria', 'Crataegus_Monogyna', 'Cytisus_Battandieri',
    'Eucalyptus_Glaucescens', 'Eucalyptus_Neglecta', 'Eucalyptus_Urnigera', 'Fagus_Sylvatica',
    'Ginkgo_Biloba', 'Ilex_Aquifolium', 'Ilex_Cornuta', 'Liquidambar_Styraciflua',
    'Liriodendron_Tulipifera', 'Lithocarpus_Cleistocarpus', 'Lithocarpus_Edulis', 'Magnolia_Heptapeta',
    'Magnolia_Salicifolia', 'Morus_Nigra', 'Olea_Europaea', 'Phildelphus',
    'Populus_Adenopoda', 'Populus_Grandidentata', 'Populus_Nigra', 'Prunus_Avium',
    'Prunus_X_Shmittii', 'Pterocarya_Stenoptera', 'Quercus_Afares', 'Quercus_Agrifolia',
    'Quercus_Alnifolia', 'Quercus_Brantii', 'Quercus_Canariensis', 'Quercus_Castaneifolia',
    'Quercus_Cerris', 'Quercus_Chrysolepis', 'Quercus_Coccifera', 'Quercus_Coccinea',
    'Quercus_Crassifolia', 'Quercus_Crassipes', 'Quercus_Dolicholepis', 'Quercus_Ellipsoidalis',
    'Quercus_Greggii', 'Quercus_Hartwissiana', 'Quercus_Ilex', 'Quercus_Imbricaria',
    'Quercus_Infectoria_sub', 'Quercus_Kewensis', 'Quercus_Nigra', 'Quercus_Palustris',
    'Quercus_Phellos', 'Quercus_Phillyraeoides', 'Quercus_Pontica', 'Quercus_Pubescens',
    'Quercus_Pyrenaica', 'Quercus_Rhysophylla', 'Quercus_Rubra', 'Quercus_Semecarpifolia',
    'Quercus_Shumardii', 'Quercus_Suber', 'Quercus_Texana', 'Quercus_Trojana',
    'Quercus_Variabilis', 'Quercus_Vulcanica', 'Quercus_x_Hispanica', 'Quercus_x_Turneri',
    'Rhododendron_x_Russellianum', 'Salix_Fragilis', 'Salix_Intergra', 'Sorbus_Aria',
    'Tilia_Oliveri', 'Tilia_Platyphyllos', 'Tilia_Tomentosa', 'Ulmus_Bergmanniana',
    'Viburnum_Tinus', 'Viburnum_x_Rhytidophylloides', 'Zelkova_Serrata',
]

# train.txt: one "<id> <class index>" pair per line
train_csv = pd.read_csv(r'leaf-classification/train.csv')
with open('train.txt', 'w') as train_txt:
    for img_id, name in zip(train_csv['id'], train_csv['species']):
        train_txt.write(f'{img_id} {classes.index(str(name))}\n')

# test.txt: one test-image id per line
test_csv = pd.read_csv(r'leaf-classification/test.csv')
with open('test.txt', 'w') as test_txt:
    for img_id in test_csv['id']:
        test_txt.write(f'{img_id}\n')
```

Model: ResNet-101 (resnet.py)
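The model stacks a small MLP head (Linear 1000→512→256→99, with ReLU and Dropout in between) on the 1000-way output of the pretrained ResNet-101. As quick arithmetic, the head's parameter count follows from each `nn.Linear(in, out)` holding an `out × in` weight matrix plus `out` biases. The helper below is illustrative and not part of the author's code:

```python
def linear_params(in_features, out_features):
    """Parameters in nn.Linear(in, out) with bias: weight (out x in) plus bias (out)."""
    return in_features * out_features + out_features


# Linear layers of the classifier head (ReLU and Dropout have no parameters).
head = [(1000, 512), (512, 256), (256, 99)]
counts = [linear_params(i, o) for i, o in head]
print(counts, sum(counts))  # [512512, 131328, 25443] 669283
```

About 0.67M parameters, small next to the roughly 44.5M of the ResNet-101 backbone. Since nothing is frozen, almost all of the updated weights belong to the pretrained network.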
```python
# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/12/8 10:24 AM
import torch
import torch.nn as nn
import torchvision.models as models


class resnet101(nn.Module):
    def __init__(self, num_classes=1000):
        super(resnet101, self).__init__()
        self.num_classes = num_classes
        # ImageNet-pretrained backbone; its 1000-way output serves as the feature vector.
        # torchvision >= 0.13 API; older versions use models.resnet101(pretrained=True).
        self.feature_extract = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)
        # MLP head mapping the 1000-d output to num_classes logits
        self.net = nn.Sequential(
            nn.Linear(1000, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        x = self.feature_extract(x)
        x = self.net(x)
        return x


# Quick shape check:
# x = torch.randn((1, 3, 224, 224))
# net = resnet101(num_classes=99)
# print(net)
# print(net(x).shape)  # torch.Size([1, 99])
```

Dataloader (dataset.py)
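Each line of train.txt has the form `<id> <label>`. Slicing off the last character with `line[:-1]` only works if every line ends with a newline; splitting on whitespace also copes with a missing final newline and with blank lines. A small stdlib sketch of that parsing step (the name `parse_label_file` is hypothetical):

```python
def parse_label_file(text):
    """Parse '<id> <label>' lines, as written to train.txt, into two parallel lists.

    split() discards surrounding whitespace, including the newline, so the last
    line is parsed correctly even when the file does not end with a newline.
    """
    images, labels = [], []
    for line in text.splitlines():
        parts = line.split()
        if not parts:  # skip blank lines
            continue
        image_id, label = parts
        images.append(image_id)
        labels.append(int(label))
    return images, labels


images, labels = parse_label_file("1 3\n2 49\n3 65")  # no trailing newline
print(images, labels)  # ['1', '2', '3'] [3, 49, 65]
```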
```python
# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/12/8 10:24 AM
import os

import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image

data_root = r'leaf-classification/images/'

# Fallback transform used when none is passed in (named so it does not shadow the module).
default_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


class leaf_Dataset(data.Dataset):
    def __init__(self, is_train=True, transform=None):
        self.is_train = is_train
        self.transform = transform if transform is not None else default_transform
        self.images = []
        self.labels = []
        # Only the training split is read here; test images are loaded in the prediction script.
        if is_train:
            with open('train.txt', 'r') as file:
                for line in file:
                    parts = line.split()  # "<id> <label>"
                    if not parts:
                        continue
                    self.images.append(parts[0])
                    self.labels.append(int(parts[1]))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image_path = os.path.join(data_root, self.images[index] + '.jpg')
        img = Image.open(image_path).convert('RGB')
        img = self.transform(img)
        label = torch.tensor(self.labels[index], dtype=torch.long)
        return img, label
```

Training

```python
# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/12/8 10:25 AM
"""
Fine-tune an ImageNet-pretrained ResNet-101 on the leaf dataset.
"""
import os

import torch
import torch.nn as nn
import torch.utils.data
import torchvision.transforms as transforms

from dataset import leaf_Dataset
from resnet import resnet101

# Train the whole network with Adam; no parameters are frozen.
# Hyperparameters
num_epochs = 200
lr = 1e-3
b1 = 0.9
b2 = 0.999
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
train_loss = []

# Build the model
net = resnet101(num_classes=99)
net.to(device)
net.train()

# Load data, with random flips as augmentation
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
])
train_set = leaf_Dataset(is_train=True, transform=train_transform)
dataloader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)

loss_func = nn.CrossEntropyLoss()
opt = torch.optim.Adam(net.parameters(), lr=lr, betas=(b1, b2))

os.makedirs('model', exist_ok=True)  # torch.save does not create missing directories
for epoch in range(1, num_epochs + 1):
    for i, (x, y) in enumerate(dataloader):
        x = x.to(device)
        y = y.to(device)
        pred = net(x)
        loss = loss_func(pred, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        train_loss.append(loss.item())
        print("epoch: %d batch_idx:%d loss:%.3f" % (epoch, i, loss.item()))
    # One checkpoint per epoch (a ':' in file names is fine on Linux but not allowed on Windows)
    torch.save(net.state_dict(), 'model/epoch:%d.pth' % epoch)

from utils import plot_curve  # helper from the author's repository, not shown in this post
plot_curve(train_loss)
```

[Figure: training loss curve]
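`train_loss` receives one value per batch, so the raw curve is noisy. `plot_curve` comes from the author's utils module and is not shown here; as an illustrative alternative (not the author's helper, and the name `moving_average` is hypothetical), a trailing moving average can smooth the per-batch losses before plotting:

```python
def moving_average(values, window=20):
    """Trailing moving average: each output averages up to `window` most recent values."""
    out, running = [], 0.0
    for i, v in enumerate(values):
        running += v
        if i >= window:
            running -= values[i - window]  # drop the value that left the window
        out.append(running / min(i + 1, window))
    return out


# Hypothetical per-batch losses, as collected in train_loss.
losses = [4.0, 2.0, 3.0, 1.0]
print(moving_average(losses, window=2))  # [4.0, 3.0, 2.5, 2.0]
```

The first few points average fewer than `window` values, so the smoothed curve starts at the same height as the raw one instead of near zero.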
Writing the predictions to the submission file
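```python
# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/12/8 5:42 PM
import os

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image

from resnet import resnet101

image_path = r'leaf-classification/images'
with open('test.txt', 'r') as f:
    test_ids = [line.strip() for line in f if line.strip()]
test_file = [i + '.jpg' for i in test_ids]

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net = resnet101(num_classes=99)
print('load weight........')
net.load_state_dict(torch.load('model/epoch:200.pth', map_location=device))
net.to(device)
net.eval()  # disable Dropout for inference

# Same resizing as in training, without the random flips
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
res = []
with torch.no_grad():
    for image in test_file:
        img = Image.open(os.path.join(image_path, image)).convert('RGB')
        img = test_transform(img)
        img = torch.unsqueeze(img, dim=0).to(device)  # add batch dimension: (1, 3, 224, 224)
        pred = net(img)
        pred = F.softmax(pred, dim=1).flatten()  # logits -> class probabilities
        res.append(pred.cpu().numpy())

# Write the id column plus 99 probability columns, with the header from sample_submission.csv.
# Assumes its species columns are in the same alphabetical order as `classes` in preprocessing.
sample = pd.read_csv(r'leaf-classification/sample_submission.csv')
submission = pd.DataFrame(np.array(res), columns=sample.columns[1:])
submission.insert(0, 'id', [int(i) for i in test_ids])
submission.to_csv('result.csv', index=False)
```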
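Before uploading, it is worth checking that the output follows the sample_submission.csv layout: an `id` header followed by one column per species, with each row's probabilities summing to 1. A hypothetical stdlib checker (the name `check_submission` and the tolerance are assumptions), shown on a toy 2-class file:

```python
import csv
import io
import math


def check_submission(csv_text, n_classes=99, tol=1e-4):
    """Check for an 'id' + n_classes header and rows whose probabilities sum to ~1."""
    reader = csv.reader(io.StringIO(csv_text))
    header = next(reader)
    if header[0] != "id" or len(header) != n_classes + 1:
        return False
    for row in reader:
        probs = [float(x) for x in row[1:]]
        if len(probs) != n_classes or not math.isclose(sum(probs), 1.0, abs_tol=tol):
            return False
    return True


# Toy 2-class files in the submission layout (the real file has 99 species columns).
good = "id,Acer_Mono,Zelkova_Serrata\n4,0.25,0.75\n7,0.5,0.5\n"
bad = "0.25,0.75\n0.5,0.5\n"  # probabilities only, as np.savetxt(res) writes them: no id, no header
print(check_submission(good, n_classes=2), check_submission(bad, n_classes=2))  # True False
```

For the real result.csv, read the file's text and keep the default `n_classes=99`.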