forked from pancetta/sdc-gym
-
Notifications
You must be signed in to change notification settings - Fork 1
/
dp_playground.py
109 lines (85 loc) · 2.88 KB
/
dp_playground.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import datetime
import itertools
from pathlib import Path

import jax
import jax.numpy as jnp

import dp
import utils
def _current_time():
return str(datetime.datetime.now()).replace(':', '-').replace(' ', 'T')
def _get_test_cp_path(logger, model_path):
if logger.steps > 0:
# Load best checkpoint for testing
load_cp_path = logger.best_cp_path
elif model_path is not None:
load_cp_path = model_path
else:
load_cp_path = None
return load_cp_path
def _maybe_load_test_cp(args, load_cp_path, model, params):
if load_cp_path is not None:
params = dp.model_utils.load_weights(load_cp_path)
print(f'Testing model at {load_cp_path}.')
dp.model_utils.check_output_size(args, model, params)
return params
def main():
    """Train a model on the configured problem, then run the test suite.

    Arguments come from ``dp.arguments.parse_args()``. Training performs
    exactly ``args.steps`` optimizer updates, and none at all when
    ``args.steps`` is 0 (e.g. to only evaluate ``args.model_path``). The
    model is then rebuilt with ``train=False``. The checkpoint chosen by
    ``_get_test_cp_path`` is loaded if there is one, and
    ``dp.testing.run_tests`` is called with
    ``fig_path=dp_results_<script start time>.pdf``.

    Returns:
        The checkpoint path whose weights were tested, or ``None`` if the
        final in-memory parameters were tested instead.
    """
    script_start = _current_time()
    args = dp.arguments.parse_args()
    dp.setup.configure_jax(args)
    utils.setup(True)
    eval_seed = dp.setup.get_eval_seed(args.seed)

    logger = dp.logging.TrainLogger(args, script_start)
    rng_key = jax.random.PRNGKey(args.seed)

    problem = dp.setup.get_problem(args)
    dataloader, rng_key = dp.setup.get_dataloader(args, problem, rng_key)
    params, model, model_arch, rng_key = dp.setup.get_model(
        args, logger, rng_key)

    opt_state, opt_update, opt_get_params = dp.training.build_opt(
        args, params, logger.old_steps)
    # We now always want to use `opt_get_params(opt_state)` to obtain
    # our parameters.
    del params

    loss = dp.setup.get_loss_func(
        args,
        problem,
        dataloader,
        model,
        opt_get_params,
        opt_update,
    )
    update = dp.setup.get_update_func(loss, opt_get_params, opt_update)

    steps = int(args.steps)
    logger.start_timing()
    # Take exactly `steps` batches. The bound is applied before any update
    # is run, so `steps == 0` means "no training", which the `steps > 0`
    # checkpoint save below and `_get_test_cp_path` rely on.
    batches = itertools.islice(dataloader, max(steps, 0))
    for step, (lams, input_data, loss_data) in enumerate(batches):
        loss_, opt_state, _aux_data, rng_key = update(
            # Global step index, continuing from any previously run steps.
            jnp.array(step + logger.old_steps),
            opt_state,
            lams,
            input_data,
            loss_data,
            rng_key,
        )
        # This also saves a checkpoint for the best model.
        logger.log_step(step, opt_get_params(opt_state), model_arch, loss_)
    duration = logger.end_timing()
    print(f'Training took {duration} seconds.')

    # Rebuild the model in evaluation mode for testing.
    _, model = dp.model.from_model_arch(model_arch, train=False)
    params = opt_get_params(opt_state)
    if steps > 0:
        logger.save_cp(params, model_arch)

    load_cp_path = _get_test_cp_path(logger, args.model_path)
    params = _maybe_load_test_cp(args, load_cp_path, model, params)

    fig_path = Path(f'dp_results_{script_start}.pdf')
    dp.testing.run_tests(
        model,
        params,
        args,
        seed=eval_seed,
        fig_path=fig_path,
        # NOTE(review): `loss` was built from the training-mode model;
        # confirm that is intended for evaluation.
        loss_func=loss,
        # stats_path=args.model_path + '.stats.npz',
    )
    return load_cp_path
if __name__ == '__main__':
    # Run training and testing only when executed as a script, not on import.
    main()