TL;DR
대형 언어 모델의 프리트레이닝은 계산 비용이 많이 들며 데이터 처리량이 병렬 처리의 한계에 도달하는 문제를 안고 있다. 본 연구는 모델 아키텍처나 토크나이저를 변경하지 않고, 입력 토큰을 s개로 묶어 처리량을 높이고, 다음 묶음 예측을 다중-핫 크로스 엔트로피로 학습하는 두 단계의 학습 흐름(TST)을 제시한다. 이를 통해 10B Mixture-of-Experts 규모에서도 equal-FLOPs 조건에서 전처리 속도가 크게 개선되며, 총 프리트레이닝 시간은 최대 2.5배까지 절감될 수 있다.
왜 중요한가
대형 언어 모델의 프리트레이닝은 계산 비용이 많이 들며 데이터 처리량이 병렬 처리의 한계에 도달하는 문제를 안고 있다. 본 연구는 모델 아키텍처나 토크나이저를 변경하지 않고, 입력 토큰을 s개로 묶어 처리량을 높이고, 다음 묶음 예측을 다중-핫 크로스 엔트로피로 학습하는 두 단계의 학습 흐름(TST)을 제시한다. 이를 통해 10B Mixture-of-Experts 규모에서도 equal-FLOPs 조건에서 전처리 속도가 크게 개선되며, 총 프리트레이닝 시간은 최대 2.5배까지 절감될 수 있다.
핵심 기여
Two-phase Token Superposition Training(TST)
입력 토큰의 연속 묶음을 평균으로 합성한 s-token embeddings를 사용하고, 출력은 next-bag-of-tokens에 대한 다중-핫 크로스 엔트로피를 적용하는 두 단계 학습 방식으로 프리트레이닝 throughput을 증가시킨다.
Input Superposition: Bag-of-Token-Embeddings
토큰 시퀀스를 비중복 비토큰 길이 l의 연속 bags로 분할하고, 각 bag의 임베딩을 평균하여 s-token 임베딩을 생성한다. 이렇게 하면 per-step 처리 데이터 토큰 수가 증가하고, 동일 FLOPs에서 더 많은 토큰을 학습에 사용한다.
Output Superposition: Multi-hot CE Loss
다수의 타깃 토큰으로 이루어진 bag에 대해 각 토큰의 확률을 1/|y|로 균등하게 분배하는 LMCE를 도입한다. 이는 일반적인 one-hot CE와 다르게, bag의 엔트로피를 보존하기 위해 KL(t || P) 형식으로 수식을 정리한다.
Two-phase Task Alignment
두 학습 phase 간 input embedding과 LM head의 공유를 유지해 representational alignment을 확보한다. phase 간 임의 재초기화(randomization)은 Gains를 완전히 제거한다.
Empirical Validation across Scales
270M, 600M, 3B, 10B MoE(Qwen3 계열) 모델에서 2.5x 속도 증가를 포함한 일관된 이득을 보이며, equal-FLOPs 및 equal-Loss 구간에서 Baseline 대비 성능 향상을 확인했다.
핵심 아이디어 이해하기
출발점: 토큰화와 데이터 처리량은 프리트레이닝 효율의 핵심 변수다. 표준 AR(next-token) 학습은 시퀀스 길이에 따라 매 스텝의 연산이 발생하지만, 이때 입력 표현의 해상도를 낮출수록 FLOPs는 유지한 채 더 많은 토큰 정보를 한 번에 처리할 수 있다. 해결 원리: 입력을 s개의 contiguous 토큰으로 묶어 s-token으로 표현하고, 출력은 next bag-of-tokens로 예측한다. 출력 손실은 다중-핫 크로스 엔트로피(LMCE)로 계산되며, 각 bag의 토큰들에 균등한 기여를 부여한다. 이로써 토큰 간의 상호작용 정보를 보존하면서도 샘플당 처리량을 높여 학습 속도를 증가시킨다. 달라지는 점: TST는 입력 granularity를 고르게 축소하는 input superposition과, 미래 토큰의 요약 표현을 예측하는 output superposition의 두 축을 결합한다. 두 축은 서로 다른 학습 신호를 제공하며, 함께 사용될 때 Baseline 대비 평균적으로 더 나은 downstream 성능을 보인다. 손실의 구성에서 power-law 가중치가 큰 s에서 더 안정적이고 성능이 우수한 경우가 관찰되었다.
방법론
전체 접근 방식: TST는 두 단계로 구성된다. (1) 입력 superposition 단계에서 연속 토큰을 s개 단위로 묶어 bag으로 처리하고, 각 bag의 임베딩을 평균하여 s-token 임베딩을 생성한다. 이때 shape는 B × l × s × V에서 B × l × d로 축소된다. (2) 출력 superposition 단계에서 다음 bag-of-tokens를 예측하기 위해 MCE를 사용한다. 라벨은 y로 정의되며, |y| = s이다. LMCE(z, y) = (1/|y|) Σ_{y∈y} LCE(z, y). 라벨의 절대 위치는 causality를 유지하기 위해 좌로 s−1만큼 시프트한다. 회복(recovery) 단계에서 저장된 체크포인트로 표준 토큰 예측으로 복귀한다. 회복 phase에서는 TST 코드가 제거되어 실험의 오염을 방지한다. 구현 세부사항: 학습은 AdamW, warmup-stable-decay 스케줄러를 사용하며 초기 warmup는 2000 steps, 마지막 10% steps 동안 감쇠한다. 시퀀스 길이 L은 s로 나눠진 형태로 처리되며, per-step FLOPs는 baseline과 동일하게 유지된다. 토큰 묶음의 수를 조정하는 superposition bag size s와 ratio r(전체 step 중 superposition 단계의 비율)을 통해 Loss-Throughput trade-off를 조정한다. 대규모 학습에서 64 GPU(B200)와 TorchTitan, FSDP를 사용하고, 작은 모델은 8 GPU에서 수행했다. 데이터셋은 DCLM을 주 데이터로 사용하고, SmolLM 계열의 설정을 차용하되 Embedding은 untyed 형태로 사용하여 모델의 파라미터 수를 증가시켰다. 실험은 270M/600M/3B/10B 파라미터 모델에서 수행되었고, 10B MoE 구성은 1.05T tokens까지 학습했다. 평가로는 ARC, BoolQ, Hellaswag, MMLU, OpenBookQA, PIQA, Winogrande를 Eleuther AI LM-Eval harness로 0-shot prompting 방식으로 진행했다.
주요 결과
주요 벤치마크에서 TST는 Baseline 대비 일관된 개선을 보였다. 3B 모델의 경우 20000 스텝에서 Final Loss가 Baseline 2.808에서 2.676으로 감소했고, HellaSwag/ARC-E/ARC-C/MMLU 지표가 각각 향상됐다. 10B MoE(1B 활성 파라미터, A1B 혼합)에서는 Final Loss가 2.252에서 2.236으로 감소했고, HellaSwag, ARC-E, ARC-C, MMLU에서 모두 상승했다. 270M의 경우 Equal-FLOPs 조건에서 Final Loss가 3.212에서 3.142로 감소했고, Equal-Loss 조건에서도 3.092에서 3.048로 감소했다. 더불어 10B A1B 스케일에서 2.5x의 총 프리트레이닝 시간 절감이 관찰되었으며, 동일한 손실에서 속도-향상이 확인되었다. ablation에서 Input+Output Full Superposition이 단독 구성보다 더 높은 성능을 보여 주된 기여를 확인했다. 또한 s=8 이상에서 Power-Law 가중치가 더 안정적이고 성능이 우수했다. Phase 간 representation alignment가 Gains의 핵심 요인임이 Table 2의 Randomization 실험으로 뒷받침된다.
기술 상세
아키텍처 구성: Token Superposition Training은 입력-표현의 coarse-to-fine 그레이데이션과 출력-예측의 다중-토큰 표적을 결합한다. 입력은 B×L×V에서 B×l×s×V로 확장되며, 임베딩 레이어에서 s-token 임베딩은 토큰 임베딩들의 평균으로 계산된다. h = (sum_{i=1}^{s} tok_embeddings(token_i)) / s로 정규화한다. 따라서 한 스텝당 처리되는 tokens의 수가 증가하되 per-step FLOPs는 baseline과 같다. 출력은 Next Bag-of-Tokens Prediction으로 바뀌며, 다중-핫 크로스 엔트로피(LMCE) 손실을 사용한다. LMCE(z,y) = (1/|y|) Σ_{y∈y} LCE(z,y)로 정의되며, y는 s개의 유효 타깃 토큰으로 구성된다. 좌측으로 s−1만큼 라벨을 시프트해 causality를 보존한다. 회복 phase에서는 TST 코드가 제거되고 표준 AR 학습으로 재개한다. 초깃값은 AdamW 최적학습률로 sweep를 수행하되, 270M/600M 모델은 학습률 스윕으로 최적값을 찾아 적용하고 3B/10B에 대해서는 권장 학습률(2e-4, 3e-4)을 사용한다. Warmup은 2000 steps, 마지막 10% step에서 decay를 적용한다. 데이터 파이프라인은 TorchTitan과 FSDP를 이용하고, 대형 모델은 64GPU, 소형은 8GPU에서 실행한다. 회귀 테스트로 0-shot 평가를 Eleuther AI LM-Eval harness를 사용한다. 밀집 Baseline 대비 TST의 성능은지표별로 향상되며, equal-FLOPs/ equal-Loss 구간에서 측정되었다.
한계점
초기 시나리오에서 데이터 소비를 더 늘려 학습을 수행하므로 compute-bound 제약이 기본 가정이다. 데이터-제약이 증가하는 시나리오에서 Output-Only Superposition 등의 대안 설정이 더 이점일 수 있다. 거시 규모(ablation 필요)이나 다중 identical run의 통계적 유의성은 추가 연구가 필요하다. 또한 코드/모듈의 공개 여부는 불확실하다.
실무 활용
고정된 모델 아키텍처와 토크나이저를 유지하면서 프리트레이닝 속도를 높이려는 상황에서 적용 가능하다. 두 가지 초점은 입력의 해상도 감소에 따른 처리량 증가와 출력 예측의 다중-타깃 신호로의 학습이다. recovery 단계에서 원래 토큰 예측으로 되돌아와 성능을 유지한다.
- Compute-bound LLM pretraining에 적용하여 데이터 처리량을 증가시키고 학습 시간 절감을 달성
- MoE 기반 대규모 모델의 프리트레이닝 효율성 개선
- HPC 자원 활용 효율화 및 에너지 소비 감소에 기여
코드 공개 여부: 미확인
키워드
코드 예제
Listing 1: Input folding in Pytorch
if superposition_bag_size is not None and superposition_bag_size > 1:
bs , seq = inputs .shape
inputs = inputs .reshape (bs , seq // superposition_bag_size ,
superposition_bag_size )
입력 폴딩은 PyTorch에서 bag-size로 입력 시퀀스를 재구성하는 예시 코드이다.
Listing 2: Bag-of-Token embeddings input in PyTorch
# Sum in float32 for better numerical precision
h = self.tok_embeddings (tokens[..., 0])
h = (h + self.tok_embeddings (tokens[..., 1]).float()) / superposition_bag_sizes-토큰 임베딩은 각 토큰의 임베딩을 합산해 하나의 임베딩으로 만든다.
Listing 3: Next bag-of-words prediction loss code in PyTorch
def cross_entropy_loss(pred, labels):
# pred: (bs, seq, dim), labels: (bs, seq)
loss = 0.
for i in range(superposition_bag_size):
target = labels[..., i].flatten(0,1)
loss += torch.nn.functional.cross_entropy(pred.flatten(0,1), target)
return loss / superposition_bag_size
다중-핫 손실(L_MCE) 계산 예시로, 각 bag의 각 토큰에 대한 크로스 엔트로피를 합산해 평균한다.
AI 요약 · 북마크 · 개인 피드 설정 — 무료
출처 · 인용 안내
인용 시 "요약 출처: AI Trends (aitrends.kr)"를 표기하고, 사실 확인은 원문 보기 기준으로 진행해 주세요. 자세한 기준은 운영 정책을 참고해 주세요.