-
논문 리뷰) LLM2Vec: Large Language Models Are Secretly Powerful Text EncodersAI 논문 리뷰 2024. 9. 24. 22:28
본 글은 [Open-Up] 오픈소스 소프트웨어 통합지원센터로부터 지원받아 작성하였습니다.
[총평]
- 24년 4월에 axriv에 등록된, LLM 모델을 임베딩 모델로 변환하는 방법론 제안
- Last Hidden State를 Mean Pooling한 Representation이 좋은 임베딩 표현력을 갖추도록 크게 3가지 기법을 적용.
- Bi-directional attention, Masked next token prediction(MNTP), SimCSE
- 실험 당시에는 SOTA였을지 모르나, 24년 4월 기준으로 MTEB 9위 달성 (24년 9월 기준으로 MTEB 25위)
- 전형적인 LLM 기반 임베딩 모델답게 4096 차원의 높은 차원수가 아쉬움
※ 본 논문을 읽기전에, 먼저 아래 두 논문을 읽어보는 것이 좋다.
- SimCSE 논문 리뷰 링크
- Improving Text Embeddings with Large Language Model 논문 리뷰 링크
[주요 학습 포인트]
1. Bi-directional attention
- 디코더 전용 LLM의 casual attention mask를 전부 1로 구성된 행렬로 대체. 즉, self-attention으로 변경
- 각 토큰이 시퀀스 내의 다른 모든 토큰에 접근할 수 있게 되어 양방향 LLM으로 변환하고자 의도
- 하지만, LLM은 uni-directional 혹은 Masked Attention을 적용하여, 학습되었으므로 이를 위한 ‘조정’이 필요함.
=> next token prediction(MNTP)를 제안더보기※LLM2VEC 모듈에서는 어떻게 bi-directional 파라미터를 처리할까?
출처 : https://github.com/McGill-NLP/llm2vec/blob/main/llm2vec/llm2vec.py
@classmethod
def _get_model_class(cls, config_class_name, enable_bidirectional):
if not enable_bidirectional:
return AutoModel
if config_class_name == "MistralConfig":
return MistralBiModel
elif config_class_name == "LlamaConfig":
return LlamaBiModel
elif config_class_name == "GemmaConfig":
return GemmaBiModel
elif config_class_name == "Qwen2Config":
return Qwen2BiModel
else:
raise ValueError(
f"{config_class_name} is not supported yet with bidirectional models."
enable_bidirectional이 False인 경우, AutoModel을 반환시킴.
즉, causal attention만을 사용하는 모델을 선택하게 됨.
반면에, enable_bidirectional이 True인 경우, 각기 다른 모델 설정(config_class_name)에 맞춰 parameter를 지정함.
(이때, Config 세팅에는 bidirectional : True로 지정되어있음)더보기출처 : https://github.com/McGill-NLP/llm2vec/blob/main/train_configs/simcse/Mistral.json
MistralConfig 세팅값 확인
{
"model_name_or_path": "mistralai/Mistral-7B-Instruct-v0.2",
"peft_model_name_or_path": "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp",
"simcse_dropout": 0.3,
"bidirectional": true,
"pooling_mode": "mean",
"dataset_name": "Wiki1M",
"dataset_file_path": "cache/wiki1m_for_simcse.txt",
"remove_unused_columns": false,
"learning_rate": 3e-5,
"loss_scale": 20,
"per_device_train_batch_size": 128,
"gradient_accumulation_steps": 1,
"do_train": true,
"disable_tqdm": false,
"max_seq_length": 128,
"overwrite_output_dir": true,
"output_dir": "output/mntp-simcse/Mistral-7B-Instruct-v0.2",
"logging_steps": 50,
"save_steps": 200,
"save_only_model": true,
"stop_after_n_steps": 1000,
"lora_r": 16,
"gradient_checkpointing": true,
"torch_dtype": "bfloat16",
"attn_implementation": "flash_attention_2",
"seed": 42
}2. Mask next token prediction(MNTP)
- MNTP는 Next Token Prediction(NTP)와 Masked Language Modeling(MLM)을 결합한 방식
- 모델이 과거뿐만 아니라, 미래 context를 기반으로 Masked Token을 예측하도록 훈련시키고자 의도.
- Bi-Directional Attention에 적응시키기 위한 장치
- 중요한 것은, Figure 1에서 위치 iii의 마스킹된 토큰을 예측할 때, 마스킹된 위치에서의 Logit(파랑색 원)이 아닌
이전 위치 ‘ii’에서 얻은 Logit(노랑색 원)을 기반으로 Loss를 계산한다는 점
- 이는 decoder-only LLM의 pretraing 방식(NTP)과의 align을 위해, masked token 위치의 representation 값을 쓰지 않고,
그 직전 token의 representation으로 masked token 맞추도록한 것.3. Unsupervised contrastive learning : SimCSE
- 이전 두 단계는 디코더 전용 LLM을 단어 수준 작업을 위한 인코더로 변환할 수 있지만, 시퀀스 표현에 대해서는 충분하지 않을 수 있음.
- LLM 모델은 전체 시퀀스의 맥락을 캡처하도록 explicitly하게 훈련되지 않았음.
- 이 격차를 메우기 위해 SimCSE(Gao et al., 2021)를 통해 비지도 대조 학습을 적용함.
- SimCSE는 drop-out을 활용하여 하나의 문장을 Positive Pair(유사 문장)로 구성함,
그리고 학습 단위인 Batch 내에서 다른 문장은 Negative로 취급한 것이 핵심 아이디어
- SimCSE는 drop-out을 활용해, 동일 문장을 2번 Encoder에 입력하는 것이 마치 유사한 문장 2개가 입력되는 것처럼 활용.[실제 모델 변환(LLM to Vec)을 위한 구체적 방법과 Spec]
1. 대상 모델 (1.3B ~ 8B의 LLM)
- Sheared-LLaMA-1.3B (S-LLaMA-1.3B)
- Llama-2-7B-chat (LLaMA-2-7B)
- Mistral-7B-Instruct-v0.2 (Mistral-7B)
- 추가적으로, Meta-Llama-3-8B-Instruct (Meta-LLaMA-3-8B)
2. 학습 데이터
- MNTP 단계: Wikitext-103 데이터셋
- SimCSE 단계: Gao et al. (2021)이 공개한 위키피디아 문장 서브셋
- MNTP 수행시, 파인튜닝을 위해서 Lora 방식을 적용함
- 1000 step, 32 batch로 1개의 A100(80GB)가지고 학습
=> 7B/8B 모델에 대해 100분 밖에 안걸림
3. Unsupervised contrastive learning : SimCSE
- SimCSE 학습 이전에, MNTP 모델로 학습한 Lora adaptor 가중치를 그냥 기본모델로 통합시킴.
- 그리고 새로운 Lora 파라미터로 SimCSE를 학습 진행함.
- 1000 step, 128 batch로 A100(80GB) 1개로 학습
=> 7B/8B 모델에 대해 180분 밖에 안걸림[Evaluation on Unsupervised Learning]
3.2 Evaluation on sequence-level tasks
<Setup>
- MTEB로 평가
- E5-7B-instruct 임베딩 모델에 제안된 방법(Improving Text Embeddings with Large Language Model)을 따라,task-specific instruct를 수행
- E5-7B-instruct와 동일한 Instruction을 사용 (Table 10 참고, E5-7B-instruct과 동일)
- query 앞에 instruction을 붙여서 활용
- mean pooling시에는, instruction token을 제외하고 수행<Result1>
Figure 3에서 확인된 결과는 아래와 같다.
- 학습없이 Bi-directional만 적용하면 성능이 당연히 떨어지지만 (LLM이 학습했던 방법이 아니므로)MNTP까지 추가되면 크게 좋아지는 것 확인됨
- EOS token에 대한 Last-hidden-state를 활용하는 것은 성능이 떨어짐
- weighted mean pooling도 좋으나, mean pooling이 제일 성능이 좋은 것 확인.
- Mistral-7B의 경우, 학습없이 Bi-directional만 적용해도 성능이 증가가 확인됨.
- 이후에, 나오지만 LLM2Vec 방식이 LLM에 영향을 주는 것에 대한 실험중 Mistral-7B의 경우,Bi-directional과 같은 CLM 이외의 학습 방법이 있을것으로 예상됨.
<Result2>
Table 1에서 확인된 결과는 아래와 같다.
- 대부분의 경우에서 MNTP와 SimCSE를 적용한 것이 효과가 좋았음
- 전반적인 성능이 크게 향상됨 확인[Evaluation on Supervised Learning]
5.1 LLM2Vec leads to strong performance on the MTEB leaderboard
<Setup>
- Lora 활용하여 MNTP / SimCSE 학습
- Public Data로 학습(E5 dataset, 150만개 샘플)
- Hard Negative와 in-batch negative로 contrastive learning 실시
- 512 batch, 1000 step 수행
<Result>
Table2에서 확인된 결과는 아래와 같다.
- 예상되었듯, SimCSE 효과 적음
- SimCSE 없이 MNTP만 수행한 경우가 더 성능 좋은 경우가 많음
- 하지만, 샘플 대비 성능 효율성에서 SimCSE가 좋은 역할을 한다 (다음 5.2절)5.2 LLM2Vec leads to more sample-efficient training
- LLM2VEC 방법론을 쓴 경우, 빠르게 높은 성능으로 수렴하는 모습을 볼수 있다.
- MNTP에 SimCSE까지 수행한 경우가, 가장 빠르게 높은 성능으로 수렴한다.
=> 즉, data sample-efficient하다.
- 지도 학습에서 최종 성능은 MNTP만 하는 것이 좋았더라도,
SimCSE까지 수행하는 것을 고려해봐야하는 이유다.[ LLM에서 벡터를 취하는 위치 ]
본 논문에서는 어디서 어떻게 임베딩 벡터를 취했다는 확실한 워딩이 없다( 못찾은 걸수도...)
코드를 찾아보니, E5-7B instruct와 동일하게 last_hidden_state에서 임베딩 벡터를 취했다.# 출처 : https://github.com/McGill-NLP/llm2vec/blob/main/llm2vec/llm2vec.py class LLM2Vec(nn.Module): .... def get_pooling(self, features, last_hidden_states): # All models padded from left assert ( self.tokenizer.padding_side == "left" ), "Pooling modes are implemented for padding from left." if self.skip_instruction: self._skip_instruction(features) seq_lengths = features["attention_mask"].sum(dim=-1) if self.pooling_mode == "mean": return torch.stack( [ last_hidden_states[i, -length:, :].mean(dim=0) for i, length in enumerate(seq_lengths) ], dim=0, ) elif self.pooling_mode == "weighted_mean": bs, l, _ = last_hidden_states.shape complete_weights = torch.zeros(bs, l, device=last_hidden_states.device) for i, seq_l in enumerate(seq_lengths): if seq_l > 0: complete_weights[i, -seq_l:] = torch.arange(seq_l) + 1 complete_weights[i] /= torch.clamp( complete_weights[i].sum(), min=1e-9 ) return torch.sum(last_hidden_states * complete_weights.unsqueeze(-1), dim=1) elif self.pooling_mode == "eos_token" or self.pooling_mode == "last_token": return last_hidden_states[:, -1] elif self.pooling_mode == "bos_token": return last_hidden_states[ features["input_ids"] == self.tokenizer.bos_token_id ]
아래는 Mistral 7B 모델을 Huggingface에서 지원하는 AutoModelForCausalLM을 통해 다운로드 받아,
모델 구조를 확인한 모습이다.
Mistral 7B 기준으로, LLM 모델에서 어느 위치로 벡터를 뽑는지 확인해보자.
MistralForCausalLM(
(model): MistralModel(
(embed_tokens): Embedding(32000, 4096)
(layers): ModuleList(
(0-31): 32 x MistralDecoderLayer(
(self_attn): MistralSdpaAttention(
(q_proj): Linear(in_features=4096, out_features=4096, bias=False)
(k_proj): Linear(in_features=4096, out_features=1024, bias=False)
(v_proj): Linear(in_features=4096, out_features=1024, bias=False)
(o_proj): Linear(in_features=4096, out_features=4096, bias=False)
(rotary_emb): MistralRotaryEmbedding()
)
(mlp): MistralMLP(
(gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
(up_proj): Linear(in_features=4096, out_features=14336, bias=False)
(down_proj): Linear(in_features=14336, out_features=4096, bias=False)
(act_fn): SiLU()
)
(input_layernorm): MistralRMSNorm()
(post_attention_layernorm): MistralRMSNorm()
)
)
(norm): MistralRMSNorm() => 임베딩 벡터를 뽑는 위치해당 위치에서 나오는 출력을 last_hidden_state라고 한다.
(Last_hidden_state는 마지막 트랜스포머 층(layer)의 출력)
10개의 token이 Embedding_size(벡터 차원)가 4096인 모델에 입력된다고 해보자.
해당 위치에서 출력되는 값의 형태는 (batch_size, 10, 4096)이 된다.
Last_hidden_state 중 [EOS] token에 대한 Embedding을 취하고 싶다면,
=> Vector[ :, -1, : ]
Last_hidden_state의 모든 토큰에 대해 mean pooling 적용하고 싶다면
=> 10개 Vector 값에 대해 평균을 취함
어떤 경우든, 출력 사이즈는 (batch_size, 4096)이 됨.
)
(lm_head): Linear(in_features=4096, out_features=32000, bias=False) => LLM2VEC을 수행하기 위해서는해당 부분을 제거 or 제거된 상태로 호출
)자세한건 Last_hidden_state와 Logit 글에서 확인
728x90'AI 논문 리뷰' 카테고리의 다른 글