-
Notifications
You must be signed in to change notification settings - Fork 1
/
test.py
140 lines (108 loc) · 4.81 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import tensorflow as tf
import torch
import numpy as np
import unittest
from l4tensorflow.L4 import L4Adam as L4AdamTf
from l4tensorflow.L4 import L4Mom as L4MomTf
from l4pytorch import L4Adam as L4AdamTh
from l4pytorch import L4Mom as L4MomTh
class ModelTh(torch.nn.Module):
    """Two-layer MLP with a cross-entropy loss head, seeded from numpy arrays.

    Weights are supplied in (in_features, out_features) layout and transposed
    into torch's (out, in) ``Linear`` storage; ``get_weights`` transposes back
    so callers always see the (in, out) numpy layout.
    """

    def __init__(self, w1, b1, w2, b2):
        super(ModelTh, self).__init__()
        in1, out1 = w1.shape
        in2, out2 = w2.shape
        self.linear1 = torch.nn.Linear(in1, out1, bias=True)
        self.linear2 = torch.nn.Linear(in2, out2, bias=True)
        self.relu = torch.nn.ReLU()
        self.loss = torch.nn.CrossEntropyLoss()
        self._init(w1, b1, w2, b2)

    def _init(self, w1, b1, w2, b2):
        # torch.nn.Linear stores weight as (out, in): transpose the numpy
        # (in, out) matrices on the way in.
        assignments = (
            (self.linear1.weight, torch.from_numpy(w1).t()),
            (self.linear1.bias, torch.from_numpy(b1)),
            (self.linear2.weight, torch.from_numpy(w2).t()),
            (self.linear2.bias, torch.from_numpy(b2)),
        )
        for param, value in assignments:
            param.data.copy_(value)

    def forward(self, x, y):
        """Return the scalar cross-entropy loss for inputs x and labels y."""
        hidden = self.relu(self.linear1(x))
        logits = self.linear2(hidden)
        return self.loss(logits, y)

    def get_weights(self):
        """Return (w1, b1, w2, b2) as numpy arrays in (in, out) layout."""
        return (
            self.linear1.weight.data.t().numpy(),
            self.linear1.bias.data.numpy(),
            self.linear2.weight.data.t().numpy(),
            self.linear2.bias.data.numpy(),
        )
class ModelTf():
    """TF1-graph twin of the torch model: two-layer MLP + softmax cross-entropy.

    Exposes the input/label placeholders as ``X``/``Y`` and the mean loss
    tensor as ``out``. All variables are created under the "weights" scope so
    ``get_weights`` can fetch them from the global-variables collection.
    """

    def __init__(self, w1, b1, w2, b2):
        with tf.variable_scope("weights"):
            self.X = tf.placeholder("float64", [None, w1.shape[0]])
            self.Y = tf.placeholder("float64", [None, w2.shape[1]])
            hidden = tf.nn.relu(
                tf.add(tf.matmul(self.X, tf.Variable(w1, name='W1')),
                       tf.Variable(b1, name='B1')))
            logits = tf.add(
                tf.matmul(hidden, tf.Variable(w2, name='W2')),
                tf.Variable(b2, name='B2'))
            self.out = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits,
                                                           labels=self.Y))

    def get_weights(self):
        """Evaluate and return (w1, b1, w2, b2) from the "weights" scope.

        NOTE(review): relies on collection order matching creation order
        (W1, B1, W2, B2); holds here because the scope creates exactly
        these four variables, in that order.
        """
        scope_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                       scope="weights")
        return tuple(v.eval() for v in scope_vars[:4])
class Test(unittest.TestCase):
    """Cross-framework equivalence test for the L4 optimizers.

    Builds the same two-layer network in TF and in PyTorch from identical
    numpy weights, trains both in lockstep, and asserts all weight tensors
    stay numerically equal after every iteration.
    """

    def setUp(self):
        # Fixed seed: both frameworks must start from identical weights/data.
        np.random.seed(1234)
        self.iterations = 50
        self.n_samples = 10
        self.h1 = 5
        self.h2 = 8
        self.n_classes = 3
        self.x = np.random.normal(0, 1, size=(self.n_samples, self.h1))
        self.y = np.random.randint(0, self.n_classes, size=(self.n_samples,))
        self.w1 = np.random.normal(0, 1, size=(self.h1, self.h2))
        self.b1 = np.random.normal(0, 1, size=(self.h2,))
        self.w2 = np.random.normal(0, 1, size=(self.h2, self.n_classes))
        self.b2 = np.random.normal(0, 1, size=(self.n_classes,))
        self.setup_th()
        self.setup_tf()

    def setup_th(self):
        # Double precision to match the float64 TF graph bit-for-bit(ish).
        torch.set_default_dtype(torch.double)
        self.model_th = ModelTh(self.w1, self.b1, self.w2, self.b2)
        self.feed_dict_th = {'x': torch.from_numpy(self.x),
                             'y': torch.from_numpy(self.y)}

    def setup_tf(self):
        tf.reset_default_graph()
        self.model_tf = ModelTf(self.w1, self.b1, self.w2, self.b2)
        # softmax_cross_entropy_with_logits_v2 takes one-hot labels, unlike
        # torch's CrossEntropyLoss which takes class indices.
        y_one_hot = np.zeros((self.n_samples, self.n_classes), dtype=np.float32)
        y_one_hot[np.arange(0, self.n_samples), self.y] = 1
        self.feed_dict_tf = {self.model_tf.X: self.x, self.model_tf.Y: y_one_hot}

    def check_model_weights(self, it, atol=1e-6, rtol=1e-10):
        """Assert all four weight tensors agree between the two models."""
        state_tf = self.model_tf.get_weights()
        state_th = self.model_th.get_weights()
        for w_tf, w_th in zip(state_tf, state_th):
            assert np.allclose(w_tf, w_th, atol=atol, rtol=rtol), \
                "Failed check at iteration {}".format(it)

    def test_l4adam(self):
        optimizer_tf = L4AdamTf(dtype=tf.float64)
        optimizer_th = L4AdamTh(self.model_th.parameters())
        self.train_with(optimizer_tf, optimizer_th)

    def test_l4mom(self):
        optimizer_tf = L4MomTf(dtype=tf.float64)
        optimizer_th = L4MomTh(self.model_th.parameters())
        self.train_with(optimizer_tf, optimizer_th)

    def train_with(self, optimizer_tf, optimizer_th):
        """Train both models in lockstep, checking weights every iteration."""
        train_op = optimizer_tf.minimize(self.model_tf.out)
        with tf.Session() as sess:
            tf.global_variables_initializer().run()
            # Both models must agree before any training step.
            self.check_model_weights(it=0)
            for i in range(self.iterations):
                # TF iteration. Fix: the original also fetched the loss tensor
                # here, but the value was immediately overwritten by the torch
                # loss below — the fetch was dead code, so it is dropped.
                sess.run(train_op, feed_dict=self.feed_dict_tf)
                # Torch iteration; L4's step() takes a closure returning the
                # current loss value.
                loss = self.model_th(**self.feed_dict_th)
                optimizer_th.zero_grad()
                loss.backward()
                optimizer_th.step(lambda: loss)
                self.check_model_weights(it=i + 1)
if __name__ == '__main__':
    # Run the TF-vs-PyTorch L4 equivalence tests via the unittest runner.
    unittest.main()