핵심 요약
Flash Attention은 데이터를 작은 블록으로 나누는 Tiling과 Online Softmax 기법을 통해 SRAM 내에서 연산을 완결함으로써 HBM 접근 횟수를 줄이고 연산 속도를 가속화한다.
배경
Transformer 모델의 핵심인 Attention 연산은 시퀀스 길이에 따라 메모리 사용량이 제곱으로 증가하며 GPU 메모리 I/O 병목 현상을 겪는다.
대상 독자
딥러닝 모델 최적화와 GPU 내부 동작 원리에 관심 있는 개발자 및 연구자
의미 / 영향
Flash Attention의 도입으로 Transformer 모델의 시퀀스 길이 확장이 가능해졌으며 동일 하드웨어에서 더 큰 모델을 더 빠르게 학습시킬 수 있게 되었다. 이는 긴 문맥을 처리해야 하는 최신 LLM의 성능 향상에 결정적인 역할을 한다.
챕터별 상세
Attention 연산의 기본 과정
- •Query-Key 행렬 곱 후 Softmax 적용
- •수치 안정성을 위한 최댓값 차감 기법
- •Value와의 가중합을 통한 최종 출력 산출
표준적인 Attention 연산은 중간 결과물인 Attention Score 행렬을 메모리에 저장해야 하므로 시퀀스 길이에 비례해 메모리 점유율이 급증한다.
GPU 메모리 구조와 I/O 병목
- •HBM(느린 대용량)과 SRAM(빠른 소용량)의 계층 구조
- •중간 결과의 빈번한 HBM 읽기/쓰기로 인한 병목
- •연산 속도보다 메모리 대역폭이 성능을 제한하는 상황
메모리 대역폭 제한(Memory-bound) 문제는 현대 GPU 연산에서 연산 능력 자체보다 더 큰 성능 저하 요인이 된다.
Flash Attention의 핵심 아이디어: Tiling
- •데이터를 블록 단위로 나누는 Tiling 기법 적용
- •SRAM 내에서 연산 전체를 융합하여 수행
- •HBM 접근 횟수 최소화를 통한 속도 향상
Tiling은 커널 퓨전(Kernel Fusion)의 일종으로 여러 연산을 하나의 GPU 커널로 합쳐 메모리 접근을 줄이는 기법이다.
Online Softmax를 통한 블록 단위 계산
- •전체 데이터 없이 부분 합과 최댓값을 갱신하는 Online Softmax
- •지수 함수 스케일링을 통한 점진적 결과 보정
- •메모리 효율성과 수학적 정확성을 동시에 확보
Online Softmax는 수학적으로 표준 Softmax와 동일한 결과를 내면서도 메모리 효율성을 확보하는 핵심 알고리즘이다.
Flash Attention의 전체 알고리즘 구조
- •Query와 Key-Value 조각을 순회하는 이중 루프 구조
- •SRAM 내 누적 연산을 통한 최종 출력 행렬 완성
- •논문 이론과 실제 코드 구현의 구조적 특징
이중 루프 구조는 GPU의 병렬 처리 특성을 활용하면서도 메모리 계층 구조를 최적으로 이용하도록 설계되었다.
실무 Takeaway
- Attention 연산 시 중간 결과물을 HBM에 저장하지 않고 SRAM에서 즉시 처리하여 I/O 병목을 제거해야 한다.
- Online Softmax 기법을 활용하면 전체 시퀀스 통계량 없이도 블록 단위로 정확한 가중합을 계산할 수 있다.
- 모델 학습 및 추론 속도를 높이기 위해 Flash Attention 2와 같은 최적화 라이브러리를 적극적으로 활용하는 것이 실무적으로 중요하다.
AI 요약 · 북마크 · 개인 피드 설정 — 무료