728x90
1. Introduction and Motivating Work
- 2개의 인코더(텍스트 인코더와 이미지 인코더)가 짝이 맞는 텍스트, 이미지가 유사도가 높아지도록 짝이 맞지 않는 쌍은 유사도가 낮아지도록 사전학습합니다.
- 예를 들면, 풀을 뜯고 있는 코끼리 이미지와 '코끼리가 풀을 뜯는다'라는 문장은 유사도가 높아지도록 학습을 하고, '비행기가 날아다니고 있다' 라는 문장과는 유사도가 낮아지도록 학습하는 식입니다.
- CLIP은 초기에 image captioning baseline을 썼는데 Transformer 기반의 모델보다 zeroshot 성능이 뛰어났습니다.
- image captioning baseline보다 Contrastive 모델이 훨씬 효율적인 것으로 나타났습니다.
2. Approach
2.1 Natural Language Supervision
- 자연어로 지도학습하여 모델을 학습.
- 라벨링할 필요가 없어서 스케일링이 쉬움.
- zeroshot을 가능하게 함.
2.2 Creating a Sufficiently Large Dataset
- 주요 멀티 모달 데이터셋 3개
- MS-COCO →약 10만장
- Visual Genome → 약 10만장
- YFCC100M → 약 1억장인데 meta data 정보가 빈약한 데이터가 많아 필터링하면 약 1500만장 정도 활용 가능. 이는 ImageNet과 비슷한 정도
- 많은 양의 데이터셋을 확보하기 위해서 4억장의 이미지, 텍스트 페어를 인터넷으로부터 수집합니다.
2.3 Selecting an Efficient Pre-Training Method
Contrastive Learning
N개의 이미지, 텍스트 페어가 있을 때 배치 내에서 N * N개의 페어가 만들어질 수 있다. N^2-N 개의 틀린 페어의 코사인 유사도를 최소화하고 N개의 올바른 페어의 코사인 유사도를 최대화하도록 한다.
2.4 Choosing and Scaling a Model
Image Encoder : ResNet-50, Vision Transformer
Text Encoder : Transformer(a 63M-parameter 12-layer 512-wide model with 8 attention heads)
2.5 Training
- 5 ResNets(ResNet-50, a ResNet-101, EfficientNet-style model scaling 4x, 6x, and 64x the compute of a ResNet-50)
- 3 Vision Transformers(ViT-B/32, ViT-B/16, ViT-L/14)
- train 환경 : 32 epochs, the Adam optimizer, cosine schedule
3. Experiments
3.1 Zero-shot Transfer
3.1.4 Prompt Engineering AND ENSEMBLING
- pretraining은 주로 문장으로 학습하기 때문에 단일한 단어보다는 어느 정도의 문맥을 살린 “A photo of a {label}.”, “A photo of a {label}, a type of pet.”과 같은 prompt 템플릿을 함께 사용하여 예측하는 것이 성능 향상에 좋습니다.
- “A photo of a big {label}” , “A photo of a small {label}” 과 같은 다른 문맥을 가진 prompt를 활용하여 앙상블을 수행합니다.
3.1.5 Analysis of zeroshot CLIP performance
- Fig5는 Zeroshot Clip과 각 데이터셋을 지도 학습한 ResNet-50을 비교한 것입니다. 초록색이 Zeroshot CLIP이 더 뛰어난 성능을 보인 데이터셋, 파란색이 ResNet-50이 더 뛰어난 성능을 보인 데이터셋입니다.
- CLIP은 전문적이고, 복잡하며, 추상적인 TASK(인공위성 이미지 분류, 림프절 종양 검출, 합성 이미지에서 개체 카운트, 자율주행, 가장 가까운 거리의 차 인식)에 대해서는 약하다는 것을 보여줍니다.
- CLIP의 zeroshot은 few shot 모델보다 훨씬 성능이 좋습니다. CLIP에 Logistic Regression Classifer를 달았을 때 4 shot(클래스당 학습 이미지 개수가 4개)의 성능이 zeroshot과 비슷한 것으로 나옵니다. BiT-M의 경우는 16shot일 때 zeroshot CLIP과 성능이 비슷하게 나옵니다.
- Fig 7은 CLIP의 zeroshot 성능과 비슷하게 유지하려면 CLIP의 Linear Probe가 필요한 클래스당 라벨링 데이터 개수를 정리해둔 것입니다. FER2013 데이터셋의 경우 zeroshot CLIP과 비슷한 성능이 나오려면 Linear Probe CLIP은 클래스당 184개의 데이터가 필요합니다.
- 1개 이하인 데이터 셋에서는 zeroshot의 성능보다 Linear Probe CLIP이 1개 이상의 데이터로 fine tuning 했을 때 성능이 더 좋습니다.
- 클래스당 필요한 라벨링 개수를 1, 2, 4, 8, 16개 등으로 실험하여 log-linear interpolation으로 그래프를 그린 것입니다.
- Zeroshot이 이상적으로 task-agnostic하려면 Fig8 그래프의 y=x에 수렴하는 직선이 나와야 합니다. 그렇게 되는 데이터셋은 5가지 정도(STL10, CIFAR10, ...)이고 대부분은 Linear Probe일 때 성능이 더 좋은 것으로 나옵니다.
3.2 Representation Learning
- CLIP의 Linear probe의 성능은 CV SOTA 모델들을 여러 데이터셋에서 비교했을 때 훨씬 뛰어납니다.
- CLIP의 linear probe는 Noisy Student EfficientNet-L2와 비교했을 때 27개 중 21개의 데이터셋에서 더 뛰어난 성능을 보여줍니다.
3.3 Robustness to Natural Distribution shift
- ImageNet에서 pretraining한 후 linear probe를 수행한 모델들보다 CLIP의 linear probe가 여러 데이터셋에서 성능이 더 뛰어난 점으로 보아 전자의 모델들은 ImageNet에 과적합 되어있음을 알 수 있습니다.
- 반면, CLIP은 다양한 TASK에 강건하다는 것을 알 수 있습니다.
- 강건성을 파악하기 위해 distribution shift가 있는 여러 데이터셋(ImageNetV2, ImageNet-R, ObjectNet, ImageNetSketch, ImageNet-A, Youtube-BB, ImageNetVid)을 평가합니다.
- 왼쪽 그림에서 dashed line은 이상적인 강건 모델을 뜻합니다. distribution shift가 있는 여러 데이터셋에서 정확도와 ImageNet에서의 정확도가 일치한다는 의미입니다.
- 오른쪽 표에서 ImageNet 모델인 ResNet101은 ImageNet이 아닌 다른 데이터셋에서 성능 저하를 보이는 반면 zeroshot CLIP은 우수한 성능을 보입니다.
- zeroshot CLIP은 distribution shift에도 다른 모델에 비해 약 75% 정도 강건함을 보입니다.
4. Comparision to Human Performance
- zeroshot CLIP은 few shot CLIP보다 성능이 좋고, 강건성이 뛰어납니다. Fig15는 16 shot CLIP 모델과 Zeroshot 모델의 ImageNet에서의 성능이 비슷하지만 강건성 면에서는 zeroshot 모델이 훨씬 뛰어납니다.
- Table2는 인간은 zeroshot보다 one, two shot에서 훨씬 더 높은 정확도를 보여주는 반면에 CLIP은 zeroshot이 few shot보다 더 뛰어난 것으로 볼 수 있습니다.
- CLIP이 어려워 애완동물 종류 구분은 인간도 어려워하는 것을 Fig16에서 볼 수 있습니다.
5. Data Overlap Analysis
인터넷으로 받은 큰 데이터셋을 사전 학습하는 것은 평가 데이터셋과 overlap할 수 있는 가능성이 있습니다. 이럴 경우 의미 있는 일반화 평가가 힘들어지는데 4억장의 데이터를 모두 검수하기란 힘들기 때문에 다음과 같은 절차를 통해 overlap이 얼마나 발생하고 그에 따른 성능 차이를 분석합니다.
- duplicate detector를 실행하여 유사도가 특정 threshold 이상이면 Overlap 집합에, 그 이하이면 Clean 집합에 포함시킵니다.
- All 과 Clean을 metric으로 비교하여 데이터셋이 얼마나 오염되어 있는지를 분석합니다.
- overlap 양이 적은 경우가 많기 때문에 Clean에 대한 정확도를 귀무 가설로 사용하고 Overlap 부분 집합에 대한 단일 꼬리(더 큰) p-값을 계산하는 이항 유의성 검정도 실행합니다. 또한 오염도에 대한 99.5% 클로퍼-피어슨 신뢰 구간을 또 다른 검사로 계산합니다.
- 합성 혹은 특수한 데이터 셋에서의 overlap은 검출되지 않았기 때문에 duplicate detector의 false positive의 비율이 낮다는 것을 알 수 있습니다.
- 왼쪽 그래프에서 보듯이 정확도 차이가 20% 정도로 분명히 나는 경우가 있지만, 대부분 overlap의 비율이 낮기 때문에 오른쪽 그래프에서 Clean 집합 대비 All 집합의 정확도 상승이 1% 미만임을 알 수 있습니다.
- Country211은 overlap 비율이 21.5%로 높지만 CLIP이 학습했던 데이터셋과 다릅니다. Country211은 geo-localization을 위한 데이터셋을 측정하는 것으로 학습용 텍스트가 CLIP이 학습한 것과는 다릅니다.
분석의 한계
- duplicate detector가 완벽하지 않고 4억장의 데이터를 모두 확인해본 것은 아닙니다.
- Overlap 과 Clean 의 데이터 분포에 차이가 있어 정확한 비교가 어려울 수 있습니다. 예를 들면 Kinetics-700의 overlap에는 검정 화면이 많이 들어있었기에 clean에 비해 정확도 하락이 20%가량 달했습니다.
6. Limitations
- zeroshot CLIP은 SOTA 모델에 비해 성능이 떨어집니다.
- task-specific, fine-grained classification(차 종류, 꽃 종류, 비행선 구별), abstract/systematic task(개체 개수 세기), novel task(사진 상에서 가장 가까운 차까지의 거리 분류) 등에서 CLIP은 다소 성능이 떨어집니다.
- MNIST와 유사한 데이터를 학습한 적이 없지만 Rendered SST2라는 OCR 데이터셋을 학습했습니다. CLIP은 MNIST에서 88%의 정확도가 나오는데 이는 단순한 logistic regression 지도 학습보다 떨어지는 성능입니다. CLIP은 일반화 문제에 취약하다는 것을 나타냅니다.
7.Broader Impacts
7.1 Bias
- FairFace dataset은 기존의 face dataset에서 백인의 비율을 줄이고 성별과 인종이 고루 분포하도록 수집한 데이터셋입니다.
- 백인 카테고리에서 LR CLIP(Logistic Regression)은 FairFace Model, Linear Probe Instagram보다 성능이 잘 나왔고, ZL CLIP(Zeroshot CLIP)은 성능이 떨어졌습니다.
- 다양한 인종을 평가하면서 ZL CLIP은 성능이 상승하였습니다.
- table 6은 남녀 각각에 7개의 인종, 3개의 범죄 관련 카테고리, 4개의 인간이 아닌 카테고리를 포함해 총 14개의 카테고리로 이미지를 분류하는 작업을 수행할 때 범죄 관련(’thief’, ‘criminal’, ‘suspicious person’) 혹은 인간이 아닌 카테고리(’animal’, ‘chimpanzee’, ‘gorilla’, ‘orangutan’)로 분류된 비율을 나타낸 것입니다.
- table 7은 table 6을 연령별로 구분하고 child 카테고리를 추가했을 때 범죄 관련 혹은 인간이 아닌 카테고리로 분류된 비율을 비교한 표입니다.
- Black 이 가장 높은 비율(14.4%)로 인간이 아닌 카테고리에 분류되었고, 남성의 16.5%, 여성의 9.8%가 범죄 관련 카테고리로 분류되었습니다.
- child 카테고리를 추가했을 때 20세 이하에서 오분류된 비율이 크게 줄어들었다는 점을 볼 수 있습니다.
- 이는 class design이 모델의 성능과 unwanted bias를 결정하는 중요한 요인임을 알 수 있습니다.
7.3 Future Work
- 연구 프로세스 초기에 모델의 잠재적으로 유익한 다운스트림 사용을 식별하여 다른 연구자가 응용 프로그램에 대해 생각할 수 있도록 합니다.
- 상당한 민감성과 정책 입안자들의 개입이 필요할 수 있는 많은 사회적 이해관계자들의 작업 표면화.
- 모델에서 편견을 더 잘 특성화하여 다른 연구자에게 관심 영역과 개입 영역을 경고합니다.
- CLIP과 같은 시스템을 평가하기 위한 테스트를 생성하여 개발 주기 초기에 모델 기능을 더 잘 특성화할 것.
9. Conclusion
CLIP은 폭넓은 task에 대해 사전학습을 하고, 자연어 prompting을 통해 많은 데이터셋에 대하여 zeroshot transfer를 가능하게 합니다. 성능 향상이 필요하지만 task specific한 모델들에 견주어 볼만합니다.
# CODE
Evaluation Code
Zeroshot evaluation
import os
import clip
import torch
from torchvision.datasets import CIFAR100
# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)
# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)
# Prepare the inputs
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)
# Calculate features
with torch.no_grad():
image_features = model.encode_image(image_input)
text_features = model.encode_text(text_inputs)
# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)
# Print the result
print("\\nTop predictions:\\n")
for value, index in zip(values, indices):
print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")
- 분류하고자 하는 이미지 한 장을 전처리하고 모든 텍스트에 a photo of prompt를 붙입니다.
- 모델에 넣어 이미지 피처와 텍스트 피처를 뽑아낸 후 torch.norm 을 사용하여 normalization하고, 행렬 곱을 통해 유사도를 계산합니다.
- torch.topk 를 통해 유사도가 높은 5개를 뽑아 이미지와 유사한지 살핍니다.
Linear Probe evaluation
import os
import clip
import torch
import numpy as np
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
from tqdm import tqdm
# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)
# Load the dataset
root = os.path.expanduser("~/.cache")
train = CIFAR100(root, download=True, train=True, transform=preprocess)
test = CIFAR100(root, download=True, train=False, transform=preprocess)
def get_features(dataset):
all_features = []
all_labels = []
with torch.no_grad():
for images, labels in tqdm(DataLoader(dataset, batch_size=100)):
features = model.encode_image(images.to(device))
all_features.append(features)
all_labels.append(labels)
return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy()
# Calculate the image features
train_features, train_labels = get_features(train)
test_features, test_labels = get_features(test)
# Perform logistic regression
classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1)
classifier.fit(train_features, train_labels)
# Evaluate using the logistic regression classifier
predictions = classifier.predict(test_features)
accuracy = np.mean((test_labels == predictions).astype(np.float)) * 100.
print(f"Accuracy = {accuracy:.3f}")
- train과 test 이미지의 피처, 텍스트를 리스트에 각각 저장합니다.
- Logistic Regression을 만들어 이미지 피처 리스트로 텍스트를 학습합니다.
- test 이미지로 평가합니다.
출처
728x90