layers.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
File name: layers.py
Author: locke
Date created: 2018/10/5 2:41 PM
"""
from __future__ import absolute_import
from __future__ import unicode_literals
from __future__ import division
from __future__ import print_function
import math
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
import torch.nn.functional as F

class SpecialSpmmFunction(torch.autograd.Function):
    """Special autograd function that backpropagates only through the sparse region.

    Multiplies a sparse COO matrix (given as indices, values, shape) by a dense
    matrix, computing gradients only for the sparse values and the dense operand.
    """

    @staticmethod
    def forward(ctx, indices, values, shape, b):
        assert indices.requires_grad == False
        a = torch.sparse_coo_tensor(indices, values, shape)
        ctx.save_for_backward(a, b)
        ctx.N = shape[0]
        return torch.matmul(a, b)

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        grad_values = grad_b = None
        if ctx.needs_input_grad[1]:
            # Gradient w.r.t. the sparse values: gather the dense gradient
            # only at the stored (row, col) positions.
            grad_a_dense = grad_output.matmul(b.t())
            edge_idx = a._indices()[0, :] * ctx.N + a._indices()[1, :]
            grad_values = grad_a_dense.view(-1)[edge_idx]
        if ctx.needs_input_grad[3]:
            grad_b = a.t().matmul(grad_output)
        return None, grad_values, None, grad_b

class SpecialSpmm(nn.Module):
    def forward(self, indices, values, shape, b):
        return SpecialSpmmFunction.apply(indices, values, shape, b)
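
# A minimal sketch (not part of the original file) of how SpecialSpmm is meant to be
# used: it multiplies a sparse COO matrix, described by (indices, values, shape), with
# a dense matrix while letting gradients flow only into `values` and `b`. The helper
# name `_spmm_example` and the toy sizes below are illustrative assumptions.
def _spmm_example():
    # Three edges of a 2-node graph, i.e. a 2 x 2 sparse adjacency matrix.
    indices = torch.tensor([[0, 0, 1], [0, 1, 1]])
    values = torch.ones(3, requires_grad=True)
    b = torch.randn(2, 4, requires_grad=True)
    out = SpecialSpmm()(indices, values, torch.Size([2, 2]), b)  # dense 2 x 4 result
    out.sum().backward()  # gradients reach `values` and `b` only, not `indices`
    return out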

class MultiHeadGraphAttention(nn.Module):
    """
    Sparse version GAT layer, similar to https://arxiv.org/abs/1710.10903
    https://github.com/Diego999/pyGAT/blob/master/layers.py
    """

    def __init__(self, n_head, f_in, f_out, attn_dropout, diag=True, init=None, bias=False):
        super(MultiHeadGraphAttention, self).__init__()
        self.n_head = n_head
        self.f_in = f_in
        self.f_out = f_out
        self.diag = diag
        if self.diag:
            # Diagonal (element-wise) weight per head; broadcasting requires f_in == f_out.
            self.w = Parameter(torch.Tensor(n_head, 1, f_out))
        else:
            self.w = Parameter(torch.Tensor(n_head, f_in, f_out))
        self.a_src_dst = Parameter(torch.Tensor(n_head, f_out * 2, 1))
        self.attn_dropout = attn_dropout
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2)
        self.special_spmm = SpecialSpmm()
        if bias:
            self.bias = Parameter(torch.Tensor(f_out))
            nn.init.constant_(self.bias, 0)
        else:
            self.register_parameter('bias', None)
        if init is not None and diag:
            init(self.w)
            stdv = 1. / math.sqrt(self.a_src_dst.size(1))
            nn.init.uniform_(self.a_src_dst, -stdv, stdv)
        else:
            nn.init.xavier_uniform_(self.w)
            nn.init.xavier_uniform_(self.a_src_dst)
    def forward(self, input, adj):
        output = []
        for i in range(self.n_head):
            N = input.size()[0]
            edge = adj._indices()
            if self.diag:
                h = torch.mul(input, self.w[i])
            else:
                h = torch.mm(input, self.w[i])
            # Unnormalized attention weights over edges.
            edge_h = torch.cat((h[edge[0, :], :], h[edge[1, :], :]), dim=1)  # edge_h: E x 2*f_out
            edge_e = torch.exp(-self.leaky_relu(edge_h.mm(self.a_src_dst[i]).squeeze()))  # edge_e: E
            # Row-wise normalizer (softmax denominator) via sparse matmul with a ones vector.
            e_rowsum = self.special_spmm(edge, edge_e, torch.Size([N, N]), torch.ones(size=(N, 1)).cuda() if next(self.parameters()).is_cuda else torch.ones(size=(N, 1)))  # e_rowsum: N x 1
            edge_e = F.dropout(edge_e, self.attn_dropout, training=self.training)  # edge_e: E
            # Weighted aggregation of neighbor features, then normalization.
            h_prime = self.special_spmm(edge, edge_e, torch.Size([N, N]), h)
            h_prime = h_prime.div(e_rowsum)  # h_prime: N x f_out
            output.append(h_prime.unsqueeze(0))
        output = torch.cat(output, dim=0)  # output: n_head x N x f_out
        if self.bias is not None:
            return output + self.bias
        else:
            return output
    def __repr__(self):
        if self.diag:
            return self.__class__.__name__ + ' (' + str(self.f_out) + ' -> ' + str(self.f_out) + ') * ' + str(self.n_head) + ' heads'
        else:
            return self.__class__.__name__ + ' (' + str(self.f_in) + ' -> ' + str(self.f_out) + ') * ' + str(self.n_head) + ' heads'
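
# A minimal, self-contained sketch (not part of the original file) of how the layer
# might be driven: random node features plus a small sparse adjacency with self-loops.
# All sizes and hyper-parameters below are illustrative assumptions, not values taken
# from the original project.
if __name__ == "__main__":
    n_nodes, f_in, f_out, n_head = 5, 8, 8, 2
    x = torch.randn(n_nodes, f_in)
    # Sparse COO adjacency: a ring of edges plus one self-loop per node, so every
    # node has at least one outgoing edge and the row-sum normalizer stays non-zero.
    src = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4])
    dst = torch.tensor([1, 2, 3, 4, 0, 0, 1, 2, 3, 4])
    adj = torch.sparse_coo_tensor(torch.stack([src, dst]),
                                  torch.ones(src.size(0)),
                                  (n_nodes, n_nodes)).coalesce()
    # diag=True multiplies features element-wise, so f_in must equal f_out here.
    gat = MultiHeadGraphAttention(n_head, f_in, f_out, attn_dropout=0.0, diag=True)
    out = gat(x, adj)
    print(gat)
    print(out.shape)  # expected: torch.Size([n_head, n_nodes, f_out])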