Vision Transformer from scratch
A step-by-step implementation of the Vision Transformer (ViT) model introduced in the paper An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale (ICLR 2021), using PyTorch.
1 Background
Basic Structures of ViT
- Convert an image into a sequence of patches.
- Map the patches into \(D\) dimension with a trainable linear projection \(\mathbf{E}\).
- Prepend a learnable class token [class] to the patch embeddings.
- Add positional encoding to the patch embeddings.
- The ViT encoder consists of \(L\) stacked basic ViT blocks, each built from:
- Multi-Head Self-Attention layer
- MLP layer (called FeedForward layer in the original Transformer)
- residual connection
- layer normalization
- A classification head is attached to \(\mathbf{z}^{0}_{L}\), the first output token of the last layer \(L\).
2 Implementation with PyTorch
2-1 Convert image into patches
The standard Transformer receives a 1D sequence of token embeddings as input. In ViT, the image \(\mathbf{x}\in\mathbb{R}^{H \times W \times C}\) is reshaped into a sequence of flattened 2D patches \(\mathbf{x}_p \in \mathbb{R}^{N \times (P^2\cdot C)}\), where \((H,W)\) is the resolution of the image, \((P, P)\) is the resolution of each image patch, and \(N=HW/P^2\) is the number of patches, which also serves as the sequence length of the input patch sequence.
This step can be realized with a sliding-window extraction on the image, with kernel_size=P and stride=P. We can implement this operation with torch.nn.Unfold or torch.nn.functional.unfold in PyTorch.
import torch
import torch.nn as nn
import torch.nn.functional as F


class Patcher(nn.Module):
def __init__(self, patch_size=16):
super(Patcher, self).__init__()
self.patch_size = patch_size
def forward(self, x):
B, C, H, W = x.shape
assert (H % self.patch_size) == 0 and (W % self.patch_size) == 0
# (B,C,H,W)-->(B, C*(P**2), num_patches)
flat_patches = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size)
# (B, C*(P**2), N)-> (B, N, C*(P**2))
flat_patches = flat_patches.permute(0, 2, 1)
return flat_patches
We can test Patcher with the following function:
def test_Patcher():
import cv2
image = cv2.imread("test.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = image.astype("float32") / 255.0
H, W, C = image.shape
image = torch.from_numpy(image)
image = image.unsqueeze(0).permute(0, 3, 1, 2)
patch_size = 32
patcher = Patcher(patch_size=patch_size)
patches = patcher(image)
B, N, CPP = patches.shape
patches = patches.view(B, N, 3, patch_size, patch_size)
print(patches.shape)
N_h, N_w = H // patch_size, W // patch_size
#
patches = patches[0]
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.axes_grid import ImageGrid
fig = plt.figure(figsize=(8, 8))
grid = ImageGrid(fig, 111, nrows_ncols=(N_h, N_w), axes_pad=0.1)
for i, ax in enumerate(grid):
patch = patches[i].permute(1, 2, 0)
ax.imshow(patch)
ax.axis("off")
plt.show()
fig.savefig("show-patches.jpg")
Then the original image and the patched images should look like:
Remark: Actually, the convolution operation can also be considered as the combination of unfold (torch.nn.functional.unfold), matrix multiplication, and fold (torch.nn.functional.fold).
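To see this concretely, here is a minimal sketch (with toy shapes chosen only for illustration) checking that a bias-free nn.Conv2d matches unfold + matrix multiplication + fold:
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
conv = torch.nn.Conv2d(3, 5, kernel_size=2, stride=2, bias=False)

cols = F.unfold(x, kernel_size=2, stride=2)           # (1, 3*2*2, 16): one column per window
w = conv.weight.view(5, -1)                           # (5, 12): flattened kernels
out = torch.matmul(w, cols)                           # (1, 5, 16)
out = F.fold(out, output_size=(4, 4), kernel_size=1)  # back to a (1, 5, 4, 4) feature map
print(torch.allclose(out, conv(x), atol=1e-5))        # expected: True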
2-2 Map the flattened patches into \(D\) dimensions
This is a simple linear layer, self.embed.
class VisionTransformer(nn.Module):
def __init__(
self,
patch_size=16,
hidden_dim=768,
in_channels=3,
):
super(VisionTransformer, self).__init__()
self.patch_size = patch_size
self.hidden_dim = hidden_dim
# 1) image patcher
self.patcher = Patcher(patch_size=patch_size)
# 2) Linear projection on the patches
self.embed = nn.Linear(in_channels * (patch_size**2), hidden_dim)
Remark: An alternative (and equivalent) way to implement Patch + Embedding Projection is to directly apply an nn.Conv2d operation to the original 2D image, with out_channels=hidden_dim, kernel_size=patch_size, and stride=patch_size:
patch_embed = nn.Conv2d(
in_channels=in_channels, # 3
out_channels=hidden_dim,
kernel_size=patch_size,
stride=patch_size,
)
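For instance (a minimal sketch assuming in_channels=3, hidden_dim=768, and patch_size=16), the conv output of shape (B, hidden_dim, H/P, W/P) only needs to be flattened and transposed to obtain the same (B, N, D) layout produced by Patcher followed by self.embed:
patch_embed = nn.Conv2d(in_channels=3, out_channels=768, kernel_size=16, stride=16)
x = torch.randn(2, 3, 224, 224)
tokens = patch_embed(x)                      # (2, 768, 14, 14)
tokens = tokens.flatten(2).transpose(1, 2)   # (2, 196, 768), i.e. (B, N, D)
print(tokens.shape)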
2-3 Prepend a trainable [class] token to the flattened patches
The self.class_token is a set of trainable parameters, which can be created with nn.Parameter(). This [class] token will be prepended to the patch sequence in the forward() function.
class VisionTransformer(nn.Module):
def __init__(
self,
patch_size=16,
hidden_dim=768,
in_channels=3,
):
super(VisionTransformer, self).__init__()
self.patch_size = patch_size
self.hidden_dim = hidden_dim
# 1) image patcher
self.patcher = Patcher(patch_size=patch_size)
# 2) Linear projection on the patches
self.embed = nn.Linear(in_channels * (patch_size**2), hidden_dim)
# 3) Prepend a learnable [class] token
self.class_token = nn.Parameter(
data=torch.rand(size=(1, 1, hidden_dim)), requires_grad=True
)
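For illustration only (hypothetical shapes with \(B=2\), \(N=196\), \(D=768\)), the prepending in forward() will expand the (1, 1, D) token to the batch size and concatenate it in front of the patch embeddings:
B, N, D = 2, 196, 768
class_token = nn.Parameter(torch.rand(1, 1, D))
patch_embeddings = torch.randn(B, N, D)
tokens = torch.cat([class_token.expand(B, -1, -1), patch_embeddings], dim=1)
print(tokens.shape)   # torch.Size([2, 197, 768])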
2-4 Add positional encoding into the prepended sequences
As with the Transformer in Natural Language Processing (NLP), we also need to add positional encoding to the embeddings after the linear projection.
\[PE(pos, 2i) = \sin(\frac{pos}{10000^{\frac{2i}{D}}})\]
\[PE(pos, 2i+1) = \cos(\frac{pos}{10000^{\frac{2i}{D}}})\]
where \(D\) is the hidden_dim and \(pos\) indexes the position in the token sequence.
The positional encoder can be implemented as follows:
class PositionalEncoder(nn.Module):
def __init__(self, num_patches, embed_dim, learnable=False):
super(PositionalEncoder, self).__init__()
self.num_patches = num_patches
self.embed_dim = embed_dim
self.learnable = learnable
self.pos_embed = nn.Parameter(
torch.zeros(size=(1, num_patches, embed_dim)), requires_grad=learnable
)
        if not learnable:
            ### Method 1: for-loop (commented out; needs `import math`)
            # for p in range(num_patches):
            #     for i in range(0, embed_dim, 2):
            #         self.pos_embed.data[0, p, i] = math.sin(p / 10000 ** (i / embed_dim))
            #         self.pos_embed.data[0, p, i + 1] = math.cos(p / 10000 ** (i / embed_dim))
            ### Method 2: meshgrid (vectorized, pure PyTorch)
            rowv, colv = torch.meshgrid(
                torch.arange(num_patches), torch.arange(0, embed_dim, 2), indexing="ij"
            )
            # angles[p, i] = p / 10000^(2i / D), where colv holds the even indices 2i
            angles = rowv / torch.pow(10000, colv / embed_dim)
            self.pos_embed.data[0, rowv, colv] = torch.sin(angles)
            self.pos_embed.data[0, rowv, colv + 1] = torch.cos(angles)
def forward(self, x):
return x + self.pos_embed
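A quick sanity check with assumed sizes (197 tokens for 196 patches plus the [class] token, embedding dimension 768):
pos_encoder = PositionalEncoder(num_patches=197, embed_dim=768)
tokens = torch.randn(2, 197, 768)
print(pos_encoder(tokens).shape)   # torch.Size([2, 197, 768])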
Please note that the number of patches needs to be inferred from image_size and patch_size:
class VisionTransformer(nn.Module):
def __init__(
self,
image_size=224,
patch_size=16,
num_layers=12,
hidden_dim=768,
mlp_hidden_dim=None,
mlp_ratio=4,
n_heads=12,
num_classes=1000,
in_channels=3,
is_pretrain=False,
):
super(VisionTransformer, self).__init__()
self.image_size = image_size
assert image_size % patch_size == 0
        self.num_patches = (image_size // patch_size) ** 2
self.patch_size = patch_size
self.num_layers = num_layers
self.hidden_dim = hidden_dim
if mlp_hidden_dim is not None:
self.mlp_hidden_dim = mlp_hidden_dim
elif mlp_ratio is not None:
self.mlp_hidden_dim = hidden_dim * mlp_ratio
else:
raise ValueError(
"Must assign either one of 'mlp_hidden_dim' and 'mlp_ratio' arguments!"
)
self.n_heads = n_heads
self.num_classes = num_classes
self.in_channels = in_channels
self.is_pretrain = is_pretrain
# 1) image patcher
self.patcher = Patcher(patch_size=patch_size)
# 2) Linear projection on the patches
self.embed = nn.Linear(in_channels * (patch_size**2), hidden_dim)
# 3) Prepend a learnable [class] token
self.class_token = nn.Parameter(
data=torch.rand(size=(1, 1, hidden_dim)), requires_grad=True
)
# 4) positional encoder, length + 1 for the [class] token
self.pos_encoder = PositionalEncoder(
num_patches=self.num_patches + 1, embed_dim=hidden_dim
)
2-5 Build a basic ViT block
In this part, we first build a Multi-Head Self-Attention block; then a basic ViT block can be built by combining it with the other sublayers (e.g., the FeedForward MLP layer, layer normalization, and residual connections).
2-5-1 Multi-Head Self-Attention block
To simplify the notation, we denote n_heads=h and hidden_dim=D, following the paper.
Even though the queries, keys, and values are projected into \(\frac{D}{h}\) dimensions within each head (\(h\) heads in total), we can still implement the \(h\) projections for all heads with a single linear layer and then split the result into the \(QW_{i}^{Q}, KW_{i}^{K}, VW_{i}^{V}\) used by the scaled dot-product attention operation within each head \(i\). This better leverages the parallel computation optimizations in PyTorch and CUDA.
import math


class MultiHeadSelfAttention(nn.Module):
def __init__(self, n_heads, hidden_dim):
super(MultiHeadSelfAttention, self).__init__()
self.n_heads = n_heads
self.hidden_dim = hidden_dim
assert hidden_dim % n_heads == 0
self.dim_per_head = hidden_dim // n_heads
self.qkv_proj = nn.Linear(hidden_dim, hidden_dim * 3)
self.final_proj = nn.Linear(hidden_dim, hidden_dim)
def forward(self, x):
B, N, D = x.shape
qkv = self.qkv_proj(x).view(B, N, 3, self.n_heads, self.dim_per_head)
queries, keys, values = torch.unbind(qkv, dim=2) # 3*(B,N,h,D/h)
queries = queries.permute(0, 2, 1, 3) # (B,h,N,D/h)
keys = keys.permute(0, 2, 3, 1) # (B,h,D/h,N)
values = values.permute(0, 2, 1, 3) # (B,h,N,D/h)
out = torch.matmul(
torch.softmax(
torch.matmul(queries, keys) / math.sqrt(self.dim_per_head),
dim=-1,
),
values,
) # (B,h,N,D/h)
out = out.permute(0, 2, 1, 3).reshape(B, N, D)
out = self.final_proj(out)
return out
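A quick shape check with assumed ViT-Base-like sizes (batch of 2, 196 patches plus the [class] token):
mhsa = MultiHeadSelfAttention(n_heads=12, hidden_dim=768)
x = torch.randn(2, 197, 768)
print(mhsa(x).shape)   # torch.Size([2, 197, 768])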
2-5-2 A basic ViT block
Then, we can add the other layers to build the basic block. mlp_hidden_dim and mlp_ratio are used to specify the hidden size of the FeedForward MLP layer.
In the following code, if mlp_hidden_dim is not specified in the arguments, the default mlp_hidden_dim = mlp_ratio * D will be adopted. This is consistent with the original paper.
class ViTBlock(nn.Module):
def __init__(self, n_heads, hidden_dim, mlp_hidden_dim=None, mlp_ratio=4):
super(ViTBlock, self).__init__()
self.n_heads = n_heads
self.hidden_dim = hidden_dim
if mlp_hidden_dim is not None:
self.mlp_hidden_dim = mlp_hidden_dim
elif mlp_ratio is not None:
self.mlp_hidden_dim = hidden_dim * mlp_ratio
else:
raise ValueError(
"Must assign either one of 'mlp_hidden_dim' and 'mlp_ratio' arguments!"
)
self.mhsa = MultiHeadSelfAttention(n_heads=n_heads, hidden_dim=hidden_dim)
        self.feedforward = nn.Sequential(
            nn.Linear(hidden_dim, self.mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(self.mlp_hidden_dim, hidden_dim),
        )
self.layernorm1 = nn.LayerNorm(normalized_shape=hidden_dim)
self.layernorm2 = nn.LayerNorm(normalized_shape=hidden_dim)
    def forward(self, x):
        # Pre-norm residual sublayers: LayerNorm -> sublayer -> residual connection
        out = self.mhsa(self.layernorm1(x)) + x
        out = self.feedforward(self.layernorm2(out)) + out
        return out
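Again, a quick shape check with assumed sizes (here mlp_hidden_dim defaults to 4 * hidden_dim):
block = ViTBlock(n_heads=12, hidden_dim=768)
x = torch.randn(2, 197, 768)
print(block(x).shape)   # torch.Size([2, 197, 768])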
2-5-3 Add ViT blocks
Add \(L\) basic blocks
class VisionTransformer(nn.Module):
def __init__(
self,
image_size=224,
patch_size=16,
num_layers=12,
hidden_dim=768,
mlp_hidden_dim=None,
mlp_ratio=4,
n_heads=12,
num_classes=1000,
in_channels=3,
is_pretrain=False,
):
super(VisionTransformer, self).__init__()
self.image_size = image_size
assert image_size % patch_size == 0
        self.num_patches = (image_size // patch_size) ** 2
self.patch_size = patch_size
self.num_layers = num_layers
self.hidden_dim = hidden_dim
if mlp_hidden_dim is not None:
self.mlp_hidden_dim = mlp_hidden_dim
elif mlp_ratio is not None:
self.mlp_hidden_dim = hidden_dim * mlp_ratio
else:
raise ValueError(
"Must assign either one of 'mlp_hidden_dim' and 'mlp_ratio' arguments!"
)
self.n_heads = n_heads
self.num_classes = num_classes
self.in_channels = in_channels
self.is_pretrain = is_pretrain
# 1) image patcher
self.patcher = Patcher(patch_size=patch_size)
# 2) Linear projection on the patches
self.embed = nn.Linear(in_channels * (patch_size**2), hidden_dim)
# 3) Prepend a learnable [class] token
self.class_token = nn.Parameter(
data=torch.rand(size=(1, 1, hidden_dim)), requires_grad=True
)
# 4) positional encoder, length + 1 for the [class] token
self.pos_encoder = PositionalEncoder(
num_patches=self.num_patches + 1, embed_dim=hidden_dim
)
# 5) add multiple ViT blocks
self.layers = nn.ModuleList(
[
ViTBlock(
n_heads=n_heads,
hidden_dim=hidden_dim,
mlp_hidden_dim=mlp_hidden_dim,
mlp_ratio=mlp_ratio,
)
for _ in range(num_layers)
]
)
2-6 Add classification head
A classification head is attached to \(\mathbf{z}^{0}_{L}\), the first output embedding at the last layer \(L\). As described in the original paper,
- In pre-training mode, the classification head is an MLP with one hidden layer;
- In fine-tuning mode, the classification head is a single linear layer.
class VisionTransformer(nn.Module):
def __init__(
self,
image_size=224,
patch_size=16,
num_layers=12,
hidden_dim=768,
mlp_hidden_dim=None,
mlp_ratio=4,
n_heads=12,
num_classes=1000,
in_channels=3,
is_pretrain=False,
):
super(VisionTransformer, self).__init__()
self.image_size = image_size
assert image_size % patch_size == 0
        self.num_patches = (image_size // patch_size) ** 2
self.patch_size = patch_size
self.num_layers = num_layers
self.hidden_dim = hidden_dim
if mlp_hidden_dim is not None:
self.mlp_hidden_dim = mlp_hidden_dim
elif mlp_ratio is not None:
self.mlp_hidden_dim = hidden_dim * mlp_ratio
else:
raise ValueError(
"Must assign either one of 'mlp_hidden_dim' and 'mlp_ratio' arguments!"
)
self.n_heads = n_heads
self.num_classes = num_classes
self.in_channels = in_channels
self.is_pretrain = is_pretrain
# 1) image patcher
self.patcher = Patcher(patch_size=patch_size)
# 2) Linear projection on the patches
self.embed = nn.Linear(in_channels * (patch_size**2), hidden_dim)
# 3) Prepend a learnable [class] token
self.class_token = nn.Parameter(
data=torch.rand(size=(1, 1, hidden_dim)), requires_grad=True
)
# 4) positional encoder, length + 1 for the [class] token
self.pos_encoder = PositionalEncoder(
num_patches=self.num_patches + 1, embed_dim=hidden_dim
)
# 5) add multiple ViT blocks
self.layers = nn.ModuleList(
[
ViTBlock(
n_heads=n_heads,
hidden_dim=hidden_dim,
mlp_hidden_dim=mlp_hidden_dim,
mlp_ratio=mlp_ratio,
)
for _ in range(num_layers)
]
)
# 6) classification head
        if is_pretrain:
            # MLP with one hidden layer; tanh follows the official implementation
            self.classifier = nn.Sequential(
                nn.Linear(hidden_dim, self.mlp_hidden_dim),
                nn.Tanh(),
                nn.Linear(self.mlp_hidden_dim, num_classes),
            )
else:
self.classifier = nn.Linear(hidden_dim, num_classes)
2-7 Finish the forward() function
The forward process should now be clear. A few points are worth noting:
- Prepend the [class] token to the sequence of image patch embeddings
- Add positional encoding to the prepended sequence
- The classification head is attached to \(\mathbf{z}^{0}_{L}\), the first output embedding at the last layer \(L\)
class VisionTransformer(nn.Module):
def __init__(
self,
image_size=224,
patch_size=16,
num_layers=12,
hidden_dim=768,
mlp_hidden_dim=None,
mlp_ratio=4,
n_heads=12,
num_classes=1000,
in_channels=3,
is_pretrain=False,
):
super(VisionTransformer, self).__init__()
self.image_size = image_size
assert image_size % patch_size == 0
        self.num_patches = (image_size // patch_size) ** 2
self.patch_size = patch_size
self.num_layers = num_layers
self.hidden_dim = hidden_dim
if mlp_hidden_dim is not None:
self.mlp_hidden_dim = mlp_hidden_dim
elif mlp_ratio is not None:
self.mlp_hidden_dim = hidden_dim * mlp_ratio
else:
raise ValueError(
"Must assign either one of 'mlp_hidden_dim' and 'mlp_ratio' arguments!"
)
self.n_heads = n_heads
self.num_classes = num_classes
self.in_channels = in_channels
self.is_pretrain = is_pretrain
# 1) image patcher
self.patcher = Patcher(patch_size=patch_size)
# 2) Linear projection on the patches
self.embed = nn.Linear(in_channels * (patch_size**2), hidden_dim)
# 3) Prepend a learnable [class] token
self.class_token = nn.Parameter(
data=torch.rand(size=(1, 1, hidden_dim)), requires_grad=True
)
# 4) positional encoder, length + 1 for the [class] token
self.pos_encoder = PositionalEncoder(
num_patches=self.num_patches + 1, embed_dim=hidden_dim
)
# 5) add multiple ViT blocks
self.layers = nn.ModuleList(
[
ViTBlock(
n_heads=n_heads,
hidden_dim=hidden_dim,
mlp_hidden_dim=mlp_hidden_dim,
mlp_ratio=mlp_ratio,
)
for _ in range(num_layers)
]
)
# 6) classification head
        if is_pretrain:
            # MLP with one hidden layer; tanh follows the official implementation
            self.classifier = nn.Sequential(
                nn.Linear(hidden_dim, self.mlp_hidden_dim),
                nn.Tanh(),
                nn.Linear(self.mlp_hidden_dim, num_classes),
            )
else:
self.classifier = nn.Linear(hidden_dim, num_classes)
def forward(self, x):
# image -> (B, N, CPP)
B = x.shape[0]
flat_patches = self.patcher(x)
embeddings = self.embed(flat_patches)
### add the `[class]` token, (B, N, D) -> (B, N+1, D)
embeddings = torch.cat(
[self.class_token.expand(size=(B, -1, -1)), embeddings], dim=1
)
out = self.pos_encoder(embeddings)
for layer in self.layers:
out = layer(out)
        ### attach the head to $\mathbf{z}^{0}_{L}$
out = self.classifier(out[:, 0, :])
return out
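As a minimal end-to-end sketch (assembling the pieces above into one file, with an assumed 224x224 input and the default ViT-Base-like hyperparameters), the full model maps a batch of images to class logits:
model = VisionTransformer(image_size=224, patch_size=16, num_classes=1000)
images = torch.randn(2, 3, 224, 224)
logits = model(images)
print(logits.shape)   # torch.Size([2, 1000])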
2-8 Three variants of ViT
In the original ViT paper, three ViT variants were introduced: ViT-Base, ViT-Large, and ViT-Huge. The differences between them are:
- The number of stacked basic blocks (layers), \(L\)
- The hidden size of the model, \(D\)
- The number of heads used in the Multi-Head Self-Attention module, \(h\)
- The hidden size of the FeedForward MLP layer, mlp_hidden_dim
Here is a detailed table (Table 1 in the ViT paper):
| Model | Layers \(L\) | Hidden size \(D\) | MLP size | Heads \(h\) | Params |
| --- | --- | --- | --- | --- | --- |
| ViT-Base | 12 | 768 | 3072 | 12 | 86M |
| ViT-Large | 24 | 1024 | 4096 | 16 | 307M |
| ViT-Huge | 32 | 1280 | 5120 | 16 | 632M |
Change the num_layers, hidden_dim, mlp_hidden_dim, and n_heads arguments to build the three variants of ViT: ViT-B, ViT-L, and ViT-H.
def ViT_Base(
patch_size=16,
num_layers=12,
hidden_dim=768,
mlp_hidden_dim=3072,
n_heads=12,
in_channels=3,
):
return VisionTransformer(
patch_size=patch_size,
num_layers=num_layers,
hidden_dim=hidden_dim,
mlp_hidden_dim=mlp_hidden_dim,
n_heads=n_heads,
in_channels=in_channels,
)
def ViT_Large(
patch_size=16,
num_layers=24,
hidden_dim=1024,
mlp_hidden_dim=4096,
n_heads=16,
in_channels=3,
):
return VisionTransformer(
patch_size=patch_size,
num_layers=num_layers,
hidden_dim=hidden_dim,
mlp_hidden_dim=mlp_hidden_dim,
n_heads=n_heads,
in_channels=in_channels,
)
def ViT_Huge(
patch_size=16,
num_layers=32,
hidden_dim=1280,
    mlp_hidden_dim=5120,
n_heads=16,
in_channels=3,
):
return VisionTransformer(
patch_size=patch_size,
num_layers=num_layers,
hidden_dim=hidden_dim,
mlp_hidden_dim=mlp_hidden_dim,
n_heads=n_heads,
in_channels=in_channels,
)
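As a rough sanity check (the counts from this code are expected to land near the ~86M / 307M / 632M reported in the paper, with small differences coming from the classification head and positional embeddings), the three factory functions can be instantiated and their parameter counts printed:
for name, builder in [("ViT-B", ViT_Base), ("ViT-L", ViT_Large), ("ViT-H", ViT_Huge)]:
    model = builder()   # note: ViT-Huge alone allocates roughly 2.5 GB of float32 weights
    n_params = sum(p.numel() for p in model.parameters())
    print(f"{name}: {n_params / 1e6:.1f}M parameters")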