|
|
@ -49,7 +49,6 @@ def get_synsigns(dataset_root, batch_size, train): |
|
|
|
]) |
|
|
|
|
|
|
|
# datasets and data_loader |
|
|
|
# using first 90K samples as training set |
|
|
|
train_list = os.path.join(dataset_root, 'train_labelling.txt') |
|
|
|
synsigns_dataset = GetLoader( |
|
|
|
data_root=os.path.join(dataset_root), |
|
|
@ -57,6 +56,6 @@ def get_synsigns(dataset_root, batch_size, train): |
|
|
|
transform=pre_process) |
|
|
|
|
|
|
|
synsigns_dataloader = torch.utils.data.DataLoader( |
|
|
|
dataset=synsigns_dataset, batch_size=batch_size, shuffle=True, num_workers=0) |
|
|
|
dataset=synsigns_dataset, batch_size=batch_size, shuffle=True, num_workers=8) |
|
|
|
|
|
|
|
return synsigns_dataloader |