
Background
| | CLIP | SAM |
|---|---|---|
| Training | text-image pairs | mask labels |
| Strength | zero-shot visual recognition | highly adaptable to a wide range of downstream tasks through interactive prompts |
| Drawback | suboptimal results for dense prediction tasks, e.g., segmentation | lacking the capability to recognize the segments it identifies |
Two Naïve Baselines
Preceding this work, there were naïve attempts to combine CLIP and SAM, e.g., cropping or masking out everything except the segment and classifying the result with CLIP.
However, removing the uninteresting parts of an image also removes the context surrounding the object to be recognized.
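A minimal NumPy sketch of such a crop-and-mask baseline (a toy image and a hypothetical `crop_and_mask` helper, not the paper's code) makes the context loss concrete: every background pixel inside the crop is erased before the region would be handed to CLIP.

```python
import numpy as np

def crop_and_mask(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Naive baseline: crop to the mask's bounding box and erase the background."""
    ys, xs = np.where(mask)
    y0, y1, x0, x1 = ys.min(), ys.max() + 1, xs.min(), xs.max() + 1
    crop = image[y0:y1, x0:x1].copy()
    crop[~mask[y0:y1, x0:x1]] = 0  # context pixels around the object are lost here
    return crop

# toy 6x6 "image" with an irregular 2-pixel object
image = np.arange(36, dtype=float).reshape(6, 6)
mask = np.zeros((6, 6), dtype=bool)
mask[2, 2] = mask[3, 3] = True
crop = crop_and_mask(image, mask)  # 2x2 crop; the off-diagonal context is zeroed out
```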
Abstract
CLIP and the Segment Anything Model (SAM) are remarkable vision foundation models (VFMs).
SAM excels in segmentation tasks across diverse domains, whereas CLIP is renowned for its zero-shot recognition capabilities.
This paper presents an in-depth exploration of integrating these two models into a unified framework.
Specifically, we introduce the Open-Vocabulary SAM, a SAM-inspired model designed for simultaneous interactive segmentation and recognition, leveraging two unique knowledge transfer modules: SAM2CLIP and CLIP2SAM.
The former adapts SAM’s knowledge into the CLIP via distillation and learnable transformer adapters, while the latter transfers CLIP knowledge into SAM, enhancing its recognition capabilities.
Extensive experiments on various datasets and detectors show the effectiveness of Open-Vocabulary SAM in both segmentation and recognition tasks, significantly outperforming the naïve baselines of simply combining SAM and CLIP.
Furthermore, aided by training on image classification data, our method can segment and recognize approximately 22,000 classes.
Problems
| # | Problem | Solution |
|---|---|---|
| 1 | computational costs | a unified architecture |
| 2 | knowledge transfer between different architecture | SAM2CLIP & CLIP2SAM |
| 3 | recognition of small objects | Lightweight Feature Pyramid Network (FPN) |
| 4 | integrating Open-Vocabulary capabilities | frozen CLIP backbone |
The Unified Architecture
SAM2CLIP
Overview
SAM2CLIP sits on top of the CLIP backbone, which plays the role of the student network, and aligns the knowledge of SAM (the teacher) into CLIP.
Aim: Aligning CLIP features with SAM’s representation.
Method: Adaptation plus distillation.
SAM2CLIP has a symmetrical architecture, so as to enable bidirectional knowledge transfer.
- SAM’s feature map: $F_\mathrm{SAM}$
- CLIP’s feature map: pyramid CLIP features $E_{I}^i \quad (i = 1, 2, 3)$ (high-resolution and semantic information)
- multi-scale adapter: $A_\mathrm{SAM2CLIP}$ (several Transformer layers)
- distillation loss: $\mathcal{L}_\mathrm{distill} = \mathrm{MSE}\left(F_\mathrm{SAM}, A_\mathrm{SAM2CLIP}\left(\mathrm{Fusion}\left(E_\mathrm{I}^i\right)\right)\right)$ , where $\mathrm{Fusion}$ is achieved by bilinear upsampling.
Multi-scale processing in the adapter:
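The fusion-then-distillation step can be sketched with NumPy stand-ins: a nearest-neighbour `upsample2x` in place of the bilinear upsampling, and an identity `adapter` in place of the Transformer layers of $A_\mathrm{SAM2CLIP}$. This is illustrative only, not the paper's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

def upsample2x(x: np.ndarray) -> np.ndarray:
    # nearest-neighbour stand-in for the bilinear upsampling used in Fusion
    return x.repeat(2, axis=-2).repeat(2, axis=-1)

def fusion(pyramid):
    """Upsample every pyramid level to the finest resolution and sum them."""
    target = pyramid[0].shape[-1]
    out = np.zeros_like(pyramid[0])
    for e in pyramid:
        while e.shape[-1] < target:
            e = upsample2x(e)
        out += e
    return out

def adapter(x):
    # identity stand-in for A_SAM2CLIP (several Transformer layers in the real model)
    return x

# pyramid CLIP features E_I^i at strides 1x, 2x, 4x (8 channels each)
pyramid = [rng.normal(size=(8, 16 // s, 16 // s)) for s in (1, 2, 4)]
f_sam = rng.normal(size=(8, 16, 16))            # SAM feature map F_SAM
aligned = adapter(fusion(pyramid))
loss_distill = np.mean((f_sam - aligned) ** 2)  # MSE distillation loss
```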
Implementation Details
Below is the implementation of SAM2CLIP: `backbone_teacher` refers to the frozen SAM backbone, and `backbone_student` to the frozen CLIP backbone.
`neck_teacher` and `neck_student` align the feature representations of the teacher and student models, making it possible to compute a meaningful distillation loss (such as MSE) between them.
```python
model = dict(
    type=BackboneDistillation,
    use_cache=True,
    data_preprocessor=data_preprocessor,
    backbone_teacher=dict(  # frozen SAM backbone as the teacher model
        type=SAMBackbone,
        model_name='vit_h',
        fix=True,
        init_cfg=dict(
            type='sam_pretrain',
            checkpoint='vit_h'
        )
    ),
    backbone_student=dict(  # frozen CLIP backbone as the student model
        type=OpenCLIPBackbone,
        model_name='RN50x16',
        fix=True,
        init_cfg=dict(
            type='clip_pretrain',
            checkpoint='openai'
        )
    ),
    neck_teacher=dict(type=LastLayerNeck),  # SAM's last layer as neck_teacher
    neck_student=dict(  # multi-layer Transformer over CLIP features as neck_student
        type=MultiLayerTransformerNeck,
        input_size=(1024, 1024),
        in_channels=[384, 768, 1536, 3072],
        strides=[4, 8, 16, 32],
        layer_ids=(0, 1, 2, 3),  # fusion of layer 0 (full resolution) and the downsampled layers 1-3
        embed_channels=1280,
        out_channels=256,
        embedding_path='sam_vit_h'
    ),
    loss_distill=dict(  # MSE distillation loss
        type=MSELoss,
        reduction='mean',
        loss_weight=1.
    )
)
```
CLIP2SAM
Overview
Aim: leveraging CLIP’s knowledge to enhance the recognition capabilities of the SAM encoder.
Method: SAM uses $Q_\mathrm{mask}$ and $Q_\mathrm{IoU}$ tokens for segmentation. On this basis, CLIP2SAM appends $Q_\mathrm{label}$ tokens to store class-label information.
Note: a two-way Transformer is a transformer architecture that enables bidirectional attention between two kinds of inputs, typically sparse queries (such as points, prompts, or tokens) and dense inputs (such as image features).
Now we have:
- Tokens: $Q_\mathrm{mask}$ , $Q_\mathrm{IoU}$ , and $Q_\mathrm{label}$ ;
- Features: output of the prompt encoder, CLIP features.
A two-way transformer is a good choice for better alignment, but this design falls short of recognizing small objects (Problem 3) because the adaptation involves only a single-scale feature (the final output feature).
Multi-Scale Feature Fusion
To better recognize small objects, a lightweight Feature Pyramid Network (FPN, proposed at CVPR 2017), which fuses features at multiple scales, is used as the adapter.
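A minimal top-down FPN sketch in NumPy (random 1x1-conv weights, nearest-neighbour upsampling; illustrative only, not the paper's implementation):

```python
import numpy as np

rng = np.random.default_rng(0)

def conv1x1(x, w):
    """1x1 convolution as channel mixing: weights (C_out, C_in) applied to (C_in, H, W)."""
    return np.einsum('oc,chw->ohw', w, x)

def fpn(features, out_ch=4):
    """Minimal top-down FPN: lateral 1x1 convs, then upsample-and-add from coarse to fine."""
    laterals = [conv1x1(f, rng.normal(size=(out_ch, f.shape[0]))) for f in features]
    outs = [laterals[-1]]                 # start from the coarsest level
    for lat in reversed(laterals[:-1]):
        up = outs[0].repeat(2, axis=-2).repeat(2, axis=-1)  # upsample 2x
        outs.insert(0, lat + up)          # fuse with the next finer level
    return outs

# three feature levels whose channels/strides mimic a backbone pyramid
feats = [rng.normal(size=(c, 32 // s, 32 // s)) for c, s in [(8, 1), (16, 2), (32, 4)]]
outs = fpn(feats)  # one fused map per level, all with out_ch channels
```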
Point Prompts vs. Region Prompts
For point prompts, first obtain masks from the SAM decoder and then derive bounding boxes from the masks.
For region prompts, directly send them to the FPN.
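The mask-to-box step for point prompts can be sketched as follows (a hypothetical `mask_to_box` helper; boxes are `(x0, y0, x1, y1)` with exclusive upper bounds):

```python
import numpy as np

def mask_to_box(mask: np.ndarray):
    """Convert a binary mask (H, W) to an (x0, y0, x1, y1) bounding box."""
    ys, xs = np.where(mask)
    return int(xs.min()), int(ys.min()), int(xs.max()) + 1, int(ys.max()) + 1

mask = np.zeros((8, 8), dtype=bool)
mask[2:5, 3:6] = True          # object occupies rows 2-4, columns 3-5
box = mask_to_box(mask)        # (3, 2, 6, 5)
```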
Open Vocabulary
Fuse the learned class scores with those from the frozen CLIP via a geometric mean to leverage information from both the CLIP and CLIP2SAM.
```python
alpha: float = .1,
beta: float = .9,
# ...
clip_logit = clip_logit.softmax(-1)    # class probabilities from the frozen CLIP
query_logit = query_logit.softmax(-1)  # class probabilities from CLIP2SAM
cls_logits_seen = (
    (query_logit ** (1 - alpha) * clip_logit ** alpha).log() * overlapping_mask
)
cls_logits_unseen = (
    (query_logit ** (1 - beta) * clip_logit ** beta).log() * (1 - overlapping_mask)
)
cls_results = cls_logits_seen + cls_logits_unseen
```
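A runnable NumPy version of the same geometric-mean fusion (hypothetical logits; `seen_mask` plays the role of `overlapping_mask`):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)

def fuse_scores(query_logit, clip_logit, seen_mask, alpha=0.1, beta=0.9):
    """Geometric-mean fusion of CLIP2SAM and frozen-CLIP class probabilities."""
    q, c = softmax(query_logit), softmax(clip_logit)
    seen = q ** (1 - alpha) * c ** alpha    # weight CLIP2SAM more on seen classes
    unseen = q ** (1 - beta) * c ** beta    # weight frozen CLIP more on unseen classes
    return np.log(seen) * seen_mask + np.log(unseen) * (1 - seen_mask)

query = np.array([2.0, 0.5, -1.0])          # hypothetical CLIP2SAM logits
clip = np.array([0.1, 1.5, 0.2])            # hypothetical frozen-CLIP logits
seen_mask = np.array([1.0, 1.0, 0.0])       # last class unseen during training
scores = fuse_scores(query, clip, seen_mask)
```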
Training
- Training SAM2CLIP: $\mathcal{L}_\mathrm{distill}$
- Jointly training CLIP2SAM and the mask decoder: $\mathcal{L} = \lambda_\mathrm{cls}\mathcal{L}_\mathrm{cls} + \lambda_\mathrm{ce}\mathcal{L}_\mathrm{t\_ce} + \lambda_\mathrm{dice}\mathcal{L}_\mathrm{t\_dice},$ where $\mathcal{L}_\mathrm{t\_\star}$ are segmentation losses.
Since the features are updated, the SAM decoder must be re-trained. During the training of CLIP2SAM, SAM2CLIP is kept frozen.
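A toy NumPy sketch of the joint loss, using the `loss_weight` values from the CLIP2SAM config (2.0 / 5.0 / 5.0) as the $\lambda$ coefficients, with hand-written cross-entropy and naive Dice losses on hypothetical predictions:

```python
import numpy as np

def cross_entropy(probs, target):
    """Classification loss on a single prediction: -log p(target)."""
    return -np.log(probs[target])

def dice_loss(pred, gt, eps=1.0):
    """Naive Dice loss between soft prediction and binary ground truth."""
    inter = (pred * gt).sum()
    return 1 - (2 * inter + eps) / (pred.sum() + gt.sum() + eps)

lam_cls, lam_ce, lam_dice = 2.0, 5.0, 5.0  # mirror the config's loss_weight values

cls_probs = np.array([0.7, 0.2, 0.1])            # predicted class distribution
pred_mask = np.array([[0.9, 0.8], [0.1, 0.2]])   # soft mask prediction
gt_mask = np.array([[1.0, 1.0], [0.0, 0.0]])     # binary ground-truth mask

l_cls = cross_entropy(cls_probs, target=0)
l_ce = -np.mean(gt_mask * np.log(pred_mask) + (1 - gt_mask) * np.log(1 - pred_mask))
l_dice = dice_loss(pred_mask, gt_mask)
total = lam_cls * l_cls + lam_ce * l_ce + lam_dice * l_dice
```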
```python
model = dict(
    type=CLIP2SAM,
    # ...
    neck=dict(
        type=MultiLayerTransformerNeck,
        input_size=(1024, 1024),
        in_channels=[384, 768, 1536, 3072],
        strides=[4, 8, 16, 32],
        layer_ids=(0, 1, 2, 3),
        embed_channels=1280,
        out_channels=256,
        fix=True,  # freeze the neck (loaded from SAM2CLIP's neck_student)
        init_cfg=dict(
            type='Pretrained',
            checkpoint='./models/sam2clip_vith_rn50x16.pth',
            prefix='neck_student',
        )
    ),
    mask_decoder=dict(  # train the mask decoder while SAM2CLIP stays frozen
        type=OVSAMHead,
        model_name='vit_h',
        with_label_token=True,
        ov_classifier_name='RN50x16_LVISV1Dataset',
        roi_extractor=dict(
            type=SingleRoIExtractor,
            roi_layer=dict(type=RoIAlign, output_size=12, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]
        ),
        fix=False,
        init_cfg=dict(
            type='sam_pretrain',
            checkpoint='vit_h'
        ),
        loss_cls=dict(  # classification loss
            type=CrossEntropyLoss,
            use_sigmoid=False,
            loss_weight=2.0,
            reduction='mean'
        ),
        loss_mask=dict(  # cross-entropy segmentation loss
            type=CrossEntropyLoss,
            use_sigmoid=True,
            reduction='mean',
            loss_weight=5.0
        ),
        loss_dice=dict(  # Dice loss
            type=DiceLoss,
            use_sigmoid=True,
            activate=True,
            reduction='mean',
            naive_dice=True,
            eps=1.0,
            loss_weight=5.0
        )
    )
)
```
Comparison
Comparison with baselines
Using Detic [1] as the detector to generate box prompts.
Small Object Recognition
Comparison with CLIP & SAM
OVSAM maintains the segmentation quality of SAM after distillation and alignment.
Ablation
Summary
- Explores interactive open-vocabulary segmentation for the first time.
- Two efficient modules: SAM2CLIP and CLIP2SAM.