forked from snath-xoc/cGAN_tutorial
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsetupmodel.py
More file actions
executable file
·116 lines (110 loc) · 3.18 KB
/
setupmodel.py
File metadata and controls
executable file
·116 lines (110 loc) · 3.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import gc
import tensorflow as tf
from tensorflow import keras
from keras.optimizers import Adam
from model import Deterministic, WGANGP, VAE, generator, discriminator
def setup_model(
*,
mode=None,
arch=None,
downscaling_steps=None,
input_channels=None,
constant_fields=None,
filters_gen=None,
filters_disc=None,
noise_channels=None,
latent_variables=None,
padding=None,
kl_weight=None,
ensemble_size=None,
CLtype=None,
content_loss_weight=None,
lr_disc=None,
lr_gen=None
):
if mode in ("GAN", "VAEGAN"):
gen_to_use = {
"normal": generator,
"forceconv": generator,
"forceconv-long": generator,
}[arch]
disc_to_use = {
"normal": discriminator,
"forceconv": discriminator,
"forceconv-long": discriminator,
}[arch]
elif mode == "det":
gen_to_use = {"normal": generator, "forceconv": generator}[arch]
if mode == "GAN":
gen = gen_to_use(
mode=mode,
arch=arch,
downscaling_steps=downscaling_steps,
input_channels=input_channels,
constant_fields=constant_fields,
filters_gen=filters_gen,
noise_channels=noise_channels,
padding=padding,
)
disc = disc_to_use(
arch=arch,
downscaling_steps=downscaling_steps,
input_channels=input_channels,
constant_fields=constant_fields,
filters_disc=filters_disc,
padding=padding,
)
model = WGANGP(
gen,
disc,
mode,
lr_disc=lr_disc,
lr_gen=lr_gen,
ensemble_size=ensemble_size,
CLtype=CLtype,
content_loss_weight=content_loss_weight,
)
elif mode == "VAEGAN":
encoder, decoder = gen_to_use(
mode=mode,
arch=arch,
downscaling_steps=downscaling_steps,
input_channels=input_channels,
constant_fields=constant_fields,
filters_gen=filters_gen,
latent_variables=latent_variables,
padding=padding,
)
disc = disc_to_use(
arch=arch,
downscaling_steps=downscaling_steps,
input_channels=input_channels,
constant_fields=constant_fields,
filters_disc=filters_disc,
padding=padding,
)
gen = VAE(encoder, decoder)
model = WGANGP(
gen,
disc,
mode,
lr_disc=lr_disc,
lr_gen=lr_gen,
kl_weight=kl_weight,
ensemble_size=ensemble_size,
CLtype=CLtype,
content_loss_weight=content_loss_weight,
)
elif mode == "det":
gen = gen_to_use(
mode=mode,
arch=arch,
downscaling_steps=downscaling_steps,
input_channels=input_channels,
constant_fields=constant_fields,
filters_gen=filters_gen,
padding=padding,
)
model = Deterministic(gen, lr=lr_gen, loss="mse", optimizer=Adam)
gc.collect()
return model