How to use pre-trained models in PyTorch?

Published on Aug. 22, 2023, 12:19 p.m.

To use pre-trained models in PyTorch, you can use the torchvision.models module, which provides a collection of popular pre-trained models for image classification, object detection, and other computer vision tasks. Here is an example of how to load a pre-trained ResNet-50 model for image classification:

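import torch
import torchvision.models as models

# torchvision >= 0.13: select pre-trained weights via the weights enum
weights = models.ResNet50_Weights.DEFAULT
model = models.resnet50(weights=weights)
model.eval()  # inference mode: batch norm uses its running statistics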

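In this example, we use the resnet50() function to build the ResNet-50 architecture with weights pre-trained on the ImageNet dataset. The weights argument tells torchvision which set of pre-trained weights to download (they are cached locally after the first download), and ResNet50_Weights.DEFAULT selects the best ImageNet weights available for this model. Older torchvision releases used pretrained=True for the same purpose, but that argument has been deprecated since torchvision 0.13. Calling model.eval() before inference matters because layers such as batch normalization behave differently in training mode.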

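Once the pre-trained model is loaded, you can make predictions on new data by calling the model directly on a batch of preprocessed input tensors. Calling model(x) runs the model's forward() method together with any registered hooks, so it is preferred over calling forward() yourself. For example: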

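from PIL import Image

# Load an image (the path is a placeholder) and preprocess it
image = Image.open("path/to/image.jpg").convert("RGB")
preprocess = models.ResNet50_Weights.DEFAULT.transforms()
preprocessed_image = preprocess(image).unsqueeze(0)  # add a batch dimension

# Make a prediction using the loaded model (no gradients needed)
with torch.no_grad():
    logits = model(preprocessed_image)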

In this example, we pass the preprocessed image through the loaded ResNet-50 model to get its logits: raw, unnormalized scores with shape (1, 1000), one score per ImageNet class.
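To turn logits into readable predictions, a common approach is to apply a softmax over the class dimension and keep the k highest-scoring classes. The helper below is a minimal sketch; the name top_k_predictions is hypothetical, not part of torchvision. It is exercised on small made-up logits so it runs without downloading any weights.

```python
import torch


def top_k_predictions(logits: torch.Tensor, categories: list, k: int = 5) -> list:
    """Convert (1, num_classes) logits into (label, probability) pairs.

    Hypothetical helper for illustration; assumes a batch of one image.
    """
    # Softmax over the class dimension turns raw scores into probabilities
    probs = torch.softmax(logits, dim=1)
    # topk returns values and indices sorted from highest to lowest
    top_probs, top_indices = probs.topk(k, dim=1)
    # Index 0 selects the single image in the batch
    return [
        (categories[idx], prob)
        for idx, prob in zip(top_indices[0].tolist(), top_probs[0].tolist())
    ]


# Made-up logits for a 4-class problem, just to show the output format
fake_logits = torch.tensor([[0.1, 2.0, -1.0, 0.5]])
fake_categories = ["cat", "dog", "bird", "fish"]
result = top_k_predictions(fake_logits, fake_categories, k=2)
print(result)
```

With the real model, you would pass the logits from ResNet-50 together with models.ResNet50_Weights.DEFAULT.meta["categories"] as the list of class names.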
