Core Concepts
Linear Attention Sequence Parallelism (LASP)は、線形アテンションベースの言語モデルにおいて、長シーケンスを効率的に並列処理するための新しい手法である。LASPは、線形アテンションの特性を活かした効率的な通信メカニズムと、ハードウェア最適化により、既存の並列手法よりも高速で長いシーケンスを処理できる。
Abstract
本論文は、線形アテンションベースの言語モデルにおいて長シーケンスを効率的に並列処理するためのLinear Attention Sequence Parallelism (LASP)を提案している。
LASP の主な特徴は以下の通り:
-
線形アテンションの特性を活かした効率的な通信メカニズム
- 線形アテンションの右積カーネルトリックを利用して、通信オーバーヘッドを大幅に削減
- 並列度に依存しない通信量を実現
-
ハードウェア最適化
- カーネルの融合
- 中間状態のキャッシング
- GPUクラスタ上での高速な実装
-
各種分散データ並列手法との互換性
- PyTorch DDP、FSDP、ZeRO最適化などと組み合わせ可能
- 大規模クラスタでの長シーケンス・大バッチ学習に適用可能
実験では、LASP が既存の並列手法よりも高速で長いシーケンスを処理できることを示している。具体的には、1B パラメータモデルで最大4096Kのシーケンス長まで拡張でき、既存手法の8倍長いシーケンスを扱えるようになった。また、収束性能も既存手法と同等であることを確認している。
Stats
1つのGPUの最大メモリ使用量は128 GPUsで4096Kシーケンス長まで拡張できる
LASPは既存手法と比べて、256Kシーケンス長で38%高速、136%高スループットを達成した
Quotes
"LASP scales sequence length up to 4096K using 128 A100 80G GPUs on 1B models, which is 8× longer than existing SP methods while being significantly faster."
"LASP demonstrates a notable enhancement in throughput for linear attention, primarily due to its efficient communication design that facilitates the exchange of linear attention intermediate states."