diff --git a/models/model.py b/models/model.py index 7fd3f20..3d3b651 100644 --- a/models/model.py +++ b/models/model.py @@ -141,7 +141,7 @@ class SVHNmodel(nn.Module): nn.Linear(256, 2), ) - def forward(self, input_data, alpha): + def forward(self, input_data, alpha = 1.0): input_data = input_data.expand(input_data.data.shape[0], 3, 28, 28) feature = self.feature(input_data) feature = feature.view(-1, 64 * 4 * 4)