|
|
@ -7,7 +7,6 @@ import torch.utils.data as data |
|
|
|
from torch.utils.data.sampler import SubsetRandomSampler |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
|
|
|
|
def get_gtsrb(dataset_root, batch_size, train): |
|
|
|
"""Get GTSRB datasets loader.""" |
|
|
|
shuffle_dataset = True |
|
|
@ -36,9 +35,11 @@ def get_gtsrb(dataset_root, batch_size, train): |
|
|
|
train_sampler = SubsetRandomSampler(train_indices) |
|
|
|
valid_sampler = SubsetRandomSampler(val_indices) |
|
|
|
|
|
|
|
gtsrb_dataloader_train = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size, |
|
|
|
if train: |
|
|
|
gtsrb_dataloader_train = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size, |
|
|
|
sampler=train_sampler) |
|
|
|
gtsrb_dataloader_test = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size, |
|
|
|
return gtsrb_dataloader_train |
|
|
|
else: |
|
|
|
gtsrb_dataloader_test = torch.utils.data.DataLoader(gtsrb_dataset, batch_size=batch_size, |
|
|
|
sampler=valid_sampler) |
|
|
|
|
|
|
|
return gtsrb_dataloader_train, gtsrb_dataloader_test |
|
|
|
return gtsrb_dataloader_test |