联邦学习(Federated Learning)详解以及示例代码
聯邦學習也稱為協同學習,它可以在產生數據的設備上進行大規模的訓練,并且這些敏感數據保留在數據的所有者那里,本地收集、本地訓練。在本地訓練后,中央的訓練協調器通過獲取分布模型的更新獲得每個節點的訓練貢獻,但是不訪問實際的敏感數據。
聯邦學習本身并不能保證隱私(稍后我們將討論聯邦學習系統中的隱私破壞和修復),但它確實使隱私成為可能。
聯邦學習的用例:
- 手機輸入法的下一個詞預測(e.g. McMahan et al. 2017, Hard et al. 2019)
- 健康研究(e.g. Kaissis et al. 2020, Sadilek et al. 2021)
- 汽車自動駕駛(e.g. Zeng et al. 2021, OpenMined 的文章)
- “智能家居”系統(e.g. Matchi et al. 2019, Wu et al. 2020)
因為隱私的問題所以對于個人來說,人們寧愿放棄他們的個人數據,也不會將數據提供給平臺(平臺有時候也想著白嫖😉),所以聯邦學習幾乎涵蓋了所有以個人為單位進行預測的所有場景。
隨著公眾和政策制定者越來越意識到隱私的重要性,數據實踐中對保護隱私的機器學習的需求也正在上升,對于數據的訪問受到越來越多的審查,對聯邦學習等尊重隱私的工具的研究也越來越活躍。 在理想情況下,聯邦學習可以在保護個人和機構的隱私的前提下,使數據利益相關者之間的合作成為可能,因為以前商業機密、私人健康信息或數據泄露風險的通常使這種合作變得困難甚至無法進行。
歐盟《通用數據保護條例》或《加利福尼亞消費者隱私法》等政府法規使聯邦學習等隱私保護策略成為希望保持合法運營的企業的有用工具。與此同時,在保持模型性能和效率的同時獲得所需的隱私和安全程度,這本身就帶來了大量技術挑戰。
從個人數據生產者(我們都是其中的一員)的日常角度來看,至少在理論上是可以在私人健康和財務數據之間放置一些東西來屏蔽那種跟蹤你在網上行為設置暴露你的個人隱私的所謂的大雜燴生態系統。
如果這些問題中的任何一個引起你的共鳴,請繼續閱讀以了解更多關于聯邦學習的復雜性以及它可以為使用敏感數據的機器學習做了哪些工作。
聯邦學習簡介
聯邦學習的目的是訓練來自多個數據源的單個模型,其約束條件是數據停留在數據源上,而不是由數據源(也稱為節點、客戶端)交換,也不是由中央服務器進行編排訓練(如果存在的話)。
在典型的聯邦學習方案中,中央服務器將模型參數發送到各節點(也稱為客戶端、終端或工作器)。節點針對本地數據的一些訓練初始模型,并將新訓練的權重發送回中央服務器,中央服務器對新模型參數求平均值(通常與在每個節點上執行的訓練量有關)。在這種情況下,中央服務器或其他節點永遠不會直接看到任何一個其他節點上的數據,并使用安全聚合等附加技術進一步增強隱私。
該框架內有許多變體。例如,在本文中主要關注由中央服務器管理的聯邦學習方案,該方案在多個相同類型的設備上編排訓練,節點上每次訓練都使用自己的本地數據并將結果上傳到中央服務器,這是在 2017 年由 McMahan 等人描述的基本方案。但是某些情況下可能需要取消訓練的集中控制,當單個節點分配中央管理器的角色時,它就變成了去中心化的聯邦學習,這是針對特殊的醫療數據訓練模型的一種有效的解決方案。
典型的聯邦學習場景可能涉及大量的設備(例如手機),所有手機的計算能力大致相似,訓練相同的模型架構。但有一些方案,例如Diao等人2021年提出的HeteroFL允許在具有巨大差異的通信和計算能力的各種設備上訓練一個單一的推理模型,甚至可以訓練具有不同架構和參數數量的局部模型,然后將訓練的參數聚集到一個全局推理模型中。
聯邦學習還有一個優勢是數據保存在產生數據的設備上,訓練數據集通常比模型要大得多,因此發送后者而不是前者可以節省帶寬。
但在這些優勢中最重要的還是隱私保護,雖然有可能僅通過模型參數更新就推斷出關于私有數據集內容的某些內容。McMahan等人在2017年使用了一個簡單的例子來解釋該漏洞,即使用一個“詞袋”輸入向量訓練的語言模型,其中每個輸入向量具體對應于一個大詞匯表中的一個單詞。對于相應單詞的每個非零梯度更新將為竊聽者提供一個關于該單詞在私有數據集中存在(反之亦然)的線索,還有更復雜的攻擊也被證實了。為了解決這個問題,可以將多種隱私增強技術整合到聯邦學習中,從安全的更新聚合到使用完全同態加密進行訓練。下面我們將簡要介紹聯邦學習中對隱私的最突出的威脅及其緩解措施。
國家對數據隱私的監管是一個新興的政策領域,但是要比基于個人數據收集和分析的發展要晚10到20年。2016年頒布的《歐洲一般數據保護條例》(European General data Protection regulation,簡稱GDPR)是最重要的關于公眾個人數據的法規,這可能會有些奇怪,因為類似的保護限制企業監測和數據收集的措施施尚處于起步階段甚至是沒有。
兩年后的2018年,加州消費者隱私法案緊隨歐盟的GDPR成為法律。作為一項州法律,與GDPR相比,CCPA在地理范圍上明顯受到限制,該法案對個人信息的定義更窄。
聯邦學習的名字是由McMahan等人在2017年的一篇論文中引入的,用來描述分散數據模型的訓練。作者根據2012年白宮關于消費者數據隱私的報告為他們的系統制定了設計策略。他們提出了聯邦學習的兩個主要用例:圖像分類和用于語音識別或下一個單詞/短語預測的語言模型。
不久以后與分布式訓練相關的潛在攻擊就相繼的出現了。Phong et al. 2017和Bhowmick et al. 2018等人的工作表明,即使只訪問從聯邦學習客戶端返回到服務器的梯度更新或部分訓練的模型,也可以推斷出一些描述私人數據的細節。在inphero的文章中,你可以看到關于隱私問題的總結和解決方法。
在聯邦學習中,隱私、有效性和效率之間的平衡涉及廣泛的領域。服務器和客戶機之間的通信(或者僅僅是去中心化客戶機之間的通信)可以在傳輸時進行加密,但還有一個更健壯的選項即在訓練期間數據和模型也保持加密。同態加密可用于對加密的數據執行計算,因此(在理想情況下)輸出只能由持有密鑰的涉眾解密。OpenMined的PySyft、Microsoft的SEAL或TensorFlow Encrypted等庫為加密的深度學習提供了工具,這些工具可以應用到聯邦學習系統中。
關于聯邦學習的介紹到此為止,接下來我們將在教程部分中設置一個簡單的聯邦學習演示。
聯邦學習代碼實現
既然我們已經知道在何處以及為什么要使用聯邦學習,那么讓我們動手看看我們如何這樣做,在這里我們使用鳶尾花數據集進行聯邦學習。
有許多聯邦學習庫可供選擇,從在 GitHub 上擁有超過 1700 顆星的更主流的 Tensorflow Federated 到流行且注重隱私的 PySyft,再到面向研究的 FedJAX。 下面表中包含流行的聯邦學習存儲庫的參考列表。
在我們的演示中將使用 Flower 庫。 我們選擇這個庫的部分原因是它以一種可訪問的方式舉例說明了基本的聯邦學習概念并且它與框架無關,Flower 可以整合任何構建模型的深度學習工具包(他們在文檔中有 TensorFlow、PyTorch、MXNet 和 SciKit-Learn 的示例)所以我們將使用 SciKit-Learn 中包含的“iris”數據集和Pytorch來驗證它所說的與框架無關的這個特性。 從高層的角度來看我們需要設置一個服務器和一個客戶端,對于客戶端我們使用不同的訓練數據集。 首先就是設置中央協調器。
設置協調器的第一步就是定義一個評估策略并將其傳遞給 Flower 中的默認配置服務器。 但首先讓我們確保設置了一個虛擬環境,其中包含需要的所有依賴項。 在 Unix 命令行上:
virtualenv flower_env python==python3 source flower_env/bin/activate pip install flwr==0.17.0# I'm running this example on a laptop (no gpu) # so I am installing the cpu only version of PyTorch # follow the instructions at https://pytorch.org/get-started/locally/ # if you want the gpu optionpip install torch==1.9.1+cpu torchvision==0.10.1+cpu \-f https://download.pytorch.org/whl/torch_stable.htmlpip install scikit-learn==0.24.0隨著我們的虛擬環境啟動并運行,我們可以編寫一個模塊來啟動 Flower 服務器來處理聯邦學習。 在下面的代碼中,我們包含了 argparse,以便在從命令行調用服務器模塊時更容易地試驗不同數量的訓練輪次。 我們還定義了一個生成評估函數的函數,這是我們添加到 Flower 服務器默認配置使用的策略中的唯一其他內容。
以下我們的服務器模塊文件的內容:
import argparse import flwr as fl import torch from pt_client import get_data, PTMLPClientdef get_eval_fn(model):# This `evaluate` function will be called after every rounddef evaluate(parameters: fl.common.Weights):loss, _, accuracy_dict = model.evaluate(parameters)return loss, accuracy_dictreturn evaluateif __name__ == "__main__":parser = argparse.ArgumentParser()parser.add_argument("-r", "--rounds", type=int, default=3,\help="number of rounds to train")args = parser.parse_args()torch.random.manual_seed(42)model = PTMLPClient(split="val")strategy = fl.server.strategy.FedAvg( \eval_fn=get_eval_fn(model),\)fl.server.start_server("[::]:8080", strategy=strategy, \config={"num_rounds": args.rounds})注意上面代碼中調用的 PTMLPClient。這個是server模塊用來定義評估函數的,這個類也是用于訓練的模型類并兼作聯邦學習客戶端。接下來我們將定義 PTMLPClient,并繼承Flower 的 NumPyClient 類和 torch.nn.Module 類,如果您使用 PyTorch,你肯定就熟悉它們。
NumPyClient 類處理與服務器的通信,我們需要實現4個抽象函數 set_parameters、get_parameters、fit 和evaluate。 torch.nn.Module 類為我們提供了 PyTorch 模型,還有就是使用 PyTorch Adam 優化器進行訓練的能力。我們的 PTMLPClient 類只有 100 多行代碼,所以我們將從 init 開始依次介紹每個類的函數。
請注意,我們從兩個類繼承。從 nn.Module 繼承意味著我們必須確保使用 super 命令從 nn.Module 調用 init,但是如果您忘記這樣做,Python 會立即通知你。除此之外,我們將三個線性層初始化為矩陣(torch.tensor 數據類型),并將一些關于訓練分割和模型維度的信息存儲為類變量。
class PTMLPClient(fl.client.NumPyClient, nn.Module):def __init__(self, dim_in=4, dim_h=32, \num_classes=3, lr=3e-4, split="alice"):super(PTMLPClient, self).__init__()self.dim_in = dim_inself.dim_h = dim_hself.num_classes = num_classesself.split = splitself.w_xh = nn.Parameter(torch.tensor(\torch.randn(self.dim_in, self.dim_h) \/ np.sqrt(self.dim_in * self.dim_h))\)self.w_hh = nn.Parameter(torch.tensor(\torch.randn(self.dim_h, self.dim_h) \/ np.sqrt(self.dim_h * self.dim_h))\)self.w_hy = nn.Parameter(torch.tensor(\torch.randn(self.dim_h, self.num_classes) \/ np.sqrt(self.dim_h * self.num_classes))\)self.lr = lr接下來我們將定義 PTMLPClient 類的 get_parameters 和 set_parameters 函數。 這些函數將所有模型參數連接為一個扁平的 numpy 數組,這是 Flower 的 NumPyClient 類預期返回和接收的數據類型。 這符合聯邦學習方案,因為服務器將向每個客戶端發送初始參數(使用 set_parameters)并期望返回一組部分訓練的權重(來自 get_parameters)。 這種模式在訓練的每輪出現一次。 我們還在 set_parameters 中初始化優化器和損失函數。
def get_parameters(self):my_parameters = np.append(\self.w_xh.reshape(-1).detach().numpy(), \self.w_hh.reshape(-1).detach().numpy() \)my_parameters = np.append(\my_parameters, \self.w_hy.reshape(-1).detach().numpy() \)return my_parametersdef set_parameters(self, parameters):parameters = np.array(parameters)total_params = reduce(lambda a,b: a*b,\np.array(parameters).shape)expected_params = self.dim_in * self.dim_h \+ self.dim_h**2 \+ self.dim_h * self.num_classesstart = 0stop = self.dim_in * self.dim_hself.w_xh = nn.Parameter(torch.tensor(\parameters[start:stop])\.reshape(self.dim_in, self.dim_h).float() \)start = stopstop += self.dim_h**2self.w_hh = nn.Parameter(torch.tensor(\parameters[start:stop])\.reshape(self.dim_h, self.dim_h).float() \)start = stopstop += self.dim_h * self.num_classesself.w_hy = nn.Parameter(torch.tensor(\parameters[start:stop])\.reshape(self.dim_h, self.num_classes).float()\)self.act = torch.reluself.optimizer = torch.optim.Adam(self.parameters())self.loss_fn = nn.CrossEntropyLoss()接下來,我們將定義我們的前向傳遞和一個用于獲取損失標量的函數。
def forward(self, x):x = self.act(torch.matmul(x, self.w_xh))x = self.act(torch.matmul(x, self.w_hh))x = torch.matmul(x, self.w_hy)return xdef get_loss(self, x, y):prediction = self.forward(x)loss = self.loss_fn(prediction, y)return loss我們客戶端還需要的最后幾個函數是fit和evaluate。 對于每一輪,每個客戶端在進行幾個階段的訓練之前使用提供給fit方法的參數初始化它的參數(在本例中默認為10)。evaluate方法在計算訓練數據驗證的損失和準確性之前設置參數。
def fit(self, parameters, config=None, epochs=10):self.set_parameters(parameters)x, y = get_data(split=self.split)x, y = torch.tensor(x).float(), torch.tensor(y).long()self.train()for ii in range(epochs):self.optimizer.zero_grad()loss = self.get_loss(x, y)loss.backward()self.optimizer.step()loss, _, accuracy_dict = self.evaluate(self.get_parameters())return self.get_parameters(), len(y), \{"loss": loss, "accuracy": \accuracy_dict["accuracy"]}def evaluate(self, parameters, config=None):self.set_parameters(parameters)val_x, val_y = get_data(split="val")val_x = torch.tensor(val_x).float()val_y = torch.tensor(val_y).long()self.eval()prediction = self.forward(val_x)loss = self.loss_fn(prediction, val_y).detach().numpy()prediction_class = np.argmax(\prediction.detach().numpy(), axis=-1)accuracy = sklearn.metrics.accuracy_score(\val_y.numpy(), prediction_class)return float(loss), len(val_y), \{"accuracy":float(accuracy)}我們的客戶端類中的 fit 和evaluate都調用了一個函數 get_data,它只是 SciKit-Learn iris 數據集的包裝器。 它還將數據拆分為訓練集和驗證集,并進一步拆分訓練數據集(我們稱為“alice”和“bob”)以模擬聯邦學習,因為聯邦學習的客戶端都有自己的數據。
def get_data(split="all"):x, y = sklearn.datasets.load_iris(return_X_y=True)np.random.seed(42); np.random.shuffle(x)np.random.seed(42); np.random.shuffle(y)val_split = int(0.2 * x.shape[0])train_split = (x.shape[0] - val_split) // 2eval_x, eval_y = x[:val_split], y[:val_split] alice_x, alice_y = x[val_split:val_split + train_split], y[val_split:val_split + train_split]bob_x, bob_y = x[val_split + train_split:], y[val_split + train_split:]train_x, train_y = x[val_split:], y[val_split:]if split == "all":return train_x, train_yelif split == "alice":return alice_x, alice_yelif split == "bob":return bob_x, bob_yelif split == "val":return eval_x, eval_yelse:print("error: split not recognized.")return None現在我們只需要在文件底部填充一個 if name == “main”: 方法,以便我們可以從命令行將我們的客戶端代碼作為模塊運行。
if __name__ == "__main__":parser = argparse.ArgumentParser()parser.add_argument("-s", "--split", type=str, default="alice",\help="The training split to use, options are 'alice', 'bob', or 'all'")args = parser.parse_args()torch.random.manual_seed(42)fl.client.start_numpy_client("localhost:8080", client=PTMLPClient(split=args.split))最后,確保在客戶端模塊的頂部導入所需的所有內容。
import argparse import numpy as np import sklearn import sklearn.datasets import sklearn.metrics import torch import torch.nn as nn from functools import reduce import flwr as fl這就是我們使用 Flower 運行聯邦訓練演示所需實現的全部代碼!
要開始運行聯邦訓練,首先在其命令行終端中啟動服務器。 我們將我們的服務器保存為 pt_server.py,我們的客戶端模塊保存為 pt_client.py,兩者都在我們正在工作的目錄的根目錄中,所以為了啟動一個服務器并告訴它進行40 輪聯邦學習,我們使用以下命令。
python -m pt_server -r 40接下來打開一個新的終端,用“alice”訓練分組啟動你的第一個客戶端:
python -m pt_client -s alice啟動“bob”訓練分組的第二個客戶端。
python -m pt_client -s bob如果一切正常,在運行服務器進程的終端中看到訓練啟動和信息滾動。
這個演示在 20 輪訓練中達到了 96% 以上的準確率。 訓練運行的損失和準確度曲線如下所示:
聯邦學習的未來
人們可能會相信“再也沒有隱私這種東西了”。這些聲明主要是針對互聯網的(這樣的聲明至少從1999年就開始了),但隨著智能家居設備和愛管閑事的家用機器人的迅速普及,你可能覺得這些言論是正確的。
但是請注意是誰在做這些聲明,你會發現他們中的許多人在你的數據被竊取的過程中是能夠獲得既得利益的。這種“沒有隱私”的失敗主義態度不僅是錯誤的,而且是危險的:失去隱私會使個人和團體以他們可能不會注意到或承認的方式被巧妙地操縱。
聯邦學習是伴隨著不斷擴大的數據量而生的,數據無處不在,聯邦學習的優勢因此獲得了政府、企業等各界的關注。聯邦學習能夠有效解決數據孤島和數據隱私保護的兩難問題。這將會為未來人工智能協作,從而實現跨越式發展奠定良好基礎,在多行業、多領域都有廣泛的應用前景。
作者:James Montantes
總結
以上是生活随笔為你收集整理的联邦学习(Federated Learning)详解以及示例代码的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 虚拟服务器 dmz区别,dmz主机和虚
- 下一篇: Exploring Word Vexto