AI

[논문 리뷰] LISA: Reasoning Segmentation via Large Language Model

jyseo-me 2024. 12. 30. 01:28
제목: LISA: Reasoning Segmentation via Large Language Model
저자: Xin Lai, Zhuotao Tian, Yukang Chen, Yanwei Li, Yuhui Yuan, Shu Liu, Jiaya Jia
출처: CVPR 2024
연도: 1 Aug 2023
링크: https://arxiv.org/abs/2308.00692
Github: https://github.com/dvlab-research/LISA?tab=readme-ov-file
 

LISA: Reasoning Segmentation via Large Language Model

Although perception systems have made remarkable advancements in recent years, they still rely on explicit human instruction or pre-defined categories to identify the target objects before executing visual recognition tasks. Such systems cannot actively re

arxiv.org


Introduction

이 논문에서는 Reasoning segmentation이라는 새로운 task를 제시한다.
Reasoning segmentation
Input Query에서 Binary segmenatation mask를 generate하는 task이다. 여기서 Input query는 "the orange"와 같이 명확한 text가 아닌, "food with high vitamin C"와 같은 implicit text이다. 이를 수행하기 위해서 model은 1) reasoning complex, implicit text  2) segmentation mask 생성을 해야 한다.

 

여기서는 LLM의 reasoning ability를 활용하여 이를 해결하였다. Vision input에 LLM의 reasoning ability를 결합한 시도는 이전 연구에도 있었지만, 이들은 대부분 text generation에 초점을 맞추었으며, segmentation mask와 같은 fine-grained output에는 활용되지 못하였다.

 

LISA는 segmenatation mask를 생성하는 multimodal LLM로, existing vocabulary로 <SEG> 토큰을 생성한다. 이때 LLM의 hidden embedding은 해당하는 segmentation mask로 decode 된다.

LLM이 human intention을 이해하는 것과 같이 implicit한 text를 segmentation하는 것이 놀랍다.

 

나아가 논문에서는 1000장 이상의 image-instruction set으로 구성된 ReasonSeg benchmark dataset을 공개하였다.


Related Work

1. Image Segmentation

  • Semantic Segmentation
    encoder-decoder, dilated convolution, pyramid pooling module, non-local operator
  • Instance Segmentation
    DETR-based structures, mask attention, dynamic convolution
  • Referring Segmentation: text에 해당하는 물체 segmentation
    SAM, X-Decoder, SEEM

Segmentation task에는 위와 같이 다양한 architectural innovation이 있었음
LISA는 이전 연구에서 다뤄지지 않은 reasoning ability를 지니고 있음

2. Multimodal Large Language Model

이전에도 LLM을 vision domian으로 transfer하는 Multimodal LLM 관련 연구가 있었다.

  • Flamingo: Cross-attention으로 visual contexts에 attend (Visual in-context learning)
  • BLIP-2 / mPLUG-OWL: Visual encoder로 뽑은 image feature를 LLM의 text embeddings으로 넣음
  • Otter: In-context instruction tuning (MIMIC-IT dataset)
  • LLaVA / MiniGPT-4: 처음으로 instruction tuning을 통한 image-text feature alignment
  • Grounding Language Models to Images for Multimodal Inputs and Outputs: Image retrieval for LLMs

최근에는 multimodal LLM과 vision tasks의 intersection 관련 연구들도 나오고 있음 

  • VisionLLM: Instruction tuning 통해 interaction interface 제공 (complex reasoning X)
  • Kosmos-2: Large scale grounded image-text pairs dataset을 만들어 LLM에 grounding capability 주입
  • DetGPT: Multimodal LLM + open-vocabulary detector (instruction 기반 detection)
  • GPT4RoI: Spatial box를 input으로 region-text pair에 model을 train 

Reasoning Segmentation

정의: Image, Text → Binary Segmentation Mask

 

Referring segmentation과 유사하지만 query text의 complexity에 차이가 있다.

Reasoning segmenatation에서는 straight forward한 문장 대신 더 복잡하거나 긴 문장을 사용한다.

(Complex reasoning, World knowledge를 필요로 한다.)

 

Referring segmentation 
“the trash can”
Reasoning segmentation
“something that the garbage should be put into”
“After cooking, consuming food, preparing for food, where can we throw away the rest of the food and scraps?”

 

Benchmark

논문에서는 reasoning segmentation benchmark dataset인 ReasonSeg를 제시하였다.

OpenImages, ScanNetv2에서 가져온 image set에 implicit text instruction과 target mask로 annotation하였다.

 

다양한 상황을 위해서 text instruction은 2가지로 구성하였다. (왼쪽 그림이 짧은 phrase, 오른쪽이 긴 문장)

1) short phrases 2) long sentences

 

ReasonSeg는 1218개의 image-instruction-mask pair로 구성되어 있다. Train, Validation, Test로 각각 239/200/779으로 나누어져 있다. (Benchmark의 목적이 training이 아닌 evaluation이기 때문에 valid, test에 더 많은 수의 data가 있다.)


Method

1. Model architecture

최신 multimodal LLM들 (LLaVA, Flamingo, BLIP-2, Otter, etc) input으로 image와 text를 받을 수 있지만, output으로 text만 내보낼 수 있다. 즉, segmentation mask를 직접 내보내는 것은 불가능하다.

VisionLLM은 segmentation mask를 polygon sequence로 parsing하여, segmentation mask를 plain text로 바꾸어서 end-to-end training을 가능하게 하였다. 하지만, polygon sequence를 end-to-end training하는 것은 optimization이 어렵고, generalization을 위해서는 data와 computational resource가 매우 많이 필요하다.

cf) VisionLLM 7B: 4X8 80G A100, 50 epoch LISA 7B: 8 NVIDIA 24G 3090, 3 일 이내

 

LISA에서는 embedding-as-mask paradigm을 사용하여 multimodal segmentation을 한다. 

 

먼저 original LLM의 vocabulary를 <SEG> token으로 확장한다. 이는 segmentation output request를 의미한다.

Text instruction과 input image가 주어지면 multimodal LLM은 text response y를 output으로 내보낸다. 

 

LLM이 binary segmentation mask를 생성하고자 할 때, output y에는 <SEG> token이 포함된다.

이때 <SEG> token에 해당하는 LLM의 last-layer embedding을 extract하고, MLP layer로 projection하여 hseg를 얻는다.

동시에 vision backbone은 input image에서 dense visual feature f를 extract한다.

마지막으로 hsegf는 decoder를 통과하여 최종 segmentation mask M을 내보낸다. (Decoder는 SAM 구조, 변경 가능)

Training Objective

Model은 end-to-end로 학습되며, 2가지 loss를 사용한다. (각각에 가중치 부여)

1) Text generation loss (Autoregressive CE loss)

2) Segmentation mask loss (Dice + BCE loss)

2. Training

Training data는 크게 3 부분으로 구성되었으며, 모두 public dataset에서 가져왔다.

1) Semantic Segmentation Dataset
ADE20K, COCO-Stuff, LVIS-PACO
각 image에 대해 random하게 category 선택, QA template으로 변경 (다양한 template 사용)

USER: <IMAGE> Can you segment the {class name} in this image?
ASSISTANT: It is <SEG>.
(<IMAGE>는 image patch token을 위한 placeholder, {class name}은 chosen category)

 

2) Vanilla Referring Segmentation Dataset
refCOCO, refCOCO+, refCOCOg, refCLEF
Referring segmentation은 input image-short description으로 이루어져 있으므로 QA pair로 변환하기 쉬움.

USER: <IMAGE> Can you segment {description} in this image? ({description}은 explicit description)
ASSISTANT: Sure, it is <SEG>.

 

3) Visual Question Answering Dataset
LLaVA-Instruct-150k, LLaVA-v1.5-mix665k

Training dataset에는 reasoning segmentation datasample이 없다는 것이 특징적이다. 이렇게 LISA는 complex reasoning training data 없이도 높은 zero-shot 능력을 보여준다. 또, 239개의 trainin set에 finetuning 했을 때 성능이 더 올라간다.

 

Trainable Parameters

Pretrained multimodal LLM의 knowledge를 보존하기 위해 (이 실험에서는 LLaVA ) LoRA로 fine-tuning하고 vision backbone은 freeze한다. Decoder는 fully fine-tuned 되었다. 아래 parameter도 모두 trainable하다.
- LLM token embeddings (embed tokens)

- LLM head (lm head)

- projection layer γ

 

학습된 model은 catastrophic forgetting (original text generation capability)를 방지하며 conversation 능력을 보존하였다.
이것의 potential reason은

1) LoRA fine-tuning으로 학습되는 parameter 줄임
2) Fine-tuning 과정에 VQA dataset을 사용


Experiment

1. Experiment Setting

Network architecture

  • LLM: LLaVA-7B-v1-1, LLaVA-13B-v1-1
  • vision backbone: ViT-H SAM
  • Projection layer: MLP [256, 4096, 4096]

Implementation Details

  • 8개의 NVIDIA 24G 3090, Deepspeed 라이브러리 사용
  • AdamW, lr=3e-4, weight decay=0, WarmupDecayLR (warmup iteration=100)
  • λtxt=1.0, λmask=0, λbce=2.0, λdice=0.5
  • Batch size (per device)=2
  • Gradient accumulation step=10
  • semantic segmentation dataset에서는 각 image 당 3개의 category 선택

Evaluation Metrics

Referring segmentation task와 동일한 metric 사용

1. gIoU: Image 당 IoU (Intersection-over-Unions)의 평균

2. cIoU: cumulative IoU, 누적된 image, mask의 IoU를 계산

* cIoU는 large-area object에 bias되어 있으며, fluctuation이 심하기 때문에 gIoU를 선호한다.

 

2. Experiment 결과

1. Reasoning Segmentation

 

위 표에서 reasoning segmentation task는 거의 20%의 gIoU 차이로 LLAVA보다 LISA가 성능이 좋은 것을 볼 수 있다.

 

LLaVA1.5 + OVSeg는 2 stage method를 사용한 것을 의미한다.

1) LLAVA v1.5로 text output 생성

2) OVSeg와 같은 open vocabulary segmentation mask를 생성
(너무 길어질 경우 GPT-3.5로 요약 후 OVSeg에 input)

* End-to-End로 training 했다는 점, text를 input으로 사용하는 대신 embedding을 사용했다는 점이 차이점

 

2. Vanilla Referring Segmentation

Reasoning segmentation이 아닌 원래 referring segmentation에서는 어떤지도 실험하였다.

 

refCOCO, refCOCO+, refCOCOg에서 실험한 결과, LISA가 가장 성능이 좋았다.

3. Ablation Study

  • Vision backbone: 다양한 선택지가 있지만 pretrained SAM이 가장 성능이 좋았다.
  • Instruction을 GPT-3.5로 재구성하는 것이 gIoU를 2.2% 높였다.
  • Semantic segmentation dataset은 성능에 필수적이며, reasoning segmentation sample이 많을수록 결과가 좋았다.

자세한 ablation은 원 논문 참고

4. Qualiative result

 

5. Conclusion

본 연구에서는 reasoning segmentation이라는 새로운 task를 정의하고, LLM의 hidden embedding을 활용하여 복잡한 text에 대해서도 segmentation mask를 generation할 수 있는 모델인 LISA를 제시하였다.

[Summary]
LISA는 complex query를 이해하여 해당하는 segmentation mask를 생성하는 model
- Explicit text 대신 complex text를 이해 (LLM의 reasoning ability transfer)
- Text를 output하는 것이 아니라 segmentation mask를 output
- Vision 정보를 text, 좌표 등이 아닌 embedding을 input으로 넣는 것이 인상적!
- Embedding은 <SEG> 토큰을 생성하고, LLM의 last layer를 가져옴