联邦学习实战-1:用python从零开始实现横向联邦学习
什么是聯邦學習?
簡單來說就是在一個多方的環境中,數據集是零散的(在各個不同的客戶端中),那么怎樣實現機器學習算法呢?
首先想到的就是將多個數據集合并合并起來,然后統一的使用傳統的機器學習或者深度學習算法進行計算,但是如果有一方因為數據隱私問題不愿意提交自己的數據呢?
那么就出現了聯邦學習,核心就是“數據不動模型動,數據可用不可見”
多個客戶端不提交數據而是提交訓練時的參數/梯度給中心服務器,中心服務器進行計算后再將參數/梯度返回多個客戶端再學習的過程
整個過程數據的所有權依然在用戶手中,這就是聯邦學習
當然數據隱私方面,聯邦學習還將結合同態加密、安全多方計算、查分隱私等隱私計算技術實現更安全的保障
(ps:這里只是簡單的介紹,詳細的內容請多查閱其他資料)
基本概念入門學習見:《Federated_Machine_Learning:Concept_and_Applications》精讀
一、環境準備
實驗基于機器學習庫PyTorch, 所以需要一些基礎的PyTorch使用
(ps:不會也沒事,下面代碼有詳細的注釋,因為我也剛剛入門 😃 )
- anaconda、python3.7、PyTorch
pip install torch - GPU安裝CUDA、cuDNN
二、橫向聯邦圖像分類
基本信息
數據集:CIFAR10
模型:ResNet-18
環境角色:
- 中心服務器
- 多個客戶端
為了簡化,這里服務器客戶端都是在單機上模擬,后面使用FATE會在真實多臺機器上實現
基本的流程:
2.1 配置文件
配置文件包含了整個項目的模型、數據集、epoch等核心訓練參數
需要注意的是,一般來說配置文件需要在所有的客戶端與服務端之間同步一致
創建一個配置文件:
項目文件夾下./utils/conf.json創建配置文件:
{"model_name" : "resnet18","no_models" : 10,"type" : "cifar","global_epochs" : 20,"local_epochs" : 3,"k" : 6,"batch_size" : 32,"lr" : 0.001,"momentum" : 0.0001,"lambda" : 0.1 }- model_name:模型名稱
- no_models:客戶端總數量
- type:數據集信息
- global_epochs:全局迭代次數,即服務端與客戶端的通信迭代次數
- local_epochs:本地模型訓練迭代次數
- k:每一輪迭代時,服務端會從所有客戶端中挑選k個客戶端參與訓練。
- batch_size:本地訓練每一輪的樣本數
- lr,momentum,lambda:本地訓練的超參數設置
2.1 構建訓練數據集
構建數據集代碼如下:
datasets.py
import torchvision as tv# 獲取數據集 def get_dataset(dir, name):if name == 'mnist':# root: 數據路徑# train參數表示是否是訓練集或者測試集# download=true表示從互聯網上下載數據集并把數據集放在root路徑中# transform:圖像類型的轉換train_dataset = tv.datasets.MNIST(dir, train=True, download=True, transform=tv.transforms.ToTensor())eval_dataset = tv.datasets.MNIST(dir, train=False, transform=tv.transforms.ToTensor())elif name == 'cifar':# 設置兩個轉換格式# transforms.Compose 是將多個transform組合起來使用(由transform構成的列表)transform_train = tv.transforms.Compose([# transforms.RandomCrop: 切割中心點的位置隨機選取tv.transforms.RandomCrop(32, padding=4), tv.transforms.RandomHorizontalFlip(),tv.transforms.ToTensor(),# transforms.Normalize: 給定均值:(R,G,B) 方差:(R,G,B),將會把Tensor正則化tv.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),])transform_test = tv.transforms.Compose([tv.transforms.ToTensor(),tv.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),])train_dataset = tv.datasets.CIFAR10(dir, train=True, download=True, transform=transform_train)eval_dataset = tv.datasets.CIFAR10(dir, train=False, transform=transform_test)return train_dataset, eval_dataset2.2 服務端
服務端的主要功能是模型的聚合、評估,最終的模型也是在服務器上生成
首先創建一個服務類
所有的程序放在server.py
構造函數
定義其構造函數:
# 定義構造函數 def __init__(self, conf, eval_dataset):# 導入配置文件self.conf = conf# 根據配置獲取模型文件self.global_model = models.get_model(self.conf["model_name"])# 生成一個測試集合加載器self.eval_loader = torch.utils.data.DataLoader(eval_dataset,# 設置單個批次大小32batch_size=self.conf["batch_size"],# 打亂數據集shuffle=True)聚合函數
定義全局聯邦平均FedAvg聚合函數:
FedAvg算法的公式如下:
Gt+1=Gt+λ∑i=1m(Lit+1?Git)G^{t+1} = G^{t} + \lambda \sum^m_{i=1}(L_i^{t+1}-G_i^t)Gt+1=Gt+λ∑i=1m?(Lit+1??Git?)?
GtG^tGt表示第t輪更新的全局模型參數,Lit+1L_i^{t+1}Lit+1??表示第i個客戶端在第t+1輪本地更新后的模型
在模型聚合時,weight_accumulator就是(Lit+1?Git)i=1,2,...m(L_i^{t+1}-G_i^t) \ i = 1,2,...m(Lit+1??Git?)?i=1,2,...m?部分,具體weight_accumulator的計算會在后面詳細介紹其實現
# 全局聚合模型 # weight_accumulator 存儲了每一個客戶端的上傳參數變化值/差值 def model_aggregate(self, weight_accumulator):# 遍歷服務器的全局模型for name, data in self.global_model.state_dict().items():# 更新每一層乘上學習率update_per_layer = weight_accumulator[name] * self.conf["lambda"]# 累加和if data.type() != update_per_layer.type():# 因為update_per_layer的type是floatTensor,所以將起轉換為模型的LongTensor(有一定的精度損失)data.add_(update_per_layer.to(torch.int64))else:data.add_(update_per_layer)評估函數
定義模型評估函數
評估函數主要是不斷的評估當前模型的性能,判斷是否可以提前終止迭代或者是出現了發散退化等現象
# 評估函數def model_eval(self):self.global_model.eval() # 開啟模型評估模式(不修改參數)total_loss = 0.0correct = 0dataset_size = 0# 遍歷評估數據集合for batch_id, batch in enumerate(self.eval_loader):data, target = batch# 獲取所有的樣本總量大小dataset_size += data.size()[0]# 存儲到gpuif torch.cuda.is_available():data = data.cuda()target = target.cuda()# 加載到模型中訓練output = self.global_model(data)# 聚合所有的損失 cross_entropy交叉熵函數計算損失total_loss += torch.nn.functional.cross_entropy(output,target,reduction='sum').item()# 獲取最大的對數概率的索引值, 即在所有預測結果中選擇可能性最大的作為最終的分類結果pred = output.data.max(1)[1]# 統計預測結果與真實標簽target的匹配總個數correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()acc = 100.0 * (float(correct) / float(dataset_size)) # 計算準確率total_1 = total_loss / dataset_size # 計算損失值return acc, total_12.3 客戶端
客戶端的主要功能是:
- 接受服務器下發的指令和全局模型
- 利用本地數據進行局部模型訓練
此部分所有程序都在client.py中
構造函數
定義client類
# 構造函數def __init__(self, conf, model, train_dataset, id = 1):# 配置文件self.conf = conf# 客戶端本地模型(一般由服務器傳輸)self.local_model = model# 客戶端IDself.client_id = id# 客戶端本地數據集self.train_dataset = train_dataset# 按ID對訓練集合的拆分all_range = list(range(len(self.train_dataset)))data_len = int(len(self.train_dataset) / self.conf['no_models'])indices = all_range[id * data_len: (id + 1) * data_len]# 生成一個數據加載器self.train_loader = torch.utils.data.DataLoader(# 制定父集合self.train_dataset,# batch_size每個batch加載多少個樣本(默認: 1)batch_size=conf["batch_size"],# 指定子集合# sampler定義從數據集中提取樣本的策略sampler=torch.utils.data.sampler.SubsetRandomSampler(indices))本案例中根據ID將數據集進行橫向切分,每個客戶端之間沒有交集
本地訓練
本地模型訓練函數:采用交叉熵作為本地訓練的損失函數,并使用梯度下降來求解參數
# 模型本地訓練函數def local_train(self, model):# 整體的過程:拉取服務器的模型,通過部分本地數據集訓練得到for name, param in model.state_dict().items():# 客戶端首先用服務器端下發的全局模型覆蓋本地模型self.local_model.state_dict()[name].copy_(param.clone())# 定義最優化函數器用于本地模型訓練optimizer = torch.optim.SGD(self.local_model.parameters(), lr=self.conf['lr'], momentum=self.conf['momentum'])# 本地訓練模型self.local_model.train() # 設置開啟模型訓練(可以更改參數)# 開始訓練模型for e in range(self.conf["local_epochs"]):for batch_id, batch in enumerate(self.train_loader):data, target = batch# 加載到gpuif torch.cuda.is_available():data = data.cuda()target = target.cuda()# 梯度optimizer.zero_grad()# 訓練預測output = self.local_model(data)# 計算損失函數 cross_entropy交叉熵誤差loss = torch.nn.functional.cross_entropy(output, target)# 反向傳播loss.backward()# 更新參數optimizer.step()print("Epoch %d done" % e)# 創建差值字典(結構與模型參數同規格),用于記錄差值diff = dict()for name, data in self.local_model.state_dict().items():# 計算訓練后與訓練前的差值diff[name] = (data - model.state_dict()[name])print("Client %d local train done" % self.client_id)# 客戶端返回差值return diff2.4 整合
所有程序代碼在main.py中
import argparse import json import randomimport datasets from client import * from server import *if __name__ == '__main__':# 設置命令行程序parser = argparse.ArgumentParser(description='Federated Learning')parser.add_argument('-c', '--conf', dest='conf')# 獲取所有的參數args = parser.parse_args()# 讀取配置文件with open(args.conf, 'r') as f:conf = json.load(f)# 獲取數據集, 加載描述信息train_datasets, eval_datasets = datasets.get_dataset("./data/", conf["type"])# 開啟服務器server = Server(conf, eval_datasets)# 客戶端列表clients = []# 添加10個客戶端到列表for c in range(conf["no_models"]):clients.append(Client(conf, server.global_model, train_datasets, c))print("\n\n")# 全局模型訓練for e in range(conf["global_epochs"]):print("Global Epoch %d" % e)# 每次訓練都是從clients列表中隨機采樣k個進行本輪訓練candidates = random.sample(clients, conf["k"])print("select clients is: ")for c in candidates:print(c.client_id)# 權重累計weight_accumulator = {}# 初始化空模型參數weight_accumulatorfor name, params in server.global_model.state_dict().items():# 生成一個和參數矩陣大小相同的0矩陣weight_accumulator[name] = torch.zeros_like(params)# 遍歷客戶端,每個客戶端本地訓練模型for c in candidates:diff = c.local_train(server.global_model)# 根據客戶端的參數差值字典更新總體權重for name, params in server.global_model.state_dict().items():weight_accumulator[name].add_(diff[name])# 模型參數聚合server.model_aggregate(weight_accumulator)# 模型評估acc, loss = server.model_eval()print("Epoch %d, acc: %f, loss: %f\n" % (e, acc, loss))2.5 測試
按照以上配置,(本人)運行后的準確度以及損失為:
官方的對比:
聯邦學習與中心化訓練的效果對比
- 聯邦訓練配置:一共10臺客戶端設備(no_models=10),每一輪任意挑選其中的5臺參與訓練(k=5), 每一次本地訓練迭代次數為3次(local_epochs=3),全局迭代次數為20次(global_epochs=20)。
- 集中式訓練配置:我們不需要單獨編寫集中式訓練代碼,只需要修改聯邦學習配置既可使其等價于集中式訓練。具體來說,我們將客戶端設備no_models和每一輪挑選的參與訓練設備數k都設為1即可。這樣只有1臺設備參與的聯邦訓練等價于集中式訓練。其余參數配置信息與聯邦學習訓練一致。圖中我們將局部迭代次數分別設置了1,2,3來進行比較。
聯邦學習在模型推斷上的效果對比
圖中的單點訓練只的是在某一個客戶端下,利用本地的數據進行模型訓練的結果。
- 我們看到單點訓練的模型效果(藍色條)明顯要低于聯邦訓練 的效果(綠色條和紅色條),這也說明了僅僅通過單個客戶端的數據,不能夠 很好的學習到數據的全局分布特性,模型的泛化能力較差。
- 此外,對于每一輪 參與聯邦訓練的客戶端數目(k 值)不同,其性能也會有一定的差別,k 值越大,每一輪參與訓練的客戶端數目越多,其性能也會越好,但每一輪的完成時間也會相對較長。
學習資料來自于:
楊強:《聯邦學習實戰》
https://github.com/FederatedAI/Practicing-Federated-Learning/tree/main/chapter03_Python_image_classification
總結
以上是生活随笔為你收集整理的联邦学习实战-1:用python从零开始实现横向联邦学习的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: [失败] 网易云音乐爬虫分析
- 下一篇: ubuntu下网易云音乐适配高分辨率屏幕