最高加速9倍!字節跳動開源8比特混合精度Transformer引擎

近年來,Transformer 已經成為了 NLP 和 CV 等領域的主流模型,但龐大的模型參數限制了它的高效訓練和推理。於是字節跳動在 2019 年 12 月和 2021 年 6 月分別推出了高效推理和訓練引擎 LightSeq,大大加速了 Transformer 系列模型的訓練和推理,也打通了 Transformer 從訓練到推理的整個流程,極大最佳化了使用者使用體驗。最近,LightSeq 訓練引擎相關論文[1],被錄用難度極高的超算領域國際頂會 SC22 接收,得到了學術界的廣泛認可!

  • SC22 接收論文:https://sc22.supercomputing.org/presentation/?id=pap211&sess=sess154
  • 程式碼地址:https://github.com/bytedance/lightseq

如何繼續提升速度?降低計算精度是比較直接的方法。2017 年以來,fp16 混合精度技術 [2] 獲得了廣泛應用。在對模型效果無損的前提下,將模型訓練和推理的速度提升了 50% 以上。而為了維持模型效果,更低精度的方法(例如 int8)通常需要使用如下傳統方案:

  1. 首先使用 fp16 混合精度將模型訓練至收斂;
  2. 然後在模型計算密集型運算元的權重、輸入和輸出位置處,插入偽量化結點,進行量化感知訓練;
  3. 最後將帶有偽量化結點的模型計算圖轉換到專用的 int8 推理引擎中,進行服務部署和模型推理。

雖然在多數任務上,上述方案可以實現模型效果無損,但還是存在以下問題:

  1. 使用方法複雜。例如要多一次量化感知訓練 [4] 的過程,並且帶有偽量化節點的計算圖轉換複雜。
  2. 訓練速度慢。由於目前流行的深度學習框架不支持 int8 精度,所以量化感知訓練需要插入 fp16 的偽量化結點來模擬 int8 量化,導致量化感知訓練反而比 fp16 混合精度訓練慢 2-3 倍。
  3. 推理部署難且加速比低。對比 fp32、fp16 等類型,int8 硬體和底層軟體庫最佳化相對滯後。例如在 NVIDIA GPU 上,int8 矩陣乘法加速受限於硬體架構和特定 shape,實際加速比遠遠低於理論值。

在下文中,如無特殊說明,量化都是指的 int8 精度的量化

針對這些問題,字節跳動推出了全新版本的 LightSeq GPU 量化訓練與推理引擎。支持 Transformer 系列模型的量化訓練與推理,並做到了開箱即用,使用者友好。LightSeq 快準狠地實現了 int8 精度的量化訓練和推理:

  1. 快:A100 多卡訓練最高加速 5.2 倍,T4 單卡推理最高加速 8.9 倍。
  2. 準:訓練和推理效果基本無損。
  3. 狠:相同資料量下,視訊記憶體佔用最高減少 68%,模型儲存空間減少 75%。

總體來說,LightSeq 新版量化訓練與推理引擎具有如下幾個優點:

1. 豐富的支持

支持完整的 Transformer 模組和多種解碼演算法,支持 Transformer、BERT、GPT、BART、ViT 等多種模型結構,支持 Fairseq、Hugging Face、NeurST 等多種訓練框架接入量化訓練、匯出模型以及量化推理,提供了豐富的樣例供使用者參考。

2. 卓越的性能

相比於 fp16 精度的 LightSeq 推理引擎,int8 量化還可以進一步加速最高 70%,相比於 PyTorch 推理更是達到了最高 8.9 倍的加速比。同時視訊記憶體佔用相比 fp16 推理引擎降低了 30% 左右,模型儲存空間只需要原來的四分之一。最後經過多個任務的驗證,推理效果幾乎無損。

3. 便捷的使用

LightSeq 已經針對多個訓練庫進行了量化支持,可以一鍵開啟量化訓練,然後輕鬆匯出為 LightSeq 支持的模型格式,最後實現量化推理。除此之外,LightSeq 還支持訓練後量化,無需額外訓練即可體驗量化推理。

使用方法

使用方法

如上圖所示,為了最大程度減小量化帶來的損失,首先需要用 fp16 精度訓練一個浮點數模型,將模型效果訓到最好。然後開啟量化進行 finetune,得到微調過的量化模型,此時模型效果已經基本恢復到浮點數模型的水平。接著將量化模型轉換為 LightSeq 支持的 PB 或者 HDF5 模型格式,最後用 LightSeq 進行量化推理。

安裝方法

LightSeq 安裝非常簡單,只需要一行命令即可:

pip install lightseq

量化訓練

LightSeq 支持 Fairseq、Hugging Face、NeurST 等訓練框架的量化接入,同時也可以自定義模型並開啟量化訓練。以 encoder 層為例,只需要先定義浮點數模型,然後開啟量化即可:

from lightseq.training import LSTransformerEncoderLayerfrom lightseq.training.ops.pytorch.quantization import enable_quantconfig = LSTransformerEncoderLayer.get_config(model="bert-base",max_batch_tokens=4096,max_seq_len=512,fp16=True,local_rank=0,)layer = LSTransformerEncoderLayer(config)# 開啟量化layer.apply(enable_quant)

量化推理

LightSeq 提供了便捷的 python 推理接口,只需要三行程式碼即可實現快速的量化推理:

import lightseq.inference as lsimodel = lsi.QuantTransformer(pb_path, batch_size)result = model.infer(input)

此外 LightSeq 還提供了 BERT、GPT、ViT 等模型的 python 接口,分別調用 QuantBert、QuantGpt 和 QuanVit 即可體驗。

梯度通訊量化

LightSeq 支持 Transformer 模型的梯度通訊量化[5],使用 Fairseq 或者 Hugging Face 即可輕鬆開啟分散式量化訓練,並同時支持浮點數模型和量化模型。在構建模型後,只需要為模型註冊一個 communication hook 即可開啟梯度通訊量化,再開始訓練過程。

from lightseq.training.gradient_comm_quantization import encode_and_decode, GCQStatefrom torch.nn.parallel import DistributedDataParallel# model could be from Fairseq or Hugging Face, wrapped by DDPmodel = DistributedDataParallel(model)state =  GCQState(process_group)# register hookmodel.register_comm_hook(state=state, hook=encode_and_decode)

性能測試

LightSeq 在多個任務上測試了量化訓練、量化推理和梯度通訊量化的速度,並且分析了視訊記憶體佔用情況和量化模型的效果。

量化訓練速度

量化訓練速度

LightSeq 在 8 張 A100 顯示卡上進行了訓練實驗,主要對比對象是 Fairseq 的 Transformer、Hugging Face 的 BERT、GPT2 和 ViT。

可以看出,四種模型結構加速趨勢都是類似的,加速比都會隨著資料量的增大而減小,原因有三點:

  1. 隨著資料量的增大,矩陣乘法 GEMM 的佔比會明顯增加,因此 PyTorch QAT 增加的額外的偽量化結點時間佔比會逐漸減小,最後速度會和 PyTorch fp16 無限接近。
  2. 與此同時,隨著 GEMM 佔比升高,LightSeq fp16 自定義運算元的提速效果也逐漸減小,因此時間上也會和 PyTorch fp16 無限接近。
  3. 由於 Ampere 架構顯示卡上 int8 GEMM 在 shape 較小時甚至不如 fp16 GEMM 快,在大 shape 下才能稍快一點,因此隨著資料量增大,LightSeq int8 也會無限接近 LightSeq fp16 的速度。

量化推理速度

量化推理速度

LightSeq 在單張 T4 顯示卡上進行了推理實驗,主要對比對象是 Hugging Face 的 Transformer、BERT、GPT2 和 ViT。

可以看出,隨著輸入資料量的增大,LightSeq 與 PyTorch 的差距會逐漸減小,這也是 GEMM 佔比升高造成的。比較 LightSeq fp16 和 LightSeq int8,可以看出隨著資料量的增大,LightSeq int8 越來越快。這是因為在 T4 顯示卡上,int8 GEMM 的加速會隨著 shape 的增大而有明顯增加。因此在 T4 顯示卡上進行量化推理時,輸入資料量越大,加速效果越好。

LightSeq 還針對機器翻譯多個語向和多個測試集,測試了不同 batch size 下,LightSeq int8 推理相對於 LightSeq fp16 推理的加速比,實驗同樣是在單張 T4 顯示卡上進行的,採用的模型都是標準的 Transformer-Big。

可以得到和上文中相同的結論,隨著 batch size 的增大,量化推理的加速比會逐漸升高。相比於 LightSeq fp16,最高還可以再加速近 70%,這極大地縮短了線上翻譯模型的推理延時。

最後如上圖所示,為了展示自動 GEMM 調優技術的效果,LightSeq 測試對比了 A100 顯示卡上 Transformer 和 BERT 模型 fp16、int8 調優前和 int8 調優後的延時。可以看出調優前某些 shape 的 int8 GEMM 速度甚至比 fp16 還要慢,而調優後全面超越了 fp16。

視訊記憶體佔用

視訊記憶體佔用

LightSeq 分析了不同 batch size 下,量化模型相對於浮點數模型視訊記憶體佔用的加速比。可以看出隨著 batch size 的增大,量化模型的視訊記憶體佔用優勢更明顯,最高可以減少 30% 左右。而 LightSeq fp16 引擎相對於 PyTorch 模型也極大程度減少了視訊記憶體佔用,因此 LightSeq int8 引擎最終能夠減少最多 68% 左右的視訊記憶體。

量化模型效果

量化模型效果

針對機器翻譯多個語向和多個測試集,LightSeq 測試了量化模型推理相對於浮點數模型 BLEU 的損失,採用的模型都是標準的 Transformer-Big。

在資料量較大的語向 en2zh 上,LightSeq int8 相對 BLEU 損失較大些,最大達到了 – 0.4。而在資料量較小的語向 en2es 上,LightSeq int8 不僅沒有任何效果損失,反而比浮點數模型更好。總體而言,int8 量化模型的平均 BLEU 相比浮點數模型基本無損。在 GLUE 和 SQuAD 等多個任務上,LightSeq 也驗證了量化模型的效果。

梯度通訊量化

梯度通訊量化

由於在多機多卡場景下通訊瓶頸更加明顯,所以梯度通訊量化主要應用在分散式訓練場景。因此 LightSeq 在 2 機 8 卡的 A100 上進行了分散式訓練的速度測試。

可以看出,梯度通訊量化的訓練加速效果整體上隨著輸入資料的增大而減弱。這主要是因為隨著輸入資料的增大,計算時間佔比升高,梯度通訊時間佔比減少,梯度量化的收益也隨之減小。

LightSeq 還額外增加了不同數量網路卡(NIC)下的訓練速度測試。可以看到使用梯度通訊量化的分散式訓練速度相比原始的 LightSeq fp16 有大幅度提升。

量化技術

int8 量化的加速收益主要來自如下幾個方面:

  1. GEMM 精度從 fp16 降低到 int8 後,計算時間縮短;
  2. 自定義運算元採用 int8 輸入輸出後,資料讀寫時間縮短;
  3. 梯度採用 int8 儲存後,多機之間通訊時間縮短。

以 Transformer 模型為例,經過 LightSeq fp16 引擎加速後,自定義運算元時間大大縮短,而 GEMM 時間佔比提升到了 90% 左右,因此最佳化的重點轉移到了 GEMM 提速。將 fp16 GEMM 替換為 int8 GEMM 不僅可以縮短 GEMM 時間,還可以減小前後運算元的輸入輸出位寬,從而減小讀寫資料的時間。最後多機訓練的瓶頸主要在梯度的通訊,將梯度量化為 int8 精度可以大大加快分散式訓練的速度。

量化原理

量化原理

為了彌補量化帶來的精度損失,通常需要用量化感知訓練來模擬量化過程。如上圖所示,量化感知訓練就是將 float GEMM 的兩個 float 輸入分別做一遍量化和反量化(稱之為偽量化結點),離散化成分段的浮點數輸入,然後進行 float GEMM 運算。得到結果後再次進行量化與反量化,得到最終的浮點數結果。而量化的過程是不可導的,因此需要用 STE 方法來估計量化參數的梯度。之所以量化感知訓練中需要插入偽量化結點,然後用 float GEMM 去模擬量化過程,是因為 TensorFlow 和 PyTorch 等訓練框架不支持 int8 GEMM。

而 LightSeq 量化訓練直接採用 int8 GEMM 來真實還原量化過程,因此相比傳統的實現要更快,且更加節省視訊記憶體。在推理的時候,同樣採用離散化後的整數進行 int8 GEMM 運算,最後再反量化回浮點數結果。量化推理過程和量化訓練完全一致,並且和傳統的量化感知訓練是完全等價的。

量化位置

量化位置

整個量化 Transformer 的網路結構如上圖所示,紅色箭頭表示需要加上量化和反量化結點的位置。

首先所有 int8 GEMM 的輸入和輸出都需要進行量化。由於 int8 GEMM 的 shape 限制,部分 GEMM(例如注意力分數的計算)仍然採用 float GEMM。此外第二層 FFN 的 GEMM 採用的是 int32 的輸出,因為它的 GEMM 輸入是 ReLU 激活函數的輸出結果,只包含正數,非對稱,因此如果採用 int8 輸出的 GEMM,將無法反量化為正確的浮點數結果。

然後所有的模型權重 weight 都需要儲存為 int8 類型,因此需要對 weight 做量化。而權重 bias 參數量較小,無需量化,保留 float 精度反而可以提升模型效果。

最後需要對 decoder 端的 cache 進行量化。因為在推理時,decoder 端的 cache 需要頻繁進行讀寫,因此將 cache 量化為 int8 可以大大加快解碼的速度。

量化策略

量化策略

將一個浮點數矩陣量化為 int8 整數矩陣有很多方法,LightSeq 採用的是對稱量化,即將正負數範圍對稱的浮點數區間等比例地對映到整數區間 [-127, 127] 上。

而實際上浮點數矩陣的數值範圍通常並不對稱,存在極少的離群值。如果直接按照離群值的範圍來量化矩陣,會影響到量化後的精度,所以需要先對矩陣進行數值截斷。

LightSeq 採用 PACT 方法進行截斷[6],將截斷的範圍當作模型可學習的參數,然後利用 STE 演算法去估計參數的梯度,並進行反向傳播最佳化。根據實踐經驗,權重 weight 的初始截斷範圍設為[-1, 1],中間結果的初始截斷範圍設為[-16, 16],可以在大部分任務上達到最好的效果。最後經過截斷範圍和其他模型參數的聯合最佳化,量化模型的效果可以達到基本無損。

梯度通訊量化

針對分散式訓練場景,LightSeq 推出了梯度量化壓縮技術。即對浮點精度的梯度進行 int8 量化,以減少梯度通訊的時間消耗,從而加速訓練,這就是梯度通訊量化(GCQ)。

如上圖所示,梯度通訊量化的主要流程如下:

  1. 計算每張卡上各自梯度的截斷範圍;
  2. 對截斷範圍執行 all-reduce max 操作;
  3. 每張卡使用統一的截斷範圍對各自梯度進行 int8 量化;
  4. 對 int8 梯度執行 all-reduce sum 操作;
  5. 每張卡對 all-reduce 後的梯度進行反量化,還原為浮點數梯度,並進行參數更新。

為了解決 int8 梯度在 all-reduce 過程中溢出的問題,LightSeq 首先將每張卡上的浮點數梯度除以卡數,再使用除之前的截斷範圍進行量化,最後進行 all-reduce 操作。這樣每張卡上量化後的 int8 整數 all-reduce 完就不會溢出,但是單卡實際用於量化的比特數也因此而減少,所以目前方案在 2 機 8 卡效果幾乎無損,但隨著卡數的上漲,訓練效果會有所下降。以 en2de 和 en2fr 翻譯任務為例,在 4 機 8 卡上進行分散式量化訓練,BLEU 值分別會下降 0.4 和 1.5 左右。未來 LightSeq 將會持續探索更好的方法來解決這一問題。

通用技術

除了上一章節中提到的量化技術以外,此次更新 LightSeq 還提出了幾種通用的最佳化技術,不僅可以應用在量化模型中,也適用於其它所有精度模型的訓練與推理。

運算元融合

運算元融合

上圖是 encoder 模組量化訓練的計算圖,LightSeq 將兩次 GEMM 運算之間的所有操作融合成一個運算元[7],減少了 kernel 調用的次數,因此減少了總的計算時間。

圖中黃色矩形表示 int8 GEMM,綠色矩形表示 float GEMM。這裡採用 float GEMM 是由於 shape 的限制,不適合使用 int8 GEMM 加速。紅色箭頭表示流動資料的類型是 int8,綠色箭頭表示第二層 FFN 的 GEMM 輸出是 int32 資料類型。int8 GEMM 輸入輸出的量化與反量化操作都被融合到了前後 kernel 裡,這不僅可以減少資料搬運,還可以減小視訊記憶體佔用。

在推理時,LightSeq 還針對 decoder 做了最佳化。如上圖所示,在計算 self-attention 時,注意力得分的維度是(batch size, 1, sequence length)。因此在計算 value 乘積時,可以不採用 GEMM 運算,而直接手寫加權求和的運算元,從而將圖中虛線框中的計算融合成一個 kernel。

自動視訊記憶體管理

自動視訊記憶體管理

模型量化引入了更復雜的張量類型和張量依賴關係,這給視訊記憶體管理帶來新的挑戰。為此,LightSeq 設計了新的視訊記憶體管理機制。如上圖所示,主要包括以下過程:

  1. 訓練啟動前,根據每個運算元的拓撲依賴關係,自動計算每個張量的生命週期及視訊記憶體空間大小。其中,包含動態維度的張量按照此維度的最大量進行計算,例如機器翻譯任務中的最大句長和最大 batch 句子數量。這些最大量在訓練前已被指定;
  2. 張量確定生命週期和大小後,分析視訊記憶體複用關係。其中,無生命週期重合的張量可以共用一片視訊記憶體空間,所有視訊記憶體空間都是無資料類型的,可以被分配到任意資料類型的張量上;
  3. 根據張量視訊記憶體複用關係,申請多段視訊記憶體空間,為每個張量分配實際的視訊記憶體起止地址。

張量視訊記憶體複用的分析,LightSeq 借鑑了論文 [3] 中提出的 Greedy by Size for Offset Calculation 方法,做了三個改進:

  1. 支持了整個訓練過程的視訊記憶體複用(forward/backward);
  2. 不同資料類型能做到視訊記憶體複用(int8/fp16/fp32);
  3. 在多段視訊記憶體空間上容納所有張量,而非一段非常大的視訊記憶體空間,這樣能有效提升視訊記憶體利用率。

自動 GEMM 調優

LightSeq 的 int8 GEMM 採用了 NVIDIA 的 cuBLASLt 庫,這也是目前 NVIDIA 顯示卡上最為高效的矩陣運算庫。但是輸入資料的 shape 或者顯示卡不同的話,GEMM 所採用的最優配置(例如資料排布、GEMM 演算法等等)也可能不同,因此需要進行自動選取。LightSeq 採取的自動調優方案如下:

  1. 在多種型號顯示卡上(例如 T4 和 A100)進行不同 shape 的 GEMM 最優配置搜尋,並將結果保存到配置檔案中,使用者只需要下載即可;
  2. 模型初始化時,載入對應型號顯示卡的配置檔案,解析並保存到鍵值對為 (shape, 最優配置) 的字典中。如果沒有對應型號顯示卡的配置檔案,或者沒有需要的 GEMM shape,那麼使用者可以選擇自己搜尋並保存,或者直接使用默認配置;
  3. 模型前向或後向計算時,根據輸入的 shape 在字典中尋找最優配置,然後進行 GEMM 計算。如果沒有找到對應的 shape,那麼直接採用默認的配置。

未來工作

未來 LightSeq 還將繼續探索移動端的低精度量化、反向傳播中梯度的量化、大模型量化等方向。

引用

[1] Wang, Xiaohui, et al. “LightSeq2: Accelerated training for transformer-based models on gpus.” arXiv preprint arXiv:2110.05722 (2021).

[2] Micikevicius, Paulius, et al. “Mixed precision training.” arXiv preprint arXiv:1710.03740 (2017).

[3] Pisarchyk, Yury, and Juhyun Lee. “Efficient memory management for deep neural net inference.” arXiv preprint arXiv:2001.03288 (2020).

[4] Jacob, Benoit, et al. “Quantization and training of neural networks for efficient integer-arithmetic-only inference.” Proceedings of the IEEE conference on computer vision and pattern recognition. 2018.

[5] Alistarh, Dan, et al. “QSGD: Communication-efficient SGD via gradient quantization and encoding.” Advances in neural information processing systems 30 (2017).

[6] Choi, Jungwook, et al. “Pact: Parameterized clipping activation for quantized neural networks.” arXiv preprint arXiv:1805.06085 (2018).

[7] Wang, Xiaohui, et al. “LightSeq: A high performance inference library for transformers.” arXiv preprint arXiv:2010.13887 (2020).

相關文章

Transformer,ChatGPT 幕後的真正大佬

Transformer,ChatGPT 幕後的真正大佬

ChatGPT的背後 ChatGPT紅得發紫,強得讓人類心悸。 但在它的背後,還隱藏著一位真正的大佬。 可以說,與它相比,ChatGPT其實...