|
@ -3,7 +3,7 @@ |
|
|
|
|
|
|
|
|
import os |
|
|
import os |
|
|
import torch |
|
|
import torch |
|
|
from torchvision import datasets, transforms |
|
|
|
|
|
|
|
|
from torchvision import transforms |
|
|
import torch.utils.data as data |
|
|
import torch.utils.data as data |
|
|
from torch.utils.data.sampler import SubsetRandomSampler |
|
|
from torch.utils.data.sampler import SubsetRandomSampler |
|
|
import numpy as np |
|
|
import numpy as np |
|
@ -31,7 +31,6 @@ class GTSRB(data.Dataset): |
|
|
def __len__(self): |
|
|
def __len__(self): |
|
|
return self.n_data |
|
|
return self.n_data |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_gtsrb(dataset_root, batch_size, train): |
|
|
def get_gtsrb(dataset_root, batch_size, train): |
|
|
"""Get GTSRB datasets loader.""" |
|
|
"""Get GTSRB datasets loader.""" |
|
|
shuffle_dataset = True |
|
|
shuffle_dataset = True |
|
@ -51,7 +50,8 @@ def get_gtsrb(dataset_root, batch_size, train): |
|
|
dataset_size = len(gtsrb_dataset) |
|
|
dataset_size = len(gtsrb_dataset) |
|
|
indices = list(range(dataset_size)) |
|
|
indices = list(range(dataset_size)) |
|
|
if shuffle_dataset: |
|
|
if shuffle_dataset: |
|
|
np.random.seed(random_seed) |
|
|
|
|
|
|
|
|
#np.random.seed(random_seed) |
|
|
|
|
|
np.random.seed() |
|
|
np.random.shuffle(indices) |
|
|
np.random.shuffle(indices) |
|
|
train_indices, val_indices = indices[:train_size], indices[train_size:] |
|
|
train_indices, val_indices = indices[:train_size], indices[train_size:] |
|
|
|
|
|
|
|
|