[Model] VRGCN example #305

Merged 4 commits on Nov 10, 2021
24 changes: 24 additions & 0 deletions examples/VRGCN/README.md
@@ -0,0 +1,24 @@
# CogDL examples for ogbn-arxiv

CogDL implementation of VRGCN for [ogbn-arxiv](https://ogb.stanford.edu/docs/nodeprop/#ogbn-arxiv):

> Jianfei Chen, Jun Zhu, Le Song. Stochastic Training of Graph Convolutional Networks with Variance Reduction. [Paper in arXiv](https://arxiv.org/abs/1710.10568). In ICML'2018.

Requires CogDL 0.5-alpha0 or later.
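
At its core, VR-GCN replaces the exact full-neighborhood aggregation with a control-variate estimator built from stored historical activations. A sketch of the layer-wise update (notation loosely follows the paper):

```latex
H^{(l+1)} = \sigma\left( \left( \hat{P} \left( H^{(l)} - \bar{H}^{(l)} \right) + P \bar{H}^{(l)} \right) W^{(l)} \right)
```

Here `\hat{P}` is the propagation matrix restricted to the sampled neighbors, `P` the full propagation matrix, and `\bar{H}^{(l)}` the stored history for layer `l`. In `VRGCN.forward`, the delta term is propagated through the sampled adjacency and carries gradients, while the history term is propagated through the full adjacency and detached.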


## Training & Evaluation

```bash
# Run the model with the default config
python main.py
```
More hyper-parameters can be found in `main.py`.

## Results

Results are averaged over 10 runs and are comparable with the official OGB results reported on the leaderboard.

| Method | Test Accuracy | Validation Accuracy | #Parameters |
|:-------------------------------:|:---------------:|:-------------------:|:-----------:|
| VRGCN | 0.7224 ± 0.0042 | 0.7260 ± 0.0030 | 44,328 |
159 changes: 159 additions & 0 deletions examples/VRGCN/VRGCN.py
@@ -0,0 +1,159 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class History(torch.nn.Module):
r"""A historical embedding storage module."""
def __init__(self, num_embeddings: int, embedding_dim: int):
super().__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.emb = torch.empty(num_embeddings, embedding_dim, device='cpu',
pin_memory=True)
self._device = torch.device('cpu')
self.reset_parameters()

def reset_parameters(self):
self.emb.fill_(0)

def _apply(self, fn):
self._device = fn(torch.zeros(1)).device
return self

@torch.no_grad()
def pull(self, n_id=None):
out = self.emb
if n_id is not None:
assert n_id.device == self.emb.device
out = out.index_select(0, n_id)
return out.to(device=self._device)

    @torch.no_grad()
    def push(self, x, n_id=None):
        if n_id is None:
            # overwrite the whole table when no index set is given
            self.emb.copy_(x.to(self.emb.device))
            return
        assert n_id.device == self.emb.device
        self.emb[n_id] = x.to(self.emb.device)

def forward(self, *args, **kwargs):
raise NotImplementedError


class VRGCN(torch.nn.Module):
def __init__(self, num_nodes: int, in_channels, hidden_channels: int,
out_channels: int, num_layers: int, dropout: float = 0.0,
residual: bool = False, device=None):
super().__init__()

self.in_channels = in_channels
self.out_channels = out_channels
self.dropout = dropout
self.residual = residual
self.num_layers = num_layers

self.lins = nn.ModuleList()
self.lins.append(nn.Linear(in_channels, hidden_channels))
for i in range(num_layers - 2):
self.lins.append(nn.Linear(hidden_channels, hidden_channels))
self.lins.append(nn.Linear(hidden_channels, out_channels))

self.norms = nn.ModuleList()
for i in range(num_layers):
norm = nn.LayerNorm(hidden_channels)
self.norms.append(norm)

self.histories = torch.nn.ModuleList([
History(num_nodes, hidden_channels) if l != 0 else History(num_nodes, in_channels)
for l in range(num_layers)
])

self._device = device

def reset_parameters(self):
for history in self.histories:
history.reset_parameters()
for lin in self.lins:
lin.reset_parameters()
for norm in self.norms:
norm.reset_parameters()

def forward(self, x, sample_ids_adjs, full_ids_adjs) -> Tensor:
sample_ids, sample_adjs = sample_ids_adjs
full_ids, full_adjs = full_ids_adjs

"""VR-GCN"""
x = x[sample_ids[0]].to(self._device)
x_list = []
for i in range(self.num_layers):
sample_adj, cur_id, target_id = sample_adjs[i], sample_ids[i], sample_ids[i + 1]
full_id, full_adj = full_ids[i], full_adjs[i]
full_adj = full_adj.to(x.device)
sample_adj = sample_adj.to(x.device)

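            # Control variate: the delta between current and historical
            # activations goes through the sampled adjacency and carries
            # gradients; the historical activations go through the full
            # adjacency and are detached. The sum stays an unbiased, lower
            # variance estimate of the full aggregation.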
x = x - self.histories[i].pull(cur_id).detach()
h = self.histories[i].pull(full_id)

x = self.slow_spmm(sample_adj, x)[:target_id.shape[0]] + \
self.slow_spmm(full_adj, h)[:target_id.shape[0]].detach()
x = self.lins[i](x)

if i != self.num_layers - 1:
x = self.norms[i](x)
x = x.relu_()
x_list.append(x)
x = F.dropout(x, p=self.dropout, training=self.training)

"""history embedding update"""
for i in range(1, self.num_layers):
self.histories[i].push(x_list[i - 1].detach(), sample_ids[i])
return x.log_softmax(dim=-1)

def slow_spmm(self, graph, x):
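        # Scatter-based sparse-dense matmul: for each edge (row, col) with
        # weight w, accumulate w * x[col] into output[row].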
row, col = graph.edge_index
values = graph.edge_weight
output = x.index_select(0, col) * values.unsqueeze(-1)
output = torch.zeros_like(x).scatter_add_(0, row.unsqueeze(-1).expand_as(output), output)
return output

def initialize_history(self, x, test_loader):
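        # Warm-start every layer's history with exact full-neighborhood
        # activations, so the control variates start from the true values
        # rather than zeros.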
_, xs = self.inference_batch(x, test_loader)
for i in range(self.num_layers):
self.histories[i].push(xs[i].detach(), torch.arange(0, self.histories[i].num_embeddings))

@torch.no_grad()
def inference(self, x, adj):
x = x.to(self._device)
origin_device = adj.device
adj = adj.to(self._device)
xs = [x]
for i in range(self.num_layers):
x = self.slow_spmm(adj, x)
x = self.lins[i](x)
if i != self.num_layers - 1:
x = self.norms[i](x)
x = x.relu_()
xs.append(x)
adj = adj.to(origin_device)
return x, xs

@torch.no_grad()
def inference_batch(self, x, test_loader):
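        # Layer-wise batched inference: materialize the activations of layer i
        # for all nodes before moving on to layer i + 1, keeping memory
        # bounded by the batch size.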
device = self._device
xs = [x]
for i in range(self.num_layers):
tmp_x = []
for target_id, full_id, full_adj in test_loader:
full_adj = full_adj.to(device)
agg_x = self.slow_spmm(full_adj, x[full_id].to(device))[:target_id.shape[0]]
agg_x = self.lins[i](agg_x)

if i != self.num_layers - 1:
agg_x = self.norms[i](agg_x)
agg_x = agg_x.relu_()
tmp_x.append(agg_x.cpu())
x = torch.cat(tmp_x, dim=0)
xs.append(x)
return x, xs
84 changes: 84 additions & 0 deletions examples/VRGCN/dataloder.py
@@ -0,0 +1,84 @@
import torch
from cogdl.data import Graph
import copy


class PseudoRanger(torch.utils.data.Dataset):
def __init__(self, num):
self.indices = torch.arange(num)
self.num = num

def __getitem__(self, item):
return self.indices[item]

def __len__(self):
return self.num

def shuffle(self):
rand = torch.randperm(self.num)
self.indices = self.indices[rand]


class AdjSampler(torch.utils.data.DataLoader):
    def __init__(self, graph, sizes=(2, 2), training=True, *args, **kwargs):

self.graph = copy.deepcopy(graph)
self.sizes = sizes
self.degree = graph.degrees()
self.diag = self._sparse_diagonal_value(graph)
self.training = training
if training:
idx = torch.where(graph['train_mask'])[0]
else:
idx = torch.arange(0, graph.x.shape[0])
self.dataset = PseudoRanger(idx.shape[0])

kwargs["collate_fn"] = self.collate_fn
super(AdjSampler, self).__init__(self.dataset, *args, **kwargs)

def shuffle(self):
self.dataset.shuffle()

def _sparse_diagonal_value(self, adj):
row, col = adj.edge_index
value = adj.edge_weight
return value[row == col]

def _construct_propagation_matrix(self, sample_adj, sample_id, num_neighbors):
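        # Re-insert self-loops and rescale each sampled edge by
        # degree / num_neighbors, so the sampled aggregation is an unbiased
        # estimate of the full-neighborhood aggregation.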
row, col = sample_adj.edge_index
value = sample_adj.edge_weight
"""add self connection"""
num_row = row.max() + 1
row = torch.cat([torch.arange(0, num_row).long(), row], dim=0)
col = torch.cat([torch.arange(0, num_row).long(), col], dim=0)
value = torch.cat([self.diag[sample_id[:num_row]], value], dim=0)

value = value * self.degree[sample_id[row]] / num_neighbors
new_graph = Graph()
new_graph.edge_index = torch.stack([row, col])
new_graph.edge_weight = value
return new_graph

def collate_fn(self, idx):
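        # Training: for each layer, return both the sampled adjacency (for the
        # control-variate delta term) and the full adjacency over the same
        # nodes (for the history term); prepending keeps index 0 as the
        # input-side layer.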
if self.training:
sample_id = torch.tensor(idx)
sample_adjs, sample_ids = [], [sample_id]
full_adjs, full_ids = [], []

for size in self.sizes:
full_id, full_adj = self.graph.sample_adj(sample_id, -1)
sample_id, sample_adj = self.graph.sample_adj(sample_id, size, replace=False)

sample_adj = self._construct_propagation_matrix(sample_adj, sample_id, size)

sample_adjs = [sample_adj] + sample_adjs
sample_ids = [sample_id] + sample_ids
full_adjs = [full_adj] + full_adjs
full_ids = [full_id] + full_ids

return torch.tensor(idx), (sample_ids, sample_adjs), (full_ids, full_adjs)
else:
            # only return the full adjacency in the evaluation phase
sample_id = torch.tensor(idx)
full_id, full_adj = self.graph.sample_adj(sample_id, -1)
return sample_id, full_id, full_adj
105 changes: 105 additions & 0 deletions examples/VRGCN/main.py
@@ -0,0 +1,105 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse
from dataloder import AdjSampler
from cogdl.datasets.ogb import OGBArxivDataset
from VRGCN import VRGCN

def get_parser():
parser = argparse.ArgumentParser(description="OGBN-Arxiv (CogDL GNNs)")
parser.add_argument("--num-layers", type=int, default=2)
parser.add_argument("--num-neighbors", type=list, default=[2, 2])
parser.add_argument("--hidden-size", type=int, default=256)
parser.add_argument("--batch-size", type=int, default=2048)
parser.add_argument("--dropout", type=float, default=0.0)
parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--weight-decay", type=float, default=1e-5)
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--runs", type=int, default=10)
args = parser.parse_args()
return args

args = get_parser()

dataset = OGBArxivDataset(data_path='data/')
data = dataset.data
data.add_remaining_self_loops()
data.set_symmetric()
# data.sym_norm()
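# Self-loops and a symmetric adjacency are added up front; AdjSampler reads the
# self-loop weights off the diagonal when it builds the sampled propagation
# matrices (see _sparse_diagonal_value).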

evaluator = dataset.get_evaluator()
train_loader = AdjSampler(data, sizes=args.num_neighbors, batch_size=args.batch_size, shuffle=True)
test_loader = AdjSampler(data, sizes=[-1], batch_size=args.batch_size, shuffle=False, training=False)
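# The train loader samples a fixed number of neighbors per layer, while the
# test loader (sizes=[-1]) always returns full neighborhoods for exact
# layer-wise inference.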

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = VRGCN(num_nodes=data.x.shape[0], in_channels=dataset.num_features,
hidden_channels=args.hidden_size,
out_channels=dataset.num_classes,
dropout=args.dropout,
num_layers=args.num_layers, device=device).to(device)
model.reset_parameters()

x = data.x
y = data.y.squeeze().to(device)

def train(epoch):
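    # One pass over sampled mini-batches; the forward call also refreshes the
    # per-layer histories for the nodes visited in each batch.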
model.train()
total_loss = total_correct = 0
for batch, sample_ids_adjs, full_ids_adjs in train_loader:
optimizer.zero_grad()
out = model(x, sample_ids_adjs, full_ids_adjs)
loss = F.nll_loss(out, y[batch])
loss.backward()
optimizer.step()
total_loss += float(loss)
total_correct += int(out.argmax(dim=-1).eq(y[batch]).sum())

loss = total_loss / len(train_loader)
approx_acc = total_correct / torch.where(data['train_mask'])[0].size(0)
return loss, approx_acc


@torch.no_grad()
def test():
model.eval()

# out, _ = model.inference(x, data)
out, _ = model.inference_batch(x, test_loader)

y_true = y.cpu()
y_pred = out.cpu()

train_acc = evaluator(y_pred[data['train_mask']], y_true[data['train_mask']])
val_acc = evaluator(y_pred[data['val_mask']], y_true[data['val_mask']])
test_acc = evaluator(y_pred[data['test_mask']], y_true[data['test_mask']])
return train_acc, val_acc, test_acc


test_accs = []
for run in range(args.runs):
model.reset_parameters()
model.eval()
model.initialize_history(x, test_loader)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

best_val_acc = final_test_acc = 0
    for epoch in range(1, args.epochs + 1):
        loss, acc = train(epoch)
        train_acc, val_acc, test_acc = test()
        print(f'Run: {run + 1:02d}, '
              f'Epoch: {epoch:02d}, '
              f'Loss: {loss:.4f}, '
              f'Train: {100 * train_acc:.2f}%, '
              f'Valid: {100 * val_acc:.2f}%, '
              f'Test: {100 * test_acc:.2f}%')

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            final_test_acc = test_acc
test_accs.append(final_test_acc)

test_acc = torch.tensor(test_accs)
print('============================')
print(f'Final Test: {test_acc.mean():.4f} ± {test_acc.std():.4f}')