-
Notifications
You must be signed in to change notification settings - Fork 19
/
geco.py
51 lines (45 loc) · 1.92 KB
/
geco.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# =========================== A2I Copyright Header ===========================
#
# Copyright (c) 2003-2021 University of Oxford. All rights reserved.
# Authors: Applied AI Lab, Oxford Robotics Institute, University of Oxford
# https://ori.ox.ac.uk/labs/a2i/
#
# This file is the property of the University of Oxford.
# Redistribution and use in source and binary forms, with or without
# modification, is not permitted without an explicit licensing agreement
# (research or commercial). No warranty, explicit or implicit, provided.
#
# =========================== A2I Copyright Header ===========================
import torch
class GECO():
    """GECO (Generalized ELBO with Constrained Optimization) loss balancer.

    Maintains an exponential moving average (EMA) of the reconstruction
    error and multiplicatively adjusts the KL weight ``beta`` so that the
    error EMA is driven towards ``goal`` (Rezende & Viola, "Taming VAEs").
    """

    def __init__(self, goal, step_size, alpha=0.99, beta_init=1.0,
                 beta_min=1e-10, speedup=None):
        # goal:      target reconstruction error the EMA should converge to.
        # step_size: update rate for beta (applied in log space via exp).
        # alpha:     EMA decay factor for the error estimate.
        # beta_init: initial KL weight.
        # beta_min:  lower clamp keeping beta strictly positive; the upper
        #            clamp (1e10) keeps beta finite under repeated growth.
        # speedup:   optional multiplier on the beta update used only when
        #            the constraint is slack (err_ema below goal), so beta
        #            can rise faster once reconstruction is good enough.
        self.err_ema = None
        self.goal = goal
        self.step_size = step_size
        self.alpha = alpha
        self.beta = torch.tensor(beta_init)
        self.beta_min = torch.tensor(beta_min)
        self.beta_max = torch.tensor(1e10)
        self.speedup = speedup

    def to_cuda(self):
        # Move internal state to the default CUDA device.
        self.beta = self.beta.cuda()
        if self.err_ema is not None:
            self.err_ema = self.err_ema.cuda()

    def loss(self, err, kld):
        """Return ``err + beta * kld`` and update beta from the error EMA.

        err, kld: scalar tensors. The returned loss carries gradients for
        the caller to backprop; the beta/EMA bookkeeping is done under
        ``no_grad`` and never contributes to the graph.
        """
        # Compute loss with current beta
        loss = err + self.beta * kld
        # Update beta without computing / backpropping gradients
        with torch.no_grad():
            if self.err_ema is None:
                # BUGFIX: detach before storing. Plain assignment keeps the
                # first batch's autograd graph alive for the lifetime of
                # this object (no_grad does not sever an existing graph),
                # leaking memory. Later EMA updates are computed inside
                # no_grad and are already graph-free.
                self.err_ema = err.detach()
            else:
                self.err_ema = (1.0-self.alpha)*err + self.alpha*self.err_ema
            constraint = (self.goal - self.err_ema)
            if self.speedup is not None and constraint.item() > 0:
                # Constraint satisfied: optionally grow beta faster.
                factor = torch.exp(self.speedup * self.step_size * constraint)
            else:
                factor = torch.exp(self.step_size * constraint)
            self.beta = (factor * self.beta).clamp(self.beta_min, self.beta_max)
        # Return loss
        return loss