본문 바로가기

Others

Embedding Hook에 대해 알아보자 (+PointLLM을 사용해보며)

 

우선 embedding hook에 대해 설명하기 위해선 huggingface transformers framework를 알아야 한다. 

HuggingFace Transformers

: LLM,Vision 모델, Audio 모델을 일관된 인터페이스로 로딩,학습,추론하게 해주는 프레임 워크

 

ex)

[추론(generate) 관리]

: beam search, sampling, decoding, attention mask 생성, past_key_values 관리, stopping criteria

[tokenizer 관리]

: 텍스트를 토큰 ID로 변환, special tokens 등록, padding,truncation 제어

[모델 구조 통일]

: LLaMA, GPT, OPT, T5 등 서로 다른 모델을 같은 API로 호출

[forward 구조 표준화]

: 입력 어떻게 처리할지, attention mask 계산, embedding lookup

 

 

지금 나의 문제 상황

: chat_gradio.py를 이용해 pointLLM을 데모를 하려고 하는데 point cloud sample은 잘 등록이 되는데 llm의 output이 매우 엉뚱하다. 현재 LLM이 point feature를 못 받는 상태. 

즉, point_clouds -> encoder (O) -> projector (O) -> concat into LLM embedding (X) 여기서 문제 발생

그렇기 때문에 LLM이 point 정보 없이 텍스트만 보고 일반 답변을 생성하게 된다.

 

그렇다면 왜 이런 일이 발생할까

 

1] Transformers 버전이 모델이 테스트된 버전과 다르면 'forward hook'이 깨짐

: transformers 버전의 차이. embedding hook의 위치는 transformers 버전마다 다름 

=> generate()에서 point embedding 삽입 코드가 수행되지 않음 

2] generate() 경로에 맞춰진 LamaWrapper가 최신 transformer 구조와 맞지 않음 

: pointLLM은 generate() 내부에 custom input hook을 넣ㄴ는데 transformers가 내부 generatge() 구조를 바뀌면서 hook이 무시됨

3] pointLLM v1.2의 vision-to-LLM projector가 text-side embedding과 dimension만 맞고 위치가 달라짐

: 즉 전달은 되지만 llm에게는 먹히지 않는 위치에 들어감

 

 

embedding hook / forward hook이 무엇인가? 

: llm이 입력 처리할 때 (text->token IDs->embedding layer->transformer blocks->logits) 등 여러 단계 거침.

근데 multimodal 모델은 llm 입력 사이에 '추가 정보'를 넣어줘야 함. 예를 들어 image embedding ,point cloud embedding 등..

이걸 삽입하기 위해서 필요한 것이 'embedding hook'

 

Embedding hook

: llm의 input embedding을 가로채는 동작. (text embedding + 추가 embedding을 concat하거나 replace하는 단계)

 

[원래 LLM]

input_ids -> model.input_embeddings(input_ids)

 

[pointLLM]

point_embeddings = PointEncoder(point_cloud)

projected_points=projector(point_embeddings)

text_embeddings=embedding(input_ids)

final_input=concat(projected_points, text_embeddings)

이 단계가 "embedding hook"

 

Forward hook

forward() 함수가 호출될 때 중간 텐서를 바꾸거나 추가 데이터를 흘려보내는 동작

ex) generate() 실행 중 attention mask 확장, past_key_values 수정, multimodal token 삽입.

pointLLM의 generate()는 내부적으로 forward hook을 사용해서 point clouds->point encoder -> llm input으로 들어가게 함. 

 

그렇다면 embedding hook과 forward hook은 뭐가 어떻게 다른걸까

embedding hook(input 단계)은 llm 입력을 변형하는 지점이고, forward hook(model 내부 단계)은 llm 내부 forward 단계를 조절하는 지점이다. 

 

 

 

LLaMAWrapper

pointLLM 구조 보면 LLaMA 모델 상속해서 커스텀 기능 추가한 래퍼 클래스 존재

class PointLLMLlamaForCausalLM(LlamaForCausalLM)

- point encoder 붙이기

- projector 붙이기

- multimodal token 관리

- generate() override (generate 호툴할 때 point feature를 LLM에 묶어서 전달하는 로직)

- 텍스트-포인트 sequence 정렬

즉, llamwrapper는 llama를 그대로 쓴게 아니라 3d point embedding을 넣을 수 있게 만든 수정된 llama 엔진.

 

여기서 hook들이 transformers 구조에 의존하게 되고, transformers 버전이 바뀌면 generate() 내부 구조도 바뀌어서 hook 연결이 깨질 수 있다. (시스템은 멀쩡한데 fusion만 죽은 상태)