Browse Source

minor update, remove some params

master
wogong 7 years ago
parent
commit
837812da0c
  1. 2
      .gitignore
  2. 5
      datasets/mnist.py
  3. 7
      datasets/mnistm.py
  4. 7
      datasets/office.py
  5. 7
      datasets/officecaltech.py
  6. 5
      datasets/svhn.py
  7. 43
      models/model.py
  8. 9
      params.py

2
.gitignore

@ -3,6 +3,7 @@ __pycache__/
*.py[cod] *.py[cod]
*$py.class *$py.class
_gsdata_ _gsdata_
*.ipynb
# C extensions # C extensions
*.so *.so
@ -105,3 +106,4 @@ ENV/
.idea .idea
.DS_Store .DS_Store
main_legacy.py main_legacy.py
test.ipynb

5
datasets/mnist.py

@ -12,8 +12,9 @@ def get_mnist(train):
# image pre-processing # image pre-processing
pre_process = transforms.Compose([transforms.ToTensor(), pre_process = transforms.Compose([transforms.ToTensor(),
transforms.Normalize( transforms.Normalize(
mean=params.dataset_mean,
std=params.dataset_std)])
mean=(0.5, 0.5, 0.5),
std=(0.5, 0.5, 0.5)
)])
# datasets and data loader # datasets and data loader
mnist_dataset = datasets.MNIST(root=os.path.join(params.dataset_root,'mnist'), mnist_dataset = datasets.MNIST(root=os.path.join(params.dataset_root,'mnist'),

7
datasets/mnistm.py

@ -41,11 +41,12 @@ class GetLoader(data.Dataset):
def get_mnistm(train): def get_mnistm(train):
"""Get MNISTM datasets loader.""" """Get MNISTM datasets loader."""
# image pre-processing # image pre-processing
pre_process = transforms.Compose([transforms.Resize(params.digit_image_size),
pre_process = transforms.Compose([transforms.Resize(28),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize( transforms.Normalize(
mean=params.dataset_mean,
std=params.dataset_std)])
mean=(0.5, 0.5, 0.5),
std=(0.5, 0.5, 0.5)
)])
# datasets and data_loader # datasets and data_loader
if train: if train:

7
datasets/office.py

@ -13,8 +13,9 @@ def get_office(train, category):
pre_process = transforms.Compose([transforms.Resize(params.office_image_size), pre_process = transforms.Compose([transforms.Resize(params.office_image_size),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize( transforms.Normalize(
mean=params.imagenet_dataset_mean,
std=params.imagenet_dataset_mean)])
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225)
)])
# datasets and data_loader # datasets and data_loader
office_dataset = datasets.ImageFolder( office_dataset = datasets.ImageFolder(
@ -25,6 +26,6 @@ def get_office(train, category):
dataset=office_dataset, dataset=office_dataset,
batch_size=params.batch_size, batch_size=params.batch_size,
shuffle=True, shuffle=True,
num_workers=8)
num_workers=4)
return office_dataloader return office_dataloader

7
datasets/officecaltech.py

@ -13,8 +13,9 @@ def get_officecaltech(train, category):
pre_process = transforms.Compose([transforms.Resize(params.office_image_size), pre_process = transforms.Compose([transforms.Resize(params.office_image_size),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize( transforms.Normalize(
mean=params.imagenet_dataset_mean,
std=params.imagenet_dataset_mean)])
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225)
)])
# datasets and data_loader # datasets and data_loader
officecaltech_dataset = datasets.ImageFolder( officecaltech_dataset = datasets.ImageFolder(
@ -25,6 +26,6 @@ def get_officecaltech(train, category):
dataset=officecaltech_dataset, dataset=officecaltech_dataset,
batch_size=params.batch_size, batch_size=params.batch_size,
shuffle=True, shuffle=True,
num_workers=8)
num_workers=4)
return officecaltech_dataloader return officecaltech_dataloader

5
datasets/svhn.py

@ -15,8 +15,9 @@ def get_svhn(train):
transforms.Resize(params.digit_image_size), transforms.Resize(params.digit_image_size),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize( transforms.Normalize(
mean=params.dataset_mean,
std=params.dataset_std)])
mean=(0.5, 0.5, 0.5),
std=(0.5, 0.5, 0.5)
)])
# datasets and data loader # datasets and data loader
if train: if train:

43
models/model.py

@ -3,7 +3,6 @@
import torch.nn as nn import torch.nn as nn
from .functions import ReverseLayerF from .functions import ReverseLayerF
from torchvision import models from torchvision import models
import params
class Classifier(nn.Module): class Classifier(nn.Module):
@ -43,6 +42,45 @@ class Classifier(nn.Module):
return class_output return class_output
class MNISTmodel(nn.Module):
""" MNIST architecture"""
def __init__(self):
super(MNISTmodel, self).__init__()
self.restored = False
self.feature = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(5, 5)), # 1 28 28, 32 24 24
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=(2, 2)), # 32 12 12
nn.Conv2d(in_channels=32, out_channels=48, kernel_size=(5, 5)), # 48 8 8
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=(2, 2)), # 48 4 4
)
self.classifier = nn.Sequential(
nn.Linear(48*4*4, 100),
nn.ReLU(inplace=True),
nn.Linear(100, 100),
nn.ReLU(inplace=True),
nn.Linear(100, 10),
)
self.discriminator = nn.Sequential(
nn.Linear(48*4*4, 100),
nn.ReLU(inplace=True),
nn.Linear(100, 2),
)
def forward(self, input_data, alpha):
input_data = input_data.expand(input_data.data.shape[0], 1, 28, 28)
feature = self.feature(input_data)
feature = feature.view(-1, 48 * 4 * 4)
reverse_feature = ReverseLayerF.apply(feature, alpha)
class_output = self.classifier(feature)
domain_output = self.discriminator(reverse_feature)
return class_output, domain_output
class SVHNmodel(nn.Module): class SVHNmodel(nn.Module):
""" SVHN architecture""" """ SVHN architecture"""
@ -90,7 +128,6 @@ class SVHNmodel(nn.Module):
return class_output, domain_output return class_output, domain_output
class AlexModel(nn.Module): class AlexModel(nn.Module):
""" AlexNet pretrained on imagenet for Office dataset""" """ AlexNet pretrained on imagenet for Office dataset"""
@ -119,7 +156,7 @@ class AlexModel(nn.Module):
nn.Dropout(0.5), nn.Dropout(0.5),
nn.Linear(4096, 256), nn.Linear(4096, 256),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Linear(256, params.class_num_src),
nn.Linear(256, 31),
) )
self.discriminator = nn.Sequential( self.discriminator = nn.Sequential(

9
params.py

@ -7,16 +7,9 @@ dataset_root = os.path.expanduser(os.path.join('~', 'Datasets'))
model_root = os.path.expanduser(os.path.join('~', 'Models', 'pytorch-DANN')) model_root = os.path.expanduser(os.path.join('~', 'Models', 'pytorch-DANN'))
# params for datasets and data loader # params for datasets and data loader
dataset_mean_value = 0.5
dataset_std_value = 0.5
dataset_mean = (dataset_mean_value, dataset_mean_value, dataset_mean_value)
dataset_std = (dataset_std_value, dataset_std_value, dataset_std_value)
imagenet_dataset_mean = (0.485, 0.456, 0.406)
imagenet_dataset_std = (0.229, 0.224, 0.225)
batch_size = 64 batch_size = 64
digit_image_size = 28
office_image_size = 227 office_image_size = 227
# params for source dataset # params for source dataset

Loading…
Cancel
Save