数据集来源
可使用torch.utils.data.Dataset
来检索数据集中单个数据项,使用torch.utils.data.DataLoader
来定义数据集迭代器(单次批量batch个)。
这里用到来自torchvision
的Fashion-MNIST数据集,其中包含60000个训练图像和10000个测试图像。每个图像都是28x28的灰度图像,其中每个像素值是0到255之间的整数,共有10个类别。示例:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
| import torch from torch.utils.data import Dataset from torchvision import datasets from torchvision.transforms import ToTensor, Lambda import matplotlib.pyplot as plt %matplotlib inline
training_data = datasets.FashionMNIST( root="data", train=True, download=True, transform=ToTensor() )
test_data = datasets.FashionMNIST( root="data", train=False, download=True, transform=ToTensor() )
labels_map = { 0: "T-Shirt", 1: "Trouser", 2: "Pullover", 3: "Dress", 4: "Coat", 5: "Sandal", 6: "Shirt", 7: "Sneaker", 8: "Bag", 9: "Ankle Boot", } figure = plt.figure(figsize=(8, 8)) cols, rows = 3, 3 for i in range(1, cols * rows + 1): sample_idx = torch.randint(len(training_data), size=(1,)).item() img, label = training_data[sample_idx] figure.add_subplot(rows, cols, i) plt.title(labels_map[label]) plt.axis("off") plt.imshow(img.squeeze(), cmap="gray") plt.show()
|
数据集样本如下图所示:
数据集规范化
首先将数据加载到torch的DataLoader
迭代器中,batch_size用于指定每次迭代的样本数量(并不是一次输入所有样本进行训练,而是“循序渐进”),shuffle为True代表每次迭代时打乱样本顺序(不放回,放回方法可参考Pytorch采样器)。
1 2 3 4 5 6 7 8 9
| train_features, train_labels = next(iter(train_dataloader)) print(f"Feature batch shape: {train_features.size()}") print(f"Labels batch shape: {train_labels.size()}") img = train_features[0].squeeze() label = train_labels[0] plt.imshow(img, cmap="gray") plt.show() print(f"Label: {label}")
|
在多特征情况下,规范化的目的是“让不同特征在同一起跑线上”,举个例子:特征A的取值是0~0.5,而特征B的取值是1000~10000,但假设两者对结果的实际影响程度是相同的,则势必造成A和B的权重数量级相差极大,B的权重仅需变化一点,A的权重就需要变化B的成百上千倍来抵消影响,不利于模型训练。
1 2 3 4 5 6 7 8 9 10 11 12 13
| from torchvision import datasets from torchvision.transforms import ToTensor, Lambda
ds = datasets.FashionMNIST( root="data", train=True, download=True, transform=ToTensor(), target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)) )
target_transform = Lambda(lambda y: torch.zeros( 10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
|
torchvision.datasets数据集提供了transform
和target_transform
参数,用于对数据进行“非即时”的转换(只是将转换用的函数封装进数据结构中),transform
用于修改特征,target_transform
用于修改标签。ToTensor
是一个类对象,用于对数据进行转换,将数据转换为tensor,并将数据转换为[0,1]的浮点型(规范化)。
此处Lambda简单理解为自定义函数,x.scatter_(dim, index, src)
此处没看懂,函数是将src中的元素按照index中的索引放到x中,dim是指定放置的维度,index是指定的索引,src是指定的数据。然后用Lambda
类将这个自定义lambda函数进行封装为一个transform
,但这里若y指的是标签,那么得到的是一个1×10的值全为1的tensor;若这里y指的是单个标签元素,则得到一个6000×10的tensor,每行相当于一个独热编码,比如以[0,1,0,0,0,0,0,0,0,0]
的形式标明属于第二类别。 查阅资料推测这里应该是后者,同时实验验证得到此处target_transform应该是为数据集每张图片附加一个独热编码标签,见下图
[{"url":"https://img.chen0495.top/img2022/202208202154679.png","alt":"经过target_transform的数据"},{"url":"https://img.chen0495.top/img2022/202208202155842.png","alt":"原始数据"}]