if __name__ == '__main__':
    # Command-line entry point: parse flags, seed the RNGs, choose the image
    # loading backend, redirect stdout to a log file, print the resolved
    # configuration as a table, then hand control to main().
    from utils import boolean_string

    legal_models = ['resnet18', 'resnet34', 'mobilenet_v2', 'shufflenet_v2_x1_0',
                    'squeezenet1_1', 'densenet121', 'googlenet',
                    'resnext50_32x4d', 'vgg11']

    parser = argparse.ArgumentParser(description='Hands-On GANs - Chapter 8')
    # `choices` makes argparse reject unknown model names with a usage error;
    # a plain `assert` would be silently skipped under `python -O`.
    parser.add_argument('--model', type=str, default='resnet18',
                        choices=legal_models,
                        help='one of {}'.format(legal_models))
    parser.add_argument('--cuda', type=boolean_string, default=True, help='enable CUDA.')
    parser.add_argument('--train_single', type=boolean_string, default=True, help='train single model.')
    parser.add_argument('--train_ensemble', type=boolean_string, default=True, help='train final model.')
    parser.add_argument('--model_dir', type=str, default='models', help='directory for trained models')
    # NOTE(review): machine-specific default path; override with --data_dir.
    parser.add_argument('--data_dir', type=str, default='/media/john/FastData/cats-dogs-kaggle/train', help='Directory for dataset.')
    parser.add_argument('--data_split', type=float, default=0.8, help='split ratio for train and val data')
    parser.add_argument('--cutout', type=boolean_string, default=False, help='use cutout')
    parser.add_argument('--cutout_length', type=int, default=64, help='length of the cutout region')
    parser.add_argument('--out_dir', type=str, default='output', help='Directory for output.')
    parser.add_argument('--epochs', type=int, default=60, help='number of epochs')
    parser.add_argument('--batch_size', type=int, default=128, help='size of batches')
    parser.add_argument('--lr', type=float, default=0.01, help='learning rate')
    parser.add_argument('--classes', type=int, default=2, help='number of classes')
    parser.add_argument('--img_size', type=int, default=224, help='size of images')
    parser.add_argument('--channels', type=int, default=3, help='number of image channels')
    parser.add_argument('--log_interval', type=int, default=50, help='interval between logging and image sampling')
    parser.add_argument('--pretrained_epoch', type=int, default=60, help='epoch number of pretrained generator')
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    FLAGS = parser.parse_args()

    # Use CUDA only if it was requested AND a device is actually available.
    FLAGS.cuda = FLAGS.cuda and torch.cuda.is_available()

    if FLAGS.seed is not None:
        torch.manual_seed(FLAGS.seed)
        if FLAGS.cuda:
            torch.cuda.manual_seed(FLAGS.seed)
        np.random.seed(FLAGS.seed)

    # Let cuDNN auto-tune convolution algorithms. This may introduce
    # run-to-run nondeterminism even with the seeds set above.
    cudnn.benchmark = True

    # Prefer the accimage backend when it is installed; otherwise torchvision
    # keeps its default PIL loader. Catch only ImportError so that Ctrl-C and
    # genuine errors are not swallowed.
    try:
        import accimage
        torchvision.set_image_backend('accimage')
        print('Image loader backend: accimage')
    except ImportError:
        print('Image loader backend: PIL')

    if FLAGS.train_single:
        utils.clear_folder(FLAGS.out_dir)
    # Ensure the output directory exists even when it was not cleared above.
    # utils.StdOut presumably opens log_file for writing; verify against utils.
    os.makedirs(FLAGS.out_dir, exist_ok=True)

    log_file = os.path.join(FLAGS.out_dir, 'log.txt')
    print("Logging to {}\n".format(log_file))
    # From here on, print() output goes through utils.StdOut.
    sys.stdout = utils.StdOut(log_file)

    print("PyTorch version: {}".format(torch.__version__))
    print("CUDA version: {}\n".format(torch.version.cuda))

    # Dump every parsed flag as a "name | type | value" table.
    print(" " * 9 + "Args" + " " * 9 + "| " + "Type" + " | " + "Value")
    print("-" * 50)
    for arg_name, value in vars(FLAGS).items():
        type_name = type(value).__name__
        # ljust pads without truncating, exactly like the `" " * (n - len)` idiom.
        print(" " + arg_name.ljust(20) + "|"
              + " " + type_name.ljust(10) + "|"
              + " " + str(value))

    main()