
Coding a Vision Transformer from scratch using PyTorch

Interested in learning Computer Vision hands-on? Check this: https://hands-on-cv.vizuara.ai/

The basic idea behind transformers, such as those used in ChatGPT, is to split a sentence into words (or tokens) and then convert these tokens into vector representations, which are used to predict the next word. Vision transformers are similar in many respects. In a vision transformer, you divide an image into sub-images (also called patches). You then create a vector representation of each patch, prepend an additional learnable vector called the CLS token, and pass the whole sequence through multi-head attention (very similar to the attention used in transformers for text). Finally, you use the output representation of the CLS token to perform the classification.
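As a rough illustration of the patch-splitting step, the sketch below cuts a batch of MNIST-sized 1x28x28 images into non-overlapping 7x7 patches using only tensor reshapes. It is an assumption-based example, not the lecture's Colab code; the helper name `image_to_patches` and the patch size of 7 are hypothetical choices.

```python
import torch


def image_to_patches(images: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Split (B, C, H, W) images into flattened non-overlapping patches.

    Returns a tensor of shape (B, num_patches, C * patch_size * patch_size).
    """
    b, c, h, w = images.shape
    assert h % patch_size == 0 and w % patch_size == 0, "image must divide evenly into patches"
    # (B, C, H/p, p, W/p, p): carve height and width into a grid of p x p tiles
    x = images.reshape(b, c, h // patch_size, patch_size, w // patch_size, patch_size)
    # (B, H/p, W/p, C, p, p): move the grid position in front of the pixel dims
    x = x.permute(0, 2, 4, 1, 3, 5)
    # (B, num_patches, C*p*p): each patch becomes one flat "token", row by row
    return x.reshape(b, (h // patch_size) * (w // patch_size), c * patch_size * patch_size)


images = torch.randn(8, 1, 28, 28)  # a batch of MNIST-sized images
patches = image_to_patches(images, patch_size=7)
print(patches.shape)  # torch.Size([8, 16, 49])
```

A 28x28 image with 7x7 patches gives a 4x4 grid, so each image becomes a sequence of 16 patch vectors of 49 pixels each, in row-major order.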

In this lecture, we code a vision transformer completely from scratch using PyTorch. We divide the transformer into three parts.

Part one is the embedding stage. The image is split into patches, and each patch is converted into a vector embedding. A CLS token is then concatenated at the beginning of this sequence of patch embeddings, and a position embedding is added to every token, including the CLS token.
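A minimal sketch of this embedding stage, assuming MNIST-sized inputs and hypothetical sizes (`patch_size=7`, `embed_dim=64`) that may not match the Colab notebook. It uses a common shortcut: an `nn.Conv2d` whose kernel size and stride both equal the patch size applies the same linear projection to every patch, which is equivalent to flattening each patch and passing it through one shared `nn.Linear`. Learnable position embeddings (as in the original ViT) are also an assumption here.

```python
import torch
import torch.nn as nn


class PatchEmbedding(nn.Module):
    """Patches -> linear embeddings, prepend a learnable CLS token, add learnable position embeddings."""

    def __init__(self, img_size=28, patch_size=7, in_channels=1, embed_dim=64):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # kernel_size == stride == patch_size: one shared linear projection per non-overlapping patch
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        # A single learnable CLS vector, shared by every image in the batch
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One learnable position vector per token (num_patches + 1 to cover the CLS token)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        # (B, C, H, W) -> (B, embed_dim, H/p, W/p) -> (B, num_patches, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        # Prepend the CLS token to every sequence in the batch
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls, x], dim=1)
        # Add position embeddings element-wise (broadcast over the batch)
        return x + self.pos_embed


embed = PatchEmbedding()
tokens = embed(torch.randn(8, 1, 28, 28))
print(tokens.shape)  # torch.Size([8, 17, 64])
```

The sequence length is 17: 16 patch tokens plus the CLS token at position 0.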

Part two is the transformer encoder, which can be repeated in series. For example, we can stack four transformer encoders in series, or we can use just one; it depends on what we want to do. Inside each transformer encoder:
- There are two layer normalizations.
- There are two residual connections, which help mitigate the vanishing gradient problem.
- There is one multi-head attention block.

Within the multi-head attention, several attention heads run in parallel, and each head processes only a fraction of the dimensions of the input patch embeddings. At the output of the multi-head attention, these partial vectors are concatenated back to the original embedding dimension. The resulting vectors are then passed through a multi-layer perceptron to increase the representational capacity.
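One way to write the encoder described above, as a sketch that assumes the pre-norm layout of the original ViT paper (LayerNorm before attention and before the MLP) and hypothetical sizes: `embed_dim=64`, `num_heads=4`, and an MLP that expands to 2x the embedding width. Each head attends over a `head_dim = embed_dim // num_heads` slice, and the head outputs are concatenated and mixed by a final linear layer. The lecture's own code may differ in these details.

```python
import torch
import torch.nn as nn


class MultiHeadSelfAttention(nn.Module):
    """Each head attends over a slice of size embed_dim // num_heads; head outputs are concatenated."""

    def __init__(self, embed_dim=64, num_heads=4):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)  # queries, keys, values in one projection
        self.out = nn.Linear(embed_dim, embed_dim)      # mixes the concatenated heads

    def forward(self, x):
        b, n, d = x.shape
        # (B, N, 3D) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # Scaled dot-product attention, computed independently for each head
        scores = q @ k.transpose(-2, -1) / self.head_dim ** 0.5  # (B, heads, N, N)
        weights = scores.softmax(dim=-1)
        heads = weights @ v                                      # (B, heads, N, head_dim)
        # Concatenate the heads back to the full embedding dimension
        concat = heads.transpose(1, 2).reshape(b, n, d)
        return self.out(concat)


class EncoderBlock(nn.Module):
    """Pre-norm encoder block: two LayerNorms, two residual connections, one MHA, one MLP."""

    def __init__(self, embed_dim=64, num_heads=4, mlp_ratio=2):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_ratio * embed_dim),
            nn.GELU(),
            nn.Linear(mlp_ratio * embed_dim, embed_dim),
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))  # residual connection 1
        x = x + self.mlp(self.norm2(x))   # residual connection 2
        return x


# Stack encoder blocks in series: four here, but a single block also works
encoder = nn.Sequential(*[EncoderBlock() for _ in range(4)])
out = encoder(torch.randn(8, 17, 64))
print(out.shape)  # torch.Size([8, 17, 64])
```

Because every block maps `(B, N, embed_dim)` back to the same shape, blocks can be stacked in any number without changing the surrounding code.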

Part three is the multi-layer perceptron (MLP) head, which performs the classification. Unlike text-based transformers, we are not predicting the next token. Text transformers typically perform next-word prediction, but here we classify the image into one of a fixed number of classes (for example 10 or 5, depending on the task). For the classification, we use only the CLS token: only the CLS token's output is passed through the MLP head.
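A minimal sketch of the classification head, assuming it is a LayerNorm followed by a single linear layer (the lecture's MLP head may include a hidden layer). It reads only position 0 of the encoder output, which is where the CLS token was prepended. The class name `ClassificationHead` and the sizes are hypothetical.

```python
import torch
import torch.nn as nn


class ClassificationHead(nn.Module):
    """Maps the final CLS token representation to class logits."""

    def __init__(self, embed_dim=64, num_classes=10):
        super().__init__()
        self.norm = nn.LayerNorm(embed_dim)
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, encoded):
        # encoded: (B, 1 + num_patches, embed_dim); index 0 holds the CLS token
        cls_out = encoded[:, 0]
        return self.fc(self.norm(cls_out))  # (B, num_classes) logits


head = ClassificationHead(num_classes=10)
logits = head(torch.randn(8, 17, 64))
preds = logits.argmax(dim=-1)  # predicted class index per image
print(logits.shape, preds.shape)  # torch.Size([8, 10]) torch.Size([8])
```

Since only `encoded[:, 0]` is read, the patch tokens influence the prediction solely through attention inside the encoder, where the CLS token gathers information from them.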

In this lecture, we build a vision transformer from scratch using PyTorch. We don't use any pre-trained model; we train the vision transformer from scratch on the well-known MNIST dataset of handwritten digits, which makes this a 10-class classification problem. We get 96%+ training accuracy and similar validation accuracy, which is on par with what many CNNs achieve. With further hyperparameter tuning, we could push this to 99% accuracy or beyond.
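The training setup can be sketched end to end as follows. To keep the block short and self-contained, this hypothetical `TinyViT` swaps the from-scratch encoder for PyTorch's built-in `nn.TransformerEncoderLayer` (with `norm_first=True` for a pre-norm layout), and it trains for one epoch on random tensors shaped like MNIST so it runs without a download. The optimizer (Adam, learning rate 1e-3) and all sizes are assumptions rather than the Colab notebook's values, and accuracies measured on random data are meaningless.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


class TinyViT(nn.Module):
    """Compact ViT for 28x28 grayscale digits, using PyTorch's built-in encoder layer as a shortcut."""

    def __init__(self, patch_size=7, embed_dim=64, depth=4, num_heads=4, num_classes=10):
        super().__init__()
        num_patches = (28 // patch_size) ** 2
        self.proj = nn.Conv2d(1, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim) * 0.02)
        layer = nn.TransformerEncoderLayer(embed_dim, num_heads, dim_feedforward=2 * embed_dim,
                                           activation="gelu", batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)
        self.head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, num_classes))

    def forward(self, x):
        x = self.proj(x).flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)
        x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1) + self.pos_embed
        return self.head(self.encoder(x)[:, 0])      # classify from the CLS token only


def train_one_epoch(model, loader, optimizer, loss_fn):
    """One pass over the data; returns the training accuracy for the epoch."""
    model.train()
    correct = total = 0
    for images, labels in loader:
        optimizer.zero_grad()
        logits = model(images)
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.numel()
    return correct / total


@torch.no_grad()
def evaluate(model, loader):
    """Returns classification accuracy without updating the weights."""
    model.eval()
    correct = total = 0
    for images, labels in loader:
        correct += (model(images).argmax(dim=1) == labels).sum().item()
        total += labels.numel()
    return correct / total


# Random stand-in data shaped like MNIST (1x28x28 images, labels 0-9)
data = TensorDataset(torch.randn(64, 1, 28, 28), torch.randint(0, 10, (64,)))
loader = DataLoader(data, batch_size=16, shuffle=True)
model = TinyViT()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train_acc = train_one_epoch(model, loader, optimizer, nn.CrossEntropyLoss())
val_acc = evaluate(model, loader)
```

For real training, the random `TensorDataset` could be replaced with `torchvision.datasets.MNIST(root="data", train=True, download=True, transform=torchvision.transforms.ToTensor())` (and `train=False` for the validation split), with `train_one_epoch` run for several epochs.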

We code the entire architecture in Google Colab, so you can get started with minimal setup. I hope you enjoy this lecture, learn the theory behind vision transformers, and understand how to code every part of a vision transformer essentially from scratch.

Here is the Google Colab code file link: https://colab.research.google.com/drive/19zAnmFvBU-vx64yriADswlZHbaU2NN2s?usp=sharing

Video "Coding a Vision Transformer from scratch using PyTorch" from the Vizuara channel