准备数据集和浮点模型

准备数据集和浮点模型#

从执行必要的导入、定义一些辅助函数和准备数据开始。这些步骤与 PyTorch 中的静态量化(急速模式)完全相同

要在本教程中使用整个 ImageNet 数据集运行代码,首先按照这里 ImageNet Data 的说明下载 ImageNet。

import torch
import torch.nn as nn
from torchvision.models.resnet import resnet18, ResNet18_Weights
from imagenet import ImageNet

train_batch_size = 30
eval_batch_size = 50
# data_path = 'data/imagenet'
data_path = "/media/pc/data/lxw/home/data/datasets/ILSVRC"
dataset = ImageNet(data_path)
data_loader = dataset.train_loader(train_batch_size)
data_loader_test = dataset.test_loader(eval_batch_size)
example_inputs = (next(iter(data_loader))[0])
criterion = nn.CrossEntropyLoss()
float_model = resnet18(weights=ResNet18_Weights.DEFAULT)
float_model = float_model.to("cpu").eval()