"""Single-image inference with a trained ``MyModel`` checkpoint.

Loads an image from disk, preprocesses it to a 32x32 RGB tensor, runs it
through the model and prints the predicted class label.
"""
import torch
import torchvision.transforms
from PIL import Image
from torch import nn
from torch.nn import Conv2d, Flatten, Linear, MaxPool2d, Sequential

# NOTE(review): relative paths resolve against the current working directory,
# not this file's location — run the script from its own directory.
image_path = "../imgs/dog.png"
# Image.open is lazy: the file is identified here, pixel data is read on first use.
image = Image.open(image_path)
print(image)
# Force exactly 3 channels: PNGs are often RGBA or palette-based, but the first
# Conv2d layer below expects 3 input channels.
image = image.convert('RGB')
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),  # PIL HWC image -> float CHW tensor
])
image = transform(image)
print(image.shape)


class MyModel(nn.Module):
    """Small CNN: three conv + max-pool stages followed by two linear layers.

    The layer sizes imply an input of shape (N, 3, 32, 32); ``forward`` returns
    raw (un-normalised) scores of shape (N, 10).
    """

    def __init__(self):
        super().__init__()
        # The attribute name ``model1`` is part of the state_dict keys
        # ("model1.0.weight", ...), so it must match the saved checkpoint.
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),   # padding=2 keeps 32x32
            MaxPool2d(2),                  # -> 16x16
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),                  # -> 8x8
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),                  # -> 4x4
            Flatten(),                     # 64 channels * 4 * 4 = 1024 features
            Linear(1024, 64),
            Linear(64, 10),
        )

    def forward(self, x):
        """Return the 10 raw class scores for each image in batch ``x``."""
        return self.model1(x)


model = MyModel()
# map_location="cpu": the directory name suggests the checkpoint was saved from
# an MPS (Apple GPU) run; if it holds device tensors, remapping lets it load on
# machines without that backend. load_state_dict copies into the CPU params anyway.
model.load_state_dict(torch.load('../model_mps/mymodel_9.pth', map_location="cpu"))
model.eval()  # inference mode; explicit even though these layers behave the same
print(model)
# Prepend a batch dimension: (C, H, W) -> (1, C, H, W); (3, 32, 32) here.
image = torch.reshape(image, (1, *image.shape))
print("torch.reshape", image.shape)
with torch.no_grad():  # inference only: skip autograd graph construction
    output = model(image)
print(output)

# Index of the highest-scoring class for each row of the batch.
predicted_index = output.argmax(1)
print(predicted_index)

# NOTE(review): order must match the label indices used at training time
# (this is the CIFAR-10 class order) — confirm against the training script.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
# .item() needs a single-element tensor, i.e. a batch of exactly one image.
predicted_class = classes[predicted_index.item()]
print(f"预测结果: {predicted_class}")  # "Prediction result: <class>"