最近,Mamba 團隊的研究令人矚目:來自康奈爾和普林斯頓等高校的研究者們成功將 Llama 這一大型 Transformer 模型 “蒸餾” 成了 Mamba,並設計了一種新型的推理解碼算法,顯著提高了模型的推理速度。
研究人員的目標是讓 Llama 變成 Mamba。爲什麼這麼做呢?因爲從零開始訓練一個大型模型代價高昂,而 Mamba 自問世以來受到了廣泛關注,但實際上很少有團隊自己訓練大規模的 Mamba 模型。雖然市面上有一些名聲在外的變種,比如 AI21的 Jamba 和 NVIDIA 的 Hybrid Mamba2,但衆多成功的 Transformer 模型中蘊藏了豐富的知識。如果我們能夠鎖住這些知識,同時將 Transformer 微調爲 Mamba,那問題就迎刃而解了。
研究團隊結合了漸進式蒸餾、監督微調和定向偏好優化等多種方法,成功達成了這個目標。值得注意的是,在保證性能不打折的前提下,速度也顯得至關重要。Mamba 在長序列推理中的優勢非常明顯,而 Transformer 也有推理加速方案,比如推測解碼。由於 Mamba 的獨特結構無法直接應用這些方案,研究者們特意設計了一種全新的算法,並結合硬件特性來實現基於 Mamba 的推測解碼。
最終,研究人員將 Zephyr-7B 和 Llama-38B 成功轉換爲線性 RNN 模型,且性能與蒸餾前的標準模型相當。整個訓練過程僅使用了20B 的 token,結果與使用1.2T 個 token 從頭訓練的 Mamba7B 模型及3.5T 個 token 訓練的 NVIDIA Hybrid Mamba2模型不相上下。
在技術細節方面,線性 RNN 與線性注意力是相通的,因此研究者能夠直接複用注意力機制中的投影矩陣,並通過參數初始化完成模型構建。此外,研究團隊凍結了 Transformer 中 MLP 層的參數,逐步用線性 RNN 層(即 Mamba)替換掉注意力頭,並對跨頭共享鍵和值的分組查詢注意力進行了處理。
在蒸餾過程中,採用了逐步替換注意力層的策略。監督微調包括兩種主要方法:一種是基於 word-level 的 KL 散度,另一種是序列級知識蒸餾。針對用戶偏好的調優階段,團隊利用了直接偏好優化(DPO)的方法,通過與老師模型的輸出進行對比,確保模型在生成內容時能更好地符合用戶的期望。
接下來,研究者們開始將 Transformer 的推測解碼應用到 Mamba 模型中。推測解碼可以簡單理解爲用一個小模型生成多個輸出,然後使用大模型對這些輸出進行驗證。小模型運行迅速,可以快速生成多個輸出向量,而大模型則負責評估這些輸出的準確性,從而提升整體推理速度。
爲了實現這一過程,研究者們設計了一套算法,每次使用小模型生成 K 個草稿輸出,隨後大模型通過驗證返回最終的輸出和中間狀態的緩存。這一方法在 GPU 上得到了很好的效果,Mamba2.8B 實現了1.5倍的推理加速,且接受率達到了60%。儘管在不同架構的 GPU 上效果有所差異,研究團隊通過融合內核和調整實現方式進行進一步優化,最終達成了理想的加速效果。
在實驗階段,研究人員利用 Zephyr-7B 和 Llama-3Instruct8B 進行了三階段的蒸餾訓練,最終僅需在8卡80G A100上運行3到4天,便成功復現了研究成果。這項研究不僅展示了 Mamba 與 Llama 之間的轉變之路,也爲未來模型的推理速度和性能提升提供了新的思路。
論文地址:https://arxiv.org/pdf/2408.15237