-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathget_config.py
More file actions
135 lines (113 loc) · 6.72 KB
/
get_config.py
File metadata and controls
135 lines (113 loc) · 6.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# ---------------------------------------------------------------
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# for DiffPure. To view a copy of this license, see the LICENSE file.
# ---------------------------------------------------------------
import argparse
import logging
import yaml
import os
import time
import random
import numpy as np
import torch
def dict2namespace(config):
    """Recursively convert a (possibly nested) dict into an ``argparse.Namespace``.

    Every key becomes an attribute. Values that are themselves dicts are
    converted recursively, so ``config['a']['b']`` is reachable as ``ns.a.b``.
    All other values, including lists (even lists of dicts), are attached
    unchanged.
    """
    ns = argparse.Namespace()
    for name, item in config.items():
        setattr(ns, name, dict2namespace(item) if isinstance(item, dict) else item)
    return ns
def str2bool(v):
    """Interpret a command-line value as a boolean (for use as an argparse ``type=``).

    A ``bool`` is returned as-is. A string is matched case-insensitively
    against common yes/no spellings.

    Raises:
        argparse.ArgumentTypeError: if the string is not a recognized spelling.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')
    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
def _build_parser(mode):
    """Build the command-line parser (not yet parsed).

    Args:
        mode: ``'train'`` additionally registers the watermark-training
            options; any other value registers only the shared options.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(description=globals()['__doc__'])
    # training setting
    if mode == 'train':
        parser.add_argument('--wm_type', type=str, default='combine', choices=['dwtDct', 'dwtDctSvd', 'rivaGan', 'Digimarc', 'HiDDeN', 'StegaStamp', 'combine'], help='Type of watermark')
        parser.add_argument('--length', type=int, default=4, help='Watermark length')
        parser.add_argument('--color_space', type=str, default='rgb', choices=['rgb', 'yuv'], help='Color space: rgb or yuv')
        parser.add_argument('--size', type=int, default=256, help='Image size')
        parser.add_argument('--total_step', type=int, default=10000, help='Total step')
        parser.add_argument('--bs', type=int, default=3, help='Batch size')
        parser.add_argument('--d_bs_mult', type=int, default=3, help='Discriminator batch size multiplier')
        parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
        parser.add_argument('--d_lr_mult', type=int, default=1, help='Discriminator Learning rate multiplier')
        parser.add_argument('--scheduler', type=str, default='None', choices=['None', 'LambdaLR', 'CosineAnnealingWarmRestarts', 'ExponentialLR'], help='Scheduler for learning rate')
        parser.add_argument('--d_step', type=int, default=1, help='Discriminator step')
        parser.add_argument('--d_adv_step', type=int, default=1, help='Discriminator adversarial step')
        parser.add_argument('--factor', type=float, default=0.25, help='Factor of loss')
        parser.add_argument('--lambda_mse', type=float, default=1.0, help='Lambda for mse')
        parser.add_argument('--lambda_lpips', type=float, default=0.03, help='Lambda for lpips')
        parser.add_argument('--lambda_adv', type=float, default=0.0002, help='Lambda for adversarial loss')
        parser.add_argument('--not_val', action='store_true', help='No validation')
    # diffusion models
    parser.add_argument('--config', type=str, default='celeba.yml', help='Path to the config file')
    parser.add_argument('--data_seed', type=int, default=0, help='Random seed')
    parser.add_argument('--seed', type=int, default=1234, help='Random seed')
    parser.add_argument('--exp', type=str, default='./exp_results', help='Path for saving running related data.')
    parser.add_argument('--verbose', type=str, default='info', help='Verbose level: info | debug | warning | critical')
    parser.add_argument('-i', '--image_folder', type=str, default='images', help="The folder name of samples")
    parser.add_argument('--ni', action='store_true', help="No interaction. Suitable for Slurm Job launcher")
    parser.add_argument('--sample_step', type=int, default=1, help='Total sampling steps')
    parser.add_argument('--t', type=int, default=30, help='Sampling noise scale')
    parser.add_argument('--t_delta', type=int, default=15, help='Perturbation range of sampling noise scale')
    parser.add_argument('--rand_t', type=str2bool, default=False, help='Decide if randomize sampling noise scale')
    parser.add_argument('--diffusion_type', type=str, default='sde', help='[ddpm, sde]')
    parser.add_argument('--score_type', type=str, default='guided_diffusion', help='[guided_diffusion, score_sde]')
    parser.add_argument('--eot_iter', type=int, default=20, help='only for rand version of autoattack')
    parser.add_argument('--use_bm', action='store_true', help='whether to use brownian motion')
    parser.add_argument('--log_dir', type=str, default='output')
    # LDSDE
    parser.add_argument('--sigma2', type=float, default=1e-3, help='LDSDE sigma2')
    parser.add_argument('--lambda_ld', type=float, default=1e-2, help='lambda_ld')
    parser.add_argument('--eta', type=float, default=5., help='LDSDE eta')
    parser.add_argument('--step_size', type=float, default=1e-3, help='step size for ODE Euler method')
    # adv
    parser.add_argument('--domain', type=str, default='imagenet', help='which domain: celebahq, cat, car, imagenet')
    parser.add_argument('--classifier_name', type=str, default='Eyeglasses', help='which classifier to use')
    parser.add_argument('--partition', type=str, default='val')
    parser.add_argument('--adv_batch_size', type=int, default=1)
    parser.add_argument('--attack_type', type=str, default='square')
    parser.add_argument('--lp_norm', type=str, default='Linf', choices=['Linf', 'L2'])
    parser.add_argument('--attack_version', type=str, default='rand')
    parser.add_argument('--num_sub', type=int, default=16, help='imagenet subset')
    parser.add_argument('--adv_eps', type=float, default=0.07)
    return parser


def _setup_logging(verbose):
    """Set the root logger's level and attach one formatted stream handler.

    Args:
        verbose: name of a ``logging`` level, case-insensitive (e.g. ``'info'``).

    Raises:
        ValueError: if ``verbose`` does not name an integer ``logging`` level.
    """
    level = getattr(logging, verbose.upper(), None)
    if not isinstance(level, int):
        raise ValueError('level {} not supported'.format(verbose))
    handler_name = 'get_config.parse_args_and_config'
    logger = logging.getLogger()
    # Reuse the handler from an earlier call. Adding a fresh one on every call
    # would make each log record print once per call.
    if not any(h.get_name() == handler_name for h in logger.handlers):
        handler1 = logging.StreamHandler()
        handler1.set_name(handler_name)
        formatter = logging.Formatter('%(levelname)s - %(filename)s - %(asctime)s - %(message)s')
        handler1.setFormatter(formatter)
        logger.addHandler(handler1)
    logger.setLevel(level)


def _seed_everything(seed):
    """Seed ``torch``, ``random`` and ``numpy``, plus all CUDA devices when CUDA is available."""
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def parse_args_and_config(mode='test', argv=None):
    """Parse CLI arguments, load the YAML config, and prepare the run environment.

    Side effects: configures the root logger (level and a single stream
    handler), seeds ``torch``/``random``/``numpy`` (and CUDA when available),
    and sets ``torch.backends.cudnn.benchmark = True``.

    Args:
        mode: ``'train'`` also accepts the training options; defaults to ``'test'``.
        argv: list of arguments to parse instead of ``sys.argv[1:]``. ``None``
            (the default) reads the real command line, as before.

    Returns:
        tuple: ``(args, new_config)``. ``args`` is the parsed
        ``argparse.Namespace``. ``new_config`` is ``configs/<args.config>``
        converted with ``dict2namespace``, plus a ``device`` attribute holding a
        ``torch.device``.

    Raises:
        ValueError: if the config file does not contain a YAML mapping, or if
            ``--verbose`` does not name a ``logging`` level.
    """
    args = _build_parser(mode).parse_args(argv)

    # parse config file
    config_path = os.path.join('configs', args.config)
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    # An empty file loads as None, and a list or scalar as a non-dict. Report it
    # clearly instead of failing with an AttributeError inside dict2namespace.
    if not isinstance(config, dict):
        raise ValueError('config file {} must contain a YAML mapping'.format(config_path))
    new_config = dict2namespace(config)

    _setup_logging(args.verbose)

    # add device
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    logging.info("Using device: {}".format(device))
    new_config.device = device

    # set random seed
    _seed_everything(args.seed)
    # NOTE: cuDNN autotuning picks kernels at runtime and may be
    # nondeterministic, so results can still differ across runs despite the
    # seeding above.
    torch.backends.cudnn.benchmark = True
    return args, new_config