핵심 요약
LLM 추론 가속의 핵심인 Speculative Decoding에서 보조 모델의 성능을 결정짓는 '수락률'을 직접 최적화하는 새로운 학습 방식을 제시했다. 기존 KL Divergence 방식의 한계를 수학적으로 증명하고, 추가 연산 비용 없이 다양한 모델 아키텍처에서 실질적인 속도 향상을 달성했다는 점에서 실무적 가치가 높다.
왜 중요한가
LLM 추론 가속의 핵심인 Speculative Decoding에서 보조 모델의 성능을 결정짓는 '수락률'을 직접 최적화하는 새로운 학습 방식을 제시했다. 기존 KL Divergence 방식의 한계를 수학적으로 증명하고, 추가 연산 비용 없이 다양한 모델 아키텍처에서 실질적인 속도 향상을 달성했다는 점에서 실무적 가치가 높다.
핵심 기여
LK 손실 함수(LK Losses) 제안
추측 제어 디코딩의 핵심 지표인 수락률을 직접 타겟팅하는 로그 수락률(Log-acceptance) 및 하이브리드 손실 함수를 개발했다.
적응형 블렌딩 스케줄링 도입
학습 초기에는 KL Divergence로 안정성을 확보하고, 수락률이 개선됨에 따라 TV(Total Variation) 거리 최적화로 부드럽게 전환하는 스케줄링 기법을 적용했다.
광범위한 모델 및 아키텍처 검증
Llama-3.1, Qwen3, DeepSeek-V3 등 8B에서 685B 규모의 모델과 EAGLE-3, MEDUSA 등 다양한 보조 모델 구조에서 일관된 성능 향상을 입증했다.
핵심 아이디어 이해하기
Transformer 기반 LLM 추론은 메모리 대역폭 제한으로 인해 한 번에 하나의 토큰만 생성하는 과정에서 자원이 낭비된다. Speculative Decoding은 작은 모델이 미리 여러 토큰을 추측하고 큰 모델이 이를 한 번에 검증하여 이 문제를 해결한다. 기존에는 작은 모델이 큰 모델의 확률 분포를 그대로 흉내 내도록 KL Divergence를 최소화하는 방식을 썼지만, 작은 모델은 용량 한계로 인해 큰 모델을 완벽히 복제할 수 없다.
이 논문의 핵심 아이디어는 작은 모델이 큰 모델의 모든 분포를 맞추려 애쓰는 대신, 큰 모델이 실제로 수락할 가능성이 높은 '정답' 토큰들에만 집중하게 만드는 것이다. 이를 위해 수학적으로 수락률과 직접 연결된 TV(Total Variation) 거리를 손실 함수에 도입했다.
학습 초기에는 TV 거리의 그래디언트(Gradient)가 너무 작아 학습이 안 되는 문제가 발생하는데, 이를 해결하기 위해 수락률의 역수를 가중치로 곱해 신호를 증폭시키거나 KL Divergence와 섞어서 사용하는 전략을 취했다. 결과적으로 보조 모델은 자신이 가진 적은 파라미터로도 타겟 모델이 수락할 확률이 가장 높은 토큰을 더 정확하게 골라내게 된다.
방법론
수락률 를 직접 최적화하기 위해 두 가지 접근 방식을 사용한다. 첫 번째는 로 정의되는 로그 수락률 손실이다. 타겟 확률 와 보조 모델 확률 의 중첩 영역인 를 입력으로 받아 음의 로그를 계산하여 수락률을 높인다. 이 방식은 가 작을 때 그래디언트를 배로 증폭시켜 초기 학습의 수렴 속도를 높이는 효과가 있다.
두 번째는 하이브리드 목적 함수 이다. 여기서 식을 통해 현재 수락률 를 입력으로 받아 가중치 를 산출한다. 수락률이 낮을 때는 가 1에 가까워져 KL Divergence 위주로 학습하고, 수락률이 높아지면 TV 거리 최적화 비중을 높여 실제 성능을 극대화한다.
학습 시에는 660K개의 프롬프트를 사용하여 타겟 모델의 응답 분포를 먼저 생성하고, 이를 보조 모델이 학습하는 지식 증류(Knowledge Distillation) 구조를 취한다. 특히 EAGLE-3와 같은 최신 아키텍처에서 어휘 사전 절단(Vocabulary Truncation)이 발생해도 LK 손실은 KL과 달리 무한대 값이 발생하지 않아 안정적인 학습이 가능하다.
주요 결과
Llama-3.1-8B 모델을 타겟으로 한 실험에서, EAGLE-3 보조 모델에 LK 손실을 적용했을 때 기존 KL 방식 대비 평균 수락 길이()가 MT-bench 기준 3.39에서 3.48로 향상되었다. 특히 모델 용량이 매우 작은 MLP Speculator 구조에서는 수락률이 8.3%까지 개선되는 등 저용량 모델에서 더 큰 효과를 보였다.
DeepSeek-V3(685B) 모델의 기본 MTP(Multi-Token Prediction) 모듈을 LK 손실로 파인튜닝한 결과, 기존 KL 방식 대비 수락률이 추가로 5.6% 향상되었다. 이는 거대 모델에서도 직접적인 수락률 최적화가 유효함을 입증하는 결과이다.
다양한 도메인(일반 대화, 코딩, 수학) 벤치마크에서 일관된 성능 향상을 보였으며, 특히 수락 길이가 길어질수록(K가 커질수록) 기존 방식과의 격차가 더 벌어지는 경향을 확인했다. 이는 LK 손실이 장기적인 토큰 시퀀스 예측 정확도를 높이는 데 기여함을 시사한다.
기술 상세
보조 모델 가 타겟 모델 를 모사할 때, 수락률 는 수학적으로 와 동일하다. 의 그래디언트 분석 결과 임이 도출되었으며, 이는 수락률이 낮을 때 TV 거리의 그래디언트를 자동으로 증폭시켜 학습 소실 문제를 해결함을 의미한다.
KL Divergence는 분포의 모든 토큰에 가중치를 두어 작은 모델의 용량을 낭비하게 만들지만, TV 거리는 확률 질량이 큰 주요 토큰들에만 집중하게 만든다. 실험 결과, 보조 모델의 파라미터 수가 타겟 모델의 1~5% 수준으로 극히 제한적인 상황에서 이러한 선택적 최적화가 성능 향상의 핵심 요인으로 작용했다.
구현 측면에서 EAGLE-3 아키텍처를 사용할 경우, 타겟 모델의 중간 레이어 히든 스테이트를 입력으로 받아 단일 Transformer 레이어를 통해 예측을 수행한다. 이때 LK 손실은 타겟 모델의 로짓(Logit) 분포를 그대로 사용하면서도 보조 모델의 예측 정확도를 직접적으로 높이는 Trust-region 최적화와 유사한 효과를 낸다.
한계점
보조 모델의 초기 분포가 타겟 모델과 너무 동떨어져 수락률이 극도로 낮은 경우, TV 거리 기반의 그래디언트 방향이 부정확할 수 있다. 따라서 학습 초기에는 KL Divergence를 병행하는 하이브리드 방식이 필수적이다.
실무 활용
이 연구에서 제안한 LK 손실 함수는 기존의 추측 제어 디코딩 학습 파이프라인에 코드 몇 줄만 수정하여 즉시 적용할 수 있는 높은 실무 범용성을 갖추고 있다.
- vLLM, SGLang 등 추론 엔진의 Speculative Decoding 보조 모델 학습 및 최적화
- 모바일이나 엣지 디바이스에서 실행되는 경량 LLM의 응답 속도 가속
- 실시간 고객 상담 에이전트 등 지연 시간(Latency)이 중요한 AI 서비스의 처리량 개선
- DeepSeek-V3와 같이 MTP 모듈을 내장한 모델의 추론 효율 극대화
코드 공개 여부: 공개
코드 저장소 보기키워드
AI 요약 · 북마크 · 개인 피드 설정 — 무료
출처 · 인용 안내
인용 시 "요약 출처: AI Trends (aitrends.kr)"를 표기하고, 사실 확인은 원문 보기 기준으로 진행해 주세요. 자세한 기준은 운영 정책을 참고해 주세요.