-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpostprocessing.py
More file actions
executable file
·145 lines (112 loc) · 5.18 KB
/
postprocessing.py
File metadata and controls
executable file
·145 lines (112 loc) · 5.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import torch
import copy
import os
import argparse
from argparse import Namespace
from util import parse_bool, parse_int_list, set_seed
from train_fns import get_val_loss, get_model_and_optimizer_context
from dataset_context import get_loaders_synthetic_with_context
def get_row_embeds(config, model, loader, device):
    """Gather the 'Z' row-embedding tensors from every batch of *loader*.

    Each batch dict is moved to *device* in place before its ``'Z'`` entry is
    collected; the per-batch tensors are concatenated along dim 0.

    NOTE: ``config`` and ``model`` are currently unused — kept for interface
    parity with the other postprocessing helpers.
    """
    gathered = []
    with torch.no_grad():
        for batch in loader:
            # Move every tensor in the batch to the target device in place.
            for key in batch:
                batch[key] = batch[key].to(device)
            gathered.append(batch['Z'])
    return torch.concatenate(gathered)
def get_row_embeds_fname(out_dir, prefix='best_loss_'):
    """Return the on-disk cache path for row embeddings under *out_dir*."""
    return f'{out_dir}/{prefix}row_embeds.pt'
def get_predictions_fname(out_dir, prefix='best_loss_'):
    """Return the on-disk cache path for model predictions under *out_dir*."""
    return f'{out_dir}/{prefix}predictions.pt'
def get_posterior_samples_fname(out_dir, prefix='best_loss_'):
    """Return the on-disk cache path for posterior samples under *out_dir*."""
    return f'{out_dir}/{prefix}posterior_samples.pt'
def save_row_embeds(config, model, loader_dict, out_dir, device, loader_names, prefix='best_loss_', recalc=False):
    """Compute row embeddings for each named loader, memoized on disk.

    Results live at ``get_row_embeds_fname(out_dir, prefix)``. Unless
    *recalc* is set, an existing cache is loaded and loaders already present
    in it are skipped; the cache is re-saved after each new loader so a
    partial run can resume. Returns the ``{loader_name: tensor}`` dict.
    """
    cache_path = get_row_embeds_fname(out_dir, prefix)
    if os.path.exists(cache_path) and not recalc:
        results = torch.load(cache_path, map_location='cpu')
    else:
        results = {}
    with torch.no_grad():
        for name in loader_names:
            if name in results:
                continue
            print(f'Saving row embeds for {name}')
            results[name] = get_row_embeds(config, model, loader_dict[name + '_loader'], device)
            # Persist incrementally so already-computed loaders survive a crash.
            torch.save(results, cache_path)
    return results
def save_model_predictions(model, loader_dict, out_dir, device, loader_names, prefix='best_loss_',
                           recalc=False, embed_data=False,
                           use_X_model=False):
    """Evaluate *model* on each named loader, memoized on disk.

    Mirrors :func:`save_row_embeds`: an existing cache at
    ``get_predictions_fname(out_dir, prefix)`` is reused unless *recalc* is
    set, already-cached loaders are skipped, and the cache is re-saved after
    each new entry. Returns the ``{loader_name: val_loss_result}`` dict.
    """
    cache_path = get_predictions_fname(out_dir, prefix)
    if os.path.exists(cache_path) and not recalc:
        results = torch.load(cache_path, map_location='cpu')
    else:
        results = {}
    for name in loader_names:
        if name in results:
            continue
        print(f'Saving model loss for {name}')
        results[name] = get_val_loss(model, loader_dict[name + '_loader'], device,
                                     embed_data=embed_data, use_X_model=use_X_model)
        # Persist incrementally so already-computed loaders survive a crash.
        torch.save(results, cache_path)
    return results
def get_device(gpu):
    """Resolve a torch.device from an optional GPU index.

    A non-negative *gpu* selects that CUDA device explicitly; ``None`` or a
    negative index falls back to generic CUDA when available, else CPU.
    """
    if gpu is None or int(gpu) < 0:
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return torch.device(f'cuda:{gpu}')
def load_old_model(config, sd, check=None):
    """Rebuild the model described by *config* and load state dict *sd*.

    The model is placed on CPU before loading. ``check`` is accepted for
    interface compatibility but unused, and the optimizer returned by the
    factory is discarded.
    """
    model, _optimizer = get_model_and_optimizer_context(config)
    model.to('cpu')
    model.load_state_dict(sd)
    return model
def do_postprocessing(args):
    """Load the best checkpoint from ``args.run_dir`` and cache model
    predictions and row embeddings for every data loader.

    Expects ``args`` to carry: run_dir, device, batch_size, and
    postproc_force_recalc (see add_default_postproc_params / main).
    """
    print(f'ARGS: {args}')
    device = args.device
    # Checkpoint is loaded onto CPU first; the model moves to `device` below.
    check = torch.load(args.run_dir + '/best_loss.pt', map_location='cpu')
    config = copy.deepcopy(check['config'])
    # Backfill a field older checkpoints may lack, then override run-time ones.
    if not hasattr(config, 'embed_data_dir'):
        setattr(config, 'embed_data_dir', False)
    setattr(config, 'batch_size', args.batch_size)
    setattr(config, 'device', args.device)
    model = load_old_model(config, check['state_dict'], check)
    # Seed before building loaders so data order is reproducible.
    set_seed(config.seed)
    loaders = get_loaders_synthetic_with_context(config, train_deterministic_row_order=True, extras=True)
    model.to(device)
    model.eval()
    recalc = args.postproc_force_recalc
    # Loader-dict keys look like '<name>_loader'; recover the bare names.
    all_loader_names = [x.split('_loader')[0] for x in loaders.keys() if x.endswith('_loader')]
    # NOTE(review): redundant — embed_data_dir was already backfilled above.
    if not hasattr(config, 'embed_data_dir'):
        config.embed_data_dir=False
    # this will probably do some unnecessary computation, e.g. on val set,
    # where those outputs are already in e.g. best_loss.pt (but I think we don't care)
    predictions = save_model_predictions(model, loaders,
                                         args.run_dir, device, all_loader_names, recalc=recalc,
                                         embed_data = config.embed_data_dir,
                                         use_X_model=config.use_X_model)
    print(predictions['train']['theta_hats'].shape)
    row_embeds = save_row_embeds(config, model, loaders,
                                 args.run_dir, device, all_loader_names, recalc=recalc)
def add_default_postproc_params(parser):
    """Register the postprocessing CLI flags shared by entry points."""
    parser.add_argument('--postproc_force_recalc', type=parse_bool, default=True)  # true for actual usage
    parser.add_argument('--post_sample_all_num_prev_obs', type=parse_int_list,
                        default=[0,1,2,5,10,25],
                        help='an integer or a list of integers separated by commas')
    parser.add_argument('--post_sample_num_repetitions', type=int, default=250)
    parser.add_argument('--post_sample_num_imagined', type=int, default=500)
def main():
    """CLI entry point: parse arguments, initialize wandb, resolve the
    device, and run :func:`do_postprocessing` on the given run directory."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--run_dir', type=str, help='directory with model outputs')
    parser.add_argument('--gpu', type=int, default=None)
    parser.add_argument('--batch_size', type=int, default=100)
    parser.add_argument('--wandb_entity', default='ps-autoregressive')
    add_default_postproc_params(parser)
    args = parser.parse_args()
    # wandb is imported lazily so the module itself has no hard dependency on it.
    import wandb
    wandb.login()
    wandb.init(project='postprocessing', entity=args.wandb_entity)
    device = get_device(args.gpu)
    # Stash the resolved device on args; do_postprocessing reads args.device.
    args.device = device
    print(f'Arg gpu: {args.gpu}, Device: {device}')
    do_postprocessing(args)
if __name__ == "__main__":
    main()