Browse Source

using random seed.

master
wogong 5 years ago
parent
commit
793969a3ca
  1. 6
      datasets/gtsrb.py

6
datasets/gtsrb.py

@ -3,7 +3,7 @@
import os
import torch
from torchvision import datasets, transforms
from torchvision import transforms
import torch.utils.data as data
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np
@ -31,7 +31,6 @@ class GTSRB(data.Dataset):
def __len__(self):
return self.n_data
def get_gtsrb(dataset_root, batch_size, train):
"""Get GTSRB datasets loader."""
shuffle_dataset = True
@ -51,7 +50,8 @@ def get_gtsrb(dataset_root, batch_size, train):
dataset_size = len(gtsrb_dataset)
indices = list(range(dataset_size))
if shuffle_dataset:
np.random.seed(random_seed)
#np.random.seed(random_seed)
np.random.seed()
np.random.shuffle(indices)
train_indices, val_indices = indices[:train_size], indices[train_size:]

Loading…
Cancel
Save