728x90

1. Introduction

이 논문의 contribution

  1. vector-quantized knowledge distillation
  2. patch aggregation strategy
  3. extensive experiments on downstream tasks, such as ImageNet fine-tuning, linear probing, and semantic segmentation.

2. Method

2.1. Training Visual Tokenizer

Backbone : VIT(Vision Transformers)

224 x 224 images를 14 x 14 Grid image patch로 나눈다. 각 patch를 flatten하고, transformers input embedding으로 projection합니다.

 

VQ-KD(Vector Quntized Knowledge Distillation)은 2개의 모듈 Visual Tokenizer와 Decoder를 학습시킵니다. Visual Tokenizer는 vision Transformer encoder와 quantizer로 구성되어 있습니다. Tokenizer는 먼저 이미지를 벡터로 인코딩합니다. 다음에, vector quantizer는 codebook에서 각 patch의 인코딩된 벡터에 대해 nearest neighbor를 찾습니다.

이 distance는 cosine similarity로 codes를 찾는 것과 같습니다.

이미지를 visual tokens(z_i)로 quantizing 한 이후에 L2 normalized 된 codebook embeddings(l2norm(v_z_i))를 decoder에 넣어줍니다. decoder는 multi-layer transformers이고, output을 teacher model인 DINO와 CLIP의 semantic features를 재구성하도록 학습합니다. decoder output과 teacher guidance의 cosine similiarity를 최대화합니다.

quantization process는 미분 불가능하기 때문에 encoder에서 back-propagate를 수행하기 위해서 decoder의 input에서 gradients를 encoder의 output으로 복사합니다.

 

quantizer는 각각의 encoder output을 위한 nearest code를 찾기 때문에 gradients of codebook embeddings는 encoder의 유용한 최적화 방향을 일러줍니다.

 

Object function of VQKD

  • reconstruction loss(KD) + codebook loss + commitment loss
    • reconstruction loss : Decoder가 Teacher Model의 feature map을 모방하도록 학습
    • codebook loss : codebook vector가 encoder의 output을 잘 표현하도록 학습
    • commitment loss : encoder의 output이 codebook embedding과 값의 차이가 벌어지지 않도록 규제
  • D : tokenizer training에 사용된 이미지 개수
  • o_i : decoder output vector of i-th image patch
  • t_i : teacher model’s feature vector of i-th image patch
  • h_i : encoding vector of i-th image patch
  • v_z_i : quantized vector of i-th image patch

codebook 활용의 향상

  • codebook lookup을 할 때 차원 축소(32차원)와 l2 norm(equation 1)을 적용하였습니다.
  • 낮은 차원(32차원)의 codebook embeddings에서 높은 차원으로 mapping한 후 decoder에 넣어줍니다.
  • EMA(Exponential Moving Averages)를 적용하여 안정적인 실험 결과를 얻었습니다.

VQKD 코드

class VQKD(nn.Module):
    def __init__(self,
             encoder_config,
             decoder_config,
             n_embed=8192, 
             embed_dim=32,
             decay=0.99,
             process_type='default',
             quantize_kmeans_init=True,
             teacher_model_type='clip',
             decoder_out_dim=512,
             rec_loss_type='cosine',
             **kwargs
             ):
        super().__init__()

        # encoder & decode params
        print('Final encoder config', encoder_config)
        self.encoder = VisionTransformer(**encoder_config)

        print('Final decoder config', decoder_config)
        self.decoder = VisionTransformer(**decoder_config)

        self.quantize = NormEMAVectorQuantizer(
            n_embed=n_embed, embedding_dim=embed_dim, beta=1.0, kmeans_init=quantize_kmeans_init, decay=decay,
        )

        ## Teacher model setting
        self.teacher_model_type = teacher_model_type
        self.decoder_out_dim = decoder_out_dim
        if self.teacher_model_type == 'clip':
            self.scaling_layer = ScalingLayerForClip()
            self.teacher_model, _ = clip.load("ViT-B/16", device='cpu', jit=False)
            self.decoder_out_dim = 512

        elif self.teacher_model_type == 'dino':
            self.scaling_layer = ScalingLayerForIM()
            self.teacher_model = get_dino_vit_base()
            self.decoder_out_dim = 768

        else:
            self.teacher_model = None

        if self.teacher_model is not None:
            for param in self.teacher_model.parameters():
                param.requires_grad = False # fix teacher_model model

            self.teacher_model.eval()

        # task layer
        self.encode_task_layer = nn.Sequential(
            nn.Linear(encoder_config['embed_dim'], encoder_config['embed_dim']),
            nn.Tanh(),
            nn.Linear(encoder_config['embed_dim'], embed_dim) # for quantize
        )
        self.decode_task_layer = nn.Sequential(
            nn.Linear(decoder_config['embed_dim'], decoder_config['embed_dim']),
            nn.Tanh(),
            nn.Linear(decoder_config['embed_dim'], self.decoder_out_dim),
        ) # codebook dimension 32차원에서 512차원으로 차원 증가

        self.encode_task_layer.apply(self._init_weights)
        self.decode_task_layer.apply(self._init_weights)

    def get_tokens(self, data, **kwargs):

        data = self.pre_process(data)
        quantize, embed_ind, loss = self.encode(data)
        output = {}
        output['token'] = embed_ind.view(data.shape[0], -1)
        output['input_img'] = data

        return output

    def encode(self, x):
        encoder_features = self.encoder(x, return_patch_tokens=True)

        with torch.cuda.amp.autocast(enabled=False):
            to_quantizer_features = self.encode_task_layer(encoder_features.type_as(self.encode_task_layer[-1].weight))

        N = to_quantizer_features.shape[1]
        h, w = int(math.sqrt(N)), int(math.sqrt(N))

        to_quantizer_features = rearrange(to_quantizer_features, 'b (h w) c -> b c h w', h=h, w=w) # reshape for quantizer
        quantize, loss, embed_ind = self.quantize(to_quantizer_features)

        return quantize, embed_ind, loss
    
    def decode(self, quantize, **kwargs):
        # reshape tokens to feature maps for patch embed in decoder
        # quantize = rearrange(quantize, 'b (h w) c -> b c h w', h=self.token_shape[0], w=self.token_shape[1])
        decoder_features = self.decoder(quantize, return_patch_tokens=True)
        rec = self.decode_task_layer(decoder_features) # 32차원 channel에서 512차원으로 확장

        return rec
    
    def get_codebook_indices(self, x, **kwargs):
        # for beit pre-training
        return self.get_tokens(x, **kwargs)['token']

    @torch.no_grad()
    def get_regress_target(self, x, **kwargs): # teacher model의 feature map 뽑기

        norm_imgs = self.scaling_layer(x)
        if self.teacher_model_type == 'clip':
            target = self.teacher_model.encode_image(norm_imgs, return_all_tokens=True) @ self.teacher_model.visual.proj
        elif self.teacher_model_type == 'dino':
            target = self.teacher_model.forward(norm_imgs, return_patch_tokens=True)
        else:
            raise NotImplementedError

        return target

    def calculate_rec_loss(self, rec, target):
        if self.rec_loss_type == 'cosine':
            target = target / target.norm(dim=-1, keepdim=True)
            rec = rec / rec.norm(dim=-1, keepdim=True)
            rec_loss = (1 - (target * rec).sum(-1)).mean() # teacher model feature map과 decoder의 
        else:                                              # feature map의 cosine 유사도를 최대화함.
            raise NotImplementedError

        return rec_loss

	def get_tokens(self, data, **kwargs):      
        data = self.pre_process(data)
        quantize, embed_ind, loss = self.encode(data)
        output = {}
        output['token'] = embed_ind.view(data.shape[0], -1)
        output['input_img'] = data

        return output

    def forward(self, x, **kwargs):
        """
        x: shape [B, 3, H, W] in [0, 1]
        """
        x = self.pre_process(x) # rescale to [-1, 1]

        target = self.get_regress_target(x, **kwargs) # [B, H, W, 512]

        quantize, embed_ind, emb_loss = self.encode(x) # [B, C, H, W] -> [B, H, W, C]
        xrec = self.decode(quantize) # [B, H, W, 512]

        rec_loss = self.calculate_rec_loss(xrec, target)
        loss = emb_loss + rec_loss

        return loss

Codebook 코드

class NormEMAVectorQuantizer(nn.Module):
    def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5, 
                statistic_code_usage=True, kmeans_init=False, codebook_init_path=''):
        super().__init__()
        self.codebook_dim = embedding_dim
        self.num_tokens = n_embed
        self.beta = beta
        self.decay = decay
        
        # learnable = True if orthogonal_reg_weight > 0 else False
        self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps, kmeans_init, codebook_init_path)

    def forward(self, z):
        # reshape z -> (batch, height, width, channel) and flatten
        #z, 'b c h w -> b h w c'
        z = rearrange(z, 'b c h w -> b h w c')
        z = l2norm(z)
        z_flattened = z.reshape(-1, self.codebook_dim)
        
        self.embedding.init_embed_(z_flattened)
        
        d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \\ # distance for looking up nearest neighbor vector 
            self.embedding.weight.pow(2).sum(dim=1) - 2 * \\
            torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
        
        encoding_indices = torch.argmin(d, dim=1) # visual tokens

        z_q = self.embedding(encoding_indices).view(z.shape) # quantizing vector
        
        encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)     
        

        if self.training and self.embedding.update:
            #EMA cluster size

            bins = encodings.sum(0)
            self.all_reduce_fn(bins)

            # self.embedding.cluster_size_ema_update(bins)
            ema_inplace(self.cluster_size, bins, self.decay)

            zero_mask = (bins == 0)
            bins = bins.masked_fill(zero_mask, 1.)

            embed_sum = z_flattened.t() @ encodings
            self.all_reduce_fn(embed_sum)
                        
            embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
            embed_normalized = l2norm(embed_normalized)
            
            embed_normalized = torch.where(zero_mask[..., None], self.embedding.weight,
                                           embed_normalized)
            norm_ema_inplace(self.embedding.weight, embed_normalized, self.decay)

        # compute loss for embedding
        loss = self.beta * F.mse_loss(z_q.detach(), z) 
        
        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        #z_q, 'b h w c -> b c h w'
        z_q = rearrange(z_q, 'b h w c -> b c h w')
        
        return z_q, loss, encoding_indices # quantizing vector, embedding loss, visual tokens

 

2.2 Pretraining BEIT v2

1. L_MIM : BEIT2에서 MIM(Masked Image Modeling)은 이미지 패치의 40%를 마스킹하여 마스킹 된 패치가 원래 이미지 패치의 어떤 Visual Token이었는지 맞추도록 학습하는 과정입니다. softmax classifer인 MIM Head는 마스킹된 인코딩 벡터를 입력으로 받아 visual token을 예측합니다.

  • D : pretraining images의 개수
  • M : 마스킹 된 패치
  • z_i : 원래 이미지 패치의 visual token

2. L_MIM_C : global representation을 위한 CLS token pretraining입니다. 패치 수준의 pretraining과 이미지를 종합하는 것의 간극을 줄이기 위해 학습을 합니다. MIM에서 사용되는 encoder의 l번째 layer의 N개의 이미지 패치 토큰들과 마지막 레이어의 CLS를 concat하여 얕은 decoder의 input으로 feed합니다. 마스킹된 벡터는 Visual Token을 예측하도록 학습됩니다. MIM loss를 줄이기 위하여 l+1부터 L번째 layer에 있는 모든 파라미터를 활용하여 마지막 레이어의 CLS token에 global information을 집어넣게 됩니다. BEIT v2 pretraining 최종적인 loss는 L_MIM + L_MIM_C 이 됩니다

 

3. 코드

samples, images, bool_masked_pos = batch # 마스킹 된 이미지 패치, 이미지, 마스킹 여부
images = images.to(device, non_blocking=True)
samples = samples.to(device, non_blocking=True)
bool_masked_pos = bool_masked_pos.to(device, non_blocking=True)

with torch.no_grad():
    with torch.cuda.amp.autocast():
        input_ids = vqkd.get_codebook_indices(images) # 원본 이미지로부터 visual token 뽑기
    bool_masked_pos = bool_masked_pos.flatten(1).to(torch.bool)
    labels = input_ids[bool_masked_pos] # 마스킹 된 부분의 visual token만 뽑도록 필터링

with torch.cuda.amp.autocast(): # enabled=False
    outputs = model(samples, bool_masked_pos=bool_masked_pos)

    if isinstance(outputs, list):
        loss_1 = loss_fn(input=outputs[0], target=labels) # MIM Loss
        loss_2 = loss_fn(input=outputs[1], target=labels) # MIM_C Loss
        loss = loss_1 + loss_2 
    else:
        loss = loss_fn(input=outputs, target=labels)

loss_value = loss.item()

3. Experiment

3.1 Pretraining Setup

Visual Tokenizer training : VQKD는 ImageNet-1k 224 x 224 데이터셋, CLIP teacher model을 활용합니다. codebook size K는 8192개, dimension D는 32개를 사용합니다.

 

Masked image modeling : ImageNet-1k 224 x 224 데이터셋을 사용하여 self-supervised learning으로 pretraining합니다. CLS token pretraining에서 ViT-B를 쓸 때는 l을 9, ViT-L을 쓸 때는 l을 21로 두었고, decoder의 depth를 2로 설정합니다. 이미지의 40%를 마스킹합니다.

3.2 Downstream task

Fine-tuning

Fine-tuning은 전체 파라미터를 학습시키는 것을 의미합니다. BEIT와 동일한 방식으로 finetuning합니다. BEIT v2는 ImageNet-1k와 ADE20K에서 SOTA를 달성하였습니다.

 

Linear probing

Linear probing은 backbone의 파라미터는 동결시키고 뒤에 붙힌 linear layers만 학습시키는 것을 의미합니다. Dall-e의 visual tokenizer를 사용하는 BEIT와 CAE의 성능을 웃돌았습니다.

3.3 Ablation Studies

Visual tokenizer training - model architecture와 codebook size가 VQKD에 미치는 영향을 알아봅니다. Table4는 decoder의 깊이가 깊어질수록 reconstruction은 잘하지만, codebook usage와 downstream task의 성능을 감소시키는 것을 나타냅니다. codebook의 차원을 축소하는 것이 codebook utilization을 향상시킵니다.

CLS token pretraining - table 5에서 head depth가 얕은 모델이(head depth=2) 깊은 모델(head depth=3)보다 성능이 더 좋습니다.

VQKD targets - table 6에서 보듯이 teacher model의 성능보다 더 좋은 성능을 내면서 VQKD의 scalability를 증명합니다.

4. Conclusion

visual tokenizer를 학습하는 VQKD와 global image representation을 학습하는 CLS token pretraining을 제안하였습니다.

5. Reference

 

728x90

+ Recent posts