Background

|          | CLIP | SAM |
| -------- | ---- | --- |
| Training | text-image pairs | mask labels |
| Strength | zero-shot visual recognition | highly adaptable to a wide range of downstream tasks through interactive prompts |
| Drawback | suboptimal results for dense prediction tasks, e.g., segmentation | lacks the capability to recognize the segments it identifies |

Two Naïve Baselines

Prior to this work, there were some naïve attempts to combine CLIP and SAM.

Image Cropping Baseline

Cropping away the uninteresting parts of an image also removes the context surrounding the object to be recognized.

Feature Cropping Baseline
Cropping CLIP features instead of image pixels yields subpar results for small-scale objects.

Abstract

CLIP and the Segment Anything Model (SAM) are remarkable vision foundation models (VFMs).

SAM excels in segmentation tasks across diverse domains, whereas CLIP is renowned for its zero-shot recognition capabilities.

This paper presents an in-depth exploration of integrating these two models into a unified framework.

Specifically, we introduce the Open-Vocabulary SAM, a SAM-inspired model designed for simultaneous interactive segmentation and recognition, leveraging two unique knowledge transfer modules: SAM2CLIP and CLIP2SAM.

The former adapts SAM’s knowledge into CLIP via distillation and learnable transformer adapters, while the latter transfers CLIP knowledge into SAM, enhancing its recognition capabilities.

Extensive experiments on various datasets and detectors show the effectiveness of Open-Vocabulary SAM in both segmentation and recognition tasks, significantly outperforming the naïve baselines of simply combining SAM and CLIP.

Furthermore, aided with image classification data training, our method can segment and recognize approximately 22,000 classes.

Problems

| # | Problem | Solution |
| - | ------- | -------- |
| 1 | computational costs | a unified architecture |
| 2 | knowledge transfer between different architectures | SAM2CLIP & CLIP2SAM |
| 3 | recognition of small objects | lightweight Feature Pyramid Network (FPN) |
| 4 | integrating open-vocabulary capabilities | frozen CLIP backbone |

The Unified Architecture

Working pipeline
Open-Vocabulary SAM

SAM2CLIP

Overview

SAM2CLIP sits on the student (CLIP) side and aligns SAM’s knowledge into CLIP.

Aim: Aligning CLIP features with SAM’s representation.

Method: Adaptation plus distillation.

SAM2CLIP

SAM2CLIP has a symmetrical architecture, so as to enable bi-directional knowledge transfer.

  • SAM’s feature map: $F_\mathrm{SAM}$
  • CLIP’s feature map: pyramid CLIP features $E_{I}^i \quad (i = 1, 2, 3)$ (high-resolution and semantic information)
  • multi-scale adapter: $A_\mathrm{SAM2CLIP}$ (several Transformer layers)
  • distillation loss: $\mathcal{L}_\mathrm{distill} = \mathrm{MSE}\left(F_\mathrm{SAM}, A_\mathrm{SAM2CLIP}\left(\mathrm{Fusion}\left(E_\mathrm{I}^i\right)\right)\right)$ , where $\mathrm{Fusion}$ is achieved by bilinear upsampling.

Multi-scale processing in the adapter fuses the pyramid CLIP features into a single feature map before they pass through the Transformer layers.
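As a rough sketch (not the authors’ implementation), the Fusion step and the distillation loss can be illustrated in numpy. The shapes are hypothetical, nearest-neighbor upsampling stands in for the bilinear upsampling, and the Transformer adapter $A_\mathrm{SAM2CLIP}$ is omitted for brevity:

```python
import numpy as np

def upsample_nn(x, target_hw):
    """Nearest-neighbor upsampling; a simplified stand-in for bilinear."""
    H, W = target_hw
    h, w = x.shape[-2:]
    rows = np.arange(H) * h // H
    cols = np.arange(W) * w // W
    return x[..., rows[:, None], cols[None, :]]

def fusion(pyramid, target_hw):
    # sum all pyramid levels after bringing them to a common resolution
    return sum(upsample_nn(e, target_hw) for e in pyramid)

rng = np.random.default_rng(0)
f_sam = rng.standard_normal((256, 64, 64))                          # F_SAM (hypothetical shape)
pyramid = [rng.standard_normal((256, s, s)) for s in (16, 32, 64)]  # E_I^i, i = 1, 2, 3

fused = fusion(pyramid, (64, 64))
# the adapter A_SAM2CLIP would transform `fused` before this comparison
loss_distill = np.mean((f_sam - fused) ** 2)  # MSE distillation loss
```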

Implementation Details

Below is the implementation of SAM2CLIP: backbone_teacher is the frozen SAM backbone and backbone_student the frozen CLIP backbone. neck_teacher and neck_student align the feature representations of the teacher and student models, making it possible to compute a meaningful distillation loss (here, MSE) between them.

model = dict(
    type=BackboneDistillation,
    use_cache=True,
    data_preprocessor=data_preprocessor,
    backbone_teacher=dict( # frozen SAM backbone as the teacher
        type=SAMBackbone,
        model_name='vit_h',
        fix=True,
        init_cfg=dict(
            type='sam_pretrain',
            checkpoint='vit_h'
        )
    ),
    backbone_student=dict( # frozen CLIP backbone as the student
        type=OpenCLIPBackbone,
        model_name='RN50x16',
        fix=True,
        init_cfg=dict(
            type='clip_pretrain',
            checkpoint='openai'
        )
    ),
    neck_teacher=dict(type=LastLayerNeck), # SAM's last layer serves as neck_teacher
    neck_student=dict( # CLIP's multi-layer Transformer serves as neck_student
        type=MultiLayerTransformerNeck,
        input_size=(1024, 1024),
        in_channels=[384, 768, 1536, 3072],
        strides=[4, 8, 16, 32],
        layer_ids=(0, 1, 2, 3), # feature-map fusion: layer 0 (original resolution) and layers 1-3 (downsampled feature maps)
        embed_channels=1280,
        out_channels=256,
        embedding_path='sam_vit_h'
    ),
    loss_distill=dict( # MSE distillation loss
        type=MSELoss,
        reduction='mean',
        loss_weight=1.
    )
)

CLIP2SAM

Overview

Aim: leverage CLIP’s knowledge to enhance the recognition capabilities of SAM.

Method: SAM uses $Q_\mathrm{mask}$ and $Q_\mathrm{IoU}$ tokens for segmentation; on this basis, $Q_\mathrm{label}$ tokens are appended to carry class-label information.

Note: a two-way transformer is a transformer architecture designed to enable bidirectional attention flow between two different types of inputs, typically sparse queries (like points, prompts, or tokens) and dense inputs (like image features or text embeddings).

Now we have:

  • Tokens: $Q_\mathrm{mask}$ , $Q_\mathrm{IoU}$ , and $Q_\mathrm{label}$ ;
  • Features: output of the prompt encoder, CLIP features.
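Schematically, the label tokens are simply concatenated to SAM’s existing query tokens before entering the two-way transformer. The token counts below (one IoU token, four mask tokens, one label token) and the dimension are illustrative assumptions:

```python
import numpy as np

B, C = 1, 256                 # batch size and token dimension (assumed)
q_iou = np.zeros((B, 1, C))   # Q_IoU
q_mask = np.zeros((B, 4, C))  # Q_mask (SAM predicts multiple candidate masks)
q_label = np.zeros((B, 1, C)) # new learnable Q_label token from CLIP2SAM

# concatenate along the token axis; the two-way transformer then attends
# between these sparse queries and the dense image features
queries = np.concatenate([q_iou, q_mask, q_label], axis=1)
```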

A two-way transformer is a good choice for better alignment, but this design falls short at recognizing small objects (Problem 3) because the adaptation process only involves single-scale features (only the final output feature map is used).

Multi-Scale Feature Fusion

To better recognize small objects, a Feature Pyramid Network (FPN, proposed in CVPR 2017), which fuses multi-scale features, is used as the adapter.
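A minimal numpy sketch of the FPN idea: lateral 1×1 projections bring all levels to a common channel width, then a top-down pathway upsamples coarser levels and adds them in. Channel counts, strides, and output width are illustrative, not the paper’s values:

```python
import numpy as np

rng = np.random.default_rng(0)

def lateral_1x1(x, w):
    # a 1x1 convolution is a per-pixel channel projection
    return np.einsum('oc,chw->ohw', w, x)

def upsample2x(x):
    # nearest-neighbor 2x upsampling
    return x.repeat(2, axis=-2).repeat(2, axis=-1)

# hypothetical pyramid: three levels with channels 128/256/512 at sizes 32/16/8
feats = [rng.standard_normal((c, s, s)) for c, s in [(128, 32), (256, 16), (512, 8)]]
d = 64  # FPN output channels (assumed)
laterals = [lateral_1x1(f, rng.standard_normal((d, f.shape[0])) * 0.01) for f in feats]

# top-down pathway: start from the coarsest level, upsample, and add
outs = [laterals[-1]]
for lat in reversed(laterals[:-1]):
    outs.append(lat + upsample2x(outs[-1]))
outs = outs[::-1]  # finest level first
```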

CLIP2SAM

Point Prompts vs. Region Prompts

For point prompts, first obtain masks from the SAM decoder and then derive bounding boxes from the masks.

For region prompts, directly send them to the FPN.

Open Vocabulary

Fuse the learned class scores with those from the frozen CLIP via a geometric mean, leveraging information from both CLIP and CLIP2SAM.

alpha: float = .1,
beta: float = .9,
# ...
clip_logit = clip_logit.softmax(-1)   # class probabilities from the frozen CLIP
query_logit = query_logit.softmax(-1) # class probabilities from CLIP2SAM

# overlapping_mask marks classes seen during training; its complement marks novel classes
cls_logits_seen = (
    (query_logit ** (1 - alpha) * clip_logit ** alpha).log() * overlapping_mask
)
cls_logits_unseen = (
    (query_logit ** (1 - beta) * clip_logit ** beta).log() * (1 - overlapping_mask)
)
cls_results = cls_logits_seen + cls_logits_unseen
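The same geometric-mean ensemble can be checked in isolation with numpy. The probabilities are hypothetical and `seen_mask` plays the role of `overlapping_mask`; with the default alpha and beta, seen classes mostly trust CLIP2SAM while unseen classes mostly trust the frozen CLIP:

```python
import numpy as np

def geometric_fuse(query_p, clip_p, seen_mask, alpha=0.1, beta=0.9):
    """Geometric-mean ensemble of CLIP2SAM and frozen-CLIP probabilities (sketch)."""
    seen = (query_p ** (1 - alpha)) * (clip_p ** alpha)      # weight toward CLIP2SAM
    unseen = (query_p ** (1 - beta)) * (clip_p ** beta)      # weight toward CLIP
    return np.log(seen) * seen_mask + np.log(unseen) * (1 - seen_mask)

query_p = np.array([0.7, 0.2, 0.1])    # hypothetical CLIP2SAM probabilities
clip_p = np.array([0.5, 0.3, 0.2])     # hypothetical frozen-CLIP probabilities
seen_mask = np.array([1.0, 1.0, 0.0])  # first two classes seen during training
fused = geometric_fuse(query_p, clip_p, seen_mask)  # fused log-scores per class
```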

Training

  • Training SAM2CLIP: $\mathcal{L}_\mathrm{distill}$

  • Jointly training CLIP2SAM and mask decoder: $\mathcal{L} = \lambda_\mathrm{cls}\mathcal{L}_\mathrm{cls} + \lambda_\mathrm{ce}\mathcal{L}_\mathrm{t\_ce} + \lambda_\mathrm{dice}\mathcal{L}_\mathrm{t\_dice},$ where $\mathcal{L}_\mathrm{t\_ce}$ and $\mathcal{L}_\mathrm{t\_dice}$ are the segmentation (cross-entropy and Dice) losses.

Since the features are updated, the SAM decoder must be re-trained. During the training of CLIP2SAM, SAM2CLIP is kept frozen.
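The joint objective can be written as a one-line helper. The default λ values below mirror the loss_weight entries of the training config (2.0 for classification, 5.0 each for mask cross-entropy and Dice), under the assumption that those weights correspond to the λ coefficients:

```python
def joint_loss(l_cls, l_ce, l_dice, lam_cls=2.0, lam_ce=5.0, lam_dice=5.0):
    """L = lam_cls * L_cls + lam_ce * L_t_ce + lam_dice * L_t_dice."""
    return lam_cls * l_cls + lam_ce * l_ce + lam_dice * l_dice
```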

model = dict(
    type=CLIP2SAM,
    # ...
    neck=dict(
        type=MultiLayerTransformerNeck,
        input_size=(1024, 1024),
        in_channels=[384, 768, 1536, 3072],
        strides=[4, 8, 16, 32],
        layer_ids=(0, 1, 2, 3),
        embed_channels=1280,
        out_channels=256,
        fix=True, # freeze neck_student
        init_cfg=dict(
            type='Pretrained',
            checkpoint='./models/sam2clip_vith_rn50x16.pth',
            prefix='neck_student',
        )
    ),
    mask_decoder=dict( # train the mask decoder while keeping SAM2CLIP frozen
        type=OVSAMHead,
        model_name='vit_h',
        with_label_token=True,
        ov_classifier_name='RN50x16_LVISV1Dataset',
        roi_extractor=dict(
            type=SingleRoIExtractor,
            roi_layer=dict(type=RoIAlign, output_size=12, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]
        ),
        fix=False,
        init_cfg=dict(
            type='sam_pretrain',
            checkpoint='vit_h'
        ),
        loss_cls=dict( # classification loss
            type=CrossEntropyLoss,
            use_sigmoid=False,
            loss_weight=2.0,
            reduction='mean'
        ),
        loss_mask=dict( # cross-entropy segmentation loss
            type=CrossEntropyLoss,
            use_sigmoid=True,
            reduction='mean',
            loss_weight=5.0
        ),
        loss_dice=dict( # Dice loss
            type=DiceLoss,
            use_sigmoid=True,
            activate=True,
            reduction='mean',
            naive_dice=True,
            eps=1.0,
            loss_weight=5.0
        )
    )
)

Comparison

Comparison with baselines

Using Detic [1] as the detector to generate box prompts.

Comparison with baseline

Small Object Recognition

Comparison on small object recognition

Comparison with CLIP & SAM

Comparison on instance segmentation
Comparison of mask quality

OVSAM maintains the segmentation quality of SAM after distillation and alignment.

Ablation

Effectiveness of SAM2CLIP & CLIP2SAM
Different Design of SAM2CLIP & CLIP2SAM
Different CLIP backbones & SAM variants

Summary

  • Explores interactive open-vocabulary segmentation for the first time.

  • Two efficient modules: SAM2CLIP and CLIP2SAM.

[1]
X. Zhou, R. Girdhar, A. Joulin, P. Krähenbühl, and I. Misra, “Detecting twenty-thousand classes using image-level supervision,” in European conference on computer vision, Springer, 2022, pp. 350–368.