FlashAttention

Материал из MachineLearning.

Версия от 07:00, 2 июля 2026; Mihail Mishin (Обсуждение | вклад)
(разн.) ← Предыдущая | Текущая версия (разн.) | Следующая → (разн.)
Перейти к: навигация, поиск
Статья написана с использованием LLM DeepSeek-V3 и проверена участником М. Мишин 10:00, 2 июля 2026 (MSD)

Промпт приводится полностью в Обсуждение:FlashAttention


Содержание

FlashAttention — семейство IO‑aware алгоритмов для вычисления механизма внимания в трансформерах, позволяющее значительно ускорить обучение и инференс больших языковых моделей и снизить потребление памяти с квадратичного до линейного относительно длины последовательности. Впервые предложена в 2022 году группой исследователей из Стэнфордского университета. Ключевая инновация — переосмысление вычислений с учётом иерархии памяти GPU, что позволяет минимизировать дорогостоящие операции чтения/записи между медленной глобальной памятью (HBM) и быстрой кэш-памятью (SRAM).

В отличие от приближённых методов, FlashAttention вычисляет точное внимание без потери качества, но при этом работает в 2–4 раза быстрее оптимизированных реализаций и сокращает объём используемой памяти. Благодаря FlashAttention появилась возможность создавать модели с контекстным окном в сотни тысяч и миллионы токенов.

Мотивировка: проблема стандартного внимания

Механизм самовнимания (self-attention) является вычислительным ядром трансформеров. Для входных последовательностей Q, K, V ∈ ℝ^{n×d} (где n — длина последовательности, d — размерность представления) стандартное внимание вычисляется как:


\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right) V.

Наивная реализация требует материализации матрицы оценок внимания S = QK^T / √d размера n×n, что приводит к двум фундаментальным проблемам:

  1. Квадратичная сложность по памяти — O(n²). При длинных последовательностях (n > 4096) хранение полной матрицы S становится невозможным даже на самых современных GPU.
  2. IO-узкое место — основным ограничением производительности оказывается не число арифметических операций, а постоянные чтения/записи в медленную глобальную память GPU (HBM).

Иерархия памяти GPU

Современные GPU имеют два основных уровня памяти:

  • HBM (High Bandwidth Memory) — большая (40–80 ГБ), но медленная память с пропускной способностью ~1,5 ТБ/с.
  • SRAM (Static RAM) — небольшая (∼20 МБ), но чрезвычайно быстрая кэш-память на кристалле с пропускной способностью до 19 ТБ/с.

Стандартный алгоритм внимания постоянно читает и записывает промежуточные матрицы (размером n×n) из HBM, что и становится главным источником задержек. Арифметические операции (матричные умножения) в этом случае простаивают в ожидании данных — скорость работы ограничена пропускной способностью памяти.

Основная идея FlashAttention

FlashAttention переформулирует вычисление внимания как IO‑aware алгоритм, минимизирующий количество обращений к HBM. Это достигается за счёт трёх ключевых приёмов.

1. Разбиение на блоки (Tiling)

Входные матрицы Q, K, V разбиваются на небольшие блоки (tiles), которые полностью помещаются в быструю SRAM. Алгоритм последовательно загружает эти блоки из HBM в SRAM, выполняет все необходимые вычисления для данного блока и обновляет результат, никогда не материализуя полную матрицу внимания в глобальной памяти.

2. Онлайн‑softmax

Ключевое техническое новшество — модифицированный алгоритм softmax, который может вычисляться по частям. Стандартный softmax требует знания всех элементов вектора для нормализации:


\text{softmax}(x_i) = \frac{\exp(x_i - m)}{\sum_j \exp(x_j - m)}, \quad m = \max_j x_j.

FlashAttention поддерживает два скалярных состояния при обходе блоков: текущий максимум m и сумму экспонент ℓ. При обработке каждого нового блока эти состояния обновляются, что позволяет получить точный результат без повторного чтения ранее обработанных данных.

3. Перевычисление в обратном проходе

Для обратного распространения ошибки стандартное внимание сохраняет промежуточную матрицу S (размером n×n). FlashAttention вместо этого перевычисляет необходимые промежуточные значения из сохранённых блоков Q, K, V и статистик softmax (m, ℓ), экономя память ценой дополнительных вычислений.

IO-сложность

Авторы показали, что FlashAttention требует O(n² d² / M) чтений/записей HBM для некоторых конфигураций, где M — размер SRAM, что значительно меньше, чем O(n² + nd) у стандартного внимания. Для большинства практических размеров блоков количество обращений к HBM сокращается в 10–100 раз.

Эволюция версий

Семейство FlashAttention активно развивается, каждая новая версия адаптируется к возможностям современных GPU и вводит дополнительные оптимизации.

FlashAttention (2022)

Первая версия:

  • Снижение памяти с O(n²) до O(n).
  • Ускорение в 2–4× по сравнению с оптимизированными реализациями.
  • На BERT-large (seq. length 512) — ускорение на 15% end‑to‑end.
  • На GPT-2 (seq. length 1K) — ускорение в 3×.
  • Первые трансформеры, показавшие результат лучше случайного на Path‑X (seq. length 16K) и Path‑256 (seq. length 64K).

Однако на NVIDIA A100 использование GPU составляло лишь 25–40% от теоретического пика FLOPs — основная причина заключалась в неоптимальном распределении работы между потоками.

FlashAttention‑2 (2023)

Вторая версия устранила недостатки первой за счёт:

  • Улучшенного распараллеливания: вычисления для одной головы внимания распределяются между разными блоками потоков.
  • Оптимизации распределения работы между варпами внутри блока, что сократило обмен через разделяемую память.
  • Сокращения числа операций, отличных от матричного умножения (non‑matmul FLOPs).

Результат:

  • Ускорение ∼2× относительно FlashAttention.
  • Использование A100 достигло 50–73% от теоретического пика.
  • При обучении GPT‑подобных моделей — до 225 TFLOPs/s на GPU A100 (72% utilisation).

FlashAttention‑3 (2024)

Третья версия ориентирована на архитектуру NVIDIA Hopper (H100) и использует новые аппаратные возможности:

  • WGMMA (Warpgroup Matrix Multiply‑Accumulate) — новый тип инструкций для тензорных ядер, почти вдвое быстрее, чем в Ampere.
  • TMA (Tensor Memory Accelerator) — аппаратный ускоритель для асинхронной передачи данных между HBM и разделяемой памятью, освобождающий вычислительные ядра.
  • Поддержка FP8 — низкоточные вычисления с сохранением точности.

Результаты:

  • Ускорение 1,5–2,0× относительно FlashAttention‑2 в прямом проходе и 1,5–1,75× в обратном.
  • До 740 TFLOPS (75% utilisation H100) для FP16.


Современные направления развития

Помимо основной линии FlashAttention‑1/2/3, активно развиваются специализированные расширения и адаптации.

Адаптация для RISC‑V векторных процессоров

Оригинальные реализации FlashAttention заточены под GPU NVIDIA. Однако растёт интерес к развёртыванию LLM на открытой архитектуре RISC‑V. В работе 2025 года предложена первая векторизованная реализация FlashAttention для RISC‑V векторных процессоров. Основные особенности:

  • Минимизация скалярного кода и упрощение вычисления экспонент через низкозатратную аппроксимацию.
  • Исследование стратегий разбиения на блоки для улучшения локальности памяти.
  • Значительный прирост производительности при обработке слоёв внимания в практических приложениях.

Другие исследования показывают, что аппаратная реализация FlashAttention на RISC‑V может обеспечивать на 10³ меньшее энергопотребление и задержку по сравнению с CPU, синтезированным по той же технологии.

Гибридные вычисления с плавающей и логарифмической точностью

Аппаратная реализация FlashAttention сталкивается с двумя вызовами: дорогими операциями деления и экспоненты в softmax. Работа H‑FA (Hybrid Floating‑point and Logarithmic Approach) предлагает вычислять attention, используя смесь представлений:

  • Оценки внимания (scores) вычисляются в арифметике с плавающей запятой.
  • Fused softmax и умножение на V выполняются в логарифмической области с фиксированной точкой, где умножение и деление заменяются на сложение и вычитание.
  • Операции экспоненты эффективно сливаются с остальными вычислениями.

Результаты на 28‑нм технологии: сокращение площади на 26,5% и энергопотребления на 23,4% по сравнению с чисто floating‑point реализациями без потери производительности.

Эффективная работа с масками внимания

Стандартный FlashAttention оптимизирован для полных (dense) и причинных (causal) масок. Однако многие приложения используют разреженные или частично заполненные маски (LongFormer, BigBird, tree‑masking для MEDUSA, упаковка последовательностей). Наивное применение FlashAttention к таким маскам сохраняет квадратичную сложность.

Предложены два подхода:

  • Binary Block Masking (BinBlkMsk) — расширение FlashAttention, поддерживающее произвольные маски через обработку только блоков, содержащих хотя бы один ненулевой элемент маски. Дополнительные оптимизации для масок с непрерывными ненулевыми паттернами и для крайне разреженных масок. Эксперименты показывают ускорение до 9× на реальных сценариях.
  • FlashMask — вводит столбцовое разреженное представление масок, эффективно поддерживающее широкий спектр типов масок и обеспечивающее линейную сложность по памяти.

Практические аспекты использования

Реализации и фреймворки

Официальный репозиторий FlashAttention доступен на GitHub. Библиотека широко интегрирована в экосистему машинного обучения:

  • PyTorch — функция `torch.nn.functional.scaled_dot_product_attention` использует FlashAttention в качестве бэкенда при наличии совместимого GPU.
  • Hugging Face Transformers — многие модели автоматически применяют FlashAttention при установленной библиотеке.
  • FlashInfer — библиотека ядер для инференса LLM, включающая оптимизированные версии FlashAttention.
  • Поддержка AMD GPU (MI300, RDNA) через реализацию для fp16.


Актуальные научные подходы

Исследования вокруг FlashAttention продолжаются по нескольким направлениям.

Теоретический анализ IO‑сложности

Работы показывают, что FlashAttention является оптимальным по числу обращений к HBM для широкого диапазона размеров SRAM. Дальнейшие исследования уточняют границы оптимальности для различных конфигураций памяти и типов матриц.

FlashAttention‑4 и новые архитектуры

Уже анонсирована FlashAttention‑4, написанная на CuTe и оптимизированная для Hopper и Blackwell (H100, B200). Ожидается дальнейшее использование аппаратных возможностей новых GPU.

Связь с разреженным вниманием

Блочно‑разреженное расширение FlashAttention (block‑sparse) позволяет работать с разреженными паттернами внимания, достигая ещё большего ускорения. Это направление активно развивается в контексте моделей с длинным контекстом и специализированных архитектур (например, FlashSFA для работы с разреженными перекрытиями).

Интеграция с квантованием и низкоточными вычислениями

FlashAttention‑3 с FP8 демонстрирует, как низкая точность может быть эффективно использована без потери качества. Исследуются также комбинации с 4‑битным квантованием KV‑кэша и другими техниками сжатия.

Заключение

FlashAttention представляет собой существенный шаг в развитии эффективных алгоритмов вычисления механизма внимания, знаменуя сдвиг от парадигмы, ориентированной исключительно на пиковую производительность арифметических устройств, к IO‑aware подходам, учитывающим иерархию памяти современных GPU. Данный алгоритм преодолевает фундаментальное ограничение стандартного внимания — квадратичную сложность по объёму требуемой памяти — и обеспечивает практическую возможность работы с последовательностями длины, ранее недоступной для точных вычислений. Теоретический анализ показывает, что FlashAttention достигает почти оптимального числа обращений к медленной глобальной памяти при заданном объёме быстрой кэш-памяти, что подтверждается эмпирическими результатами на широком спектре моделей.

См. также

Примечания

Литература