ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • 논문 리뷰) LLM2Vec: Large Language Models Are Secretly Powerful Text Encoders
    AI 논문 리뷰 2024. 9. 24. 22:28

    본 글은 [Open-Up] 오픈소스 소프트웨어 통합지원센터로부터 지원받아 작성하였습니다.

    [총평]

    - 24년 4월에 axriv에 등록된, LLM 모델을 임베딩 모델로 변환하는 방법론 제안
    - Last Hidden State를 Mean Pooling한 Representation이 좋은 임베딩 표현력을 갖추도록 크게 3가지 기법을 적용.
    - Bi-directional attention, Masked next token prediction(MNTP), SimCSE
    - 실험 당시에는 SOTA였을지 모르나, 24년 4월 기준으로 MTEB 9위 달성 (24년 9월 기준으로 MTEB 25위) 
    - 전형적인 LLM 기반 임베딩 모델답게 4096 차원의 높은 차원수가 아쉬움

    ※ 본 논문을 읽기전에, 먼저 아래 두 논문을 읽어보는 것이 좋다. 
    SimCSE 논문 리뷰 링크 
    - Improving Text Embeddings with Large Language Model 논문 리뷰 링크

    - Last_hidden_state와 Logit


    [주요 학습 포인트]


     1. Bi-directional attention 

    - 디코더 전용 LLM의 casual attention mask를 전부 1로 구성된 행렬로 대체. 즉, self-attention으로 변경
    - 각 토큰이 시퀀스 내의 다른 모든 토큰에 접근할 수 있게 되어 양방향 LLM으로 변환하고자 의도
    - 하지만, LLM은 uni-directional 혹은 Masked Attention을 적용하여, 학습되었으므로 이를 위한 ‘조정’이 필요함.  
       =>  next token prediction(MNTP)를 제안

    더보기

    ※LLM2VEC 모듈에서는 어떻게 bi-directional 파라미터를 처리할까?

    출처 : https://github.com/McGill-NLP/llm2vec/blob/main/llm2vec/llm2vec.py
     

       @classmethod
        def _get_model_class(cls, config_class_name, enable_bidirectional):
            if not enable_bidirectional:
                return AutoModel
            if config_class_name == "MistralConfig":
                return MistralBiModel
            elif config_class_name == "LlamaConfig":
                return LlamaBiModel
            elif config_class_name == "GemmaConfig":
                return GemmaBiModel
            elif config_class_name == "Qwen2Config":
                return Qwen2BiModel
            else:
                raise ValueError(
                    f"{config_class_name} is not supported yet with bidirectional models."


    enable_bidirectional이 False인 경우, AutoModel을 반환시킴.
    즉, causal attention만을 사용하는 모델을 선택하게 됨.


    반면에, enable_bidirectional이 True인 경우, 각기 다른 모델 설정(config_class_name)에 맞춰 parameter를 지정함. 
    (이때, Config 세팅에는 bidirectional : True로 지정되어있음)

     

    더보기

    출처 : https://github.com/McGill-NLP/llm2vec/blob/main/train_configs/simcse/Mistral.json
    MistralConfig 세팅값 확인
    {
        "model_name_or_path": "mistralai/Mistral-7B-Instruct-v0.2",
        "peft_model_name_or_path": "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp",
        "simcse_dropout": 0.3,
        "bidirectional": true,
        "pooling_mode": "mean",
        "dataset_name": "Wiki1M",
        "dataset_file_path": "cache/wiki1m_for_simcse.txt",
        "remove_unused_columns": false,
        "learning_rate": 3e-5,
        "loss_scale": 20,
        "per_device_train_batch_size": 128,
        "gradient_accumulation_steps": 1,
        "do_train": true,
        "disable_tqdm": false,
        "max_seq_length": 128,
        "overwrite_output_dir": true,
        "output_dir": "output/mntp-simcse/Mistral-7B-Instruct-v0.2",
        "logging_steps": 50,
        "save_steps": 200,
        "save_only_model": true,
        "stop_after_n_steps": 1000,
        "lora_r": 16,
        "gradient_checkpointing": true,
        "torch_dtype": "bfloat16",
        "attn_implementation": "flash_attention_2",
        "seed": 42
    }

     

     

    2. Mask next token prediction(MNTP) 


    - MNTP는 Next Token Prediction(NTP)와 Masked Language Modeling(MLM)을 결합한 방식
    - 모델이 과거뿐만 아니라, 미래 context를 기반으로 Masked Token을 예측하도록 훈련시키고자 의도.
    - Bi-Directional Attention에 적응시키기 위한 장치
    - 중요한 것은, Figure 1에서 위치 iii의 마스킹된 토큰을 예측할 때, 마스킹된 위치에서의 Logit(파랑색 원)이 아닌 
      이전 위치 ‘ii’에서 얻은 Logit(노랑색 원)을 기반으로 Loss를 계산한다는 점
    - 이는 decoder-only LLM의 pretraing 방식(NTP)과의 align을 위해, masked token 위치의 representation 값을 쓰지 않고, 
      그 직전 token의 representation으로 masked token 맞추도록한 것.

     

     

    3. Unsupervised contrastive learning : SimCSE 

    - 이전 두 단계는 디코더 전용 LLM을 단어 수준 작업을 위한 인코더로 변환할 수 있지만, 시퀀스 표현에 대해서는 충분하지 않을 수 있음.
    - LLM 모델은 전체 시퀀스의 맥락을 캡처하도록 explicitly하게 훈련되지 않았음.
    - 이 격차를 메우기 위해 SimCSE(Gao et al., 2021)를 통해 비지도 대조 학습을 적용함.
    - SimCSE는 drop-out을 활용하여 하나의 문장을 Positive Pair(유사 문장)로 구성함, 
      그리고 학습 단위인 Batch 내에서 다른 문장은 Negative로 취급한 것이 핵심 아이디어
    - SimCSE는 drop-out을 활용해, 동일 문장을 2번 Encoder에 입력하는 것이 마치 유사한 문장 2개가 입력되는 것처럼 활용. 

     

     

     

    [실제 모델 변환(LLM to Vec)을 위한 구체적 방법과 Spec]

    1. 대상 모델 (1.3B ~ 8B의 LLM)
     - Sheared-LLaMA-1.3B (S-LLaMA-1.3B)
     - Llama-2-7B-chat (LLaMA-2-7B)
     - Mistral-7B-Instruct-v0.2 (Mistral-7B)
     - 추가적으로, Meta-Llama-3-8B-Instruct (Meta-LLaMA-3-8B)

    2. 학습 데이터
     - MNTP 단계: Wikitext-103 데이터셋
     - SimCSE 단계: Gao et al. (2021)이 공개한 위키피디아 문장 서브셋
     - MNTP 수행시, 파인튜닝을 위해서 Lora 방식을 적용함
     - 1000 step, 32 batch로 1개의 A100(80GB)가지고 학습 
        => 7B/8B 모델에 대해 100분 밖에 안걸림

    3. Unsupervised contrastive learning : SimCSE 
     - SimCSE 학습 이전에, MNTP 모델로 학습한 Lora adaptor 가중치를 그냥 기본모델로 통합시킴.
     - 그리고 새로운 Lora 파라미터로 SimCSE를 학습 진행함.
     - 1000 step, 128 batch로 A100(80GB) 1개로 학습
       => 7B/8B 모델에 대해 180분 밖에 안걸림

     

     

     

    [Evaluation on Unsupervised Learning]

    3.2 Evaluation on sequence-level tasks

     

    <Setup>
     - MTEB로 평가
     - E5-7B-instruct 임베딩 모델에 제안된 방법(Improving Text Embeddings with Large Language Model)을 따라,

       task-specific instruct를 수행
     - E5-7B-instruct와 동일한 Instruction을 사용 (Table 10 참고, E5-7B-instruct과 동일)
     - query 앞에 instruction을 붙여서 활용
     - mean pooling시에는, instruction token을 제외하고 수행

     

    E5-7B-instruct에서 수행한 방식과 Instruction들

     

     

    <Result1> 

     

     Figure 3에서 확인된 결과는 아래와 같다.
     - 학습없이 Bi-directional만 적용하면 성능이 당연히 떨어지지만 (LLM이 학습했던 방법이 아니므로) 

       MNTP까지 추가되면 크게 좋아지는 것 확인됨
     - EOS token에 대한 Last-hidden-state를 활용하는 것은 성능이 떨어짐
     - weighted mean pooling도 좋으나, mean pooling이 제일 성능이 좋은 것 확인.
     - Mistral-7B의 경우, 학습없이 Bi-directional만 적용해도 성능이 증가가 확인됨. 
     - 이후에, 나오지만 LLM2Vec 방식이 LLM에 영향을 주는 것에 대한 실험중 Mistral-7B의 경우, 

       Bi-directional과 같은 CLM 이외의 학습 방법이 있을것으로 예상됨.

     

     

    <Result2>
     Table 1에서 확인된 결과는 아래와 같다.
     - 대부분의 경우에서 MNTP와 SimCSE를 적용한 것이 효과가 좋았음
     - 전반적인 성능이 크게 향상됨 확인

     

     

     

    [Evaluation on Supervised Learning]

    5.1 LLM2Vec leads to strong performance on the MTEB leaderboard

     

    <Setup>
     - Lora 활용하여 MNTP / SimCSE 학습
     - Public Data로 학습(E5 dataset, 150만개 샘플)
     - Hard Negative와 in-batch negative로 contrastive learning 실시
     - 512 batch, 1000 step 수행

    <Result>
     Table2에서 확인된 결과는 아래와 같다.
     - 예상되었듯, SimCSE 효과 적음
     - SimCSE 없이 MNTP만 수행한 경우가 더 성능 좋은 경우가 많음 
     - 하지만, 샘플 대비 성능 효율성에서 SimCSE가 좋은 역할을 한다 (다음 5.2절) 

     

     

     

    5.2 LLM2Vec leads to more sample-efficient training

     

     - LLM2VEC 방법론을 쓴 경우, 빠르게 높은 성능으로 수렴하는 모습을 볼수 있다.
     - MNTP에 SimCSE까지 수행한 경우가, 가장 빠르게 높은 성능으로 수렴한다.
       => 즉, data sample-efficient하다.
     - 지도 학습에서 최종 성능은 MNTP만 하는 것이 좋았더라도, 
       SimCSE까지 수행하는 것을 고려해봐야하는 이유다.


     

     

    [ LLM에서 벡터를 취하는 위치 ]

    본 논문에서는 어디서 어떻게 임베딩 벡터를 취했다는 확실한 워딩이 없다( 못찾은 걸수도...)
    코드를 찾아보니, E5-7B instruct와 동일하게 last_hidden_state에서 임베딩 벡터를 취했다.

     

    # 출처 : https://github.com/McGill-NLP/llm2vec/blob/main/llm2vec/llm2vec.py
    
    class LLM2Vec(nn.Module):
    
         ....
    
       def get_pooling(self, features, last_hidden_states):  # All models padded from left
            assert (
                self.tokenizer.padding_side == "left"
            ), "Pooling modes are implemented for padding from left."
            if self.skip_instruction:
                self._skip_instruction(features)
            seq_lengths = features["attention_mask"].sum(dim=-1)
            if self.pooling_mode == "mean":
                return torch.stack(
                    [
                        last_hidden_states[i, -length:, :].mean(dim=0)
                        for i, length in enumerate(seq_lengths)
                    ],
                    dim=0,
                )
            elif self.pooling_mode == "weighted_mean":
                bs, l, _ = last_hidden_states.shape
                complete_weights = torch.zeros(bs, l, device=last_hidden_states.device)
                for i, seq_l in enumerate(seq_lengths):
                    if seq_l > 0:
                        complete_weights[i, -seq_l:] = torch.arange(seq_l) + 1
                        complete_weights[i] /= torch.clamp(
                            complete_weights[i].sum(), min=1e-9
                        )
                return torch.sum(last_hidden_states * complete_weights.unsqueeze(-1), dim=1)
            elif self.pooling_mode == "eos_token" or self.pooling_mode == "last_token":
                return last_hidden_states[:, -1]
            elif self.pooling_mode == "bos_token":
                return last_hidden_states[
                    features["input_ids"] == self.tokenizer.bos_token_id
                ]

     

     

    아래는 Mistral 7B 모델을 Huggingface에서 지원하는 AutoModelForCausalLM을 통해 다운로드 받아,

    모델 구조를 확인한 모습이다.

    Mistral 7B 기준으로, LLM 모델에서 어느 위치로 벡터를 뽑는지 확인해보자. 

     

    MistralForCausalLM(
      (model): MistralModel(
        (embed_tokens): Embedding(32000, 4096)
        (layers): ModuleList(
          (0-31): 32 x MistralDecoderLayer(
            (self_attn): MistralSdpaAttention(
              (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
              (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
              (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
              (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
              (rotary_emb): MistralRotaryEmbedding()
            )
            (mlp): MistralMLP(
              (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
              (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
              (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
              (act_fn): SiLU()
            )
            (input_layernorm): MistralRMSNorm()
            (post_attention_layernorm): MistralRMSNorm()
          )
        )
        (norm): MistralRMSNorm()         => 임베딩 벡터를 뽑는 위치

                                                                 해당 위치에서 나오는 출력을 last_hidden_state라고 한다.

                                                                  (Last_hidden_state는 마지막 트랜스포머 층(layer)의 출력)

                                                                 10개의 token이 Embedding_size(벡터 차원)가 4096인 모델에 입력된다고 해보자.

                                                                  해당 위치에서 출력되는 값의 형태는 (batch_size, 10, 4096)이 된다.

                                                                  Last_hidden_state 중 [EOS] token에 대한  Embedding을 취하고 싶다면,

                                                                    => Vector[ :, -1, : ] 

                                                                  Last_hidden_state의 모든 토큰에 대해 mean pooling 적용하고 싶다면

                                                                    => 10개 Vector 값에 대해 평균을 취함

                                                                  어떤 경우든, 출력 사이즈는 (batch_size, 4096)이 됨.
      )
      (lm_head): Linear(in_features=4096, out_features=32000, bias=False)     => LLM2VEC을 수행하기 위해서는

                                                                                                                                해당 부분을 제거 or 제거된 상태로 호출
    )                             

    자세한건 Last_hidden_state와 Logit 글에서 확인

    728x90
Designed by Tistory.