統治擴散模型的U-Net要被取代了,謝賽寧等引入Transformer提出DiT

來自 UC 柏克萊的 William Peebles 以及紐約大學的謝賽寧撰文揭秘擴散模型中架構選擇的意義,併為未來的生成模型研究提供經驗基線。

近幾年,在 Transformer 的推動下,機器學習正在經歷復興。過去五年中,用於自然語言處理、計算機視覺以及其他領域的神經架構在很大程度上已被 transformer 所佔據。

不過還有許多圖像級生成模型仍然不受這一趨勢的影響,例如過去一年擴散模型在圖像生成方面取得了驚人的成果,幾乎所有這些模型都使用卷積 U-Net 作為主幹。這有點令人驚訝!在過去的幾年中,深度學習的大事件一直是跨領域的 Transformer 的主導地位。U-Net 或卷積是否有什麼特別之處使它們在擴散模型中表現得如此出色?

將 U-Net 主幹網路首次引入擴散模型的研究可追溯到 Ho 等人,這種設計模式繼承了自迴歸生成模型 PixelCNN++,只是稍微進行了一些改動。而 PixelCNN++ 由卷積層組成,其包含許多的 ResNet 塊。其與標準的 U-Net 相比,PixelCNN++ 附加的空間自注意力塊成為 transformer 中的基本元件。不同於其他人的研究,Dhariwal 和 Nichol 等人消除了 U-Net 的幾種架構選擇,例如使用自適應歸一化層為卷積層注入條件資訊和通道計數。

本文中來自 UC 柏克萊的 William Peebles 以及紐約大學的謝賽寧撰文《 Scalable Diffusion Models with Transformers 》,目標是揭開擴散模型中架構選擇的意義,併為未來的生成模型研究提供經驗基線。該研究表明,U-Net 歸納偏置對擴散模型的性能不是至關重要的,並且可以很容易地用標準設計(如 transformer)取代。

這一發現表明,擴散模型可以從架構統一趨勢中受益,例如,擴散模型可以繼承其他領域的最佳實踐和訓練方法,保留這些模型的可擴展性、魯棒性和效率等有利特性。標準化架構也將為跨領域研究開闢新的可能性。

  • 論文地址:https://arxiv.org/pdf/2212.09748.pdf
  • 項目地址:https://github.com/facebookresearch/DiT
  • 論文主頁:https://www.wpeebles.com/DiT

該研究專注於一類新的基於 Transformer 的擴散模型:Diffusion Transformers(簡稱 DiTs)。DiTs 遵循 Vision Transformers (ViTs) 的最佳實踐,有一些小但重要的調整。DiT 已被證明比傳統的卷積網路(例如 ResNet )具有更有效地擴展性。

具體而言,本文研究了 Transformer 在網路複雜度與樣本質量方面的擴展行為。研究表明,通過在潛在擴散模型 (LDM) 框架下構建 DiT 設計空間並對其進行基準測試,其中擴散模型在 VAE 的潛在空間內進行訓練,可以成功地用 transformer 替換 U-Net 主幹。本文進一步表明 DiT 是擴散模型的可擴展架構:網路複雜性(由 Gflops 測量)與樣本質量(由 FID 測量)之間存在很強的相關性。通過簡單地擴展 DiT 並訓練具有高容量主幹(118.6 Gflops)的 LDM,可以在類條件 256 × 256 ImageNet 生成基準上實現 2.27 FID 的最新結果。

Diffusion Transformers

DiTs 是一種用於擴散模型的新架構,目標是儘可能忠實於標準 transformer 架構,以保留其可擴展性。DiT 保留了 ViT 的許多最佳實踐,圖 3 顯示了完整 DiT 體系架構。

DiT 的輸入為空間表示 z(對於 256 × 256 × 3 圖像,z 的形狀為 32 × 32 × 4)。DiT 的第一層是 patchify,該層通過將每個 patch 線性嵌入到輸入中,以此將空間輸入轉換為一個 T token 序列。patchify 之後,本文將標準的基於 ViT 頻率的位置嵌入應用於所有輸入 token。

patchify 創建的 token T 的數量由 patch 大小超參數 p 決定。如圖 4 所示,將 p 減半將使 T 翻四倍,因此至少能使 transformer Gflops 翻四倍。本文將 p = 2,4,8 添加到 DiT 設計空間。

DiT 塊設計:在 patchify 之後,輸入 token 由一系列 transformer 塊處理。除了噪聲圖像輸入之外,擴散模型有時還會處理額外的條件資訊,例如噪聲時間步長 t、類標籤 c、自然語言等。本文探索了四種以不同方式處理條件輸入的 transformer 塊變體。這些設計對標準 ViT 塊設計進行了微小但重要的修改。所有模組的設計如圖 3 所示。

本文嘗試了四種因模型深度和寬度而異的配置:DiT-S、DiT-B、DiT-L 和 DiT-XL。這些模型配置範圍從 33M 到 675M 參數,Gflops 從 0.4 到 119 。

實驗

研究者訓練了四個最高 Gflop 的 DiT-XL/2 模型,每個模型使用不同的 block 設計 ——in-context(119.4Gflops)、cross-attention(137.6Gflops)、adaptive layer norm(adaLN,118.6Gflops)或 adaLN-zero(118.6Gflops)。然後在訓練過程中測量 FID,圖 5 為結果。

擴展模型大小和 patch 大小。圖 2(左)給出了每個模型的 Gflops 和它們在 400K 訓練迭代時的 FID 概況。可以發現,增加模型大小和減少 patch 大小會對擴散模型產生相當大的改進。

圖 6(頂部)展示了 FID 是如何隨著模型大小的增加和 patch 大小保持不變而變化的。在四種設置中,通過使 Transformer 更深、更寬,訓練的所有階段都獲得了 FID 的明顯提升。同樣,圖 6(底部)展示了 patch 大小減少和模型大小保持不變時的 FID。研究者再次觀察到,在整個訓練過程中,通過簡單地擴大 DiT 處理的 token 數量,並保持參數的大致固定,FID 會得到相當大的改善。

圖 8 中展示了 FID-50K 在 400K 訓練步數下與模型 Gflops 的對比:

SOTA 擴散模型 256×256 ImageNet。在對擴展分析之後,研究者繼續訓練最高 Gflop 模型 DiT-XL/2,步數為 7M。圖 1 展示了該模型的樣本,並與類別條件生成 SOTA 模型進行比較,表 2 中展示了結果。

當使用無分類器指導時,DiT-XL/2 優於之前所有的擴散模型,將之前由 LDM 實現的 3.60 的最佳 FID-50K 降至 2.27。如圖 2(右)所示,相對於 LDM-4(103.6 Gflops)這樣的潛在空間 U-Net 模型來說,DiT-XL/2(118.6 Gflops)計算效率高得多,也比 ADM(1120 Gflops)或 ADM-U(742 Gflops)這樣的像素空間 U-Net 模型效率高很多。

表 3 展示了與 SOTA 方法的比較。XL/2 在這一解析度下再次勝過之前的所有擴散模型,將 ADM 之前取得的 3.85 的最佳 FID 提高到 3.04。

更多研究細節,可參考原論文

更多研究細節,可參考原論文。

相關文章

Transformer,ChatGPT 幕後的真正大佬

Transformer,ChatGPT 幕後的真正大佬

ChatGPT的背後 ChatGPT紅得發紫,強得讓人類心悸。 但在它的背後,還隱藏著一位真正的大佬。 可以說,與它相比,ChatGPT其實...