Browse Source

using random seed.

master
wogong 5 years ago
parent
commit
793969a3ca
  1. 6
      datasets/gtsrb.py

6
datasets/gtsrb.py

@ -3,7 +3,7 @@
import os import os
import torch import torch
from torchvision import datasets, transforms
from torchvision import transforms
import torch.utils.data as data import torch.utils.data as data
from torch.utils.data.sampler import SubsetRandomSampler from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np import numpy as np
@ -31,7 +31,6 @@ class GTSRB(data.Dataset):
def __len__(self): def __len__(self):
return self.n_data return self.n_data
def get_gtsrb(dataset_root, batch_size, train): def get_gtsrb(dataset_root, batch_size, train):
"""Get GTSRB datasets loader.""" """Get GTSRB datasets loader."""
shuffle_dataset = True shuffle_dataset = True
@ -51,7 +50,8 @@ def get_gtsrb(dataset_root, batch_size, train):
dataset_size = len(gtsrb_dataset) dataset_size = len(gtsrb_dataset)
indices = list(range(dataset_size)) indices = list(range(dataset_size))
if shuffle_dataset: if shuffle_dataset:
np.random.seed(random_seed)
#np.random.seed(random_seed)
np.random.seed()
np.random.shuffle(indices) np.random.shuffle(indices)
train_indices, val_indices = indices[:train_size], indices[train_size:] train_indices, val_indices = indices[:train_size], indices[train_size:]

Loading…
Cancel
Save