Vision Transformers: attention for vision tasks

The paper “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale” recently appeared on OpenReview. It uses transformers pretrained at scale for vision tasks. Transformers have been highly successful on language tasks but have not seen comparable success in vision. In vision, they have mostly been applied either in conjunction with Convolutional Neural Networks (CNNs) or as a replacement for some components of a CNN. Recently, transformers have also shown good results on object detection (End-to-End Object Detection with Transformers). This paper applies a transformer to vision tasks without using a CNN and shows that state-of-the-art results can be obtained that way.

The cost of self-attention is quadratic in the sequence length. For images, attending at the pixel level means every pixel has to attend to every other pixel, which is costly. This is usually addressed with local attention, where each position attends only to a local neighbourhood around it. This paper instead divides the image into patches, unrolls them into a sequence, and still achieves global attention.
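To get a feel for why patches help, here is a small back-of-the-envelope sketch of the quadratic cost. The 224x224 image and 16x16 patch size are illustrative values I picked for the arithmetic, and `attention_pairs` is a hypothetical helper that simply counts query-key pairs in one full attention map.

```python
def attention_pairs(num_tokens: int) -> int:
    """Number of query-key pairs in one full self-attention map."""
    return num_tokens * num_tokens


# Illustrative sizes (assumptions, not taken from the post).
height, width, patch = 224, 224, 16

pixel_tokens = height * width                        # one token per pixel
patch_tokens = (height // patch) * (width // patch)  # one token per patch

print("pixel-level tokens:", pixel_tokens, "pairs:", attention_pairs(pixel_tokens))
print("patch-level tokens:", patch_tokens, "pairs:", attention_pairs(patch_tokens))
```

With these numbers the patch sequence has 196 tokens instead of 50,176, so a full attention map has 65,536 times fewer entries while every patch can still attend to every other patch.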

Model Architecture

Architecture

The architecture follows the original transformer very closely. This is deliberate: the transformer architecture has scaled well for NLP tasks, and optimised implementations of it can be used out of the box from various libraries. The difference lies in how images are fed to the transformer as a sequence of patches.

Patch Embedding

A transformer receives a 1D sequence of embeddings as input. To handle 2D image input, the image is divided into a sequence of flattened, fixed-size 2D patches. So an image of size H×W×C becomes a sequence of shape N×(P²·C), where P×P is the size of a patch and N = HW/P² is the number of patches.
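The reshaping from H×W×C to N×(P²·C) can be sketched in a few lines of NumPy. This is a minimal illustration, not code from the paper or from my repo; `image_to_patches` is a hypothetical helper name and the 32x32 input is a toy size.

```python
import numpy as np


def image_to_patches(image: np.ndarray, patch_size: int) -> np.ndarray:
    """Split an (H, W, C) image into N flattened patches of length P*P*C."""
    h, w, c = image.shape
    p = patch_size
    if h % p != 0 or w % p != 0:
        raise ValueError("H and W must be divisible by the patch size")
    # (H, W, C) -> (H/P, P, W/P, P, C): split each spatial axis into blocks.
    x = image.reshape(h // p, p, w // p, p, c)
    # Move the two block indices to the front: (H/P, W/P, P, P, C).
    x = x.transpose(0, 2, 1, 3, 4)
    # Flatten the block grid into N and each block into P*P*C.
    return x.reshape((h // p) * (w // p), p * p * c)


# Toy 32x32 RGB image with 16x16 patches: N = 4, P*P*C = 768.
image = np.arange(32 * 32 * 3, dtype=np.float32).reshape(32, 32, 3)
patches = image_to_patches(image, patch_size=16)
print(patches.shape)
```

Patches come out in row-major order, so index 1 is the patch to the right of the top-left one.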

Before passing the patches to the transformer, the paper puts them through a linear projection to get the patch embeddings. The official JAX implementation uses a conv layer for this (a simple linear layer would also work, but it's costly). My PyTorch implementation covers this step as well.
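To see why a conv layer can stand in for the linear projection, here is a rough NumPy sketch (not the official code and not my PyTorch snippet). It assumes the convolution uses a kernel size and stride both equal to the patch size P; all sizes and the function names `conv_patch_embed` and `linear_patch_embed` are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
P, C, D = 4, 3, 8  # toy patch size, channels, embedding width (assumptions)
image = rng.standard_normal((8, 8, C))
kernel = rng.standard_normal((P, P, C, D))  # one P x P x C filter per output dim
bias = rng.standard_normal(D)


def conv_patch_embed(image, kernel, bias):
    """Convolution with kernel size = stride = P, written as explicit loops."""
    p = kernel.shape[0]
    rows, cols = image.shape[0] // p, image.shape[1] // p
    out = np.empty((rows, cols, kernel.shape[-1]))
    for i in range(rows):
        for j in range(cols):
            window = image[i * p:(i + 1) * p, j * p:(j + 1) * p, :]
            # Contract the (P, P, C) window against the filter bank -> (D,).
            out[i, j] = np.tensordot(window, kernel, axes=3) + bias
    return out.reshape(rows * cols, -1)


def linear_patch_embed(image, kernel, bias):
    """Flatten patches first, then apply one shared linear layer."""
    p, _, c, d = kernel.shape
    h, w, _ = image.shape
    patches = (image.reshape(h // p, p, w // p, p, c)
                    .transpose(0, 2, 1, 3, 4)
                    .reshape(-1, p * p * c))
    weight = kernel.reshape(p * p * c, d)  # same numbers, viewed as a matrix
    return patches @ weight + bias


conv_out = conv_patch_embed(image, kernel, bias)
linear_out = linear_patch_embed(image, kernel, bias)
print(conv_out.shape, np.allclose(conv_out, linear_out))
```

Under that assumption the two routes produce the same embeddings, so the choice between them comes down to implementation convenience and speed rather than to what the model computes.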

As with BERT’s [class] token, a learnable class token is concatenated to the patch embeddings and serves as the class representation.

To retain positional information about the patches, positional embeddings are added to the patch embeddings. The paper explored a 2D-aware variant as well as standard 1D positional embeddings, but did not see much advantage of one over the other.
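Putting the class token and the positional embeddings together, the encoder input can be sketched like this. It is a minimal NumPy illustration assuming the class token is prepended at position 0 and a standard 1D positional embedding is used. In a real model both are learnable parameters; here they are random placeholders, and `build_encoder_input` is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 4, 8  # toy number of patches and embedding width (assumptions)

patch_embeddings = rng.standard_normal((N, D))
# Random stand-ins for what would be learnable parameters in a real model.
class_token = rng.standard_normal((1, D))
position_embeddings = rng.standard_normal((N + 1, D))  # one row per token


def build_encoder_input(patch_embeddings, class_token, position_embeddings):
    """Prepend the class token, then add 1D positional embeddings."""
    tokens = np.concatenate([class_token, patch_embeddings], axis=0)  # (N+1, D)
    return tokens + position_embeddings


encoder_input = build_encoder_input(patch_embeddings, class_token, position_embeddings)
print(encoder_input.shape)
```

The sequence grows from N to N+1 tokens, and the output of the encoder at position 0 is what a classification head would read.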

Hybrid Architecture

An alternative is to use intermediate feature maps of a ResNet instead of image patches as input to the transformer. The 2D feature map from the earlier layers of the ResNet is flattened, projected to the transformer dimension, and fed to the transformer. The class token and positional embeddings are added as described above.
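The flatten-and-project step of the hybrid variant might look roughly like this. The feature map here is a random array standing in for a ResNet output, and its 14x14x256 shape, the width of 64, and the name `feature_map_to_tokens` are all assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Random stand-in for an intermediate ResNet feature map of shape (h, w, c).
feature_map = rng.standard_normal((14, 14, 256))
D = 64  # hypothetical transformer width
projection = rng.standard_normal((256, D)) * 0.02  # would be learned in practice


def feature_map_to_tokens(feature_map, projection):
    """Treat every spatial location of the feature map as one token."""
    h, w, c = feature_map.shape
    sequence = feature_map.reshape(h * w, c)  # (h*w, c), row-major order
    return sequence @ projection              # (h*w, D)


tokens = feature_map_to_tokens(feature_map, projection)
print(tokens.shape)
```

From here on the sequence is handled exactly like the patch embeddings.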

Results

The Vision Transformer is pretrained on large datasets such as ImageNet-1k, ImageNet-21k, and JFT-300M, and then fine-tuned on the dataset of the target task. The table below shows the results of fine-tuning a Vision Transformer pretrained on JFT-300M.

What the Model Learns

Left: Filters of the initial linear embedding of RGB values of ViT-L/32. Center: Similarity of position embeddings of ViT-L/32. Tiles show the cosine similarity between the position embedding of the patch with the indicated row and column and the position embeddings of all other patches. Right: Size of attended area by head and network depth. Each dot shows the mean attention distance across images for one of 16 heads at one layer. See Appendix D.6 for details. (Reference: https://openreview.net/pdf?id=YicbFdNTTy)

As can be seen from the image above (left), the filters the model learns are analogous to what a convolutional network learns. The model also learns positional embeddings well (center). As the right panel shows, the span of attention increases with the number of layers, and after a certain depth (around the middle) it covers almost the whole image. This is quite similar to a CNN, where the receptive field grows with depth; the advantage with transformers is that they can attend to distant elements even in the earlier layers.
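The "mean attention distance" on the right panel can be approximated along these lines. This is my own reading of the caption, not the procedure from the paper's Appendix D.6: each query patch gets the attention-weighted average of its pixel distance to every patch, and the result is averaged over queries. The function name and the choice to ignore the class token are assumptions.

```python
import numpy as np


def mean_attention_distance(attention, grid_size, patch_size):
    """Attention-weighted mean pixel distance for one head.

    attention: (N, N) array, rows sum to 1, N = grid_size**2 patch tokens
               (class token assumed to be removed beforehand).
    """
    rows, cols = np.divmod(np.arange(grid_size * grid_size), grid_size)
    coords = np.stack([rows, cols], axis=1) * patch_size  # (N, 2) in pixels
    diff = coords[:, None, :] - coords[None, :, :]        # (N, N, 2)
    distances = np.sqrt((diff ** 2).sum(axis=-1))         # (N, N)
    per_query = (attention * distances).sum(axis=1)       # weighted mean per query
    return float(per_query.mean())


# A head that only looks at its own patch vs. one that attends uniformly.
local_head = np.eye(4)
global_head = np.full((4, 4), 0.25)
print(mean_attention_distance(local_head, grid_size=2, patch_size=16))
print(mean_attention_distance(global_head, grid_size=2, patch_size=16))
```

A head that stays on its own patch scores 0, while a head spreading attention uniformly scores much higher, which is the kind of spread the dots in the plot summarise per layer.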

You can find my repo with the PyTorch implementation here. I used the ImageNet-1k pretrained weights from https://github.com/rwightman/pytorch-image-models/ and updated the checkpoint for my implementation. The checkpoint can be found here.

You can also find a PyTorch Kaggle Kernel for fine-tuning the Vision Transformer on TPU here.

If you like my work do consider sponsoring me, it’ll help me put out more such work.
