介绍
Tiny ImageNet Challenge 来源于斯坦福 CS231N 课程,共237M
Tiny ImageNet 有 200 个类。每个类有 500 张训练图像、50 张验证图像和 50 张测试图像。
下载链接:
http://cs231n.stanford.edu/tiny-imagenet-200.zip
数据集使用
因为下载来的train跟val文件夹下图片存放位置不一样,所以路径需要一些变动
wnids.txt存放着标签
words.txt存放标签跟对应的描述,可以在few-shot或是zero-shot的时候用(下面的加载代码没有使用,只是做简单的分类任务)
train/label/xx/xx_boxes.txt与val/val_annotations.txt: 包括label与bounding box的标注,在目标检测任务中使用(下面的加载代码没有使用,只是做简单的分类任务)
下面附上代码:
import argparse
import glob
import os

import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from PIL import Image


class TrainTinyImageNet(Dataset):
    """Tiny ImageNet training split (classification only).

    Expects the standard extracted layout
    ``root/train/<wnid>/images/<wnid>_N.JPEG`` and maps each image to an
    integer class index via the supplied wnid -> index dictionary.
    """

    def __init__(self, root, id, transform=None):
        super().__init__()
        # train/<wnid>/images/<file>.JPEG -> three path components under root.
        # os.path.join keeps the pattern portable (the original used Windows
        # backslash literals, which break on other platforms).
        self.filenames = glob.glob(os.path.join(root, "train", "*", "*", "*.JPEG"))
        self.transform = transform
        self.id_dict = id

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        img_path = self.filenames[idx]
        # Convert unconditionally: the dataset contains 'L' (grayscale) as
        # well as 'P'/'RGBA' images; checking only for 'L' misses the rest.
        image = Image.open(img_path).convert("RGB")
        # The wnid label is the directory three levels above the file:
        # .../train/<wnid>/images/xxx.JPEG
        label = self.id_dict[os.path.normpath(img_path).split(os.sep)[-3]]
        if self.transform:
            image = self.transform(image)
        return image, label


class ValTinyImageNet(Dataset):
    """Tiny ImageNet validation split.

    Validation images all live flat in ``root/val/images``; their labels come
    from ``root/val/val_annotations.txt`` (whitespace/tab-separated:
    filename, wnid, then bounding-box columns, which are ignored here).
    """

    def __init__(self, root, id, transform=None):
        super().__init__()
        self.filenames = glob.glob(os.path.join(root, "val", "images", "*.JPEG"))
        self.transform = transform
        self.id_dict = id
        self.cls_dic = {}
        # The annotation file is tab-separated, so split(' ') would yield a
        # single field; split() handles both tabs and spaces.  `with` closes
        # the file handle that the original leaked.
        with open(os.path.join(root, "val", "val_annotations.txt"), "r") as f:
            for line in f:
                parts = line.split()
                img, cls_id = parts[0], parts[1]
                self.cls_dic[img] = self.id_dict[cls_id]

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        img_path = self.filenames[idx]
        image = Image.open(img_path).convert("RGB")
        # Validation labels are keyed by bare filename.
        label = self.cls_dic[os.path.basename(img_path)]
        if self.transform:
            image = self.transform(image)
        return image, label


def load_tinyimagenet(args):
    """Build train/val DataLoaders for Tiny ImageNet.

    Reads ``args.batch_size`` and ``args.workers``; the dataset root may be
    overridden via an optional ``args.data_root`` (falls back to the original
    hard-coded path for backward compatibility).

    Returns:
        (train_loader, val_loader, num_classes)
    """
    batch_size = args.batch_size
    nw = args.workers
    # getattr keeps the original call signature working while allowing the
    # root to be configured instead of hard-coded.
    root = getattr(args, "data_root", None) or os.path.join(
        "E:", "PythonProjects", "dataset", "tiny-imagenet-200"
    )

    # wnids.txt: one wnid per line -> contiguous integer class indices.
    # strip() removes the trailing newline (replace(' ', '') did not).
    id_dic = {}
    with open(os.path.join(root, "wnids.txt"), "r") as f:
        for i, line in enumerate(f):
            id_dic[line.strip()] = i
    num_classes = len(id_dic)

    # Standard ImageNet normalization statistics.
    normalize = transforms.Normalize(
        [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    )
    data_transform = {
        "train": transforms.Compose([
            transforms.Resize(224),
            transforms.RandomCrop(224, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
        "val": transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            normalize,
        ]),
    }

    train_dataset = TrainTinyImageNet(root, id=id_dic, transform=data_transform["train"])
    val_dataset = ValTinyImageNet(root, id=id_dic, transform=data_transform["val"])
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        pin_memory=True, num_workers=nw,
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        pin_memory=True, num_workers=nw,
    )

    print("TinyImageNet Loading SUCCESS"
          + " len of train dataset: " + str(len(train_dataset))
          + " len of val dataset: " + str(len(val_dataset)))
    return train_loader, val_loader, num_classes


if __name__ == '__main__':
    parser = argparse.ArgumentParser("parameters")
    # Help text now matches the actual default (was "(default, 512)").
    parser.add_argument("--batch-size", type=int, default=120,
                        help="number of batch size, (default: 120)")
    parser.add_argument('--workers', type=int, default=7)
    parser.add_argument('--data-root', type=str, default=None,
                        help='path to the extracted tiny-imagenet-200 folder')
    parser.add_argument('--seed', default=42, type=int, nargs='+',
                        help='seed for initializing training. ')
    args = parser.parse_args()

    train, val, num_classes = load_tinyimagenet(args)
workers是数据预加载的参数,可以根据cpu情况自行更改