[Example] implement Simple-HGN in pure cogdl #265

Merged 1 commit on Jul 31, 2021
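The diff below replaces the DGL-based myGATConv layer in examples/simple_hgn/conv.py with a pure-cogdl implementation of the Simple-HGN layer, and updates examples/simple_hgn/run.py to build the heterogeneous graph with cogdl.data.Graph instead of DGL. For orientation, the per-edge attention the new layer computes for an edge (i, j) of relation type psi(i, j) can be summarized as follows; this is a sketch of the formula implied by the code, not text taken from the PR:

\hat{\alpha}_{ij} = \mathrm{LeakyReLU}\left( a_l^{\top} W h_i + a_r^{\top} W h_j + a_e^{\top} W_e \, r_{\psi(i,j)} \right),
\qquad
\alpha_{ij} = \mathrm{softmax}\big( \hat{\alpha}_{ij} \big),

where the softmax is taken per attention head over the edges incident to a node (cogdl's mul_edge_softmax). When res_attn is passed, the layer additionally mixes in the previous layer's attention as \alpha \leftarrow (1 - \texttt{alpha}) \, \alpha + \texttt{alpha} \, \alpha_{\mathrm{prev}}.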
210 changes: 87 additions & 123 deletions examples/simple_hgn/conv.py
@@ -1,141 +1,105 @@
"""Torch modules for graph attention networks(GAT)."""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch as th
from torch import nn
import math

from dgl import function as fn
from dgl.nn.pytorch import edge_softmax
from dgl._ffi.base import DGLError
from dgl.nn.pytorch.utils import Identity
from dgl.utils import expand_as_pair
import torch
import torch.nn as nn
import torch.nn.functional as F

from cogdl.utils import check_mh_spmm, mh_spmm, mul_edge_softmax, spmm, get_activation, get_norm_layer


# pylint: enable=W0235
class myGATConv(nn.Module):
"""
Adapted from
https://docs.dgl.ai/_modules/dgl/nn/pytorch/conv/gatconv.html#GATConv
cogdl implementation of Simple-HGN layer
"""

def __init__(
self,
edge_feats,
num_etypes,
in_feats,
out_feats,
num_heads,
feat_drop=0.0,
attn_drop=0.0,
negative_slope=0.2,
residual=False,
activation=None,
allow_zero_in_degree=False,
bias=False,
alpha=0.0,
self, edge_feats, num_etypes, in_features, out_features, nhead, feat_drop=0.0, attn_drop=0.5, negative_slope=0.2, residual=False, activation=None, alpha=0.0
):
super(myGATConv, self).__init__()
self._edge_feats = edge_feats
self._num_heads = num_heads
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._out_feats = out_feats
self._allow_zero_in_degree = allow_zero_in_degree
self.edge_emb = nn.Embedding(num_etypes, edge_feats)
if isinstance(in_feats, tuple):
self.fc_src = nn.Linear(self._in_src_feats, out_feats * num_heads, bias=False)
self.fc_dst = nn.Linear(self._in_dst_feats, out_feats * num_heads, bias=False)
else:
self.fc = nn.Linear(self._in_src_feats, out_feats * num_heads, bias=False)
self.fc_e = nn.Linear(edge_feats, edge_feats * num_heads, bias=False)
self.attn_l = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats)))
self.attn_r = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats)))
self.attn_e = nn.Parameter(th.FloatTensor(size=(1, num_heads, edge_feats)))
self.edge_feats = edge_feats
self.in_features = in_features
self.out_features = out_features
self.nhead = nhead
self.edge_emb = nn.Parameter(torch.zeros(size=(num_etypes, edge_feats))) # nn.Embedding(num_etypes, edge_feats)

self.W = nn.Parameter(torch.FloatTensor(in_features, out_features * nhead))
self.W_e = nn.Parameter(torch.FloatTensor(edge_feats, edge_feats * nhead))

self.a_l = nn.Parameter(torch.zeros(size=(1, nhead, out_features)))
self.a_r = nn.Parameter(torch.zeros(size=(1, nhead, out_features)))
self.a_e = nn.Parameter(torch.zeros(size=(1, nhead, edge_feats)))

self.feat_drop = nn.Dropout(feat_drop)
self.attn_drop = nn.Dropout(attn_drop)
self.leaky_relu = nn.LeakyReLU(negative_slope)
self.dropout = nn.Dropout(attn_drop)
self.leakyrelu = nn.LeakyReLU(negative_slope)
self.act = None if activation is None else get_activation(activation)

if residual:
if self._in_dst_feats != out_feats:
self.res_fc = nn.Linear(self._in_dst_feats, num_heads * out_feats, bias=False)
else:
self.res_fc = Identity()
self.residual = nn.Linear(in_features, out_features * nhead)
else:
self.register_buffer("res_fc", None)
self.register_buffer("residual", None)
self.reset_parameters()
self.activation = activation
self.bias = bias
if bias:
self.bias_param = nn.Parameter(th.zeros((1, num_heads, out_feats)))
self.alpha = alpha

def reset_parameters(self):
gain = nn.init.calculate_gain("relu")
if hasattr(self, "fc"):
nn.init.xavier_normal_(self.fc.weight, gain=gain)
else:
nn.init.xavier_normal_(self.fc_src.weight, gain=gain)
nn.init.xavier_normal_(self.fc_dst.weight, gain=gain)
nn.init.xavier_normal_(self.attn_l, gain=gain)
nn.init.xavier_normal_(self.attn_r, gain=gain)
nn.init.xavier_normal_(self.attn_e, gain=gain)
if isinstance(self.res_fc, nn.Linear):
nn.init.xavier_normal_(self.res_fc.weight, gain=gain)
nn.init.xavier_normal_(self.fc_e.weight, gain=gain)

def set_allow_zero_in_degree(self, set_value):
self._allow_zero_in_degree = set_value

def forward(self, graph, feat, e_feat, res_attn=None):
with graph.local_scope():
if not self._allow_zero_in_degree:
if (graph.in_degrees() == 0).any():
raise DGLError(
"There are 0-in-degree nodes in the graph, "
"output for those nodes will be invalid. "
"This is harmful for some applications, "
"causing silent performance regression. "
"Adding self-loop on the input graph by "
"calling `g = dgl.add_self_loop(g)` will resolve "
"the issue. Setting ``allow_zero_in_degree`` "
"to be `True` when constructing this module will "
"suppress the check and let the code run."
)

if isinstance(feat, tuple):
h_src = self.feat_drop(feat[0])
h_dst = self.feat_drop(feat[1])
if not hasattr(self, "fc_src"):
self.fc_src, self.fc_dst = self.fc, self.fc
feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats)
def reset(tensor):
stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
tensor.data.uniform_(-stdv, stdv)

reset(self.a_l)
reset(self.a_r)
reset(self.a_e)
reset(self.W)
reset(self.W_e)
reset(self.edge_emb)

def forward(self, graph, x, res_attn=None):
x = self.feat_drop(x)
h = torch.matmul(x, self.W).view(-1, self.nhead, self.out_features)
h[torch.isnan(h)] = 0.0
e = torch.matmul(self.edge_emb, self.W_e).view(-1, self.nhead, self.edge_feats)

row, col = graph.edge_index
tp = graph.edge_type
# Self-attention on the nodes - Shared attention mechanism
h_l = (self.a_l * h).sum(dim=-1)[row]
h_r = (self.a_r * h).sum(dim=-1)[col]
h_e = (self.a_e * e).sum(dim=-1)[tp]
edge_attention = self.leakyrelu(h_l + h_r + h_e)
# edge_attention: E * H
edge_attention = mul_edge_softmax(graph, edge_attention)
edge_attention = self.dropout(edge_attention)
if res_attn is not None:
edge_attention = edge_attention * (1 - self.alpha) + res_attn * self.alpha

if check_mh_spmm() and next(self.parameters()).device.type != "cpu":
if self.nhead > 1:
h_prime = mh_spmm(graph, edge_attention, h)
out = h_prime.view(h_prime.shape[0], -1)
else:
h_src = h_dst = self.feat_drop(feat)
feat_src = feat_dst = self.fc(h_src).view(-1, self._num_heads, self._out_feats)
if graph.is_block:
feat_dst = feat_src[: graph.number_of_dst_nodes()]
e_feat = self.edge_emb(e_feat)
e_feat = self.fc_e(e_feat).view(-1, self._num_heads, self._edge_feats)
ee = (e_feat * self.attn_e).sum(dim=-1).unsqueeze(-1)
el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
graph.srcdata.update({"ft": feat_src, "el": el})
graph.dstdata.update({"er": er})
graph.edata.update({"ee": ee})
graph.apply_edges(fn.u_add_v("el", "er", "e"))
e = self.leaky_relu(graph.edata.pop("e") + graph.edata.pop("ee"))
# compute softmax
graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))
if res_attn is not None:
graph.edata["a"] = graph.edata["a"] * (1 - self.alpha) + res_attn * self.alpha
# message passing
graph.update_all(fn.u_mul_e("ft", "a", "m"), fn.sum("m", "ft"))
rst = graph.dstdata["ft"]
# residual
if self.res_fc is not None:
resval = self.res_fc(h_dst).view(h_dst.shape[0], -1, self._out_feats)
rst = rst + resval
# bias
if self.bias:
rst = rst + self.bias_param
# activation
if self.activation:
rst = self.activation(rst)
return rst, graph.edata.pop("a").detach()
edge_weight = edge_attention.view(-1)
with graph.local_graph():
graph.edge_weight = edge_weight
out = spmm(graph, h.squeeze(1))
else:
with graph.local_graph():
h_prime = []
h = h.permute(1, 0, 2).contiguous()
for i in range(self.nhead):
edge_weight = edge_attention[:, i]
graph.edge_weight = edge_weight
hidden = h[i]
assert not torch.isnan(hidden).any()
h_prime.append(spmm(graph, hidden))
out = torch.cat(h_prime, dim=1)

if self.residual:
res = self.residual(x)
out += res
if self.act is not None:
out = self.act(out)
return out, edge_attention.detach()

def __repr__(self):
return self.__class__.__name__ + " (" + str(self.in_features) + " -> " + str(self.out_features) + ")"
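
For readers who want the scoring step above in isolation: the following is a minimal, standalone sketch of the edge-type-aware attention that myGATConv.forward computes, written in plain PyTorch and not part of the PR. It replaces cogdl's mul_edge_softmax with an explicit index_add_-based softmax and assumes the usual GAT convention of normalizing over the edges that share a destination node; all tensor sizes in the toy usage are made up.

import torch
import torch.nn.functional as F

def simple_hgn_edge_attention(h, e, edge_index, edge_type, a_l, a_r, a_e, negative_slope=0.2):
    # h:  [N, H, D]   projected node features (N nodes, H heads, D dims per head)
    # e:  [T, H, De]  projected edge-type embeddings (T relation types)
    # edge_index: [2, E] as (row, col); edge_type: [E] relation id per edge
    # a_l, a_r: [1, H, D]; a_e: [1, H, De]
    row, col = edge_index
    scores = (
        (a_l * h).sum(-1)[row]          # source-node contribution, E x H
        + (a_r * h).sum(-1)[col]        # destination-node contribution
        + (a_e * e).sum(-1)[edge_type]  # relation-type contribution
    )
    scores = F.leaky_relu(scores, negative_slope)
    scores = scores - scores.max(dim=0, keepdim=True).values   # numerical stability
    exp = scores.exp()
    # sum the exponentiated scores of edges arriving at each destination, per head
    denom = torch.zeros(h.size(0), exp.size(1)).index_add_(0, col, exp)
    return exp / (denom[col] + 1e-16)   # E x H attention weights

# toy usage
N, E, T, H, D, De = 4, 5, 2, 2, 8, 6
h = torch.randn(N, H, D)
e = torch.randn(T, H, De)
edge_index = torch.tensor([[0, 1, 2, 3, 0],
                           [1, 2, 3, 0, 2]])
edge_type = torch.randint(0, T, (E,))
a_l, a_r, a_e = torch.randn(1, H, D), torch.randn(1, H, D), torch.randn(1, H, De)
attn = simple_hgn_edge_attention(h, e, edge_index, edge_type, a_l, a_r, a_e)

In the layer itself the resulting E x H weights are then fed to mh_spmm (multi-head) or spmm (single head) to aggregate neighbor features, as shown in the forward method above.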

31 changes: 11 additions & 20 deletions examples/simple_hgn/run.py
@@ -2,7 +2,6 @@
import scipy.sparse as sp
import torch
import torch.nn as nn
import dgl

from conv import myGATConv

@@ -11,6 +10,7 @@
from cogdl import experiment, options
from cogdl.models import BaseModel, register_model
from cogdl.utils import accuracy
from cogdl.data import Graph


@register_model("simple_hgn")
@@ -137,13 +137,6 @@ def __init__(
)
self.epsilon = torch.FloatTensor([1e-12]).to(self.device)

def list_to_sp_mat(self, edges, weights):
data = [x for x in weights]
i = [x for x in edges[0]]
j = [x for x in edges[1]]
total = max(max(i), max(j)) + 1
return sp.coo_matrix((data, (i, j)), shape=(total, total)).tocsr()

def build_g_feat(self, A):
edge2type = {}
edges = []
@@ -155,21 +148,21 @@ def build_g_feat(self, A):
edge2type[(u, v)] = k
edges = np.concatenate(edges, axis=1)
weights = np.concatenate(weights)
adjM = self.list_to_sp_mat(edges, weights)
g = dgl.DGLGraph(adjM)
g = dgl.remove_self_loop(g)
g = dgl.add_self_loop(g)
edges = torch.tensor(edges).to(self.device)
weights = torch.tensor(weights).to(self.device)

g = Graph(edge_index=edges, edge_weight=weights)
g = g.to(self.device)
e_feat = []
for u, v in zip(*g.edges()):
for u, v in zip(*g.edge_index):
u = u.cpu().item()
v = v.cpu().item()
e_feat.append(edge2type[(u, v)])
e_feat = torch.tensor(e_feat, dtype=torch.long).to(self.device)
g.edge_type = e_feat
self.g = g
self.e_feat = e_feat

def forward(self, A, X, target_x, target): # features_list, e_feat):
def forward(self, A, X, target_x, target):
# h = []
# for fc, feature in zip(self.fc_list, [X]):
# h.append(fc(feature))
@@ -178,11 +171,11 @@ def forward(self, A, X, target_x, target):
self.build_g_feat(A)
res_attn = None
for l in range(self.num_layers): # noqa E741
h, res_attn = self.gat_layers[l](self.g, h, self.e_feat, res_attn=res_attn)
h, res_attn = self.gat_layers[l](self.g, h, res_attn=res_attn)
h = h.flatten(1)
# output projection
logits, _ = self.gat_layers[-1](self.g, h, self.e_feat, res_attn=None)
logits = logits.mean(1)
logits, _ = self.gat_layers[-1](self.g, h, res_attn=None)
# logits = logits.mean(1)
# This is an equivalent replacement for tf.l2_normalize, see https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/math/l2_normalize for more information.
logits = logits / (torch.max(torch.norm(logits, dim=1, keepdim=True), self.epsilon))
y = logits[target_x]
@@ -200,9 +193,7 @@ def evaluate(self, data, nodes, targets):


if __name__ == "__main__":
# CUDA_VISIBLE_DEVICES=0 python custom_gcn.py --seed 0 1 2 3 4 -t heterogeneous_node_classification -dt gtn-acm -m simple_hgn --lr 0.001
parser = options.get_training_parser()
args, _ = parser.parse_known_args()
args = options.parse_args_and_arch(parser, args)
experiment(task="heterogeneous_node_classification", dataset="gtn-acm", model="simple_hgn", args=args)
# experiment(task="node_classification", dataset="cora", model="mygcn")
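
Since parts of build_g_feat are collapsed in the diff above, here is a hedged, standalone sketch of the edge-type bookkeeping it performs: converting a list of per-relation adjacency matrices into a single cogdl Graph with a relation id per edge. Only cogdl.data.Graph and the edge_index / edge_weight / edge_type attributes already used in run.py are taken from the PR; the function name and toy data are illustrative, and device placement is omitted.

import numpy as np
import scipy.sparse as sp
import torch
from cogdl.data import Graph

def build_typed_graph(A):
    # A: list of scipy.sparse adjacency matrices, one per relation type
    edges, weights, edge2type = [], [], {}
    for k, mat in enumerate(A):
        coo = mat.tocoo()
        edges.append(np.vstack([coo.row, coo.col]))
        weights.append(coo.data)
        for u, v in zip(coo.row, coo.col):
            edge2type[(int(u), int(v))] = k
    edge_index = torch.as_tensor(np.concatenate(edges, axis=1), dtype=torch.long)
    edge_weight = torch.as_tensor(np.concatenate(weights), dtype=torch.float)
    edge_type = torch.tensor(
        [edge2type[(int(u), int(v))] for u, v in zip(edge_index[0], edge_index[1])],
        dtype=torch.long,
    )
    g = Graph(edge_index=edge_index, edge_weight=edge_weight)
    g.edge_type = edge_type   # read as graph.edge_type inside myGATConv.forward
    return g

# two toy relations on three nodes
A = [
    sp.coo_matrix(([1.0], ([0], [1])), shape=(3, 3)),
    sp.coo_matrix(([1.0, 1.0], ([1, 2], [2, 0])), shape=(3, 3)),
]
g = build_typed_graph(A)

Note that the original DGL version also stripped and re-added self-loops (dgl.remove_self_loop / dgl.add_self_loop); the sketch, like the new run.py, leaves the edge set as given.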