ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • 논문 리뷰) Llama 2 : Open Foundation and Fine-Tuned Chat Models
    AI 논문 리뷰 2024. 6. 24. 23:14

    Llama 2_Open Foundation and Fine-Tuned Chat Models 논문은 베일에 쌓인 LLM 학습 방식을 A-Z까지 상세하게 기술한 한줄기의 빛과 같은 논문이다. META는 LLaMa2를 위해 500억 이상의 비용과 시간을 들였지만, 대부분의 학습 방식과 스킬을 공개하였다. 

     

    [총 평]
    - open source LLM SOTA를 달성,  gpt-3.5-turbo 버전과 성능 유사.

    - GPT-3.5(InstructGPT 논문)와 비교했을 때 방법론적으로 거의 유사하지만,
      GQA, Doubled Context, 많은 토큰수, 대화 장기기억을 위한 Ghost Attention, RLHF 반복 수행과 Distribution 맞춰주기, PPO+Reject Sampling Fine-tuning 등 신규 테크닉을 적용

    - 특히, 데이터 셋 구축 방법을 공개함.

    - Data Quality 또는 좋은 라벨러가 매우 중요. 
      (업체 인력을 통한 좋은 품질의 소량 데이터 > 크롤링 등을 통한 soso 품질의 다량 데이터)

    - pretraining LLM은 500억 이상의 비용과 시간 발생. 왠만한 대기업 아니면 꿈도 꾸지 마라.
      (A100 1개를 170만시간 돌린 시간만큼 학습, A100 170개여도 만 시간 필요)  

    - Meta는 Open-source로 LLaMA2를 개발하면서, 얻은 지식을 모두 개방.... (미친 것 같음). This is 교과서 of LLM

     

     

    [1. Introduction]

    - llama2 (7B, 13B, 70B), llama2-chat (7B, 13B, 70B) 총 6개의 버전이 릴리즈되었다.
    - llama2는 llama1 대비 publicly available data를 40% 가량 추가 사용했다.
    - context length가 2048에서 4096으로 증가했다.
    - 7B, 13B, 70B 버젼을 릴리즈했다. 34B도 학습했지만 safety 문제로 릴리즈하지 못했다.
    - llama2-chat은 llama2를 fine-tuning한 버전으로 dialog use case에 최적화되어 있다.

      이 모델 또한 7B, 13B, 70B 버젼으로 릴리즈했다.
    - 오픈소스로 개방하였지만, LLama2 chat 버젼은 모든 시나리오에 대한 safety를 보증해줄수 없다.
    - 따라서, 개발할때 safety test를 하고 하길 바란다. 다만, guide나 샘플 code는 제공할 것이다.



    [2. Pretraining]

    아래와 같이 6개의 큰 특징으로 학습함.
    - 약간의 변화를 준 optimized auto-regressive transformer 사용
    - robust data cleaning (data 클린징에 힘을 줌)
    - data mixes
    - GQA
    - context 길이를 2배로
    - 필터링을 거쳐 total data에서 40%만 학습

    ※ Data의 조건:     
    1. 메타의 프로덕트 데이터는 사용하지 않고 공개적으로 
        사용 가능한 데이터 소스만을 사용
    2. 개인 정보 데이터 사용 X
    3. 2조 개의 토큰 사용
    4. 할루시네이션 문제 감소위해 fact 기반의 데이터 비중 높임


    ※ Traing Detail

    - (llama1) standard transform architecture
    - (llama1) pre-normalization using RMSNorm
    - (llama1) SwiGLU activation function
    - (llama1) rotary positional embeddings
    (new) doubled the context length
    (new) used grouped-query attention (GQA) 
                (for improved inference scalability)
    - (llama 1) bytepair encoding algorithm using SentencePiece
    - (llama 1) split digit
    - (llama 1) use byte for Unknown UTF-8

     

    입력 가능한 context length는 4096이며, GQA는 34B, 70B에만 적용하였다.
    총 Token 수는 2.0조개인데, Public & online Data를 학습에 활용하였다. (토큰 수가 상당하다)

     

     

    토크나이저의 경우,
    2조개 토큰까지는 학습 Loss 감소, 그 이후로는 saturation되는 모습이다.

    토크나이저는 llama1에서와 같이 SentencePiece와 UTF-8로 Unknown을 인코딩해서 처리하였다.
    - (llama 1) bytepair encoding algorithm using SentencePiece
    - (llama 1) split digit
    - (llama 1) use byte for Unknown UTF-8

     

     

    GPU(A100-80GB) 1개를 사용했다고 했을때, 70B는 170만시간이나 소요된다.
    실제로는 GPU를 최대 2000개까지 풀가동했다고 한다.....(LLM은 규모의 싸움이다.)

    LLama2 학습 소요
    LLama2 학습 비용

     

     

    직접 다시 모델들을 돌려서, Benchmark에 평가해보았다.
    라마2 70B가 모든 오픈소스 모델보다 우위였다.

    LLama2 성능 비교표

     

     

    Closed-source 모델과도 성능 비교해봤는데
    GPT4와 Palm2-L 과는 성능 격차가 크게 있었지만, 
    GPT 3.5와 Palm보다는 우위였다고 한다.

    LLama2 성능 비교표

     

     

     

     

    [3. Fine-tuning]

    LLama2는 주옥같은 학습 방식을 공개하였다. 주요 학습 테크닉은 아래처럼 3개로 구분해 볼수 있겠다.

    - 기존의 instruction tuning, RLHF와 같은 테크닉들 적용
    - supervised fine-tuning(SFT), iterative reward modeling, RLHF와 같은 기존 테크닉에 

      Ghost Attention (GAtt)라는 자체 새로운 테크닉 적용
    - GAtt는 multiple turn 동안의 dialogue flow를 제어할 수 있게 돕는 테크닉임
      (Instructin에 대한 장기 기억을 보조하기 위한 테크닉)


    ※ Instruction Tuning과 Alignment Tuning의 차이점 :

    chat형 LLM을 만든다고 할때,
    Instruction Tuning은 지시사항에 적합하게 수행하도록 하는 것이기 때문에, 
    답변을 잘하도록 QA 데이터로 SFT하는 것이 그 예이고,
    Alignment Tuning은 답변의 퀄리티/선호도 정도를 튜닝하는 것이 목적이고
    RLHF/DPO(Preference optimziation)를 수행한다. 
    위 2과정을 합쳐서 Instruction Fine-tuning / 그냥 Fine-tuning이라고 한다.

     

    <라마2 학습 전체 과정>
    <라마2 학습 전체 과정>

     

     

     

    [3.1 SFT]
    다양한 소스에서 third-party SFT data를 긁어왔지만 다양성과 질이 부족했는데,
    업체 통해서 질 좋은 데이터셋을 쓰니까 양이 훨씬 적어도 결과가 더 좋았다고 한다.
    주석을 단 27540개만 수집하고 멈춰도 될 정도로 양질의 데이터는 학습효과가 좋았다.

    Training Detail.
    - cosine learning rate scheduler(init: 2e-5), weight decay: 0.1, batch size: 64, seq len: 4096
    - Prompt + special token + Answer의 형태로 데이터셋을 제작 (해당 형태가 바로 SFT 학습)
    - autoregressive objective(next token prediction)을 활용
    - Prompt에서는 loss를 항상 0으로 설정함으로써 Anaswer에 대한 loss만 backprop함
    - 2 epochs만 돌림


    나중에 실제 모델 활용할때
    사용자가 Prompt를 입력하면, '시스템'이 Prompt에 '스페셜 Token'을 붙여 모델에 입력. 
    => 모델은 스페셜 Token 이후로, Answer를 답변하는데 특화될 것이다.

    SFT 학습방식
    SFT 학습방식



     

    [3.2 RLHF]
    RLHF란 모델이 human preference와 instruction을 더 잘 따라갈 수 있도록 모델의 행동에 대해 alignment tuning하는 학습 stage이다. 쉽게 말해서, 사람의 선호도에 맞게 align하는 작업이다.

    모델의 응답을 평가해줄 reward model을 만들기 위해 human preference data를 수집하였다.

    2개의 답변을 주고 라벨러에게 무엇이 더 좋은지 고르게 하였다. 
    human이 의견을 낼 수 있는 제일 쉬운 형태(boolean)이므로 그만큼 수집할 수 있는 프롬프트의 다양성이 높아졌다. 
    (다른 방법도 많겠지만 그건 나중에 한다고 함)

    1. annotator에게 prompt를 쓰고, 
    SFT까지 완료된 모델에게 답변을 뽑아 2개 중에서 고르도록했다.
    이때, 다양성을 위해 2개의 답변을 만들 때 각각 다른 모델을 사용하고 temperature도 계속 바꿈.
    그리고 총 5개로 label.(아주 좋음/좋음/약간 좋음/무시 가능한 수준 정도로만 좋음/불확실함)
    A>B, A<B로 존재하므로 정확히는 라벨이 10개가 됨.

    2. annotation할때, helpfulness와 safety에 집중하였다. 
    helpfulness는 모델 response가 user request에 얼마나 도움되는지의 여부, safety는 모델 response의 안정성 여부를 체크했다.
    2개의 답변이 주어질 때 2개의 기준을 가지고 각각 평가함. 각각 평가해야 평가 기준이 더 명확해지기 때문.

    3. safety 단계에서는 safety 여부에 대한 라벨을 추가 수집했다.
    선호도에서 A>B일때, safety에서 B>A라면 학습데이터에서 제외시켰다.

    4. 선호도 data가 쌓일수록, reward model의 성능이 좋아졌다. 
    그리고 그 reward model을 통해 LLama2-chat 모델을 훈련하여, 모델의 답변도 좋아졌다.
    이때, 중요한 것은 LLama2-chat 모델이 훈련되면 답변의 distribution이 변경되므로, 
    reward model 역시 새로운 선호도 데이터(latest model로부터 나온)를 통해 학습시켜야한다.
    즉, 두 모델이 same distribution에 있도록 위 과정을 iteration 해줘야 한다. 이를 figure 4에서 Iterative reward modeling이라 함.
    (뒤에서 5번 반복했다고 나옴)

    5. Meta reward modeling data라는 많은 수의 dataset을 구축하였다.
    이는 helpfulness와 safetyness에 대한 label이 존재한다.
    우리 dataset이 더 대화 턴이 많고, 길다.

     

     

    ※  3.3 Reward Modeling
    helpfuness와 safetyness가 trade-off라고 주장한 논문도 있기때문에, 2개에 대하여 별개의 reward model을 만들어 훈련하였다.

    reward model은 사전학습된 (SFT 학습된) chat 모델을 사용하였다.
    왜냐하면, reward model은 chat model이 무엇을 하는 것인지 알아야하기 때문이다.
    모델의 아키텍쳐도 전부 똑같이 하였고, 모델의 맨 마지막 layer 즉, next token prediction을 위한 classification head를
    scalar reward를 출력하도록 regression head로 대체함.

    reward model 훈련을 위해 선호도 데이터를 ranking label로 바꿔주었다. 선호되는 답변이 더 높은 점수를 갖도록.

    목적 함수는 우측과 같이 loss function을 적용.
    y_c는 prefered, y_r은 rejected된 답변이다.

    여기에 sigmoid와 margin 값 m(r)을 추가함.
    preference rating은 4개로 구분되어 있다.(매우 좋음~불확실함)
    따라서, m(r)도 discrete 펑션이다. (Table 27과 28 참고)




    ※  Data Composition (데이터 조합 및 구축)
    open-source preference dataset과 meta custom preference dataset을 조합함.
    초기에는 open-source preference dataset을 단순히, annotation data 수집하는 동안에, 
    reward model을 bootstrap 하는 용도로 사용하였으나, 
    open-source preference dataset 중에 전이학습에 대해 부정적인 전이가 관찰되지 않아서, generalization 확보와 reward hacking 방지 차원에서 사용함

    **Reward hacking 이란?
    agent가 같은 의도하지 않은 편법을 통해 목표 달성 방법을 학습하는 것


    ※ Training Detail.
    - one epoch만 돌림 
    (더 돌리면 overfit됨)
    - 러닝레이트는 70B에 6e-6, 7B, 13B, 30B에 1e-5로 세팅
    - cosine learning rate scheduler 적용
    - warm-up of 3% (최대 5 에폭)
    - batch size 512 pairs
    - helpfulness와 safeness에 대해 각각 테스트해봄 (Table 7)

     

     

    Reward model Result.
    1000개 샘플을 뽑아, test 해봄
    GPT4랑도 비교해봤음.
    OpenAI의 API 활용하여,
    prompt로 'hoose the best answer between A and B라고 물음
    결론 : 자기네 데이터셋에서는 reward model 평가시, GPT4 이김 
    (학습한 데이터가 이거니까...더 유리할 것임, 이건 트릭이라과 봐야함...)

     

     

    모델이 크면 클수록 더 잘 이해하고, saturation 되지 않는 모습을 보였다.

     

     

    3.2.3 Iterative Fine-Tuning
    선호도 data가 쌓일수록, reward model의 성능이 좋아졌다. 
    그리고 그 reward model을 통해 LLama2-chat 모델을 훈련하여, 모델의 답변도 좋아졌다.
    이때, 중요한 것은 LLama2-chat 모델의 훈련되면 답변의 distribution이 변경되므로, 
    reward model 역시 새로운 선호도 데이터(latest model로부터 나온)를 통해 학습시켜야한다.
    즉, 두 모델이 same distribution에 있도록 위 과정을 iteration 해줘야 한다. 
    이를 figure 4에서 Iterative reward modeling이라 함. (5번 반복)

    다음 2가지 방법으로 RLHF 파인튜닝함.
    1. Proximal Policy Optimization (PPO) 
    2. Rejection Sampling fine-tuning
    - prompt를 주고 output을 K개씩 추출하여 reward model에서 제일 좋다고 선택된 output만 SFT 모델 가중치 업데이트에 사용

    위 2가지의 알고리즘은 아래 2가지의 차이점이 있음.
    Breadth - Rejection Sampling을 통해, 여러 K개 샘플 중 Best를 취하여 이동(가중치 업데이트)했으므로, Breadth를 넓힌 셈.
    Depth - PPO를 통해, t단계 동안 policy 탐험 수행.

    4번째까지는, Rejection Sampling 파인튜닝만 적용함. 그러고나서, RSF+PPO 순서로 적용.

     

    ※ Rejection Sampling.
    Rejection Sampling은 llama2-chat 70B에서만 수행. 
    그보다 작은 모델들은 70B모델때 나온 rejection sampled data로 fine-tuning 실시

    이전 iteration에서 나온 output 중에서만 top sample을 고르고 그걸로 PPO함.

    Rejection Sampling을 위한 샘플을 많이 뽑으면 그 중 top 샘플의 max reward는 당연히 높아짐. 
    computational cost를 고려해 RLHF 스텝마다 적절한 constant number만큼 샘플을 추출했음.


    ※ PPO.
    PPO 목적함수에 pi0라는 penalty 항목을 적용하였다.
    이는 여기에서 기존의 policy와 새로 업데이트되는 policy간의 차이가 커질수록, penalty를 주고자하는 목적이다.
    이렇게 하는 이유는, 사전학습된 LLM은 표현이 서투를뿐, 세상의 지식을 잘 내포하고 있는 모델이다.
    따라서, 무작적 reward에 따라 가중치가 업데이트되면, 지식없이 떠드는 바보상자가 될 수 있다
    실제로 실험결과, 이 항목을 통해 reward hacking을 줄여줄 수 있었다.

    따라서, Reward는 커지게 하면서도 사전학습된 가중치와는 너무 멀어지지 않게하도록 하였다.


    ※ Train Parameter.
    AdamW 옵티마이저를 썼고, beta1 = 0.9, beta2 = 0.95, eps= 0.00001이다.
    weight decay 0.1로 적용, gradient clipping 1.0, 러닝레이트는 0.000001 적용하였다.
    512 배치사이즈로 PPO iteration하였으며, PPO clip은 0.2였다. (PPO-clip 타입 적용한듯)
    모든 모델이 200~400 사이로 iteration하였으며, 얼리스탑 적용함.

     

     

    GAtt.
    RLHF 적용한 모델은 초기에 multi-turn에서 대화 내용를 잊어 버리는 문제가 있었음
    그래서, Ghost Attention (GATT)를 제안함

    Ghost Attention은 Dataset 구축이 메인 아이디어이다. 
    Instruction을 통해, 올바른 답변을 출력하게 하고, 데이터를 편집하여 대화 내에서 처음 Instruction만 남긴채, 중간에 존재하는 Instruction은 모두 지운다.
    이 대화를 통째로 SFT로 학습시키면, 모델은 처음 지시한 내용에 대해 가중치(기억)를 지속하여 지시사항을 계속 이행하게 될 것이다.
    이때, SFT를 위해서 이전 대화 내용을 입력했을때, 출력되는 Token에는 모두 loss값을 0으로 처리한다. 
    그리고 실제 답변 내용에 대해서만 loss 적용 (P+special token +A와 동일)
    (이를 통해 LLaMA2 계열은 모두 Instruction이 맨 앞에 있는 것이 유리하다고 볼수 있겠다)

     

    GAtt 데이터로 SFT했을때, 
    실제로 초기 앞부분의 Instruction에 대해 Attention이 높게 유지되는 것이 확인됨 (밝을수록 높은 attention)
    (이를 통해 LLaMA2 계열은 모두 Instruction이 맨 앞에 있는 것이 유리하다고 볼수 있겠다)

    Gatt의 시각화



    4. Safety
    left : Safety data 학습을 계속 진행하더라도, Hepfulness에 대한 영향성 없음 확인하였다.
    right : Safety data를 계속 학습할수록, Safety Score가 낮은 답변을 생성할 확률이 줄어드는 것 확인하였다.

    Safety 비교 확인

     

    728x90
Designed by Tistory.