1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
| import torch import torch.nn as nn from torchvision import transforms from torchvision.datasets import CIFAR10 from torch.utils.data import DataLoader, Subset import numpy as np import sys
class CustomCNN(nn.Module):
    """Small two-convolution CNN producing 10 class logits per image.

    Layout: conv(3->16) -> ReLU -> 2x2 max-pool -> conv(16->32) -> ReLU
    -> 2x2 max-pool -> linear(32*8*8 -> 128) -> ReLU -> linear(128 -> 10).
    ``fc1`` expects 32*8*8 features, so inputs must be 3x32x32 images.

    The attribute names below determine the ``state_dict`` keys; renaming
    them would break loading of existing checkpoints.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(in_features=32 * 8 * 8, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return logits of shape (batch, 10) for x of shape (batch, 3, 32, 32).

        Raises:
            RuntimeError: if the spatial size is not 32x32, because the
                flattened feature count then no longer matches ``fc1``.
        """
        # padding=1 with a 3x3 kernel keeps H/W; each pool halves them.
        x = self.pool(self.relu(self.conv1(x)))  # (B, 16, 16, 16)
        x = self.pool(self.relu(self.conv2(x)))  # (B, 32, 8, 8)
        # Flatten per sample, keeping the batch dimension. A view(-1, 2048)
        # would silently re-batch mis-sized inputs instead of failing.
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)
def load_model(model_path, map_location="cpu"):
    """Build a CustomCNN, load its weights from *model_path*, and return it in eval mode.

    Args:
        model_path: Path to a saved state dict (the loaded object is passed
            directly to ``load_state_dict``).
        map_location: Where stored tensors are mapped while deserializing.
            Defaults to ``"cpu"`` so a checkpoint saved from a GPU can still
            be read on a CPU-only host.

    Returns:
        The ``CustomCNN`` with the loaded weights, switched to ``eval()``.
    """
    model = CustomCNN()
    # The path is user-supplied. weights_only=True restricts unpickling to
    # tensors and plain containers, so a tampered file cannot run arbitrary
    # code on load. NOTE: requires torch >= 1.13.
    state_dict = torch.load(model_path, map_location=map_location, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model
def prepare_dataloader(batch_size=32, num_samples=64, root='/root/datasets/', seed=None):
    """Return a DataLoader over a random subset of the CIFAR-10 test split.

    Args:
        batch_size: Number of samples per batch.
        num_samples: How many test images to draw, without replacement.
            Capped at the dataset size.
        root: Directory containing an already-downloaded CIFAR-10
            (nothing is downloaded).
        seed: Optional seed for the subset draw. ``None`` draws from NumPy's
            global random state, as before.

    Only deterministic preprocessing is applied. Random flip/crop
    augmentation belongs to training; applying it here would make the
    measured accuracy change from run to run.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        # Presumably the statistics used at training time; keep them in sync.
        transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
    ])
    dataset = CIFAR10(root=root, train=False, download=False, transform=transform)
    rng = np.random if seed is None else np.random.default_rng(seed)
    # Sampling without replacement fails if asked for more items than exist.
    count = min(num_samples, len(dataset))
    indices = rng.choice(len(dataset), count, replace=False).tolist()
    subset = Subset(dataset, indices=indices)
    return DataLoader(subset, batch_size=batch_size, shuffle=False)
def evaluate_model(model, dataloader):
    """Compute, print and return top-1 accuracy (in percent) of *model* over *dataloader*.

    Args:
        model: A classifier whose output has class scores along dim 1.
        dataloader: Any iterable yielding ``(images, labels)`` batches.

    Returns:
        Accuracy as a float in [0, 100], or ``nan`` when the loader yields
        no samples (accuracy is undefined then).

    The model is evaluated in ``eval()`` mode; its previous training/eval
    mode is restored afterwards.
    """
    # Feed inputs to wherever the model's weights live (CPU if it has none).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in dataloader:
                images, labels = images.to(device), labels.to(device)
                predicted = model(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)

    if total == 0:
        print('[!] No samples to evaluate; accuracy is undefined.')
        return float('nan')
    accuracy = 100 * correct / total
    print(f'[+] Accuracy of the model on the test dataset: {accuracy:.2f}%')
    return accuracy
def main(model_path):
    """Load the checkpoint at *model_path* and report its accuracy on sampled CIFAR-10 test images."""
    net = load_model(model_path)
    print("[+] Loaded Model.")

    loader = prepare_dataloader()
    print("[+] Dataloader ready. Evaluating model...")

    evaluate_model(net, loader)
if __name__ == "__main__":
    if len(sys.argv) < 2:
        # sys.exit(str) writes the message to stderr and exits with status 1,
        # so callers and CI see a missing argument as a failure.
        sys.exit("Usage: python script.py <path_to_model.pth>")
    main(sys.argv[1])
|