Vision Transformers

Vision transformers are quickly becoming popular in computer vision. Where convolutional neural networks (CNNs) once reigned supreme, vision transformers are now giving CNNs a run for their money. A quick look at the Papers with Code state-of-the-art leaderboard for ImageNet image classification shows that 4 out of the top 5 models incorporate some form of transformer.

A step-by-step walkthrough

To understand transformers for vision, it helps to have a working understanding of the original Attention Is All You Need paper and ideally a little about CNNs. In this walkthrough, I’ll go through the paper An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. For a full copy of the notebook, click here. The code here is simplified for educational purposes.

Embeddings

Transformers for vision work almost identically to a standard transformer. A vision transformer splits an image into smaller patches, and each patch is passed into the transformer as a token. In the paper, each patch is 16x16 pixels, which is what the “16x16 words” in the title refers to. For example, ImageNet pictures are 224 x 224 pixels, so splitting them into 16x16-pixel patches gives $\frac{224}{16} = 14$ patches along each side, or $14 \times 14 = 196$ patches (tokens) per image.

Since color images are 3-dimensional (depth (R, G, B channels) x height x width), we use a convolution, with kernel size and stride equal to the patch size, to generate a feature map in which each spatial position is the embedding of one patch.

# imports used throughout this walkthrough
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbeddings(nn.Module):
    def __init__(self, img_size, patch_size, embed_dim, input_channels):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        assert img_size % patch_size == 0, "image size must be divisible by patch size"
        self.num_patches = (img_size // patch_size) ** 2

        # a convolution with kernel_size = stride = patch_size projects each
        # non-overlapping patch of the input image (a 3d tensor) to a vector;
        # each output channel contributes one dimension of the patch embedding
        self.proj = nn.Conv2d(
            in_channels = input_channels,
            out_channels = embed_dim,
            stride = patch_size,
            kernel_size = patch_size
        )

We then take an input tensor $x$ of shape (batch_size x input_channels x height x width). When we run our tensor through the convolution, we get a tensor of shape (batch_size x embed_dim x num_patches ** .5 x num_patches ** .5). We then flatten the last two dimensions so that our tensor has shape (batch_size x embed_dim x num_patches), and finally swap the last two dimensions by transposing the tensor. We get an output tensor of shape (batch_size, num_patches, embed_dim).

    def forward(self, x):
        """
        converts a tensor of shape (batch_size x input_channels x height x width)
        to a tensor of shape (batch_size, num_patches, embed_dim)
        """
        x = self.proj(x)          # (batch_size, embed_dim, num_patches ** .5, num_patches ** .5)
        x = x.flatten(2)          # (batch_size, embed_dim, num_patches)
        x = x.transpose(-2, -1)   # (batch_size, num_patches, embed_dim)
        return x
         

Note that this only gives us the patch embeddings. If we fed these embeddings into a transformer as-is, the model would have no sense of position, and the patches would effectively be treated as a bag of words.
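
As a quick sanity check (this snippet is not from the original notebook, and the hyperparameters are purely illustrative), we can confirm the shapes for a 224 x 224 image split into 16x16-pixel patches:

# hypothetical example: 224x224 RGB images, 16x16-pixel patches, 768-dim embeddings
patch_embed = PatchEmbeddings(img_size = 224, patch_size = 16, embed_dim = 768, input_channels = 3)
images = torch.randn(8, 3, 224, 224)   # a dummy batch of 8 images
patches = patch_embed(images)
print(patch_embed.num_patches)         # 196
print(patches.shape)                   # torch.Size([8, 196, 768])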

Attention

Now that we have the patch embeddings, the rest of the ViT is basically just a standard transformer. We first compute scaled dot-product self-attention, defined as $\mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$. In practice, we compute self-attention in batches, which is why we make use of torch.bmm for batch matrix multiplication.

def attention(query, key, value):
    """
    calculate scaled dot product attention given a query, key and value
    Params:
    query -> query tensor of shape (batch_size, seq_len, dim)
    key -> key tensor of shape (batch_size, seq_len, dim)
    value -> value tensor of shape (batch_size, seq_len, dim)
    """
    dim = query.shape[-1]
    # (query @ transpose(key)) / sqrt(dim) -> (batch_size, seq_len, seq_len)
    scores = torch.bmm(query, key.transpose(-2, -1)) / math.sqrt(dim)
    # softmax over the last dimension, so each row of attention weights sums to 1
    weights = F.softmax(scores, dim = -1)
    return torch.bmm(weights, value)
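
To make the shapes concrete, here is a small example (not from the original notebook) using random tensors; the output of attention has the same shape as the value tensor:

# hypothetical example: a batch of 2 sequences, 5 tokens each, 8-dimensional q/k/v
q = torch.randn(2, 5, 8)
k = torch.randn(2, 5, 8)
v = torch.randn(2, 5, 8)
out = attention(q, k, v)
print(out.shape)   # torch.Size([2, 5, 8])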

We generate the query, key, and value vectors by applying separate learnable linear transformations to our input embedding $x$. To do this, we make a class AttentionHead that generates the Q, K, V vectors and computes self-attention for a single head.

class AttentionHead(nn.Module):
    """
    Generates the Q, K, V vectors from a given input embedding x 
    and calculates the attention for one head
    Params:
    embed_dim -> embedding dimension of input vector x
    head_dim -> dimension that embed_dim gets transformed into from the qkv transformation
    """
    def __init__(self, embed_dim, head_dim):
        super().__init__()
        self.to_q = nn.Linear(embed_dim, head_dim)
        self.to_k = nn.Linear(embed_dim, head_dim)
        self.to_v = nn.Linear(embed_dim, head_dim)
    def forward(self, x):
        attn = attention(self.to_q(x), self.to_k(x), self.to_v(x))
        return attn

Now we simply need to calculate multi-head attention. In multi-head attention, each head works in a smaller dimension, head_dim = embed_dim // num_heads, so that concatenating the outputs of all the heads gives us back a vector of size embed_dim.

class MultiHeadAttention(nn.Module):
    """
    Calculates the Multi-Headed attention
    Params:
    num_heads -> number of heads to use, each head_dim is calculated as embed_dim // num_heads 
    embed_dim -> embedding dimension of vector x
    """
    def __init__(self, num_heads, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads 

        self.heads = nn.ModuleList(
            AttentionHead(self.embed_dim, self.head_dim) 
            for _ in range(num_heads)
        )
        self.to_out = nn.Linear(embed_dim, embed_dim)

We concatenate the outputs of all the heads along the last dimension and pass the concatenated matrix through a final linear layer.

    def forward(self, x):
        # calculate attention for each head and concatenate tensor
        x = torch.cat([head(x) for head in self.heads], dim = -1)
        x = self.to_out(x)
        return x
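
As another quick check (again, not part of the original notebook), multi-head attention preserves the shape of its input: each head outputs head_dim = embed_dim // num_heads features, and concatenating all the heads brings us back to embed_dim:

# hypothetical example: 12 heads over a 768-dimensional embedding (head_dim = 64)
mha = MultiHeadAttention(num_heads = 12, embed_dim = 768)
tokens = torch.randn(8, 196, 768)   # e.g. the patch embeddings from earlier
print(mha(tokens).shape)            # torch.Size([8, 196, 768])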

MLP

The multi-layer perceptron (MLP) is simply a feed-forward network: linear layers with an activation function (here GELU) in between.

class FeedForward(nn.Module):
    def __init__(self, embed_dim, hidden_dim, dropout = .1):
        """
        A feed forward network after scaled dot product attention
        Params:
        embed_dim = embedding dimension of vector
        hidden_dim = hidden dimension in the FF network, generally 2-4x larger than embed_dim
        dropout = % dropout for training
        """
        super().__init__()
        self.FFN = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        x = self.FFN(x)
        return x

Dropout randomly zeroes out elements of an input tensor with probability $p$. This helps regularize the network, making it more robust and less prone to overfitting. Note that dropout only happens during training; when we put the model in evaluation mode to validate it, dropout is automatically disabled. We also generally make the hidden dimension of the network 2-4x larger than the embedding dimension.
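
To see the training/evaluation distinction in action, here is a small demo of nn.Dropout (not from the original notebook): in train() mode roughly a fraction $p$ of the elements are zeroed and the survivors are scaled by $\frac{1}{1-p}$, while in eval() mode the input passes through unchanged.

drop = nn.Dropout(p = 0.5)
x = torch.ones(1, 8)

drop.train()
print(drop(x))   # roughly half the entries are 0, the rest are scaled to 2.0

drop.eval()
print(drop(x))   # identical to x: dropout is a no-op at evaluation time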

Putting the Encoder Together

class ViTBlock(nn.Module):
    """
    A first attempt at a ViT encoder block: MHA -> FF
    (we will add layer normalization and residual connections shortly)
    Params:
    num_heads -> number of heads to use, each head_dim is calculated as embed_dim // num_heads
    embed_dim -> embedding dimension of vector x
    hidden_dim -> hidden dimension of the Feed Forward network, generally 2-4x larger than embed_dim
    """
    def __init__(self, num_heads, embed_dim, hidden_dim):
        super().__init__()
        self.MHA = MultiHeadAttention(num_heads, embed_dim)
        self.FFN = FeedForward(embed_dim, hidden_dim)

    def forward(self, x):
        x = self.MHA(x)
        x = self.FFN(x)
        return x

Great, we now have a working transformer encoder! Now all we need to do is stack as many encoders on top of each other as we like, and SOTA here we come, right? Well, it’s not quite that simple. If we naively stack many of these encoders, we run into the vanishing gradient problem: during backpropagation the gradient is multiplied through every layer, and repeated small factors (for example from saturated or “dead” activations) shrink it toward zero, which makes it very hard for a deep network to converge.

Luckily, we can mitigate this problem by adding residual connections. Residual connections help “smooth out” the loss landscape and give the gradient a shortcut path, which helps with the vanishing gradient problem. We can define a residual connection as $X^{(i)} = \mathrm{Layer}_i(X^{(i-1)}) + X^{(i-1)}$: the output after layer $i$ is the output of $\mathrm{Layer}_i$ applied to $X^{(i-1)}$ plus $X^{(i-1)}$ itself, carried over unchanged from the previous layer.

We also use layer normalization within each encoder. Layer normalization rescales each token’s features to zero mean and unit standard deviation (followed by a learned scale and shift). This helps our network converge faster.
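
Here is a tiny demo (not from the original notebook) of what nn.LayerNorm does to the last dimension of a tensor:

ln = nn.LayerNorm(8)
x = torch.randn(2, 5, 8) * 3 + 7           # arbitrary mean and scale
y = ln(x)
print(y.mean(dim = -1))                    # approximately 0 for every token
print(y.std(dim = -1, unbiased = False))   # approximately 1 for every token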

Okay, that sounds like a lot, but it’s actually quite simple to implement in code. Here’s how it looks.

class ViTBlock(nn.Module):
    def __init__(self, num_heads, embed_dim, hidden_dim):
        super().__init__()
        # pre-norm encoder block: layernorm -> MHA -> residual -> layernorm -> FF -> residual
        self.layernorm1 = nn.LayerNorm(embed_dim)
        self.MHA = MultiHeadAttention(num_heads, embed_dim)
        self.layernorm2 = nn.LayerNorm(embed_dim)
        self.FFN = FeedForward(embed_dim, hidden_dim)

    def forward(self, x):
        # normalize, apply multi-head attention, then add the residual (the original x)
        x = x + self.MHA(self.layernorm1(x))
        # normalize, apply the feed forward network, then add the residual
        x = x + self.FFN(self.layernorm2(x))
        return x

Putting it all together

Now we can put the full transformer together. Here we use PyTorch Lightning: a wrapper that sits on top of PyTorch and lets us write and train our models with less boilerplate code. We first initialize our patch embeddings.

import pytorch_lightning as pl

class ViTTransformer(pl.LightningModule):
    def __init__(self, img_size, patch_size, input_channels, num_heads, 
        embed_dim, hidden_dim, num_layers, num_classes, dropout = .1):
        super().__init__()
        # patch embedding -> (num_layers * ViTBlock) -> layernorm -> linear head
        self.embedding = PatchEmbeddings(
            img_size = img_size,
            patch_size = patch_size, 
            embed_dim = embed_dim, 
            input_channels = input_channels
        )

Remember from earlier that we only created the patch embeddings, so our tokens have no sense of position. We therefore create learnable positional embeddings and add them to our patch embeddings to inject a sense of position into each token.

self.pos_embeddings = nn.Parameter(torch.zeros(1,1 + self.embedding.num_patches, embed_dim))

We also introduce a [cls] token: a classification token that we prepend to each sequence of patch embeddings. As the [cls] token passes through each layer, its representation is updated by attending to all of the patch tokens, so that by the final layer it (hopefully) summarizes the class of the object we’re trying to predict. At the output, we take only the [cls] token, push it through a linear layer to get logits, and take the softmax of the logits to obtain a probability distribution over classes.

self.cls_token = nn.Parameter(torch.zeros(1,1, embed_dim))

We also do one last layer normalization before we push our tensor through the final linear layer. Here it is all together.

class ViTTransformer(pl.LightningModule):
    def __init__(self, img_size, patch_size, input_channels, num_heads, 
        embed_dim, hidden_dim, num_layers, num_classes, dropout = .1):
        super().__init__()
        # patch embedding -> (num_layers * ViTBlock) -> layernorm -> linear head
        self.embedding = PatchEmbeddings(
            img_size = img_size,
            patch_size = patch_size, 
            embed_dim = embed_dim, 
            input_channels = input_channels
        )
        
        # we create a classification token and prepend it to the beginning of each sequence
        self.cls_token = nn.Parameter(torch.zeros(1,1, embed_dim))
        self.pos_embeddings = nn.Parameter(torch.zeros(1,1 + self.embedding.num_patches, embed_dim))
        self.layers = nn.ModuleList(
            [ViTBlock(
                num_heads = num_heads,
                embed_dim = embed_dim,
                hidden_dim = hidden_dim,
            ) for _ in range(num_layers)] 
        )
        
        self.layernorm = nn.LayerNorm(embed_dim)
        # linear head for classification
        self.to_out = nn.Linear(embed_dim, num_classes)
    
    def forward(self, x):

        batch_size = x.shape[0]
        x = self.embedding(x)

        # add cls token and positional embeddings
        cls_token = self.cls_token.expand(batch_size, -1, -1) # (bs, 1, embed_dim)
        x = torch.cat((cls_token, x), dim = 1) #(bs, 1 + num_patches, embed_dim)
        x = x + self.pos_embeddings

        # go through model
        for layer in self.layers: 
            x = layer(x)

        x = self.layernorm(x)
        cls_token_only = x[:, 0]  # we want only the cls token
        x = self.to_out(cls_token_only)   # linear head for classification
        return x
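
As a final sanity check (hypothetical, with illustrative hyperparameters rather than the ones used in the original notebook), here is how the model might be instantiated for 32 x 32 CIFAR-10 style images:

model = ViTTransformer(
    img_size = 32,
    patch_size = 4,          # 32 / 4 = 8, so 8 * 8 = 64 patches per image
    input_channels = 3,
    num_heads = 8,
    embed_dim = 256,
    hidden_dim = 512,        # 2x the embedding dimension
    num_layers = 6,
    num_classes = 10
)

images = torch.randn(16, 3, 32, 32)   # a dummy batch of CIFAR-10 sized images
logits = model(images)
print(logits.shape)                   # torch.Size([16, 10])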

Things to note

While transformers are an amazing architecture, they’re not perfect. Transformers lack the inductive biases of CNNs (such as locality and translation equivariance), meaning that at the start of training the model has no built-in notion of the 2D structure of an image, and the positional embeddings have to be learned from scratch. This also becomes a problem when we fine-tune on images of a different resolution than the one the model saw during training, since the learned positional embeddings no longer match the new number of patches.

Secondly, transformers require significantly more training data to achieve performance on par with CNNs, which can be an issue if we’re training them from scratch.

The last thing to note is that vision transformers have started to bring self-supervised pre-training into vision, something the NLP field has been exploiting for years. The authors experiment with this by randomly masking out 50% of the patch embeddings and predicting the 3-bit mean color of each masked patch. Self-supervised training lets us make use of unlabelled data, so that we only need labelled data for fine-tuning. The DINO paper goes much deeper into self-supervision, but that’s beyond the scope of this introduction.

Please check out the original notebook for the full code, which even trains on the CIFAR-10 dataset.