[Features] Add operators for manipulating node/edge features #266

Merged (5 commits) on Aug 6, 2021
12 changes: 12 additions & 0 deletions cogdl/configs.py
@@ -380,6 +380,18 @@
"weight_decay": 0,
"max_epoch": 3000,
},
"revgcn": {
"general": {},
"cora": {
"hidden_size": 128,
"lr": 0.001,
"dropout": 0.4706458854,
"weight_decay": 0.0008907,
"norm": "layernorm",
"num_layers": 10,
# 81.40
},
},
},
"graph_classification": {
"gin": {
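A hedged sketch of how the new `revgcn` preset would typically be picked up (the `experiment` entry point and its keyword arguments reflect CogDL's public API as assumed here, and are not part of this diff):

from cogdl import experiment

# The "revgcn"/"cora" block above supplies defaults such as hidden_size=128,
# lr=0.001, num_layers=10 and layernorm unless they are overridden explicitly.
experiment(dataset="cora", model="revgcn")
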
62 changes: 51 additions & 11 deletions cogdl/data/data.py
@@ -14,6 +14,7 @@
get_degrees,
)
from cogdl.utils import check_fast_spmm as check_indicator
from cogdl.utils import RandomWalker
from cogdl.operators.sample import sample_adj_c, subgraph_c


@@ -150,6 +151,7 @@ def set_weight(self, weight):
self.__symmetric__ = False

def get_weight(self, indicator=None):
"""If `indicator` is not None, the normalization will not be implemented"""
if self.weight is None or self.weight.shape[0] != self.col.shape[0]:
self.weight = torch.ones(self.num_edges, device=self.device)
weight = self.weight
@@ -168,19 +170,31 @@ def get_weight(self, indicator=None):
return weight

def add_remaining_self_loops(self):
edge_index, self.weight = add_remaining_self_loops((self.row, self.col), num_nodes=self.num_nodes)
self.row, self.col = edge_index
if self.attr is not None:
edge_index, weight_attr = add_remaining_self_loops(
(self.row, self.col), edge_weight=self.attr, fill_value=0, num_nodes=self.num_nodes
)
self.row, self.col = edge_index
self.attr = weight_attr
self.weight = torch.ones_like(self.row).float()
else:
edge_index, self.weight = add_remaining_self_loops(
(self.row, self.col), fill_value=1, num_nodes=self.num_nodes
)
self.row, self.col = edge_index
self.attr = None
self.row_ptr, reindex = coo2csr_index(self.row, self.col, num_nodes=self.num_nodes)
self.row = self.row[reindex]
self.col = self.col[reindex]
self.attr = None
# if self.attr is not None:

def remove_self_loops(self):
mask = self.row == self.col
inv_mask = ~mask
self.row = self.row[inv_mask]
self.col = self.col[inv_mask]
for item in self.__attr_keys__():
if self[item] is not None:
self[item] = self[item][inv_mask]

self.convert_csr()

@@ -370,6 +384,28 @@ def __attr_keys__(self):
def clone(self):
return Adjacency.from_dict({k: v.clone() for k, v in self})

def to_scipy_csr(self):
data = self.get_weight().cpu().numpy()
num_nodes = int(self.num_nodes)
if self.row_ptr is None:
row = self.row.cpu().numpy()
col = self.col.cpu().numpy()
mx = sp.csr_matrix((data, (row, col)), shape=(num_nodes, num_nodes))
else:
row_ptr = self.row_ptr.cpu().numpy()
col_ind = self.col.cpu().numpy()
mx = sp.csr_matrix((data, col_ind, row_ptr), shape=(num_nodes, num_nodes))
return mx

def random_walk(self, start, length=1, restart_p=0.0):
if not hasattr(self, "__walker__"):
scipy_adj = self.to_scipy_csr()
self.__walker__ = RandomWalker(scipy_adj)
return self.__walker__.walk(start, length, restart_p=restart_p)

def random_walk_with_restart(self, start, length, restart_p):
return self.random_walk(start, length, restart_p)

@staticmethod
def from_dict(dictionary):
r"""Creates a data object from a python dictionary."""
@@ -402,6 +438,9 @@ def is_read_adj_key(key):
class Graph(BaseGraph):
def __init__(self, x=None, y=None, **kwargs):
super(Graph, self).__init__()
if x is not None:
if not torch.is_tensor(x):
raise ValueError("Node features must be Tensor")
self.x = x
self.y = y

@@ -506,18 +545,16 @@ def test_nid(self):
return self.mask2nid("test")

@contextmanager
def local_graph(self, key=None):
def local_graph(self):
self.__temp_adj_stack__.append(self._adj)
if key is None:
adj = copy.copy(self._adj)
else:
adj = copy.copy(self._adj)
key = KEY_MAP.get(key, key)
adj[key] = self._adj[key].clone()
adj = copy.copy(self._adj)
others = [(key, val) for key, val in self.__dict__.items() if not key.startswith("__") and "adj" not in key]
self._adj = adj
yield
del adj
self._adj = self.__temp_adj_stack__.pop()
for key, val in others:
self[key] = val

@property
def edge_index(self):
@@ -801,6 +838,9 @@ def edge_subgraph(self, edge_idx, require_idx=True):
else:
return g

def to_scipy_csr(self):
return self._adj.to_scipy_csr()

@staticmethod
def from_dict(dictionary):
r"""Creates a data object from a python dictionary."""
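A minimal usage sketch for the new adjacency operators (graph construction is illustrative; `to_scipy_csr` follows the signature added above, and `random_walk` is shown on the underlying `Adjacency` object since this diff only adds it there — the argument types for `start` are an assumption):

import torch
from cogdl.data import Graph

# Toy graph: a 4-node cycle; node features must now be a Tensor.
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
g = Graph(x=torch.randn(4, 8), edge_index=edge_index)

# New: export the adjacency as a scipy.sparse.csr_matrix.
sp_adj = g.to_scipy_csr()

# New: random walks with optional restart, backed by cogdl.utils.RandomWalker.
# Going through the private _adj attribute is only for illustration here.
walks = g._adj.random_walk(start=[0, 1, 2, 3], length=5, restart_p=0.0)
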
9 changes: 8 additions & 1 deletion cogdl/data/dataset.py
@@ -222,7 +222,14 @@ def get(self, idx):
if self.slices is not None:
return self._get(idx)
return self.data[idx]
elif len(idx) > 1:
if isinstance(idx, slice):
start = idx.start
end = idx.stop
step = idx.step
idx = list(range(start, end, step))

if len(idx) > 1:
# `slice` inputs were expanded to an index list above
if self.slices is not None:
return [self._get(int(i)) for i in idx]
return [self.data[i] for i in idx]
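A short sketch of what the new slice branch enables (the dataset name and builder are illustrative of CogDL's dataset API, not part of this diff):

from cogdl.datasets import build_dataset_from_name

dataset = build_dataset_from_name("mutag")   # any multi-graph dataset
# The slice is unpacked into list(range(start, stop, step)) and then handled
# by the existing multi-index branch, returning a list of graphs.
first_four = dataset[0:4:1]
# Note: the branch reads idx.start/idx.stop/idx.step directly, so a slice with
# omitted fields (e.g. dataset[:4]) would pass None values through to range().
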
2 changes: 1 addition & 1 deletion cogdl/datasets/customized_data.py
@@ -94,7 +94,7 @@ def __init__(self, path="cus_graph_data.pt", metric="accuracy"):
super(GraphDataset, self).__init__(root=path)
# try:
data = torch.load(path)
if hasattr(data[0], "y"):
if hasattr(data[0], "y") and data[0].y is not None:
self.y = torch.cat([idata.y for idata in data])
self.data = data

7 changes: 5 additions & 2 deletions cogdl/datasets/ogb.py
@@ -155,8 +155,11 @@ def process(self):
x = x / deg.view(-1, 1)
data.x = x

data.node_species = torch.as_tensor(graph["node_species"])

node_species = torch.as_tensor(graph["node_species"])
n_species, new_index = torch.unique(node_species, return_inverse=True)
one_hot_x = torch.nn.functional.one_hot(new_index, num_classes=len(n_species))
data.species = node_species
data.x = torch.cat([data.x, one_hot_x], dim=1)
torch.save(data, self.processed_paths[0])
return data

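The species-to-one-hot pattern above, shown in isolation with made-up values (`num_classes` is taken as the number of distinct species, which is the assumption behind the re-indexing):

import torch

node_species = torch.tensor([3, 9, 3, 4])        # raw species id per node
uniq, new_index = torch.unique(node_species, return_inverse=True)
# uniq == tensor([3, 4, 9]); new_index == tensor([0, 2, 0, 1])
one_hot_x = torch.nn.functional.one_hot(new_index, num_classes=len(uniq)).float()
x = torch.randn(4, 8)                             # existing node features
x = torch.cat([x, one_hot_x], dim=1)              # append species indicator columns
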
17 changes: 7 additions & 10 deletions cogdl/layers/deepergcn_layer.py
@@ -6,7 +6,7 @@
from torch.utils.checkpoint import checkpoint

from .mlp_layer import MLP
from cogdl.utils import get_activation, mul_edge_softmax, get_norm_layer
from cogdl.utils import get_activation, mul_edge_softmax, get_norm_layer, batch_max_pooling


class GENConv(nn.Module):
@@ -59,8 +59,6 @@ def __init__(
self.eps = 1e-7

self.s = torch.nn.Parameter(torch.Tensor([1.0]), requires_grad=learn_msg_scale and use_msg_norm)
self.act = None if activation is None else get_activation(activation)
self.norm = None if norm is None else get_norm_layer(norm, in_feats)
self.residual = residual

if edge_attr_size is not None and edge_attr_size[0] > 0:
@@ -78,10 +76,6 @@ def message_norm(self, x, msg):
return x + self.s * msg_norm

def forward(self, graph, x):
if self.norm is not None:
x = self.norm(x)
if self.act is not None:
x = self.act(x)
edge_index = graph.edge_index
dim = x.shape[1]
edge_msg = x[edge_index[1]]
@@ -105,9 +99,12 @@ def forward(self, graph, x):
deg_rev[torch.isinf(deg_rev)] = 0
h = edge_msg * deg_rev[edge_index[0]].unsqueeze(-1)
else:
raise NotImplementedError
h = edge_msg

h = torch.zeros_like(x).scatter_add_(dim=0, index=edge_index[0].unsqueeze(-1).repeat(1, dim), src=h)
if self.aggr == "max":
h = batch_max_pooling(h, edge_index[0])
else:
h = torch.zeros_like(x).scatter_add_(dim=0, index=edge_index[0].unsqueeze(-1).repeat(1, dim), src=h)
if self.aggr == "powermean":
h = h.pow(1.0 / self.p)
if self.use_msg_norm:
@@ -158,7 +155,7 @@ def __init__(
self.out_norm = get_norm_layer(norm, out_channels)
else:
self.out_norm = None
self.checkpoint_grad = False
self.checkpoint_grad = checkpoint_grad

def forward(self, graph, x, dropout=None, *args, **kwargs):
h = self.norm(x)
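For reference, a hedged stand-in for `batch_max_pooling` (the real helper lives in `cogdl.utils` and is not shown in this diff; the sketch only illustrates the segment-wise max it is assumed to compute):

import torch

def segment_max(src: torch.Tensor, index: torch.Tensor, num_nodes: int) -> torch.Tensor:
    # out[i] = element-wise max over the rows of src whose target index is i,
    # i.e. a max-aggregation of edge messages per destination node.
    out = torch.zeros(num_nodes, src.shape[1], dtype=src.dtype, device=src.device)
    idx = index.unsqueeze(-1).expand_as(src)
    # scatter_reduce_ with "amax" (PyTorch >= 1.12); nodes with no incoming
    # edges keep the zero initial value, which may differ from cogdl's helper.
    return out.scatter_reduce_(0, idx, src, reduce="amax", include_self=False)

# usage mirroring the hunk above: h = segment_max(edge_msg, edge_index[0], x.shape[0])
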
2 changes: 1 addition & 1 deletion cogdl/layers/reversible_layer.py
@@ -492,7 +492,7 @@ def inverse(self, y, graph, *args):

# RevGNN BaseBlock
class RevGNNLayer(nn.Module):
def __init__(self, conv, group, norm=None):
def __init__(self, conv, group):
super(RevGNNLayer, self).__init__()
self.groups = nn.ModuleList()
for i in range(group):
7 changes: 6 additions & 1 deletion cogdl/layers/sage_layer.py
@@ -24,7 +24,10 @@ def __init__(self, in_feats, out_feats, normalize=False, aggr="mean", dropout=0.
self.out_feats = out_feats
self.fc = nn.Linear(2 * in_feats, out_feats)
self.normalize = normalize
self.dropout = dropout
if dropout > 0:
self.dropout = nn.Dropout(dropout)
else:
self.dropout = None
if aggr == "mean":
self.aggr = MeanAggregator()
elif aggr == "sum":
@@ -38,4 +41,6 @@ def forward(self, graph, x):
out = self.fc(out)
if self.normalize:
out = F.normalize(out, p=2.0, dim=-1)
if self.dropout:
out = self.dropout(out)
return out
7 changes: 1 addition & 6 deletions cogdl/models/nn/deepergcn.py
@@ -68,7 +68,6 @@ def __init__(
self.feat_encoder = nn.Linear(in_feat, hidden_size)

self.layers = nn.ModuleList()
self.layers.append(GENConv(hidden_size, hidden_size))
for i in range(num_layers - 1):
self.layers.append(
ResGNNLayer(
@@ -87,7 +86,7 @@ def __init__(
in_channels=hidden_size,
activation=activation,
dropout=dropout,
checkpoint_grad=(num_layers > 3) and ((i + 1) == num_layers // 2),
checkpoint_grad=False,
)
)
self.norm = nn.BatchNorm1d(hidden_size, affine=True)
@@ -106,7 +105,3 @@ def forward(self, graph):

def predict(self, graph):
return self.forward(graph)

@staticmethod
def get_trainer(args):
return RandomClusterTrainer
16 changes: 13 additions & 3 deletions cogdl/models/nn/drgcn.py
@@ -18,6 +18,8 @@ def add_args(parser):
parser.add_argument("--hidden-size", type=int, default=16)
parser.add_argument("--num-layers", type=int, default=2)
parser.add_argument("--dropout", type=float, default=0.5)
parser.add_argument("--norm", type=str, default=None)
parser.add_argument("--activation", type=str, default="relu")
# fmt: on

@classmethod
@@ -28,9 +30,11 @@ def build_model_from_args(cls, args):
args.hidden_size,
args.num_layers,
args.dropout,
args.norm,
args.activation,
)

def __init__(self, num_features, num_classes, hidden_size, num_layers, dropout):
def __init__(self, num_features, num_classes, hidden_size, num_layers, dropout, norm=None, activation="relu"):
super(DrGCN, self).__init__()

self.num_features = num_features
@@ -39,7 +43,13 @@ def __init__(self, num_features, num_classes, hidden_size, num_layers, dropout):
self.num_layers = num_layers
self.dropout = dropout
shapes = [num_features] + [hidden_size] * (num_layers - 1) + [num_classes]
self.convs = nn.ModuleList([GCNLayer(shapes[layer], shapes[layer + 1]) for layer in range(num_layers)])
self.convs = nn.ModuleList(
[
GCNLayer(shapes[layer], shapes[layer + 1], activation=activation, norm=norm)
for layer in range(num_layers - 1)
]
)
self.convs.append(GCNLayer(shapes[-2], shapes[-1]))
self.ses = nn.ModuleList(
[SELayer(shapes[layer], se_channels=int(np.sqrt(shapes[layer]))) for layer in range(num_layers)]
)
@@ -49,7 +59,7 @@ def forward(self, graph):
x = graph.x
x = self.ses[0](x)
for se, conv in zip(self.ses[1:], self.convs[:-1]):
x = F.relu(conv(graph, x))
x = conv(graph, x)
x = se(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.convs[-1](graph, x)
13 changes: 5 additions & 8 deletions cogdl/models/nn/gat.py
@@ -33,6 +33,7 @@ def add_args(parser):
parser.add_argument("--alpha", type=float, default=0.2)
parser.add_argument("--nhead", type=int, default=8)
parser.add_argument("--last-nhead", type=int, default=1)
parser.add_argument("--norm", type=str, default=None)
# fmt: on

@classmethod
@@ -48,6 +49,7 @@ def build_model_from_args(cls, args):
args.nhead,
args.residual,
args.last_nhead,
args.norm,
)

def __init__(
@@ -62,20 +64,14 @@ def __init__(
nhead,
residual,
last_nhead,
norm=None,
):
"""Sparse version of GAT."""
super(GAT, self).__init__()
self.dropout = dropout
self.attentions = nn.ModuleList()
self.attentions.append(
GATLayer(
in_feats,
hidden_size,
nhead=nhead,
attn_drop=attn_drop,
alpha=alpha,
residual=residual,
)
GATLayer(in_feats, hidden_size, nhead=nhead, attn_drop=attn_drop, alpha=alpha, residual=residual, norm=norm)
)
for i in range(num_layers - 2):
self.attentions.append(
@@ -86,6 +82,7 @@
attn_drop=attn_drop,
alpha=alpha,
residual=residual,
norm=norm,
)
)
self.attentions.append(
2 changes: 1 addition & 1 deletion cogdl/models/nn/ppnp.py
@@ -80,7 +80,7 @@ def get_ready_format(input, edge_index, edge_attr=None):
final_preds = F.dropout(self.vals) @ local_preds
else: # appnp
preds = local_preds
with graph.local_graph("edge_weight"):
with graph.local_graph():
graph.edge_weight = F.dropout(graph.edge_weight, p=self.dropout, training=self.training)
graph.set_symmetric()
for _ in range(self.niter):
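Since `local_graph()` no longer takes a key, the temporary dropout on edge weights is rolled back automatically when the block exits. A minimal sketch of the pattern (graph construction is illustrative):

import torch
import torch.nn.functional as F
from cogdl.data import Graph

graph = Graph(x=torch.randn(3, 4), edge_index=torch.tensor([[0, 1, 2], [1, 2, 0]]))

with graph.local_graph():
    # the dropped-out weights only live inside this block
    graph.edge_weight = F.dropout(graph.edge_weight, p=0.5, training=True)
# on exit the snapshot taken by local_graph() is restored, so graph.edge_weight
# is back to its original values here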