diff --git a/core/train.py b/core/train.py index 6ad0b6f..8da0c46 100644 --- a/core/train.py +++ b/core/train.py @@ -21,7 +21,7 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_ if not params.finetune_flag: print("training non-office task") - optimizer = optim.SGD(model.parameters(), lr=params.lr, momentum=params.momentum) + optimizer = optim.SGD(model.parameters(), lr=params.lr, momentum=params.momentum, weight_decay=params.weight_decay) else: print("training office task") parameter_list = [{