|
@ -61,6 +61,12 @@ class Config(object): |
|
|
momentum = 0.9 |
|
|
momentum = 0.9 |
|
|
weight_decay = 1e-6 |
|
|
weight_decay = 1e-6 |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self): |
|
|
|
|
|
public_props = (name for name in dir(self) if not name.startswith('_')) |
|
|
|
|
|
with open(self.config, 'w') as f: |
|
|
|
|
|
for name in public_props: |
|
|
|
|
|
f.write(name + ': ' + str(getattr(self, name)) + '\n') |
|
|
|
|
|
|
|
|
params = Config() |
|
|
params = Config() |
|
|
logger = SummaryWriter(params.model_root) |
|
|
logger = SummaryWriter(params.model_root) |
|
|
device = torch.device("cuda:" + params.gpu_id if torch.cuda.is_available() else "cpu") |
|
|
device = torch.device("cuda:" + params.gpu_id if torch.cuda.is_available() else "cpu") |
|
|