Implementation of "Adversarial Discriminative Domain Adaptation" in PyTorch
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

32 lines
798 B

import os
import shutil
import torch
def save(log_dir, state_dict, is_best,
         checkpoint_name='checkpoint.pt', best_name='best_model.pt'):
    """Write ``state_dict`` to ``log_dir`` and optionally keep a "best" copy.

    The rolling checkpoint is always (re)written. When ``is_best`` is true,
    the freshly written file is also copied to a separate best-model file.
    Later calls with ``is_best=False`` therefore do not overwrite it.

    Args:
        log_dir: Directory to write into; created if it does not exist.
        state_dict: Object passed to ``torch.save`` (presumably a model /
            optimizer state dict — confirm against callers).
        is_best: If true, also copy the checkpoint to ``best_name``.
        checkpoint_name: File name of the rolling checkpoint inside ``log_dir``.
        best_name: File name of the best-model copy inside ``log_dir``.
    """
    # NOTE(review): the original file names were lost when this file was
    # extracted (they appeared as ''); these defaults are a reconstruction —
    # confirm they match what existing runs / loaders expect.
    os.makedirs(log_dir, exist_ok=True)
    checkpoint_path = os.path.join(log_dir, checkpoint_name)
    torch.save(state_dict, checkpoint_path)
    if is_best:
        best_model_path = os.path.join(log_dir, best_name)
        # Copy the file just written rather than re-serializing the state.
        shutil.copyfile(checkpoint_path, best_model_path)
class AverageMeter(object):
    """Track the latest value and a weighted running average of a metric.

    Attributes (all reset to 0 by ``reset``):
        val: The most recent value passed to ``update``.
        sum: Running sum of ``val * n`` over all updates.
        count: Running total of the weights ``n``.
        avg: ``sum / count``, or 0 while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (e.g. a batch mean over ``n`` items)."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against a zero total weight (e.g. update(x, n=0) on a fresh
        # meter), which would otherwise raise ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0