본문 바로가기

논문리뷰

Medical SAM Adapter : Adapting Segment Anything Model for Medical Image Segmentation

 

https://arxiv.org/abs/2304.12620

 

Medical SAM Adapter: Adapting Segment Anything Model for Medical Image Segmentation

The Segment Anything Model (SAM) has recently gained popularity in the field of image segmentation due to its impressive capabilities in various segmentation tasks and its prompt-based interface. However, recent studies and individual experiments have show

arxiv.org

 

Abstract 

Segment Anything Model (SAM)은 최근에 다양한 segmentation task에 활용할 수 있고 prompt를 사용할 수 있는 prompt-based interface이기 때문에 image segmentation 분야에서 많은 인기를 얻고 있습니다. 그러나 SAM은 general한 범용적인 모델이기 때문에 아무래도 medical에 특화된 지식은 부족하기 때문에 medical image segmentation에선 우수한 성능을 내는 데에 한계가 있습니다. 그렇다면 medical image에 대한 SAM의 segmentation 능력을 키우기 위해선 어떻게 해야 할까요? SAM model을 전체적으로 fine-tuning하는 대신, 간단한 adaptation 기술을 segmentation model에 적용해 domian에 특화된 medical knowledge를 추가해줄 수 있는 Medical SAM Adpater (Med-SA)를 제시했습니다. Med-SA에선 크게 두가지 technique이 소개됩니다. 2D SAM을 3D medical image에 적용하기 위한 Spase-Depth Transpose (SD-Trans)와 prompt 조건의 adaption을 하기 위한 Hyp-Adpt가 존재합니다. 다양한 modality에서 17개의 medical image segmentation task를 통해 실험을 진행하고 효과를 입증했습니다. Med-SA는 기존의 medical image segmentation method에 존재하는 SOTA를 뛰어넘었고, parameter는 단지 2%만 추가했습니다. 

 

Introduction

아주 최근에, Segmentation Anything Model은 다양한 곳에 사용될 수 있고 강력한 성능으로 인해 vision segmentation model에서 많은 주목을 받고 있습니다. SAM은 user prompt를 기반으로 다양하고 디테일한 segmentation mask를 생성할 수 있습니다. natural한 image에 대해선 좋은 성능을 내지만, 최근 연구에 의하면 medical image segmentation에선 평균 이하의 저조한 성능을 낸다는 결과가 있습니다. SAM 처럼 medical image segmentation을 interactive하게 (prompt를 통한) 만드는 것은 엄청난 clinical한 가치를 지니고 있습니다. interactive system은 clinician인 전문가에 의해 알게되는 관심있는 부분에 대해 우선순위를 둘 수 있게 하기 때문에, 보다 몰입감 있고 개인화된 경험을 제공할 수 있습니다.  예를 들어, 하나의 fundus 이미지에는 혈관, optic disc, optic cup, macula와 같이 중첩되고 복잡하게 얽혀 있는 구조가 있는 경우가 많습니다. interactive segmentation은 이러한 복잡한 구조로부터 효과적으로 target tissue를 구별해 clinicial을 보조할 수 있습니다. 대규모의 annotated 데이터셋을 얻는 건 어렵기 때문에, clinical에 활용을 하기 위해선 SAM과 같이 foundational한 interactive model을 사용하는 것이 중요해졌습니다. 

 

SAM에 medical image를 사용할 때 성능이 저조한 이유

1. SAM은 natural image로 학습된 모델이기 때문에 의학적 특수 지식이 부족함

2. medical image (MRI, CT)는 일반적인 이미지와 달리 조직 간의 대조가 약해 구분하는 것이 더 어려움

3. tissue boundary가 애매한 경우가 많음 

4. 보통 우리가 구분하고자 하는 lesion region은 크기가 작아 segmentation 하기 더 어려움 

 

SAM에 medical image가 잘 적용될 수 있도록 제안한 SOTA 논문인 MedSAM(Segment Anything in Medical Images)은 vanilla SAM model을 fully fine-tuning했습니다. SAM model 전체를 fine-tuning하는 것은 computation 측면에서도 memory 측면에서도 많은 비용이 소모 됩니다. 그리고 기존의 pre-trained visual model을 medical image로 transfer가 가능한데 SAM을 전체적으로 fine-tuning하는 것이 꼭 필요할까요? 

이 논문에선, 아주 적은 변형으로 학습된 SAM을 medical image segmentation에 적용하는 방법을 제시합니다. 엄밀히 말하면, Adaption이라고 불리는 parameter-efficient fine-tunign (PEFT)를 사용해 사전 학습된 SAM을 fine-tune하도록 합니다. Adaption은 NLP에서 기초적인 사전 학습된 모델을 다양한 downstream task에 적용하도록 fine-tune하는 널리사용되는 인기있는 기술입니다. Adaption의 main idea는 기존 모델에 부분적인 parameter를 갖고 있는 Adapter module을 삽입해 기존의 사전 학습된 큰 모델은 frozen시키고 추가적인 적은 양의 Adapter parameter만 업데이트 하는 것입니다. 

 

 Adaption technique을 medical 상황에 바로 적용하는 것은 그렇게 간단한 건 아닙니다.

1.  NLP가 아닌 image modality를 사용

natural image와 달리, 많은 medical image는 CT와 MRI처럼 3D의 형태입니다. 2D로 이루어진 SAM model을 3D medical image segmentation에 맞게 적용해야 합니다. 

2. Adaption은 NLP에서는 잘 적용되지만 visual model에, 특히 SAM과 같은 interactive visual model에 잘 적용되는지에 대한 연구는 많이 이루어지지 않음

interactive visual model에선 유저가 제공하는 visual prompt가 최종 예측에 상당히 중요한 역할을 합니다.  이러한 중요한 visual prompt에 맞게 합치는 방법은 아직 밝혀지지 않았습니다.

 

이런 문제를 극복하기 위해, Medical SAM Adapter (Med-SA)로 불리는 새로운 adaptation framework를 제안합니다.

Med-SA에선, 2D에서 3D로의 적용시키기 위해 Space-Depth Transpose(SD-Trans) 기법을 소개합니다. SD-Trans에선, input embedding의 spatial dimension을 depth dimension으로 transpose 시켜, 같은 self-attention block이 다른 input이 주어졌을 때 다른 dimensioanl information도 처리할 수 있도록 합니다. 그 다음 Hyper-Prompting Adapter (HyP-Adpt)를 통해 prompt에 맞는 adpation을 할 수 있게 합니다. 여기서 visual prompt를 통해 weight series를 생성하고 이 weight를 통해 adaptation embedding에 효과적으로 적용하고, prompt-adaptation ineraction을 편하게 할 수 있게 됩니다.

 

 다양한 modality (CT, MRI, ultrasound image, fundus image, dermoscopic image)로 이루어진 17개의 medical image segmentation task를 통해 실험을 진행해 성능을 평가했습니다. Med-SA가 당연히 기본 SAM과 fully fine-tuned SAM(MedSAM)에 비해 훨씬 좋은 성능을 냈습니다. Med-SA는 SAM에서 뿐만 아니라 nnUet, TransUNet,UNetr, Swin-UNetr과 같이 직접 수정한 medical image segmentation 방법으로 이루어진 여러 SOTA도 넘겼습니다. 더 중요한 건 Med-SA는 전체 SAM parameter의 2%에 해당되는 추가 parameter만 update해서 이런 우수한 성능을 얻었다는 점입니다. 

 

Related Work

Interactive Segmentation

interactive segmentation은 초기에 optimization 기술로 다뤄졌습니다. DIOS 연구는 deep learning과 병합하고 distance map으로 positive와 negative click을 합침으로써 interactive segmentation으로 발전했습니다. (2019)이후엔 여러 potential result를 예측하고 selection network로 선택을 하거나 user가 직접 선택하도록 함으로써 'uncertainty'를 다루는 데에 초점을 맞췄습니다. CDNet(2021)은 나아가 더 일정된 prediction을 생성하기 위해 self-attention을 사용해 interactive segmentation을 발전시켰습니다. RITM(2022)과 AccuracyNet(2020)은 prediction을 조금 더 robust하고 정확하게 하기 위해 이전의 mask를 input으로 사용했습니다. 최근엔, SAM(2023)은 zero shot segmentation에서 interactive segmentation은 많은 영향을 끼치는 것을 보였고 visaul foundation  model의 중요성을 강조했습니다. 그러나 interactive medical image segmentation은 의학적으로 중요한 역할을 함에도 불구하고 비교적 많은 관심을 얻지 못했습니다. 예를 들어, 하나의 fundus image는 다양한 상황에 따라 여러 target (vessel, optic disc, optic cup, macula)의 segmentation을 필요로 할 것입니다. Med-SA는 interactive medical image segmentation의 훌륭한 출발점을 제공하고 이후 연구에 많은 영감을 줄 것입니다.

 

Parameter-Efficient Fine-Tuning

PEFT는 특정 상황에서 사용했을 때 거대하고 기본이 되는 model을 fine-tuning하는 것이 좋은 방법이 될 수 있다는 것을 보였습니다. 전체를 fine-tuning하는 것과 비교했을 때, 대부분의 parameter는 동결시키고 매우 적은 parameter를 학습시키게 됩니다. 이 때 매우 적은 건 보통 전체 parameter의 5%도 안되는 수입니다. 그렇기 때문에 훨씬 더 빨리 update해서 효율적으로 학습할 수 있습니다. PEFT를 사용한 방법은 catastrophic forgetting을 하지 않고 out-of-domain scenario (특히 low-data regime)에도 더 잘 일반화할 수 있기 때문에 full fine-tuning보다 더 좋은 성능을 낼 수 있습니다. 여러 PEFT 방법 중, Adaption은 NLP뿐만 아니라 computer vision에서도 대규모의 기본 vision model을 downstrea task에 적용하기 위해 fine-tuning을 할 수 있는 효과적은 도구입니다. 최근 연구들은 Adaption이 다양한 downstream computer vision task에서 쉽게 적용될 수 있다는 것을 보였습니다. 그렇기 때문에 Adaption은 SAM을 medical domain에 사용할 수 있는 가장 적합한 기술이라 생각했고, 이 Med-SA는 foundationaml medical model의 발전에 굉장한 가능성을 열어줄 것이라 생각합니다.

 

 

Method

Preliminary : SAM architecture

시작하기에 앞서, SAM의 구조에 대해 알아보겠습니다. SAM은 크게 3개의 main component로 이루어져 있습니다. 

image encoder, prompt encoder, mask decoder

image encoder는 MAE로 사전 학습된 표준 Vision Transformer(ViT)를 기반으로 합니다. 구체적으로, ViT-H/16의 variant를 사용하고, 이는 14x14 windowed attention과 four equally-spaced global attention block을 사용합니다.

image encoder의 output은 16x donwsample된 input image의 embedding입니다. prompt encoder는 sprase(point,box)와 dense(mask)가 될 수 있습니다. 여기선 sparse encoder에만 다룰 예정이고, 이는 각 prompt type에 대해 학습된 embedding과 positional encoding으로 합쳐지도록 표현됩니다. mask decoder는 dynamic mask prediction head가 추가되도록 수정된 Transformer decoder block입니다. decoder는 prompt와 image embedding간의 interaction을 학습하기 위해 two-way cross-attention을 사용합니다. 이후, SAM은 image embedding을 upsample하고, MLP은 dynamic linear classifier에 output token을 mapping하고, 결국 이는 주어진 image에 대해 target mask를 예측하게 됩니다. 

 

Med-SA architecture

우리의 목표는 SAM 구조가 medical image segmentation task에 대해 fine-tuning을 통해 medical capability를 향상시키는 것입니다. 모든 parameter를 전체적으로 조정하는 것 보다, 사전 학습된 SAM parameter는 고정시키고 Adapter module을 고안해 설계된 위치에 통합합니다. Adapter는 bottleneck model로 제공하고, 순차적으로 down-projection, ReLU activation, up-projection으로 이루어져 있습니다.

down-projection은 심플한  MLP layer를 사용해 주어진 embedding을 압축해 더 낮은 차원으로 바꿔줍니다. up-projection은 또다른 MLP layer를 사용해 압축된 embedding을 다시 original dimension으로 확장시켜줍니다. 

SAM encoder에서, 각각의 ViT block에 대해 2개의 adapter를 사용합니다. standard ViT block에 대해, 첫 Adapter는  multi-head attention 이후와 residual connection 이전에 위치합니다. 두번째 Adapter는 multi-head attention 이후에 MLP layer의 residual path에 위치합니다. 두번째 adapter 이후에 scale factor s로 embedding을 scaling해줬다고 합니다. 

SAM decoder에서, 각 ViT block에 대해 3개의 adapter를 사용합니다. 첫번째 Adapter는 prompt embedding을 통합하는 역할을 하고, 이를 위해 Hyper-Prompting Adapter라 불리는 novel structure를 제안했습니다. 두번째 Adapter는 encoder에서와 정확히 일치하고, MLP-enhanced embedding을 사용합니다. 세번째 Adapter는 image embedding-to-prompt cross-attention의 residual connection이후에 위치합니다. 다른 residual connection과 layer normalization은 adaption 이후에 마지막 결과를 도출하기 위해 연결됩니다. 

 

SD-Trans

SAM을 medical image segmentation에 적용하는 것은 2D image간의 dimension 차이와 MRI 나 CT는 3D modality로 이루어져 있기 때문에 다소 어렵습니다. 임상적으로 봤을 때, slice간의 상관관계를 이해하는 것은 정확한 결정을 내리는 데에 매우 중요합니다.  최종 segmentation을 얻기 위해 한 volume에서 각각의 slice가 SAM에 사용될 수 있겠지만, 이 방법은 3D medical imag esegmentation에서 가까운 volume 관점에서의 correlation을 고려할 수 없게 됩니다. 이런 문제를 극복하기 위해, SD-Trans를 제시했고 이는 image를 video에 적용하는 방법을 제시한 "Deep image to-video adaptation and fusion networks for action recognition" 논문의 영감을 받은 아이디어 입니다. 

위의 그림에서 볼 수 있듯이, 각 block에서 attention operation을 depth branch, space branch 이렇게 두 branch로 나뉘어 집니다. depth가 D인 3D sample이 주어졌을 때, spcae branch의 multi-head attention에 .DxNxL을 입력합니다. 여기서 N은 embedding의 갯수, L은 embedding 길이가 됩니다. 여기서 D는 NxL에 적용되는 interaction에 대해 operation의 갯수라고 볼 수 있습니다. 이 때 NxL은 embedding으로 spatial correlation을 알아내고 요약하게 됩니다. depth branch에선, 기존의 input matrix를 NxDxL를 얻을 수 있게 transpose를 시켜주고 똑같이 multi-head attention에 넣어주게 됩니다. 같은 attention mechanism을 사용했지만, interaction은 DxL에서 발생하기 때문에, depth correlation을 학습하고 요약할 수 있게 됩니다. 마지막으로 NxDxL로 바꾼 depth branch를 다시 original shape으로 바꿔 space branch의 결과와 더해주어 최종적으로 깊이 정보도 포함하게 됩니다. 

 

HyP-Adpt 

이전 몇몇의 연구에서 visual model에 adaptation 기술을 적용해왔지만, interactive한 visual model에 적용하는 건 여전히 많이 연구되지 않았습니다. source task에서의 interactive behavior와 donwstream task의 interactive behavior은 다를 수 밖에 없습니다. 그렇기 때문에, interactive model에서 중요한 역할을 하는 visual prompt 정보를 adapter에 합하고자 했습니다. 이와 관련해, HyP-Adpt라는 solution을 제공해 prompt-conditioned adaptation을 완성하고자 했습니다. 

HyP-Adpt의 아이디어는 HyperNetworks로부터 영감을 얻었고, 이는 knowledge conditioning을 위해 하나의 network가 다른 network의 weight를 생성하는 것입니다. 이런 HyperNetworks의 핵심 개념을 사용하되 이를 feature level에서 더 효율적으로 적용될 수 있도록 수정했습니다. prompt embedding을 바탕으로 weight map의 sequence를 생성하기 위해 간단히 projection과 operation reshaping만 사용합니다. 이렇게 생성된 weight map은 adapter embedding에 바로 사용됩니다. 이 방법은 wide하고 deep한 feature-level interaction도 가능하고 학습에 필요한 parameter의 수도 줄일 수 있습니다.

Adapter (e(down))의 embedding에 대해 hyper-prompting을 수행합니다. 그러는 동안 prompt information은 concat되고 prompt embedding (e(prompt))로 줄어듭니다. 그 다음 e(prompt)를 weight의 sequence를 생성하는 데에 사용합니다. 

여기서 Re는 reshape을 나타내고, M은 MLP layer로 NxL shape을 가진 e(prompt)를 Nx(L(in)*L(out))의 shape을 가지도록 project합니다.  첫번째 weight의 L(in)은 e(down)의 길이가 되고, 마지막 weight의 L(out)은 output의 target length가 됩니다. 이 후, e(prompt)의 shape을 1D embedding에서 2D weight w(prompt) (NxL(in)xL(out))로 reshape하고 e(down)에 적용합니다. 

 

Training Strategy 

interactive segmentation을 위해, model이 학습하는 동안 click prompt와 bounding box를 이용합니다. BBox prompt를 생성하기 위해, SAM과 같은 방법을 적용합니다. 그러나 original SAM에선 click prompt generation에 대해 자세히 나와있지 않기 때문에, 다른 방법을 고안했습니다. 

click prompt generation 과정의 근본적인 개념은 foreground region을 나타내기 위해 positive click을 사용하고 background region을 나타내기 위해 negative click을 사용합니다. 

 

 

Experiments 

Dataset

2가지 타입으로 분류될 수 있는 5개의 distinct medical image segmentation dataset에 대해 실험을 진행했습니다. 첫번째 type은 일반적인 segmentation performance를 평가할 수 있도록 초점이 맞춰졌습니다. benchmark dataset으로 12개의 anatomy로 이루어진 널리 사용되는 BTCV dataset을 사용했습니다. 

4개의 task에 대해선 다른 modality간의 실험을 통해 model의 generalization을 확인했습니다.

(fundus image에 대해 optic disc, optic cup segmentation)

(brain MRI에 대해 brain tumor segmentation)

(ultrasound image에 대해 thyroid nodule segementation)

(dermoscopic image에 대해 melanoma or nevus segmentation)

 

 

Implementation Details

interactive model에서, 4개의 다른 prompt setting에 대해 실험했습니다. 

(1) "1-point"로 명시된 임의의 1개의 positive point

(2) "3-point"로 명시된 3개의 positive point

(3) "BBox 0.5"로 명시된 target의 50% overlapping된 bounding box

(4) "BBox 0.75"로 명시된 target의 75% overlappiung된 bounding box

 

Comparing with SOTA on Abdominal Multi-organ Segmentation

Med-SA model의 general performance를 입증하기 위해 BTCB dataset에 대해 SOTA segmentation method와 비교를 진행했습니다. 이 중 잘 알려진 nnUNet. TransUNet, UNetr, EnsDiff, SegDiff, vanilla SAM, MedSAM 이 존재합니다. 그리고 segmentation performance는 Dice score를 통해 평가했습니다. 

1-point prompt를 사용할 때 SAM보다 Med-SA 성능이 뛰어났고, BTCV dataset에 대해 one-point Med-SA는 12개의 organ에 대해 SOTA performance를 이뤘습니다. 그리고 더 정제된 prompt를 줄수록 더 좋은 결과가 나왔고, BBox 0.75에 대해 89.8%의 Dice 성능이 나왔습니다. 

 

 

Conclusion

이 논문에선 강력한 general segmentation model인 SAM을 확장해 medical image segmentation을 다룰 수 있도록 했습니다. 간단하지만 효과적인 parameter-efficient adaptation인 SD-Trans와 HyP-Adpt를 적용해, 기존 SAM model에 대해 상당히 성능이 향상됐습니다. 5개의 다른 modality를 가진 17개의 medical image segmentation task에 대해 SOTA를 얻었습니다. 이 연구를 통해 foundation medical image segmentation에 대한 디딤돌이 될 수 있습니다.