Browse Source

update datasets

master
wogong 5 years ago
parent
commit
62bc1b6445
  1. 11
      datasets/gtsrb.py
  2. 3
      datasets/synsigns.py

11
datasets/gtsrb.py

@ -7,7 +7,6 @@ import torch.utils.data as data
from torch.utils.data.sampler import SubsetRandomSampler from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np import numpy as np
def get_gtsrb(dataset_root, batch_size, train): def get_gtsrb(dataset_root, batch_size, train):
"""Get GTSRB datasets loader.""" """Get GTSRB datasets loader."""
shuffle_dataset = True shuffle_dataset = True
@ -36,9 +35,11 @@ def get_gtsrb(dataset_root, batch_size, train):
train_sampler = SubsetRandomSampler(train_indices) train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices) valid_sampler = SubsetRandomSampler(val_indices)
gtsrb_dataloader_train = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size,
if train:
gtsrb_dataloader_train = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size,
sampler=train_sampler) sampler=train_sampler)
gtsrb_dataloader_test = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size,
return gtsrb_dataloader_train
else:
gtsrb_dataloader_test = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size,
sampler=valid_sampler) sampler=valid_sampler)
return gtsrb_dataloader_train, gtsrb_dataloader_test
return gtsrb_dataloader_test

3
datasets/synsigns.py

@ -49,7 +49,6 @@ def get_synsigns(dataset_root, batch_size, train):
]) ])
# datasets and data_loader # datasets and data_loader
# using first 90K samples as training set
train_list = os.path.join(dataset_root, 'train_labelling.txt') train_list = os.path.join(dataset_root, 'train_labelling.txt')
synsigns_dataset = GetLoader( synsigns_dataset = GetLoader(
data_root=os.path.join(dataset_root), data_root=os.path.join(dataset_root),
@ -57,6 +56,6 @@ def get_synsigns(dataset_root, batch_size, train):
transform=pre_process) transform=pre_process)
synsigns_dataloader = torch.utils.data.DataLoader( synsigns_dataloader = torch.utils.data.DataLoader(
dataset=synsigns_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
dataset=synsigns_dataset, batch_size=batch_size, shuffle=True, num_workers=8)
return synsigns_dataloader return synsigns_dataloader
Loading…
Cancel
Save