Vision Transformer from scratch

A step-by-step implementation of the Vision Transformer (ViT) model introduced in the paper An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale (ICLR 2021), with PyTorch.

1 Background

Basic Structure of ViT

  • Convert an image into a sequence of patches.
  • Map the patches into \(D\) dimension with a trainable linear projection \(\mathbf{E}\).
  • Prepend a learnable class token [class] to the patch embeddings.
  • Add positional encodings to the patch embeddings.
  • The ViT encoder consists of \(L\) basic ViT blocks, each containing:
    • Multi-Head Self-Attention layer
    • MLP layer (called FeedForward layer in the original Transformer)
    • residual connection
    • layer normalization
  • A classification head is attached to \(\mathbf{z}^{0}_{L}\), the first output token of the last layer \(L\).

2 Implementation with PyTorch

2-1 Convert image into patches

The standard Transformer receives a 1D sequence of token embeddings as the input. In ViT, the image \(\mathbf{x}\in\mathbb{R}^{H \times W \times C}\) is reshaped into a sequence of flattened 2D patches \(\mathbf{x}_p \in \mathbb{R}^{N \times (P^2\cdot C)}\), where \((H,W)\) is the resolution of the image, \((P, P)\) is the resolution of each image patch, and \(N=HW/P^2\) is the number of patches, which also serves as the sequence length of the input patch sequence.

This step can be realized with a sliding window extraction on the image, with kernel_size=P and stride=P. We can implement this operation with torch.nn.Unfold or torch.nn.functional.unfold in PyTorch.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class Patcher(nn.Module):
    def __init__(self, patch_size=16):
        super(Patcher, self).__init__()
        self.patch_size = patch_size

    def forward(self, x):
        B, C, H, W = x.shape
        assert (H % self.patch_size) == 0 and (W % self.patch_size) == 0
        # (B,C,H,W)-->(B, C*(P**2), num_patches)
        flat_patches = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size)
        # (B, C*(P**2), N)-> (B, N, C*(P**2))
        flat_patches = flat_patches.permute(0, 2, 1)
        return flat_patches

We can test Patcher with the following function:

def test_Patcher():
    import cv2

    image = cv2.imread("test.jpg")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = image.astype("float32") / 255.0
    H, W, C = image.shape
    image = torch.from_numpy(image)
    image = image.unsqueeze(0).permute(0, 3, 1, 2)
    patch_size = 32

    patcher = Patcher(patch_size=patch_size)
    patches = patcher(image)
    B, N, CPP = patches.shape
    patches = patches.view(B, N, 3, patch_size, patch_size)

    print(patches.shape)
    N_h, N_w = H // patch_size, W // patch_size

    # visualize the patches in an image grid
    patches = patches[0]
    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1.axes_grid import ImageGrid

    fig = plt.figure(figsize=(8, 8))
    grid = ImageGrid(fig, 111, nrows_ncols=(N_h, N_w), axes_pad=0.1)
    for i, ax in enumerate(grid):
        patch = patches[i].permute(1, 2, 0)
        ax.imshow(patch)
        ax.axis("off")

    plt.show()
    fig.savefig("show-patches.jpg")

Then the original image and the extracted patches should look like this:

Remark: In fact, the convolution operation can also be viewed as a combination of unfold (torch.nn.functional.unfold), a matrix multiplication, and fold (torch.nn.functional.fold).
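
As a small illustration of this remark, the following sketch (with arbitrary tensor sizes) checks that F.conv2d matches an unfold + matrix-multiplication implementation of the same convolution:

x = torch.randn(2, 3, 32, 32)                            # (B, C, H, W)
weight = torch.randn(8, 3, 4, 4)                         # (out_channels, C, kH, kW)

# unfold -> (B, C*kH*kW, L), then multiply by the flattened kernels
cols = F.unfold(x, kernel_size=4, stride=4)              # (2, 48, 64)
out_unfold = (weight.view(8, -1) @ cols).view(2, 8, 8, 8)

out_conv = F.conv2d(x, weight, stride=4)                 # (2, 8, 8, 8)
print(torch.allclose(out_unfold, out_conv, atol=1e-5))   # True, up to float error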

2-2 Map the flattened patches into \(D\) dimensions

This is a simple linear layer self.embed.

class VisionTransformer(nn.Module):
    def __init__(
        self,
        patch_size=16,
        hidden_dim=768,
        in_channels=3,
    ):
        super(VisionTransformer, self).__init__()
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim

        # 1) image patcher
        self.patcher = Patcher(patch_size=patch_size)

        # 2) Linear projection on the patches
        self.embed = nn.Linear(in_channels * (patch_size**2), hidden_dim)

Remark: An alternative (and equivalent) way to implement patching + embedding projection is to directly apply an nn.Conv2d operation to the original 2D image, with out_channels=hidden_dim, kernel_size=patch_size and stride=patch_size:

patch_embed = nn.Conv2d(
    in_channels=in_channels, # 3
    out_channels=hidden_dim,
    kernel_size=patch_size,
    stride=patch_size,
)
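
Note that with this alternative, the conv output has shape (B, hidden_dim, H/P, W/P) and still needs to be flattened and transposed to obtain the (B, N, D) patch sequence. A minimal sketch, reusing patch_embed from above with an assumed 224x224 input:

x = torch.randn(1, in_channels, 224, 224)   # dummy image batch
feat = patch_embed(x)                       # (1, hidden_dim, 14, 14) for patch_size=16
tokens = feat.flatten(2).transpose(1, 2)    # (1, 196, hidden_dim) == (B, N, D)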

2-3 Prepend a trainable [class] token to the flattened patches

The self.class_token is a trainable parameter, which can be created with nn.Parameter(). This [class] token will be prepended to the patch sequence in the forward() function.

class VisionTransformer(nn.Module):
    def __init__(
        self,
        patch_size=16,
        hidden_dim=768,
        in_channels=3,
    ):
        super(VisionTransformer, self).__init__()
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim

        # 1) image patcher
        self.patcher = Patcher(patch_size=patch_size)

        # 2) Linear projection on the patches
        self.embed = nn.Linear(in_channels * (patch_size**2), hidden_dim)

        # 3) Prepend a learnable [class] token
        self.class_token = nn.Parameter(
            data=torch.rand(size=(1, 1, hidden_dim)), requires_grad=True
        )

2-4 Add positional encoding into the prepended sequences

As with the Transformer in Natural Language Processing (NLP), we also need to add positional encodings to the embeddings after the linear projection.

\[PE(pos, 2i) = \sin(\frac{pos}{10000^{\frac{2i}{D}}})\] \[PE(pos, 2i+1) = \cos(\frac{pos}{10000^{\frac{2i}{D}}})\]

where \(pos\) is the position in the sequence, \(i\) indexes the embedding dimensions, and \(D\) is the hidden_dim.

The positional encoder can be implemented as follows:

class PositionalEncoder(nn.Module):
    def __init__(self, num_patches, embed_dim, learnable=False):
        super(PositionalEncoder, self).__init__()
        self.num_patches = num_patches
        self.embed_dim = embed_dim
        self.learnable = learnable
        self.pos_embed = nn.Parameter(
            torch.zeros(size=(1, num_patches, embed_dim)), requires_grad=learnable
        )
        if not learnable:
            ### Method 1: for-loop (slow but explicit)
            # for p in range(num_patches):
            #     for i in range(0, embed_dim, 2):
            #         self.pos_embed.data[0, p, i] = math.sin(p / 10000 ** (i / embed_dim))
            #         self.pos_embed.data[0, p, i + 1] = math.cos(p / 10000 ** (i / embed_dim))

            # Method 2: meshgrid (vectorized)
            rowv, colv = torch.meshgrid(
                torch.arange(num_patches), torch.arange(0, embed_dim, 2), indexing="ij"
            )
            angles = rowv / torch.pow(10000, colv / embed_dim)
            self.pos_embed.data[0, rowv, colv] = torch.sin(angles)
            self.pos_embed.data[0, rowv, colv + 1] = torch.cos(angles)

    def forward(self, x):
        return x + self.pos_embed
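
A quick shape check of PositionalEncoder (the sizes below are arbitrary); for position 0, the first entries should be sin(0)=0, cos(0)=1, and so on:

pos_encoder = PositionalEncoder(num_patches=197, embed_dim=768, learnable=False)
tokens = torch.zeros(2, 197, 768)   # (B, N+1, D), e.g. 196 patches + [class]
out = pos_encoder(tokens)
print(out.shape)                    # torch.Size([2, 197, 768])
print(out[0, 0, :4])                # tensor([0., 1., 0., 1.])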

Please note that the number of patches needs to be inferred from the image size and the patch size:

class VisionTransformer(nn.Module):
    def __init__(
        self,
        image_size=224,
        patch_size=16,
        num_layers=12,
        hidden_dim=768,
        mlp_hidden_dim=None,
        mlp_ratio=4,
        n_heads=12,
        num_classes=1000,
        in_channels=3,
        is_pretrain=False,
    ):
        super(VisionTransformer, self).__init__()
        self.image_size = image_size
        assert image_size % patch_size == 0
        self.num_patches = (image_size // patch_size) ** 2
        self.patch_size = patch_size
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        if mlp_hidden_dim is not None:
            self.mlp_hidden_dim = mlp_hidden_dim
        elif mlp_ratio is not None:
            self.mlp_hidden_dim = hidden_dim * mlp_ratio
        else:
            raise ValueError(
                "Must assign either one of 'mlp_hidden_dim' and 'mlp_ratio' arguments!"
            )
        self.n_heads = n_heads
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.is_pretrain = is_pretrain

        # 1) image patcher
        self.patcher = Patcher(patch_size=patch_size)

        # 2) Linear projection on the patches
        self.embed = nn.Linear(in_channels * (patch_size**2), hidden_dim)

        # 3) Prepend a learnable [class] token
        self.class_token = nn.Parameter(
            data=torch.rand(size=(1, 1, hidden_dim)), requires_grad=True
        )

        # 4) positional encoder, length + 1 for the [class] token
        self.pos_encoder = PositionalEncoder(
            num_patches=self.num_patches + 1, embed_dim=hidden_dim
        )

2-5 Build a basic ViT block

In this part, we first build a Multi-Head Self-Attention block; then a basic ViT block can be built by combining it with the other sublayers (the FeedForward MLP layer, layer normalization, and residual connections).

A basic ViT block

2-5-1 Multi-Head Self-Attention block

To simplify, we denote n_heads=h and hidden_dim=D according to notation in the paper.

Even though the queries, keys and values are projected into \(\frac{D}{h}\) dimensions within each head (\(h\) heads in total), we can still implement the projections for all \(h\) heads with a single linear layer and then split the result into the per-head \(QW_{i}^{Q}, KW_{i}^{K}, VW_{i}^{V}\) for the scaled dot-product attention within each head \(i\). This better leverages the parallel computation optimizations in PyTorch and CUDA.

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, n_heads, hidden_dim):
        super(MultiHeadSelfAttention, self).__init__()
        self.n_heads = n_heads
        self.hidden_dim = hidden_dim
        assert hidden_dim % n_heads == 0
        self.dim_per_head = hidden_dim // n_heads

        self.qkv_proj = nn.Linear(hidden_dim, hidden_dim * 3)
        self.final_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        B, N, D = x.shape
        qkv = self.qkv_proj(x).view(B, N, 3, self.n_heads, self.dim_per_head)
        queries, keys, values = torch.unbind(qkv, dim=2)  # 3*(B,N,h,D/h)
        queries = queries.permute(0, 2, 1, 3)  # (B,h,N,D/h)
        keys = keys.permute(0, 2, 3, 1)  # (B,h,D/h,N)
        values = values.permute(0, 2, 1, 3)  # (B,h,N,D/h)

        out = torch.matmul(
            torch.softmax(
                torch.matmul(queries, keys) / math.sqrt(self.dim_per_head),
                dim=-1,
            ),
            values,
        )  # (B,h,N,D/h)
        out = out.permute(0, 2, 1, 3).reshape(B, N, D)
        out = self.final_proj(out)
        return out
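
A quick sanity check that the attention block preserves the (B, N, D) shape; the sizes below are arbitrary:

mhsa = MultiHeadSelfAttention(n_heads=12, hidden_dim=768)
x = torch.randn(2, 197, 768)   # (B, N, D), e.g. 196 patches + [class] token
print(mhsa(x).shape)           # torch.Size([2, 197, 768])

As a side note, in PyTorch 2.0+ the explicit softmax/matmul pair above could also be replaced by torch.nn.functional.scaled_dot_product_attention, which computes the same scaled dot-product attention with fused kernels.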

2-5-2 A basic ViT block

Then, we can add other layers to the basic block. mlp_hidden_dim and mlp_ratio are used to specify the hidden size of the FeedForward MLP layer.

In the following code, if mlp_hidden_dim is not specified in the arguments, the default mlp_hidden_dim = mlp_ratio * D will be adopted. This is consistent with the original paper.

class ViTBlock(nn.Module):
    def __init__(self, n_heads, hidden_dim, mlp_hidden_dim=None, mlp_ratio=4):
        super(ViTBlock, self).__init__()
        self.n_heads = n_heads
        self.hidden_dim = hidden_dim
        if mlp_hidden_dim is not None:
            self.mlp_hidden_dim = mlp_hidden_dim
        elif mlp_ratio is not None:
            self.mlp_hidden_dim = hidden_dim * mlp_ratio
        else:
            raise ValueError(
                "Must assign either one of 'mlp_hidden_dim' and 'mlp_ratio' arguments!"
            )

        self.mhsa = MultiHeadSelfAttention(n_heads=n_heads, hidden_dim=hidden_dim)
        self.feedforward = nn.Sequential(
            nn.Linear(hidden_dim, self.mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(self.mlp_hidden_dim, hidden_dim),
        )
        self.layernorm1 = nn.LayerNorm(normalized_shape=hidden_dim)
        self.layernorm2 = nn.LayerNorm(normalized_shape=hidden_dim)

    def forward(self, x):
        out = self.mhsa(self.layernorm1(x)) + x
        out = self.feedforward(self.layernorm2(out)) + out
        return out
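
As a rough sanity check on the block size, we can count the parameters of a single ViT-Base-sized block (the numbers in the comment are approximate):

block = ViTBlock(n_heads=12, hidden_dim=768, mlp_ratio=4)
n_params = sum(p.numel() for p in block.parameters())
print(f"{n_params / 1e6:.2f}M parameters")   # ~7.09M per block; 12 blocks give ~85M, close to ViT-Base's reported 86M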

2-5-3 Add ViT blocks

Stack \(L\) basic ViT blocks (num_layers) in the VisionTransformer:

class VisionTransformer(nn.Module):
    def __init__(
        self,
        image_size=224,
        patch_size=16,
        num_layers=12,
        hidden_dim=768,
        mlp_hidden_dim=None,
        mlp_ratio=4,
        n_heads=12,
        num_classes=1000,
        in_channels=3,
        is_pretrain=False,
    ):
        super(VisionTransformer, self).__init__()
        self.image_size = image_size
        assert image_size % patch_size == 0
        self.num_patches = (image_size // patch_size) ** 2
        self.patch_size = patch_size
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        if mlp_hidden_dim is not None:
            self.mlp_hidden_dim = mlp_hidden_dim
        elif mlp_ratio is not None:
            self.mlp_hidden_dim = hidden_dim * mlp_ratio
        else:
            raise ValueError(
                "Must assign either one of 'mlp_hidden_dim' and 'mlp_ratio' arguments!"
            )
        self.n_heads = n_heads
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.is_pretrain = is_pretrain

        # 1) image patcher
        self.patcher = Patcher(patch_size=patch_size)

        # 2) Linear projection on the patches
        self.embed = nn.Linear(in_channels * (patch_size**2), hidden_dim)

        # 3) Prepend a learnable [class] token
        self.class_token = nn.Parameter(
            data=torch.rand(size=(1, 1, hidden_dim)), requires_grad=True
        )

        # 4) positional encoder, length + 1 for the [class] token
        self.pos_encoder = PositionalEncoder(
            num_patches=self.num_patches + 1, embed_dim=hidden_dim
        )

        # 5) add multiple ViT blocks
        self.layers = nn.ModuleList(
            [
                ViTBlock(
                    n_heads=n_heads,
                    hidden_dim=hidden_dim,
                    mlp_hidden_dim=mlp_hidden_dim,
                    mlp_ratio=mlp_ratio,
                )
                for _ in range(num_layers)
            ]
        )

2-6 Add classification head

A classification head is attached to \(\mathbf{z}^{0}_{L}\), the first output embedding at the last layer \(L\). As described in the original paper,

  • In pre-training mode, the classification head is a MLP with one hidden layer;
  • In fine-tuning mode, the classification head is a single linear layer.
class VisionTransformer(nn.Module):
    def __init__(
        self,
        image_size=224,
        patch_size=16,
        num_layers=12,
        hidden_dim=768,
        mlp_hidden_dim=None,
        mlp_ratio=4,
        n_heads=12,
        num_classes=1000,
        in_channels=3,
        is_pretrain=False,
    ):
        super(VisionTransformer, self).__init__()
        self.image_size = image_size
        assert image_size % patch_size == 0
        self.num_patches = (image_size // patch_size) ** 2
        self.patch_size = patch_size
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        if mlp_hidden_dim is not None:
            self.mlp_hidden_dim = mlp_hidden_dim
        elif mlp_ratio is not None:
            self.mlp_hidden_dim = hidden_dim * mlp_ratio
        else:
            raise ValueError(
                "Must assign either one of 'mlp_hidden_dim' and 'mlp_ratio' arguments!"
            )
        self.n_heads = n_heads
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.is_pretrain = is_pretrain

        # 1) image patcher
        self.patcher = Patcher(patch_size=patch_size)

        # 2) Linear projection on the patches
        self.embed = nn.Linear(in_channels * (patch_size**2), hidden_dim)

        # 3) Prepend a learnable [class] token
        self.class_token = nn.Parameter(
            data=torch.rand(size=(1, 1, hidden_dim)), requires_grad=True
        )

        # 4) positional encoder, length + 1 for the [class] token
        self.pos_encoder = PositionalEncoder(
            num_patches=self.num_patches + 1, embed_dim=hidden_dim
        )

        # 5) add multiple ViT blocks
        self.layers = nn.ModuleList(
            [
                ViTBlock(
                    n_heads=n_heads,
                    hidden_dim=hidden_dim,
                    mlp_hidden_dim=mlp_hidden_dim,
                    mlp_ratio=mlp_ratio,
                )
                for _ in range(num_layers)
            ]
        )
        # 6) classification head
        if is_pretrain:
            # MLP with one hidden layer, as described in the paper (tanh nonlinearity here)
            self.classifier = nn.Sequential(
                nn.Linear(hidden_dim, self.mlp_hidden_dim),
                nn.Tanh(),
                nn.Linear(self.mlp_hidden_dim, num_classes),
            )
        else:
            self.classifier = nn.Linear(hidden_dim, num_classes)

2-7 Finish the forward() function

The forward process should now be straightforward. A few points are worth noting:

  • Prepend the [class] token into the sequence of the image patches
  • Add positional encoding into the prepended sequence
  • The classification head is attached to \(\mathbf{z}^{0}_{L}\), the first output embedding at the last layer \(L\)
class VisionTransformer(nn.Module):
    def __init__(
        self,
        image_size=224,
        patch_size=16,
        num_layers=12,
        hidden_dim=768,
        mlp_hidden_dim=None,
        mlp_ratio=4,
        n_heads=12,
        num_classes=1000,
        in_channels=3,
        is_pretrain=False,
    ):
        super(VisionTransformer, self).__init__()
        self.image_size = image_size
        assert image_size % patch_size == 0
        self.num_patches = (image_size // patch_size) ** 2
        self.patch_size = patch_size
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        if mlp_hidden_dim is not None:
            self.mlp_hidden_dim = mlp_hidden_dim
        elif mlp_ratio is not None:
            self.mlp_hidden_dim = hidden_dim * mlp_ratio
        else:
            raise ValueError(
                "Must assign either one of 'mlp_hidden_dim' and 'mlp_ratio' arguments!"
            )
        self.n_heads = n_heads
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.is_pretrain = is_pretrain

        # 1) image patcher
        self.patcher = Patcher(patch_size=patch_size)

        # 2) Linear projection on the patches
        self.embed = nn.Linear(in_channels * (patch_size**2), hidden_dim)

        # 3) Prepend a learnable [class] token
        self.class_token = nn.Parameter(
            data=torch.rand(size=(1, 1, hidden_dim)), requires_grad=True
        )

        # 4) positional encoder, length + 1 for the [class] token
        self.pos_encoder = PositionalEncoder(
            num_patches=self.num_patches + 1, embed_dim=hidden_dim
        )

        # 5) add multiple ViT blocks
        self.layers = nn.ModuleList(
            [
                ViTBlock(
                    n_heads=n_heads,
                    hidden_dim=hidden_dim,
                    mlp_hidden_dim=mlp_hidden_dim,
                    mlp_ratio=mlp_ratio,
                )
                for _ in range(num_layers)
            ]
        )
        # 6) classification head
        if is_pretrain:
            # MLP with one hidden layer, as described in the paper (tanh nonlinearity here)
            self.classifier = nn.Sequential(
                nn.Linear(hidden_dim, self.mlp_hidden_dim),
                nn.Tanh(),
                nn.Linear(self.mlp_hidden_dim, num_classes),
            )
        else:
            self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        # image -> (B, N, CPP)
        B = x.shape[0]
        flat_patches = self.patcher(x)
        embeddings = self.embed(flat_patches)
        ### add the `[class]` token, (B, N, D) -> (B, N+1, D)
        embeddings = torch.cat(
            [self.class_token.expand(size=(B, -1, -1)), embeddings], dim=1
        )
        out = self.pos_encoder(embeddings)
        for layer in self.layers:
            out = layer(out)
        ### attach the head to $\mathbf{z}^{0}_{L}$, the [class] token output
        out = self.classifier(out[:, 0, :])
        return out
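
A minimal end-to-end check of the full model; the test function name and the small configuration below are arbitrary choices, just to keep the example fast:

def test_VisionTransformer():
    model = VisionTransformer(
        image_size=224,
        patch_size=16,
        num_layers=2,     # shallow model, only for testing
        hidden_dim=192,
        mlp_ratio=4,
        n_heads=3,
        num_classes=10,
    )
    images = torch.randn(4, 3, 224, 224)   # (B, C, H, W)
    logits = model(images)
    print(logits.shape)                    # torch.Size([4, 10])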

2-8 Three variants of ViT

In the original ViT paper, three ViT variants were introduced: ViT-Base, ViT-Large and ViT-Huge. The differences between them are:

  • The number of stacks of basic block (or layers) \(L\)
  • The hidden size of the model \(D\)
  • The number of heads used in Multi-Head Self-Attention module \(h\)
  • The hidden size of the FeedForward MLP layer mlp_hidden_dim

Here is a detailed table (configurations as reported in the ViT paper):

Model      Layers (L)   Hidden size (D)   MLP size   Heads (h)   Params
ViT-Base   12           768               3072       12          86M
ViT-Large  24           1024              4096       16          307M
ViT-Huge   32           1280              5120       16          632M

Change num_layers, hidden_dim, mlp_hidden_dim and n_heads to build the three variants of ViT: ViT-B, ViT-L and ViT-H.

def ViT_Base(
    patch_size=16,
    num_layers=12,
    hidden_dim=768,
    mlp_hidden_dim=3072,
    n_heads=12,
    in_channels=3,
):
    return VisionTransformer(
        patch_size=patch_size,
        num_layers=num_layers,
        hidden_dim=hidden_dim,
        mlp_hidden_dim=mlp_hidden_dim,
        n_heads=n_heads,
        in_channels=in_channels,
    )


def ViT_Large(
    patch_size=16,
    num_layers=24,
    hidden_dim=1024,
    mlp_hidden_dim=4096,
    n_heads=16,
    in_channels=3,
):
    return VisionTransformer(
        patch_size=patch_size,
        num_layers=num_layers,
        hidden_dim=hidden_dim,
        mlp_hidden_dim=mlp_hidden_dim,
        n_heads=n_heads,
        in_channels=in_channels,
    )


def ViT_Huge(
    patch_size=16,
    num_layers=32,
    hidden_dim=1280,
    mlp_hidden_dim=5120,
    n_heads=16,
    in_channels=3,
):
    return VisionTransformer(
        patch_size=patch_size,
        num_layers=num_layers,
        hidden_dim=hidden_dim,
        mlp_hidden_dim=mlp_hidden_dim,
        n_heads=n_heads,
        in_channels=in_channels,
    )
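
As a rough sanity check, we can count the trainable parameters of the three variants and compare them against the paper; the counts differ slightly from the reported 86M / 307M / 632M because of details such as the positional embeddings and the head (note that instantiating ViT-Huge alone allocates a couple of GB of memory):

for name, builder in [("ViT-Base", ViT_Base), ("ViT-Large", ViT_Large), ("ViT-Huge", ViT_Huge)]:
    model = builder()
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"{name}: {n_params / 1e6:.1f}M trainable parameters")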


