From 62bc1b6445a09c7495f8c358e5bddd1b5c3898b1 Mon Sep 17 00:00:00 2001 From: wogong Date: Wed, 4 Sep 2019 11:20:33 +0800 Subject: [PATCH] update datasets --- datasets/gtsrb.py | 11 ++++++----- datasets/synsigns.py | 3 +-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/datasets/gtsrb.py b/datasets/gtsrb.py index c97d2d9..92fa9ba 100644 --- a/datasets/gtsrb.py +++ b/datasets/gtsrb.py @@ -7,7 +7,6 @@ import torch.utils.data as data from torch.utils.data.sampler import SubsetRandomSampler import numpy as np - def get_gtsrb(dataset_root, batch_size, train): """Get GTSRB datasets loader.""" shuffle_dataset = True @@ -36,9 +35,11 @@ def get_gtsrb(dataset_root, batch_size, train): train_sampler = SubsetRandomSampler(train_indices) valid_sampler = SubsetRandomSampler(val_indices) - gtsrb_dataloader_train = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size, + if train: + gtsrb_dataloader_train = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size, sampler=train_sampler) - gtsrb_dataloader_test = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size, + return gtsrb_dataloader_train + else: + gtsrb_dataloader_test = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size, sampler=valid_sampler) - - return gtsrb_dataloader_train, gtsrb_dataloader_test \ No newline at end of file + return gtsrb_dataloader_test \ No newline at end of file diff --git a/datasets/synsigns.py b/datasets/synsigns.py index f8bd07f..5d314a7 100644 --- a/datasets/synsigns.py +++ b/datasets/synsigns.py @@ -49,7 +49,6 @@ def get_synsigns(dataset_root, batch_size, train): ]) # datasets and data_loader - # using first 90K samples as training set train_list = os.path.join(dataset_root, 'train_labelling.txt') synsigns_dataset = GetLoader( data_root=os.path.join(dataset_root), @@ -57,6 +56,6 @@ def get_synsigns(dataset_root, batch_size, train): transform=pre_process) synsigns_dataloader = torch.utils.data.DataLoader( - dataset=synsigns_dataset, batch_size=batch_size, shuffle=True, num_workers=0) + dataset=synsigns_dataset, batch_size=batch_size, shuffle=True, num_workers=8) return synsigns_dataloader \ No newline at end of file