数据集来源

可使用torch.utils.data.Dataset来检索数据集中单个数据项,使用torch.utils.data.DataLoader来定义数据集迭代器(单次批量batch个)。

这里用到来自torchvisionFashion-MNIST数据集,其中包含60000个训练图像和10000个测试图像。每个图像都是28x28的灰度图像,其中每个像素值是0到255之间的整数,共有10个类别。示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
import matplotlib.pyplot as plt
%matplotlib inline

training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)

test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)

labels_map = { # 标签映射表
0: "T-Shirt",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
sample_idx = torch.randint(len(training_data), size=(1,)).item() # 随机选择一个样本
img, label = training_data[sample_idx]
figure.add_subplot(rows, cols, i)
plt.title(labels_map[label])
plt.axis("off")
plt.imshow(img.squeeze(), cmap="gray")
plt.show()

数据集样本如下图所示:

数据集规范化

首先将数据加载到torch的DataLoader迭代器中,batch_size用于指定每次迭代的样本数量(并不是一次输入所有样本进行训练,而是“循序渐进”),shuffle为True代表每次迭代时打乱样本顺序(不放回,放回方法可参考Pytorch采样器)。

1
2
3
4
5
6
7
8
9
# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze() # 去除batch维度
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

在多特征情况下,规范化的目的是“让不同特征在同一起跑线上”,举个例子:特征A的取值是0~0.5,而特征B的取值是1000~10000,但假设两者对结果的实际影响程度是相同的,则势必造成A和B的权重数量级相差极大,B的权重仅需变化一点,A的权重就需要变化B的成百上千倍来抵消影响,不利于模型训练。

1
2
3
4
5
6
7
8
9
10
11
12
13
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

target_transform = Lambda(lambda y: torch.zeros(
10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))

torchvision.datasets数据集提供了transformtarget_transform参数,用于对数据进行“非即时”的转换(只是将转换用的函数封装进数据结构中),transform用于修改特征,target_transform用于修改标签。ToTensor是一个类对象,用于对数据进行转换,将数据转换为tensor,并将数据转换为[0,1]的浮点型(规范化)。
此处Lambda简单理解为自定义函数,x.scatter_(dim, index, src)此处没看懂,函数是将src中的元素按照index中的索引放到x中,dim是指定放置的维度,index是指定的索引,src是指定的数据。然后用Lambda类将这个自定义lambda函数进行封装为一个transform,但这里若y指的是标签,那么得到的是一个1×10的值全为1的tensor;若这里y指的是单个标签元素,则得到一个6000×10的tensor,每行相当于一个独热编码,比如以[0,1,0,0,0,0,0,0,0,0]的形式标明属于第二类别。 查阅资料推测这里应该是后者,同时实验验证得到此处target_transform应该是为数据集每张图片附加一个独热编码标签,见下图