使用推测解码 (Speculative Decoding) 使 Whisper 实现 2 倍的推理加速
Open AI 推出的 Whisper 是一個通用語音轉錄模型,在各種基準和音頻條件下都取得了非常棒的結果。最新的 large-v3 模型登頂了 OpenASR 排行榜,被評為最佳的開源英語語音轉錄模型。該模型在 Common Voice 15 數據集的 58 種語言中也展現出了強大的多語言性能,在 42 種語言上的單詞錯誤率 (WER) 低于 30%。
盡管轉錄準確度非常優秀,但推理速度非常緩慢。即使利用 flash attention 、半精度和 分塊 等優化推理技術,1 小時長度的音頻在 16GB T4 GPU 上也需要超過 6 分鐘的轉錄時間。
在本文中,我們將演示如何運用推測解碼將 Whisper 的推理時間縮減 2 倍,同時在數學上確保完全取得與原模型 相同的輸出。因此,這種方法可以完美地替換現有的 Whisper 流水線,因為它可以在不降低準確性的情況下免費獲得 2 倍的加速。想要看附帶有更簡潔解釋的全部代碼,請參閱配套的 Google Colab。
推測解碼
推測解碼由 Yaniv Leviathan 等人在 Fast Inference from Transformers via Speculative Decoding 中提出。其思想是,一個更快的 輔助模型 通常會生成和更大的 主模型 相同的 token。
首先,輔助模型會通過自回歸生成 \(N\) 個 候選 token 序列: \(\hat{\boldsymbol{y}}_{1:N}\)。在下圖中,輔助模型生成了一個包含 5 個候選 token 的序列: The quick brown sock jumps 。
盡管這些候選 token 可以快速生成,但它們可能與主模型預測的 token 不同。因此,在第二步中,候選 token 被傳入主模型以進行“驗證”。主模型將候選 token 作為輸入,并執行 單次前饋傳播。主模型的輸出是每個步驟中“正確”token 的序列 $ \boldsymbol{y}_{1:N}$。
在上圖中,我們看到主模型預測的前三個 token 與輔助模型的 token 一致: <span style="color:green"> The quick brown 但是,輔助模型的第四個候選 token: “ <span style="color:red"> sock”與主模型的正確 token: “ <span style="color:green"> fox”不一致。
我們知道,所有候選 token 一直到第一個不匹配之前都是正確的 ( <span style="color:green"> The quick brown),因為這些與主模型的預測一致。但是,在第一個不匹配之后,候選 token 開始偏離主模型實際預測的 token。因此,我們可以用主模型的正確 token ( <span style="color:green"> fox) 替換第一個不正確的候選 token ( <span style="color:red"> sock),并放棄之后所有預測的 token,因為這些已經逐漸偏離主模型的預測。經過校正的序列 The quick brown fox 現在成為輔助模型的新輸入:
然后,輔助模型再次通過自回歸推理,生成一組新的 \(N\) 個候選 token,這些 token 再次通過主模型的單次前饋傳播進行驗證。
由于我們在生成的時候使用的快速的輔助模型進行自回歸,并且緩慢的主模型僅用于驗證前饋傳播,解碼過程將大大加快。此外,經過主模型前饋傳播驗證后可以確保與僅使用主模型時獲得完全相同的輸出。這使得推測解碼可以完美地替換現有的 Whisper 流水線,因為我們可以確定會取得相同質量的輸出。
為了最大限度地減少延遲,輔助模型應該比主模型快得多,同時盡可能頻繁地預測相同的 token 分布。實際上,這兩個屬性之間需要權衡: 模型越快,其準確度越低。然而,由于所有預測 token 中的 70-80% 往往是“較易”的 token,此權衡傾向于選擇一個更快的模型,而不是一個更準確的模型。因此,輔助模型應該至少比主模型快 3 倍 (越快越好),同時在示例中正確預測所有較“易”token。剩余的 20-30% 更“難”的 token 可以由更大的主模型進行驗證。
選擇輔助模型的唯一約束是它必須與主模型使用相同的詞匯表。也就是說,輔助模型必須使用與主模型完全一對一相同的分詞器。因此,如果我們想對諸如 large-v2 (多語言) 的 Whisper 多語言版本使用推測解碼,我們需要選擇諸如 tiny 的 Whisper 多語言版本作為輔助模型。而如果我們想對諸如 medium.en 的 Whisper 英文版本使用推測解碼,我們需要選擇諸如 tiny.en 的 Whisper 英文版本作為輔助模型。目前,large-v3 是唯一一個擴展了詞匯量的 Whisper 檢查點,因此與以前的 Whisper 檢查點不兼容。
現在我們已經了解了推測解碼背后的原理,我們準備實際實現它。在 ?? Transformers 庫中,推測解碼被實現為“輔助生成 (Assisted Generation)”推理策略。欲了解更多實現細節,建議讀者閱讀 Joao Gante 關于 輔助生成 的精彩博文。
英文語音轉錄
基準實現
我們首先使用 Whisper large-v2 進行基準測試,以獲得推理速度的基準數值。我們可以通過便捷的 AutoModelForSpeechSeq2Seq 和 AutoProcessor 類加載主模型及其對應的處理器。我們將以 float16 精度加載模型,并通過傳遞 low_cpu_mem_usage=True 確保加載時間盡可能少。此外,我們要確保模型以 safetensors 格式加載,方法是傳遞 use_safetensors=True。最后,我們將傳遞參數 attn_implementation="sdpa" ,以通過 PyTorch 的 SDPA 注意力內核 進行 Flash 注意力加速。
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "openai/whisper-large-v2"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
use_safetensors=True,
attn_implementation="sdpa",
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
讓我們加載將用于基準測試的英語語音轉錄數據集。我們將加載 LibriSpeech ASR 中驗證數據集的 clean 分組中的 73 個樣本組成的小型數據集。這大約有 9MB 的數據,因此非常輕量且可以快速下載到設備上。
from datasets import load_dataset
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
對于基準測試,我們只想測量生成時間,所以讓我們編寫一個簡短的輔助函數來測量此步驟運行的時間。下面的函數將同時返回解碼的 token 和運行模型所需的時間:
import time
def generate_with_time(model, inputs, **kwargs):
start_time = time.time()
outputs = model.generate(**inputs, **kwargs)
generation_time = time.time() - start_time
return outputs, generation_time
現在我們可以迭代語音數據集中的音頻樣本,并統計整體生成時間:
from tqdm import tqdm
all_time = 0
predictions = []
references = []
for sample in tqdm(dataset):
audio = sample["audio"]
inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
inputs = inputs.to(device=device, dtype=torch.float16)
output, gen_time = generate_with_time(model, inputs)
all_time += gen_time
predictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
references.append(processor.tokenizer._normalize(sample["text"]))
print(all_time)
Output:
100%|██████████| 73/73 [01:37<00:00, 1.33s/it]
72.99542546272278
很好!我們看到轉錄 73 個樣本花了 73 秒。讓我們檢查一下預測的 WER:
from evaluate import load
wer = load("wer")
print(wer.compute(predictions=predictions, references=references))
Output:
0.03507271171941831
我們的最終基準數值為 73 秒,WER 為 3.5%。
推測解碼
現在讓我們加載推測解碼的輔助模型。在此示例中,我們將使用 Whisper 蒸餾后的版本 distil-large-v2。蒸餾模型只使用了 Whisper 中 32 個解碼器層中的 2 個編碼器。因此,它比 Whisper 快 6 倍,同時在分布測試集上的 WER 性能相比于蒸餾前僅下降了 1%。這使其成為理想的輔助模型,因為它在轉錄準確性和生成速度方面都非常優秀\({}^1\)。
\({}^1\) 我們即將發布 Distil-Whisper 的改進版本,在 token 分布中具有更佳的對齊性,這將進一步提高推測解碼性能。關注 Distil-Whisper 存儲庫 來追蹤最新的更新信息。
由于 Distil-Whisper 使用與 Whisper 模型完全相同的編碼器,我們可以在主模型和輔助模型之間共享編碼器。然后,我們只需要從 Distil-Whisper 加載 2 層解碼器作為“僅解碼器”模型。我們可以通過便捷的 AutoModelForCausalLM 自動類實現這一點。在實踐中,相比于僅使用主模型,這僅增加了 8%的 VRAM 占用量。
from transformers import AutoModelForCausalLM
assistant_model_id = "distil-whisper/distil-large-v2"
assistant_model = AutoModelForCausalLM.from_pretrained(
assistant_model_id,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
use_safetensors=True,
attn_implementation="sdpa",
)
assistant_model.to(device)
我們可以為推測解碼的基準測試定義一個新的函數。與前面的函數唯一的區別是,我們在對 .generate 的調用中傳遞輔助模型:
def assisted_generate_with_time(model, inputs, **kwargs):
start_time = time.time()
outputs = model.generate(**inputs, assistant_model=assistant_model, **kwargs)
generation_time = time.time() - start_time
return outputs, generation_time
讓我們使用 Distil-Whisper 作為 Whisper 的助手運行推測解碼的基準測試:
all_time = 0
predictions = []
references = []
for sample in tqdm(dataset):
audio = sample["audio"]
inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
inputs = inputs.to(device=device, dtype=torch.float16)
output, gen_time = assisted_generate_with_time(model, inputs)
all_time += gen_time
predictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
references.append(processor.tokenizer._normalize(sample["text"]))
print(all_time)
Outputs:
100%|██████████| 73/73 [00:38<00:00, 1.88it/s]
32.69683289527893
使用推測解碼,推理時間僅為 33 秒,比之前快 2.2 倍!讓我們驗證一下 WER 是否相同:
print(wer.compute(predictions=predictions, references=references))
Outputs:
0.03507271171941831
太完美了!再次達到 3.5%的 WER,因為我們的輸出與僅使用主模型的時候完全相同。
推測解碼也可以與基礎的 ?? Transformers pipeline API 一起用于推理。下面,我們使用模型和處理器實例化管道,然后使用它來轉錄測試數據集中的第一個樣本。這可以擴展為轉錄任意長度的音頻樣本,包括進行批處理:
from transformers import pipeline
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
chunk_length_s=15,
batch_size=4,
generate_kwargs={"assistant_model": assistant_model},
torch_dtype=torch_dtype,
device=device,
)
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
Outputs:
Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.
使用 Whisper 和 Distil-Whisper 運行推測解碼的端到端代碼示例可在 Distil-Whisper 模型卡 中找到。它將本文中涵蓋的推理階段組合成一個代碼示例。
多語言語音轉錄
Distil-Whisper 是英語語音轉錄的最佳輔助模型,因為它與原始 Whisper 模型的 WER 誤差率僅相差 1%,而對短長語音樣本的推理速度提高了 6 倍。然而,官方的 Distil-Whisper 檢查點僅支持英語,這意味著它們無法用于多語言語音轉錄。
要使用推測解碼進行多語言語音轉錄,您可以使用 官方 Whisper 多語言檢查點 之一,或者 Whisper 的微調版本。在撰寫本文時,Hugging Face Hub 上已有超過 5000 個微調過的 Whisper 檢查點,支持超過 100 種語言。這些為選擇表現出色的輔助模型提供了極好的起點。在此示例中,我們將使用最小的官方多語言檢查點 Whisper tiny。您可以使用任意一個您的語言中微調過的不同檢查點!
讓我們為新的輔助模型 Whisper tiny 加載權重。由于 Whisper tiny 的編碼器與 large-v2 不同,這次我們將使用 AutoModelForSpeechSeq2Seq 類同時加載編碼器和解碼器:
assistant_model_id = "openai/whisper-tiny"
assistant_model = AutoModelForSpeechSeq2Seq.from_pretrained(
assistant_model_id,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
use_safetensors=True,
attn_implementation="sdpa",
)
assistant_model.to(device);
我們的基準數據集,將從 VoxPopuli 數據集的荷蘭語 (“nl”) 部分中加載 73 個樣本:
dataset = load_dataset("sanchit-gandhi/voxpopuli_dummy", "nl", split="validation")
非常好!現在我們可以像前面一樣重新運行我們的 Whisper large-v2 模型的基準測試。我們所做的唯一更改是在 generate 函數中傳遞語言和任務參數,以確保執行語音轉錄 (而不是語音翻譯)。推測解碼完全兼容語音轉錄和翻譯任務。只需如下所示設置任務參數即可:
all_time = 0
predictions = []
references = []
for sample in tqdm(dataset):
audio = sample["audio"]
inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
inputs = inputs.to(device=device, dtype=torch.float16)
output, gen_time = generate_with_time(model, inputs, language="nl", task="transcribe")
all_time += gen_time
predictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
references.append(processor.tokenizer._normalize(sample["normalized_text"]))
wer_result = wer.compute(predictions=predictions, references=references)
print("Time:", all_time)
print("WER:", wer_result)
Outputs:
100%|██████████| 73/73 [02:05<00:00, 1.72s/it]
Time: 116.50992178916931
WER: 0.127190136275146
沒錯!我們的基準時間為 117 秒,WER 為 12.8%。讓我們使用推測解碼重新運行生成過程:
all_time = 0
predictions = []
references = []
for sample in tqdm(dataset):
audio = sample["audio"]
inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
inputs = inputs.to(device=device, dtype=torch.float16)
output, gen_time = assisted_generate_with_time(model, inputs, language="nl", task="transcribe")
all_time += gen_time
predictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
references.append(processor.tokenizer._normalize(sample["normalized_text"]))
wer_result = wer.compute(predictions=predictions, references=references)
print("Time:", all_time)
print("WER:", wer_result)
Outputs:
100%|██████████| 73/73 [01:08<00:00, 1.06it/s]
Time: 62.10229682922363
WER: 0.127190136275146
Nice!我們達到了 12.8% 的 WER,但這次的推理時間只有 62 秒,表示速度提高了 1.9 倍。考慮到加載輔助模型的低開銷和確保獲得完全相同輸出的數學證明,推測解碼為現有的 Whisper 管道提供了完美的即插即用的替代方案。
高效推測解碼的策略
在本最終部分,我們將介紹兩種策略,以確保使用推測解碼時獲得可能最快的推理時間。
輔助模型
我們的目標是選擇一個至少比主模型快 3 倍 并且 正確轉錄至少 70-80% 的預測 token (通常是示例中的“更簡單”token) 的輔助模型。如果您想要轉錄某種特定語言,一種有效的策略是訓練兩個不同大小的 Whisper 模型,并將其中一個用作另一個的輔助模型:
- 首先,微調 Whisper large-v3 以用作主模型
- 其次,在同一數據集上蒸餾 Whisper large-v3 以用作快速的輔助模型
微調和蒸餾都可以提高主模型和輔助模型在您選擇的語言上的 WER 性能,同時最大化 token 分布的對齊。有關 Whisper 微調的完整指南,請參閱 此處,有關蒸餾的指南請參閱 此處。
批次大小
值得注意的是,使用推測解碼獲得的最大速度提升來自批次大小為 1。對于批處理推測解碼,批處理中的所有候選 token 必須與驗證 token 相匹配,才能被接受。如果批處理中給定位置的 token 不一致,則所有在該位置之前的候選 token 將被丟棄。因此,推測解碼更傾向于較小的批次大小。在實踐中,我們發現推測解碼可以提供速度提升,直到批次大小達到 4 為止。當批次大小超過 4 時,推測解碼的推理速度比僅用主模型還要慢。有關完整結果,請參閱 Distil-Whisper 論文 的第 D.3 節。
結論
在本博文中,我們介紹了推測解碼的推理策略,以及如何將其應用于語音轉錄的 Whisper 模型。我們展示了如何實現 2 倍的速度提升,同時數學上確保獲得與僅使用原始模型相同的輸出。我們鼓勵您嘗試將推測解碼用作現有 Whisper 管道的即插即用替代方案,因為使用額外的輔助模型的開銷很小,并且可以保證獲得相同的轉錄結果。
致謝
本博客由 Sanchit Gandhi 撰寫。非常感謝 Patrick von Platen 和 Pedro Cuenca 的建設性意見,以及 Joao Gante 在 ?? Transformers 中實現輔助生成的貢獻。
英文原文: https://hf.co/blog/whisper-speculative-decoding
作者: Sanchit Gandhi
譯者: Hu Yaoqi (yaoqi)
總結
以上是生活随笔為你收集整理的使用推测解码 (Speculative Decoding) 使 Whisper 实现 2 倍的推理加速的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 介绍一个prometheus监控数据生成
- 下一篇: P1990-覆盖墙壁