Building a CIFAR-10 Image Classifier using DINOv2 and PyTorch Lightning
Introduction
In this blog post, we will explore how to build an image classifier using DINOv2, a state-of-the-art self-supervised vision model from Meta AI that learns general-purpose visual features without labels. We will leverage PyTorch Lightning, a popular deep learning library, to streamline the training process. Our goal is to provide a step-by-step explanation of the code and concepts involved: we keep the pre-trained DINOv2 backbone frozen and train only a small linear classifier head on top of its features.
Setting Up the Environment
Before we dive into the code, make sure you have PyTorch and other necessary libraries installed. You can install PyTorch Lightning with the following command:
!pip install pytorch_lightning -q
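For reference, the snippets in the rest of this post rely on the following imports, consolidated here for convenience (PyTorch, torchvision, and Pillow usually come pre-installed in Colab/Kaggle environments; otherwise install them alongside Lightning):
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pytorch_lightning as pl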
Loading the DINOv2 Model
We start by loading the DINOv2 model with the torch.hub.load function, which pulls it from the facebookresearch/dinov2 GitHub repository. This pre-trained model will serve as our frozen feature extractor.
For details, see the paper: DINOv2: Learning Robust Visual Features without Supervision.
import torch

# Load the pre-trained DINOv2 ViT-S/14 backbone from torch.hub
dinov2_vits14 = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")
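As a quick optional sanity check, you can push a dummy batch through the backbone: dinov2_vits14 maps each 224x224 RGB image to a single feature vector of size embed_dim (384 for the ViT-S/14 variant), which is exactly what our classifier head will consume.
# Optional sanity check: the backbone returns one 384-d feature vector per image.
dummy = torch.randn(2, 3, 224, 224)  # two fake RGB images; 224 is a multiple of the 14-pixel patch size
with torch.no_grad():
    features = dinov2_vits14(dummy)
print(features.shape)            # torch.Size([2, 384])
print(dinov2_vits14.embed_dim)   # 384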
Custom Dataset and DataModule
Next, we create a custom dataset class, CustomCIFAR10Dataset, and a Lightning DataModule, CustomCIFAR10DataModule, for handling the CIFAR-10 dataset. These classes take care of data loading, train/validation splitting, and transformations.
class CustomCIFAR10Dataset(Dataset):
    """Dataset that loads labelled images from individual image files on disk."""

    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        image_path = self.file_paths[idx]
        image = Image.open(image_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        label = self.labels[idx]
        return image, label
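The DataModule below loads CIFAR-10 directly from torchvision, so this class is only needed when your images live as individual files on disk. A minimal usage sketch (the paths and labels here are made up for illustration):
# Hypothetical example: the file paths below are placeholders, not real CIFAR-10 files.
file_paths = ["images/cat_001.png", "images/dog_001.png"]
labels = [3, 5]  # CIFAR-10 class indices: 3 = cat, 5 = dog
disk_dataset = CustomCIFAR10Dataset(file_paths, labels, transform=None)
print(len(disk_dataset))  # 2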
class CustomCIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, batch_size, transform):
        super().__init__()
        self.batch_size = batch_size
        self.transform = transform

    def prepare_data(self):
        # Download CIFAR-10 once; setup() reloads it with the transforms applied.
        torchvision.datasets.CIFAR10(root="./data", train=True, download=True)
        torchvision.datasets.CIFAR10(root="./data", train=False, download=True)

    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            train_dataset = torchvision.datasets.CIFAR10(root="./data", train=True, transform=self.transform)
            num_train = len(train_dataset)
            # Hold out 5,000 of the 50,000 training images for validation.
            self.train_dataset, self.val_dataset = torch.utils.data.random_split(
                train_dataset, [num_train - 5000, 5000]
            )
        if stage == 'test' or stage is None:
            self.test_dataset = torchvision.datasets.CIFAR10(root="./data", train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)
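To verify the DataModule independently of the Trainer, you can run its hooks by hand and inspect one batch (optional; this downloads CIFAR-10 on the first call):
# Optional: exercise the DataModule outside the Trainer and check tensor shapes.
dm = CustomCIFAR10DataModule(
    batch_size=8,
    transform=transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()]),
)
dm.prepare_data()
dm.setup("fit")
images, labels = next(iter(dm.train_dataloader()))
print(images.shape, labels.shape)  # torch.Size([8, 3, 224, 224]) torch.Size([8])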
Linear Classifier Head and Custom Model
We define a linear classifier head, LinearClassifierHead, which takes the features extracted by DINOv2 and maps them to class predictions. Our main model, CustomModel, combines the frozen DINOv2 feature extractor with the linear classifier head, and configures the loss function and optimizer. Only the head's parameters are handed to the optimizer, so the backbone remains frozen throughout training.
class LinearClassifierHead(nn.Module):
    def __init__(self, embed_dim, num_classes):
        super().__init__()
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        return self.head(x)
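# Quick smoke test for the head (optional): for dinov2_vits14 the feature
# dimension is 384, so the head maps (batch, 384) features to (batch, 10) logits.
head = LinearClassifierHead(embed_dim=384, num_classes=10)
print(head(torch.randn(4, 384)).shape)  # torch.Size([4, 10])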
class CustomModel(pl.LightningModule):
    def __init__(self, embed_dim, num_classes, learning_rate=0.001):
        super().__init__()
        # Frozen DINOv2 backbone (loaded globally above) plus a trainable linear head.
        self.dinov2_vits14 = dinov2_vits14.eval()
        self.linear_classifier_head = LinearClassifierHead(embed_dim, num_classes)
        self.criterion = nn.CrossEntropyLoss()
        self.learning_rate = learning_rate
        self.validation_losses = []

    def forward(self, x):
        # No gradients flow through the backbone; only the head is trained.
        with torch.no_grad():
            features = self.dinov2_vits14(x)
        return self.linear_classifier_head(features)

    def configure_optimizers(self):
        # Optimize only the head's parameters, keeping the backbone fixed.
        optimizer = torch.optim.Adam(self.linear_classifier_head.parameters(), lr=self.learning_rate)
        return optimizer

    def training_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        val_images, val_labels = batch
        val_outputs = self(val_images)
        val_loss = self.criterion(val_outputs, val_labels)
        self.validation_losses.append(val_loss.item())
        return val_loss

    def on_validation_epoch_end(self):
        # Average the per-batch losses collected over the epoch.
        avg_val_loss = sum(self.validation_losses) / len(self.validation_losses)
        self.log('val_loss', avg_val_loss, prog_bar=True)
        self.validation_losses = []
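Note that the DataModule defines a test_dataloader, but CustomModel does not implement a test_step, so trainer.test would not report any metrics. If you want test-set numbers, a minimal sketch of a method you could add to CustomModel looks like this (an optional addition, not part of the original training run):
# Hypothetical addition: add this method to the CustomModel class above, then
# trainer.test(model, datamodule=data_module) will report loss and accuracy.
def test_step(self, batch, batch_idx):
    images, labels = batch
    outputs = self(images)
    loss = self.criterion(outputs, labels)
    acc = (outputs.argmax(dim=1) == labels).float().mean()
    self.log("test_loss", loss)
    self.log("test_acc", acc, prog_bar=True)
    return loss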
Configuration and Training
Here, we set configuration parameters such as the batch size, number of epochs, and learning rate. We also define the data transformations used to preprocess the images: the 32x32 CIFAR-10 images are resized to 224x224 (a multiple of the ViT's 14-pixel patch size) and normalized with standard ImageNet statistics, matching the preprocessing the DINOv2 backbone expects.
# Configuration
batch_size = 1024
num_epochs = 2
learning_rate = 0.001
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
data_module = CustomCIFAR10DataModule(batch_size=batch_size, transform=transform)
model = CustomModel(embed_dim=dinov2_vits14.embed_dim, num_classes=10, learning_rate=learning_rate)
trainer = pl.Trainer(
    max_epochs=num_epochs,
    # val_check_interval=0.1,
    check_val_every_n_epoch=1,
)
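Extracting DINOv2 features for 45,000 resized training images is slow on CPU, so use a GPU if you have one. Depending on your Lightning version, the Trainer may already pick up an available accelerator automatically; if you want to be explicit (or pin the number of devices), a sketch of the alternative construction with the Lightning 2.x API:
# Optional: explicitly let Lightning pick a GPU (or Apple MPS) if one is available.
trainer = pl.Trainer(
    max_epochs=num_epochs,
    accelerator="auto",   # falls back to CPU when no accelerator is present
    devices=1,
    check_val_every_n_epoch=1,
)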
%%time
trainer.fit(model, data_module)
trainer.validate(model, datamodule=data_module)
print("Training and testing complete.")
Model Evaluation and Conclusion
After training, we evaluate the model on the validation data. We display a handful of validation images with their true and predicted labels for visual inspection, and compute the accuracy over the first two validation batches as a quick quantitative check.
import matplotlib.pyplot as plt
import numpy as np
# Load the trained model
# model = CustomModel(embed_dim=dinov2_vits14.embed_dim, num_classes=10, learning_rate=learning_rate)
# model.load_state_dict(torch.load('path_to_your_saved_model.pth'))
model.eval()
# Initialize the validation dataloader
val_dataloader = data_module.val_dataloader()
# Define a function to display images with true and predicted labels as strings
def display_images_with_labels(images, true_labels, predicted_labels, class_names):
    num_images = min(len(images), 10)
    fig, axes = plt.subplots(2, 5, figsize=(12, 6))
    # Undo the ImageNet normalization so the images display with natural colors.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    for i in range(num_images):
        row = i // 5
        col = i % 5
        image = images[i].permute(1, 2, 0).cpu().numpy()
        image = np.clip(image * std + mean, 0, 1)
        true_label = class_names[true_labels[i]]
        predicted_label = class_names[predicted_labels[i]]
        axes[row, col].imshow(image)
        axes[row, col].set_title(f'True: {true_label}\nPredicted: {predicted_label}')
        axes[row, col].axis('off')
    plt.show()
# Make predictions on the validation data and visualize
true_labels = []
predicted_labels = []
images_to_display = []
with torch.no_grad():
    for i, val_batch in enumerate(val_dataloader):
        if i > 1:
            break
        val_images, val_true_labels = val_batch
        val_outputs = model(val_images)
        val_predicted_labels = torch.argmax(val_outputs, dim=1)
        true_labels.extend(val_true_labels.cpu().numpy())
        predicted_labels.extend(val_predicted_labels.cpu().numpy())
        images_to_display.extend(val_images)
# Load the CIFAR-10 class names
class_names = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck"
]
true_labels = np.array(true_labels)
predicted_labels = np.array(predicted_labels)
# Display the 10 validation images and labels with class names
display_images_with_labels(images_to_display[:10], true_labels[:10], predicted_labels[:10], class_names)
from sklearn.metrics import accuracy_score
# Compute accuracy
accuracy = accuracy_score(true_labels, predicted_labels)
print(f"Accuracy on the validation dataset: {accuracy * 100:.2f}%")
By following this blog post, you have learned how to build an image classifier by training a lightweight linear head on top of frozen DINOv2 features with PyTorch Lightning. You can use this knowledge to explore self-supervised representations further and create your own image classification models. Feel free to adapt and share this code to bring this technology to a wider audience. Happy coding!