PyTorch搭建GoogLeNet网络(奥特曼分类)

1 爬取奥特曼

get_data.py

import requests
import urllib.parse as up
import json
import time
import os

# Endpoint and request headers shared by every Baidu image-search request.
major_url = 'https://image.baidu.com/search/index?'
headers = {
    'User-Agent': (
        'Mozilla/5.0 (Windows NT 10.0; Win64; x64) '
        'AppleWebKit/537.36 (KHTML, like Gecko) '
        'Chrome/84.0.4147.135 Safari/537.36'
    ),
}
def pic_spider(kw, path, page = 10):
    """Download Baidu image-search thumbnails for one keyword.

    Images are written to ``os.path.join(path, kw)``; the folder (and any
    missing parent) is created first.  Files that already exist are left
    untouched, so the crawl can be re-run.

    Args:
        kw: Search keyword, also used as the name of the target sub-folder.
            With an empty keyword no request is made.
        path: Base directory that receives the per-keyword folder.
        page: Number of result pages to request (30 results are asked for
            per page).
    """
    path = os.path.join(path, kw)
    # makedirs also creates missing parents and tolerates an existing folder.
    os.makedirs(path, exist_ok=True)
    if kw == '':
        return

    for num in range(page):
        data = {
            "tn": "resultjson_com",
            "logid": "11587207680030063767",
            "ipn": "rj",
            "ct": "201326592",
            "is": "",
            "fp": "result",
            "queryWord": kw,
            "cl": "2",
            "lm": "-1",
            "ie": "utf-8",
            "oe": "utf-8",
            "adpicid": "",
            "st": "-1",
            "z": "",
            "ic": "0",
            "hd": "",
            "latest": "",
            "copyright": "",
            "word": kw,
            "s": "",
            "se": "",
            "tab": "",
            "width": "",
            "height": "",
            "face": "0",
            "istype": "2",
            "qc": "",
            "nc": "1",
            "fr": "",
            "expermode": "",
            "force": "",
            "pn": num*30,
            "rn": "30",
            "gsm": oct(num*30),
            "1602481599433": ""
        }
        search_url = major_url + up.urlencode(data)

        # Fetch the result list, retrying up to five times.
        pic_list = []
        for _attempt in range(5):
            try:
                response = requests.get(url=search_url, headers=headers, timeout=10)
                # "data" may be absent or null in the reply; treat both as empty.
                pic_list = response.json().get('data') or []
                break
            except Exception:
                # Exception (not a bare except) keeps Ctrl-C working.
                print('网络不好,正在重试...')
                time.sleep(1.3)

        for index, pic in enumerate(pic_list):
            if not pic:
                continue
            pic_url = pic.get('thumbURL', '')  # some entries carry no image link
            if not pic_url:
                continue

            name = pic.get('fromPageTitleEnc') or ''
            for char in ['?', '\\', '/', '*', '"', '|', ':', '<', '>']:
                name = name.replace(char, '')  # drop characters that are illegal in file names
            name = name.strip()
            if not name:
                # No usable title: fall back to a name that is unique within this crawl.
                name = '%s_%d' % (kw, num * 30 + index)

            extension = pic.get('type') or 'jpg'  # default to jpg when the type is missing
            # Format only the file name so a '%' in the directory path cannot break it.
            pic_path = os.path.join(path, '%s.%s' % (name, extension))
            if os.path.exists(pic_path):
                continue

            # Download before opening the file so a failed request leaves no empty file.
            try:
                image = requests.get(url=pic_url, headers=headers, timeout=10)
                image.raise_for_status()
            except Exception as err:
                print(name, err)
                continue
            with open(pic_path, 'wb') as f:
                f.write(image.content)
            print(name, '已完成下载')
# Crawl one image folder per Ultraman name into <cwd>/数据/下载数据.
cwd = os.getcwd()  # current working directory
file2 = '数据/下载数据'
save_path = os.path.join(cwd, file2)
lists = [
    '佐菲', '初代', '赛文', '杰克', '艾斯', '泰罗', '奥特之父', '奥特之母',
    '爱迪', '尤莉安', '雷欧', '阿斯特拉', '奥特之王', '葛雷', '帕瓦特',
    '奈克斯特', '奈克瑟斯', '哉阿斯', '迪加', '戴拿', '盖亚(大地)',
    '阿古茹(海洋)', '高斯(慈爱)', '杰斯提斯(正义)',
    '雷杰多(高斯与杰斯提斯的合体)', '诺亚(奈克斯特的最终形态)', '撒加',
    '奈欧斯', '赛文21', '麦克斯', '杰诺', '梦比优斯', '希卡利', '赛罗', '赛文X',
]

# The download root is two levels deep, so os.mkdir would fail on a first run
# when its parent folder does not exist yet; create it once, before the loop.
os.makedirs(save_path, exist_ok=True)

for ultraman in lists:
    # The suffix narrows the search to Ultraman pictures of that name.
    pic_spider(ultraman + '奥特曼', save_path, page = 10)

print("lists_len: ",len(lists))

2 划分数据集

训练集 train :80%
验证集 val :10%
测试集 predict :10%

spile_data.py

import os
from shutil import copy
import random
import cv2


def mkfile(file):
    """Create directory ``file`` (including parents) unless the path already exists."""
    if os.path.exists(file):
        return
    os.makedirs(file)


# Folder holding the downloaded images, one sub-directory per class.
file = '数据/下载数据'
# Only real directories are classes; a stray file here would otherwise be
# listed as a class and make os.listdir() on it fail.
flower_class = [cla for cla in os.listdir(file)
                if ".txt" not in cla and os.path.isdir(os.path.join(file, cla))]

# Create 数据/<split>/<class> for each of the three splits.
for split_name in ('train', 'val', 'predict'):
    mkfile('数据/' + split_name)
    for cla in flower_class:
        mkfile('数据/' + split_name + '/' + cla)

# Fraction of every class that goes to "val", and again to "predict";
# the remainder is used for training.
split_rate = 0.1

for cla in flower_class:
    cla_path = file + '/' + cla + '/'

    # Keep JPEG/PNG files only.  Matching the real extension (case-insensitive)
    # lists every file exactly once; sorting makes the order reproducible.
    candidates = sorted(name for name in os.listdir(cla_path)
                        if name.lower().endswith(('.jpg', '.jpeg', '.png')))

    # Drop images smaller than 256 pixels in either dimension.
    images = []
    for image in candidates:
        img = cv2.imread(cla_path + image)
        # imread returns None (it does not raise) when a file cannot be decoded.
        if img is None:
            continue
        if img.shape[0] > 255 and img.shape[1] > 255:
            images.append(image)

    # Shuffle with a fixed seed so the splits are not ordered by file type or
    # name, yet stay identical between runs.
    random.Random(0).shuffle(images)

    num = len(images)
    val_end = split_rate * num
    predict_end = 2 * split_rate * num

    for index, image in enumerate(images):
        if index < val_end:
            new_path = '数据/val/' + cla
        elif index < predict_end:
            new_path = '数据/predict/' + cla
        else:
            new_path = '数据/train/' + cla
        copy(cla_path + image, new_path)
        print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing bar
    print()

print("processing done!")

3 GoogLeNet model

https://github.com/pytorch/vision/blob/master/torchvision/models/googlenet.py

model.py

import warnings
from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
#from .utils import load_state_dict_from_url
from typing import Optional, Tuple, List, Callable, Any

# Names exported by ``from model import *``.
__all__ = ['GoogLeNet', 'googlenet', "GoogLeNetOutputs", "_GoogLeNetOutputs"]

# Download locations of published weight files, keyed by model name.
model_urls = {
    # GoogLeNet ported from TensorFlow
    'googlenet': 'https://download.pytorch.org/models/googlenet-1378be20.pth',
}

# Result tuple with the main logits followed by the two auxiliary-head logits.
GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1'])
# Field types are attached after creation; both auxiliary fields may be None.
GoogLeNetOutputs.__annotations__ = {'logits': Tensor, 'aux_logits2': Optional[Tensor],
                                    'aux_logits1': Optional[Tensor]}

# Script annotations failed with _GoogleNetOutputs = namedtuple ...
# _GoogLeNetOutputs set here for backwards compat
_GoogLeNetOutputs = GoogLeNetOutputs


def googlenet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> "GoogLeNet":
    r"""GoogLeNet (Inception v1) model architecture from
    `"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.

    Note that ``pretrained`` is the first positional parameter: pass the class
    count as ``num_classes=...``, never positionally.

    Args:
        pretrained (bool): If True, returns a model initialised from the local
            weight file ``./googlenet-1378be20.pth``. When ``num_classes``
            differs from the checkpoint, the classifier layers (``fc`` and the
            auxiliary ``fc2`` layers) keep their fresh initialisation and only
            the remaining weights are loaded.
        progress (bool): Unused here, because the weights are read from a local
            file instead of being downloaded; kept for interface compatibility.
        aux_logits (bool): If True, adds two auxiliary branches that can improve training.
            Default: *False* when pretrained is True otherwise *True*
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *False*; set to *True* by default
            when pretrained is True.
    """
    if not isinstance(pretrained, bool):
        # A common mistake is googlenet(num_classes, ...), which lands here.
        warnings.warn('googlenet() received a non-bool value for `pretrained` ({!r}); '
                      'pass the class count as num_classes=...'.format(pretrained))

    if pretrained:
        if 'transform_input' not in kwargs:
            kwargs['transform_input'] = True
        if 'aux_logits' not in kwargs:
            kwargs['aux_logits'] = False
        if kwargs['aux_logits']:
            warnings.warn('auxiliary heads in the pretrained googlenet model are NOT pretrained, '
                          'so make sure to train them')
        original_aux_logits = kwargs['aux_logits']
        # The checkpoint contains auxiliary-head weights, so the model is built
        # with the heads and they are removed again after loading if unwanted.
        kwargs['aux_logits'] = True
        kwargs['init_weights'] = False
        model = GoogLeNet(**kwargs)

        model_weight_path = "./googlenet-1378be20.pth"
        # map_location keeps loading independent of the device the file was saved from.
        state_dict = torch.load(model_weight_path, map_location="cpu")

        # Classifier layers depend on num_classes; when their shape differs from
        # the checkpoint keep the model's own tensors so strict loading still
        # validates every other weight.
        own_state = model.state_dict()
        for prefix in ('fc.', 'aux1.fc2.', 'aux2.fc2.'):
            for suffix in ('weight', 'bias'):
                key = prefix + suffix
                if (key in state_dict and key in own_state
                        and state_dict[key].shape != own_state[key].shape):
                    state_dict[key] = own_state[key]
        model.load_state_dict(state_dict)

        if not original_aux_logits:
            model.aux_logits = False
            model.aux1 = None  # type: ignore[assignment]
            model.aux2 = None  # type: ignore[assignment]
        return model

    return GoogLeNet(**kwargs)


class GoogLeNet(nn.Module):
    """GoogLeNet (Inception v1) classifier with two optional auxiliary heads.

    Args:
        num_classes: Number of output logits of the main and auxiliary classifiers.
        aux_logits: If True, build the two auxiliary classifiers (``aux1`` after
            ``inception4a`` and ``aux2`` after ``inception4d``); they are only
            evaluated while the module is in training mode.
        transform_input: If True, re-scale the input in ``_transform_input``
            before the first convolution.
        init_weights: If True, initialise conv/linear weights from a truncated
            normal distribution (requires scipy). ``None`` emits a
            ``FutureWarning`` and behaves like True.
        blocks: Optional list of three callables
            ``[conv_block, inception_block, inception_aux_block]`` replacing
            ``BasicConv2d``, ``Inception`` and ``InceptionAux``.
    """
    __constants__ = ['aux_logits', 'transform_input']

    def __init__(
        self,
        num_classes: int = 1000,
        aux_logits: bool = True,
        transform_input: bool = False,
        init_weights: Optional[bool] = None,
        blocks: Optional[List[Callable[..., nn.Module]]] = None
    ) -> None:
        super(GoogLeNet, self).__init__()
        if blocks is None:
            blocks = [BasicConv2d, Inception, InceptionAux]
        if init_weights is None:
            warnings.warn('The default weight initialization of GoogleNet will be changed in future releases of '
                          'torchvision. If you wish to keep the old behavior (which leads to long initialization times'
                          ' due to scipy/scipy#11299), please set init_weights=True.', FutureWarning)
            init_weights = True
        assert len(blocks) == 3
        conv_block = blocks[0]
        inception_block = blocks[1]
        inception_aux_block = blocks[2]

        self.aux_logits = aux_logits
        self.transform_input = transform_input

        # Stem: plain convolutions and pooling before the first Inception module.
        self.conv1 = conv_block(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        self.conv2 = conv_block(64, 64, kernel_size=1)
        self.conv3 = conv_block(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        # Each Inception module's input width is the sum of the previous
        # module's four branch widths (e.g. 64 + 128 + 32 + 32 = 256).
        self.inception3a = inception_block(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = inception_block(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception4a = inception_block(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = inception_block(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = inception_block(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = inception_block(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = inception_block(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.inception5a = inception_block(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = inception_block(832, 384, 192, 384, 48, 128, 128)

        if aux_logits:
            self.aux1 = inception_aux_block(512, num_classes)
            self.aux2 = inception_aux_block(528, num_classes)
        else:
            self.aux1 = None  # type: ignore[assignment]
            self.aux2 = None  # type: ignore[assignment]

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(1024, num_classes)

        if init_weights:
            self._initialize_weights()

    def _initialize_weights(self) -> None:
        """Initialise conv/linear weights from a truncated normal and reset BatchNorm.

        Conv2d and Linear weights are drawn from a normal distribution with
        scale 0.01 truncated at +/- 2 standard deviations; BatchNorm2d weights
        are set to 1 and biases to 0.
        """
        # Imported lazily so scipy is only needed when this initialisation runs;
        # the import and the frozen distribution are built once, not per layer.
        import scipy.stats as stats
        truncated_normal = stats.truncnorm(-2, 2, scale=0.01)
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                values = torch.as_tensor(truncated_normal.rvs(m.weight.numel()), dtype=m.weight.dtype)
                values = values.view(m.weight.size())
                with torch.no_grad():
                    m.weight.copy_(values)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _transform_input(self, x: Tensor) -> Tensor:
        """Re-scale each input channel when ``transform_input`` is set.

        Algebraically this maps a channel normalised as ``(p - mean) / std``
        with the per-channel constants below to ``(p - 0.5) / 0.5``.
        """
        if self.transform_input:
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x

    def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
        """Run the network and return ``(logits, aux2, aux1)``.

        The auxiliary entries are None unless the corresponding head exists and
        the module is in training mode. The shape comments assume a
        224 x 224 input.
        """
        # N x 3 x 224 x 224
        x = self.conv1(x)
        # N x 64 x 112 x 112
        x = self.maxpool1(x)
        # N x 64 x 56 x 56
        x = self.conv2(x)
        # N x 64 x 56 x 56
        x = self.conv3(x)
        # N x 192 x 56 x 56
        x = self.maxpool2(x)

        # N x 192 x 28 x 28
        x = self.inception3a(x)
        # N x 256 x 28 x 28
        x = self.inception3b(x)
        # N x 480 x 28 x 28
        x = self.maxpool3(x)
        # N x 480 x 14 x 14
        x = self.inception4a(x)
        # N x 512 x 14 x 14
        aux1: Optional[Tensor] = None
        if self.aux1 is not None:
            if self.training:
                aux1 = self.aux1(x)

        x = self.inception4b(x)
        # N x 512 x 14 x 14
        x = self.inception4c(x)
        # N x 512 x 14 x 14
        x = self.inception4d(x)
        # N x 528 x 14 x 14
        aux2: Optional[Tensor] = None
        if self.aux2 is not None:
            if self.training:
                aux2 = self.aux2(x)

        x = self.inception4e(x)
        # N x 832 x 14 x 14
        x = self.maxpool4(x)
        # N x 832 x 7 x 7
        x = self.inception5a(x)
        # N x 832 x 7 x 7
        x = self.inception5b(x)
        # N x 1024 x 7 x 7

        x = self.avgpool(x)
        # N x 1024 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 1024
        x = self.dropout(x)
        x = self.fc(x)
        # N x 1000 (num_classes)
        return x, aux2, aux1

    @torch.jit.unused
    def eager_outputs(self, x: Tensor, aux2: Optional[Tensor], aux1: Optional[Tensor]) -> GoogLeNetOutputs:
        """Return the named tuple while training with auxiliary heads, else only the logits."""
        if self.training and self.aux_logits:
            return _GoogLeNetOutputs(x, aux2, aux1)
        else:
            return x   # type: ignore[return-value]

    def forward(self, x: Tensor) -> GoogLeNetOutputs:
        """Classify a batch of images.

        Returns:
            In eager mode, ``GoogLeNetOutputs(logits, aux_logits2, aux_logits1)``
            when training with auxiliary heads and the plain logits tensor
            otherwise. Under TorchScript the tuple is always returned.
        """
        x = self._transform_input(x)
        # _forward returns (logits, aux2, aux1); unpack in that same order so
        # each auxiliary output ends up in the field that carries its name.
        x, aux2, aux1 = self._forward(x)
        aux_defined = self.training and self.aux_logits
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted GoogleNet always returns GoogleNetOutputs Tuple")
            return GoogLeNetOutputs(x, aux2, aux1)
        else:
            return self.eager_outputs(x, aux2, aux1)


class Inception(nn.Module):
    """Inception module: four parallel branches concatenated along the channel axis.

    The output has ``ch1x1 + ch3x3 + ch5x5 + pool_proj`` channels. ``conv_block``
    is the factory used for every convolution and defaults to ``BasicConv2d``.
    """

    def __init__(
        self,
        in_channels: int,
        ch1x1: int,
        ch3x3red: int,
        ch3x3: int,
        ch5x5red: int,
        ch5x5: int,
        pool_proj: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block

        # Branch 1: 1x1 convolution.
        self.branch1 = make_conv(in_channels, ch1x1, kernel_size=1)

        # Branch 2: 1x1 reduction followed by a 3x3 convolution.
        reduce_3x3 = make_conv(in_channels, ch3x3red, kernel_size=1)
        conv_3x3 = make_conv(ch3x3red, ch3x3, kernel_size=3, padding=1)
        self.branch2 = nn.Sequential(reduce_3x3, conv_3x3)

        # Branch 3: 1x1 reduction followed by the "5x5" convolution.
        # Here, kernel_size=3 instead of kernel_size=5 is a known bug.
        # Please see https://github.com/pytorch/vision/issues/906 for details.
        reduce_5x5 = make_conv(in_channels, ch5x5red, kernel_size=1)
        conv_5x5 = make_conv(ch5x5red, ch5x5, kernel_size=3, padding=1)
        self.branch3 = nn.Sequential(reduce_5x5, conv_5x5)

        # Branch 4: 3x3 max pooling followed by a 1x1 projection.
        pool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True)
        projection = make_conv(in_channels, pool_proj, kernel_size=1)
        self.branch4 = nn.Sequential(pool, projection)

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return the four branch outputs in branch order."""
        return [self.branch1(x), self.branch2(x), self.branch3(x), self.branch4(x)]

    def forward(self, x: Tensor) -> Tensor:
        """Concatenate the branch outputs along dimension 1."""
        return torch.cat(self._forward(x), 1)


class InceptionAux(nn.Module):
    """Auxiliary classifier head: pool to 4x4, 1x1 convolution, two linear layers.

    ``conv_block`` is the convolution factory and defaults to ``BasicConv2d``.
    """

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block
        self.conv = make_conv(in_channels, 128, kernel_size=1)

        # 128 channels x 4 x 4 positions = 2048 input features.
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x: Tensor) -> Tensor:
        """Return the auxiliary logits, shape N x num_classes."""
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14 -> N x C x 4 x 4
        pooled = F.adaptive_avg_pool2d(x, (4, 4))
        # N x 128 x 4 x 4 -> N x 2048
        features = torch.flatten(self.conv(pooled), 1)
        # N x 1024
        hidden = F.relu(self.fc1(features), inplace=True)
        # Dropout is only active while the module is in training mode.
        hidden = F.dropout(hidden, 0.7, training=self.training)
        return self.fc2(hidden)


class BasicConv2d(nn.Module):
    """Bias-free convolution followed by batch normalisation and ReLU.

    Extra keyword arguments (kernel_size, stride, padding, ...) are forwarded
    to ``nn.Conv2d``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        **kwargs: Any
    ) -> None:
        super().__init__()
        # The convolution has no bias of its own; BatchNorm follows it directly.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x: Tensor) -> Tensor:
        """Apply convolution, batch normalisation and ReLU in that order."""
        normalized = self.bn(self.conv(x))
        return F.relu(normalized, inplace=True)

4 训练及验证

train.py

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import googlenet
import os
import json
import time
import torchvision


# Use the first CUDA device when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)


# Pre-processing pipelines. Training uses a random 224x224 crop and a random
# horizontal flip; validation only resizes. Both normalise every channel with
# mean 0.5 and standard deviation 0.5.
train_steps = [
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
val_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
data_transform = {
    "train": transforms.Compose(train_steps),
    "val": transforms.Compose(val_steps),
}


batch_size = 4  # samples per batch
data_root = os.getcwd()  # current working directory
image_path = data_root + "/数据"  # dataset root with the train/ and val/ folders

# Training set and its loader.
train_dataset = datasets.ImageFolder(root=image_path + "/train",
                                     transform=data_transform["train"])
train_num = len(train_dataset)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

# Validation set and its loader.
validate_dataset = datasets.ImageFolder(root=image_path + "/val",
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(
    validate_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

print("训练数据集大小: ",train_num,"\n")
print("验证数据集大小: ",val_num,"\n")

def imshow(img):
    """Show an image tensor with matplotlib after undoing the 0.5/0.5 normalisation."""
    restored = img / 2 + 0.5  # unnormalize
    # Move the first axis to the end before handing the array to matplotlib.
    plt.imshow(np.transpose(restored.numpy(), (1, 2, 0)))
    plt.show()


num_classes=int(input("输入分类数目:"))
dataset_classes = len(train_dataset.class_to_idx)
if num_classes < dataset_classes:
    # Labels outside the classifier's range would otherwise fail inside the loss.
    raise ValueError("num_classes ({}) is smaller than the {} class folders in the training set"
                     .format(num_classes, dataset_classes))

# googlenet's first positional parameter is `pretrained`, so the class count
# must be passed by keyword; passed positionally it would select the
# 1000-class pretrained model instead.
net = googlenet(num_classes=num_classes, aux_logits=True, init_weights=True)

net.to(device)

loss_function = nn.CrossEntropyLoss()  # cross-entropy loss
optimizer = optim.Adam(net.parameters(), lr=0.0002)  # Adam optimizer
save_path = './googlenet.pth'  # where the best weights are stored
best_acc = 0.0  # best validation accuracy seen so far

# Alternate one training pass and one validation pass per epoch.
for epoch in range(10):
    # ---- training ----
    print(">>开始训练: ",epoch+1)
    net.train()  # enables dropout and the auxiliary classifiers
    running_loss = 0.0
    num_steps = len(train_loader)
    t1 = time.perf_counter()
    for step, data in enumerate(train_loader, start=0):
        images, labels = data
        labels = labels.to(device)

        optimizer.zero_grad()  # clear gradients of the previous step
        outputs = net(images.to(device))
        if isinstance(outputs, tuple):
            # Training mode with auxiliary heads returns three sets of logits;
            # the auxiliary losses are added with weight 0.3 each.
            logits, aux_logits2, aux_logits1 = outputs
            loss = (loss_function(logits, labels)
                    + 0.3 * loss_function(aux_logits1, labels)
                    + 0.3 * loss_function(aux_logits2, labels))
        else:
            loss = loss_function(outputs, labels)
        loss.backward()  # back-propagation
        optimizer.step()

        running_loss += loss.item()  # accumulate the loss
        rate = (step + 1) / num_steps  # progress within the epoch
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss.item()), end="")
    print()
    print(time.perf_counter()-t1)  # seconds spent on this epoch

    # ---- validation ----
    print(">>开始验证: ",epoch+1)
    net.eval()  # disables dropout; only the main logits are returned
    acc = 0.0  # number of correctly classified validation images
    with torch.no_grad():
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))

            predict_y = torch.max(outputs, dim=1)[1]

            acc += (predict_y == val_labels.to(device)).sum().item()
        val_accurate = acc / val_num  # accuracy over the whole validation set
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)  # keep the best weights only
        # Average over the number of batches (not the last batch index).
        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
              (epoch + 1, running_loss / max(num_steps, 1), val_accurate))

print('Finished Training')

# ImageFolder maps class-folder names to integer labels, e.g. {'name_a': 0, 'name_b': 1}.
flower_list = train_dataset.class_to_idx

# Invert the mapping to label -> class name, e.g. {0: 'name_a', 1: 'name_b'}.
cla_dict = {label: name for name, label in flower_list.items()}

# Serialise the inverted mapping and store it in class_indices.json.
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

5 测试

predict.py

import torch
from model import googlenet
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json
import os

# Same pre-processing as the validation pipeline used during training.
data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

cwd = os.getcwd()  # current working directory
predict = '数据/predict'
predict_path = os.path.join(cwd,predict)


# Load the label -> class-name mapping written by train.py; the with-block
# closes the file again instead of leaking the handle.
try:
    with open('./class_indices.json', 'r') as json_file:
        class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

num_classes=len(class_indict)

# Build the network and load the trained weights once, before any image is
# processed, rather than once per image.
model_weight_path = "./googlenet.pth"
# map_location lets weights saved on a GPU load on a CPU-only machine.
state_dict = torch.load(model_weight_path, map_location="cpu")
# The auxiliary heads are only used while training; drop their weights.
state_dict = {key: value for key, value in state_dict.items()
              if not key.startswith(('aux1.', 'aux2.'))}
# Size the classifier from the checkpoint so it always matches the saved
# weights. The class count must be passed by keyword because googlenet's first
# positional parameter is `pretrained`.
model = googlenet(num_classes=state_dict['fc.weight'].shape[0],
                  aux_logits=False, init_weights=False)
model.load_state_dict(state_dict)
model.eval()

# cla is the class name, j its label (a string key of the JSON mapping)
for j,cla in class_indict.items():

    print(">>测试: ",cla)

    path = os.path.join(predict_path,cla)
    images = [f1 for f1 in os.listdir(path) if ".gif" not in f1]  # skip animated gifs

    acc_ = [0] * num_classes  # number of predictions per class
    for image in images:

        # convert('RGB') guarantees three channels (e.g. for RGBA png files),
        # which the three-channel normalisation requires
        img = Image.open(os.path.join(path, image)).convert('RGB')

        # [C, H, W] -> [N, C, H, W] with a batch dimension of one
        img = torch.unsqueeze(data_transform(img), dim=0)

        with torch.no_grad():
            # Only the first num_classes logits correspond to known classes;
            # the slice is a no-op when the classifier has exactly that size.
            output = model(img)[0][:num_classes]
            probabilities = torch.softmax(output, dim=0)
            predict_cla = int(torch.argmax(probabilities))
        print(class_indict[str(predict_cla)],'\t', probabilities[predict_cla].item())

        acc_[predict_cla]+=1

    if len(images) == 0:
        print("{}文件夹为空".format(cla))
    else:
        print("{}总共有{}张图片 \n其中,".format(cla,len(images)))
        print("{}准确率为:{}%".format(cla,100*acc_[int(j)]/len(images)))
    print("\n")
print(">>测试完毕!")

Anaconda 3
python 3.6
pytorch 1.3
torchvision 0.4

|——GoogLeNet
|————数据
|————————下载数据
|————————train
|————————val
|————————predict
|————get_data.py
|————spile_data.py
|————model.py
|————train.py
|————predict.py

在GoogLeNet文件夹,右键打开终端:

python get_data.py
python spile_data.py
python train.py
python predict.py
上一篇:MATLAB||清除指令clear,clear all,clc,clf,cla


下一篇:Ubuntu18.04(乌班图)环境下安装ROS(Robot Operating System)步骤总结