From 405788cf3b62fa6ea044fa9873c0b6f3f09fcec1 Mon Sep 17 00:00:00 2001 From: Yukuo Cen Date: Tue, 23 Mar 2021 19:45:08 +0800 Subject: [PATCH 1/5] Add distributed trainer --- cogdl/data/sampler.py | 57 ++++++++- cogdl/trainers/__init__.py | 1 + cogdl/trainers/distributed_trainer.py | 167 ++++++++++++++++++++++++++ cogdl/trainers/sampled_trainer.py | 2 +- 4 files changed, 225 insertions(+), 2 deletions(-) create mode 100644 cogdl/trainers/distributed_trainer.py diff --git a/cogdl/data/sampler.py b/cogdl/data/sampler.py index 1e9d04b7..7e346989 100644 --- a/cogdl/data/sampler.py +++ b/cogdl/data/sampler.py @@ -381,6 +381,61 @@ def sample(self, batch): return batch, node_id, adj_list[::-1] +class ClusteredDataset(torch.utils.data.Dataset): + partition_tool = None + + def __init__(self, data: Data, n_cluster: int, batch_size: int, log=False): + super(ClusteredDataset).__init__() + try: + import metis + + ClusteredDataset.partition_tool = metis + except Exception as e: + print(e) + exit(1) + + self.data = data + self.batch_size = batch_size + self.log = log + self.clusters = self.preprocess(n_cluster) + self.batch_idx = np.array(range(n_cluster)) + + def shuffle(self): + random.shuffle(self.batch_idx) + + def __len__(self): + return (len(self.clusters) - 1) // self.batch_size + 1 + + def __getitem__(self, idx): + batch = self.batch_idx[idx * self.batch_size : (idx + 1) * self.batch_size] + nodes = np.concatenate([self.clusters[i] for i in batch]) + subgraph = self.data.subgraph(nodes) + + return subgraph + + def preprocess(self, n_cluster): + if self.log: + print("Preprocessing...") + edges = self.data.edge_index + edges, _ = remove_self_loops(edges) + if str(edges.device) != "cpu": + edges = edges.cpu() + edges = edges.numpy() + num_nodes = np.max(edges) + 1 + adj = sp.csr_matrix((np.ones(edges.shape[1]), (edges[0], edges[1])), shape=(num_nodes, num_nodes)) + indptr = adj.indptr + indptr = np.split(adj.indices, indptr[1:])[:-1] + _, parts = ClusteredDataset.partition_tool.part_graph(indptr, n_cluster, seed=1) + division = [[] for _ in range(n_cluster)] + for i, v in enumerate(parts): + division[v].append(i) + for k in range(len(division)): + division[k] = np.array(division[k], dtype=np.int) + if self.log: + print("Graph clustering done") + return division + + class ClusteredLoader(torch.utils.data.DataLoader): partition_tool = None @@ -414,7 +469,7 @@ def preprocess(self, n_cluster): division[v].append(i) for k in range(len(division)): division[k] = np.array(division[k], dtype=np.int) - print("Graph clustering over") + print("Graph clustering done") return division def batcher(self, batch): diff --git a/cogdl/trainers/__init__.py b/cogdl/trainers/__init__.py index 528c161c..6e6d7b89 100644 --- a/cogdl/trainers/__init__.py +++ b/cogdl/trainers/__init__.py @@ -55,4 +55,5 @@ def build_trainer(args): "clustergcn": "cogdl.trainers.sampled_trainer", "random_partition": "cogdl.trainers.sampled_trainer", "self_auxiliary_task": "cogdl.trainers.self_auxiliary_task_trainer", + "distributed_clustergcn": "cogdl.trainers.distributed_trainer", } diff --git a/cogdl/trainers/distributed_trainer.py b/cogdl/trainers/distributed_trainer.py new file mode 100644 index 00000000..2343adfe --- /dev/null +++ b/cogdl/trainers/distributed_trainer.py @@ -0,0 +1,167 @@ +import argparse +import copy +import os + +import numpy as np +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.multiprocessing import Process +from torch.nn.parallel import DistributedDataParallel as DDP +from tqdm import tqdm + +from cogdl.data.sampler import ClusteredDataset +from cogdl.trainers.base_trainer import BaseTrainer +from . import register_trainer + + +def train_step(model, data_loader, optimizer, device): + model.train() + for batch in data_loader: + batch = batch.to(device) + optimizer.zero_grad() + model.module.node_classification_loss(batch).backward() + optimizer.step() + + +def test_step(model, data, evaluator, loss_fn): + model.eval() + model = model.cpu() + masks = {"train": data.train_mask, "val": data.val_mask, "test": data.test_mask} + with torch.no_grad(): + logits = model.predict(data) + loss = {key: loss_fn(logits[val], data.y[val]) for key, val in masks.items()} + metric = {key: evaluator(logits[val], data.y[val]) for key, val in masks.items()} + return metric, loss + + +def batcher(data): + return data[0] + + +def train(model, data, args, rank, evaluator, loss_fn): + print(f"Running on rank {rank}.") + + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(args.master_port) + + # initialize the process group + dist.init_process_group("nccl", rank=rank, world_size=args.world_size) + + model = copy.deepcopy(model).to(rank) + model = DDP(model, device_ids=[rank]) + + train_dataset = ClusteredDataset(data, args.n_cluster, args.batch_size, log=(rank == 0)) + + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, num_replicas=args.world_size, rank=rank + ) + + train_loader = torch.utils.data.DataLoader( + dataset=train_dataset, + batch_size=1, + shuffle=False, + num_workers=4, + pin_memory=False, + sampler=train_sampler, + persistent_workers=True, + collate_fn=batcher, + ) + + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + + epoch_iter = tqdm(range(args.max_epoch)) if rank == 0 else range(args.max_epoch) + patience = 0 + max_score = 0 + min_loss = np.inf + best_model = None + for epoch in epoch_iter: + train_dataset.shuffle() + train_step(model, train_loader, optimizer, rank) + if (epoch + 1) % args.eval_step == 0: + if rank == 0: + acc, loss = test_step(model.module, data, evaluator, loss_fn) + train_acc = acc["train"] + val_acc = acc["val"] + val_loss = loss["val"] + epoch_iter.set_description(f"Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}") + model = model.to(rank) + object_list = [val_loss, val_acc] + else: + object_list = [None, None] + dist.broadcast_object_list(object_list, src=0) + val_loss, val_acc = object_list + + if val_loss <= min_loss or val_acc >= max_score: + if val_loss <= min_loss: + best_model = copy.deepcopy(model) + min_loss = np.min((min_loss, val_loss)) + max_score = np.max((max_score, val_acc)) + patience = 0 + else: + patience += 1 + if patience == args.patience: + break + dist.barrier() + + if rank == 0: + os.makedirs("./checkpoints", exist_ok=True) + checkpoint_path = os.path.join("./checkpoints", f"{args.model}_{args.dataset}.pt") + if best_model is not None: + print(f"Saving model to {checkpoint_path}") + torch.save(best_model.module.state_dict(), checkpoint_path) + + dist.barrier() + + dist.destroy_process_group() + + +@register_trainer("distributed_clustergcn") +class DistributedClusterGCNTrainer(BaseTrainer): + @staticmethod + def add_args(parser: argparse.ArgumentParser): + """Add trainer-specific arguments to the parser.""" + # fmt: off + parser.add_argument("--n-cluster", type=int, default=1000) + parser.add_argument("--batch-size", type=int, default=20) + parser.add_argument("--eval-step", type=int, default=10) + parser.add_argument("--world-size", type=int, default=2) + parser.add_argument("--master-port", type=int, default=13579) + # fmt: on + + @classmethod + def build_trainer_from_args(cls, args): + return cls(args) + + def __init__(self, args): + self.args = args + + def fit(self, model, dataset): + mp.set_start_method("spawn", force=True) + + data = dataset[0] + model = model.cpu() + + evaluator = dataset.get_evaluator() + loss_fn = dataset.get_loss_fn() + + device_count = torch.cuda.device_count() + if device_count < self.args.world_size: + size = device_count + print(f"Available device count ({device_count}) is less than world size ({self.args.world_size})") + else: + size = self.args.world_size + + processes = [] + for rank in range(size): + p = Process(target=train, args=(model, data, self.args, rank, evaluator, loss_fn)) + p.start() + processes.append(p) + + for p in processes: + p.join() + + model.load_state_dict(torch.load(os.path.join("./checkpoints", f"{self.args.model}_{self.args.dataset}.pt"))) + metric, loss = test_step(model, data, evaluator, loss_fn) + + return dict(Acc=metric["test"], ValAcc=metric["val"]) diff --git a/cogdl/trainers/sampled_trainer.py b/cogdl/trainers/sampled_trainer.py index 6981432e..e67f68a9 100644 --- a/cogdl/trainers/sampled_trainer.py +++ b/cogdl/trainers/sampled_trainer.py @@ -35,7 +35,7 @@ def __init__(self, args): self.weight_decay = args.weight_decay self.loss_fn, self.evaluator = None, None self.data, self.train_loader, self.optimizer = None, None, None - self.eval_step = 1 + self.eval_step = args.eval_step @classmethod def build_trainer_from_args(cls, args): From 91f109b08fa8f345b1ba1c3ceb804a689b60b108 Mon Sep 17 00:00:00 2001 From: Yukuo Cen Date: Wed, 24 Mar 2021 17:07:46 +0800 Subject: [PATCH 2/5] Split operators --- cogdl/data/data.py | 2 +- cogdl/operators/sample.py | 18 ++++++++++++++++++ cogdl/operators/{operators.py => spmm.py} | 18 +----------------- cogdl/utils/utils.py | 4 ++-- 4 files changed, 22 insertions(+), 20 deletions(-) create mode 100644 cogdl/operators/sample.py rename cogdl/operators/{operators.py => spmm.py} (80%) diff --git a/cogdl/data/data.py b/cogdl/data/data.py index 5b9d86b4..6453d1da 100644 --- a/cogdl/data/data.py +++ b/cogdl/data/data.py @@ -15,7 +15,7 @@ fast_spmm, get_degrees, ) -from cogdl.operators.operators import sample_adj_c, subgraph_c +from cogdl.operators.sample import sample_adj_c, subgraph_c indicator = fast_spmm is None diff --git a/cogdl/operators/sample.py b/cogdl/operators/sample.py new file mode 100644 index 00000000..b2c84946 --- /dev/null +++ b/cogdl/operators/sample.py @@ -0,0 +1,18 @@ +import os +from torch.utils.cpp_extension import load + +path = os.path.join(os.path.dirname(__file__)) + +# subgraph and sample_adj +try: + sample = load(name="sampler", sources=[os.path.join(path, "sample/sample.cpp")], verbose=False) + subgraph_c = sample.subgraph + sample_adj_c = sample.sample_adj + coo2csr_cpu = sample.coo2csr_cpu + coo2csr_cpu_index = sample.coo2csr_cpu_index +except Exception as e: + print(e) + subgraph_c = None + sample_adj_c = None + coo2csr_cpu_index = None + coo2csr_cpu = None diff --git a/cogdl/operators/operators.py b/cogdl/operators/spmm.py similarity index 80% rename from cogdl/operators/operators.py rename to cogdl/operators/spmm.py index efb41ca4..80c4212b 100644 --- a/cogdl/operators/operators.py +++ b/cogdl/operators/spmm.py @@ -2,24 +2,8 @@ import torch from torch.utils.cpp_extension import load -path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "operators") +path = os.path.join(os.path.dirname(__file__)) -# subgraph and sample_adj -try: - sample = load(name="sampler", sources=[os.path.join(path, "sample/sample.cpp")], verbose=False) - subgraph_c = sample.subgraph - sample_adj_c = sample.sample_adj - coo2csr_cpu = sample.coo2csr_cpu - coo2csr_cpu_index = sample.coo2csr_cpu_index -except Exception as e: - print(e) - subgraph_c = None - sample_adj_c = None - coo2csr_cpu_index = None - coo2csr_cpu = None - - -# SPMM if not torch.cuda.is_available(): spmm = None else: diff --git a/cogdl/utils/utils.py b/cogdl/utils/utils.py index ff285940..8a68a20b 100644 --- a/cogdl/utils/utils.py +++ b/cogdl/utils/utils.py @@ -14,7 +14,7 @@ import torch.nn.functional as F from tabulate import tabulate -from cogdl.operators.operators import coo2csr_cpu, coo2csr_cpu_index +from cogdl.operators.sample import coo2csr_cpu, coo2csr_cpu_index class ArgClass(object): @@ -222,7 +222,7 @@ def spmm_adj(indices, values, x, num_nodes=None): def initialize_spmm(args): if hasattr(args, "fast_spmm") and args.fast_spmm is True: try: - from cogdl.operators.operators import csrspmm + from cogdl.operators.spmm import csrspmm global fast_spmm fast_spmm = csrspmm From 199d4edd6cc182fa9b5ea72f610116d8d4b5b34a Mon Sep 17 00:00:00 2001 From: Yukuo Cen Date: Wed, 24 Mar 2021 19:20:07 +0800 Subject: [PATCH 3/5] Update ddp dataloader --- cogdl/data/data.py | 6 +++--- cogdl/trainers/distributed_trainer.py | 5 ++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/cogdl/data/data.py b/cogdl/data/data.py index 6453d1da..ec95e4eb 100644 --- a/cogdl/data/data.py +++ b/cogdl/data/data.py @@ -12,12 +12,12 @@ remove_self_loops, symmetric_normalization, row_normalization, - fast_spmm, get_degrees, ) -from cogdl.operators.sample import sample_adj_c, subgraph_c +from cogdl.operators.sample import sample_adj_c -indicator = fast_spmm is None +indicator = False +subgraph_c = None class BaseGraph(object): diff --git a/cogdl/trainers/distributed_trainer.py b/cogdl/trainers/distributed_trainer.py index 62424e92..2dd6608e 100644 --- a/cogdl/trainers/distributed_trainer.py +++ b/cogdl/trainers/distributed_trainer.py @@ -63,10 +63,9 @@ def train(model, dataset, args, rank, evaluator, loss_fn): dataset=train_dataset, batch_size=1, shuffle=False, - num_workers=4, - pin_memory=False, + num_workers=0, + pin_memory=True, sampler=train_sampler, - persistent_workers=True, collate_fn=batcher, ) From 3123af60c1d7d46918701cd1b32c451e20d2d811 Mon Sep 17 00:00:00 2001 From: Yukuo Cen Date: Tue, 30 Mar 2021 16:07:14 +0800 Subject: [PATCH 4/5] Add distributed saint trainer --- cogdl/data/sampler.py | 52 ++++++++++++++-- cogdl/trainers/__init__.py | 2 +- cogdl/trainers/distributed_trainer.py | 85 +++++++++++++++++++++------ cogdl/trainers/sampled_trainer.py | 1 + 4 files changed, 117 insertions(+), 23 deletions(-) diff --git a/cogdl/data/sampler.py b/cogdl/data/sampler.py index 92c71004..24a78312 100644 --- a/cogdl/data/sampler.py +++ b/cogdl/data/sampler.py @@ -168,10 +168,7 @@ def one_batch(self, phase, require_norm=True): return data def exists_train_nodes(self, node_idx): - for idx in node_idx: - if self.train_mask[idx]: - return True - return False + return self.train_mask[node_idx].any().item() def node_induction(self, node_idx): node_idx = np.unique(node_idx) @@ -201,6 +198,53 @@ def sample(self): pass +class SAINTDataset(torch.utils.data.Dataset): + partition_tool = None + + def __init__(self, dataset, args_sampler, require_norm=True, log=False): + super(SAINTDataset).__init__() + + self.data = dataset.data + self.dataset_name = dataset.__class__.__name__ + self.args_sampler = args_sampler + self.require_norm = require_norm + self.log = log + + if self.args_sampler["sampler"] == "node": + self.sampler = NodeSampler(self.data, self.args_sampler) + elif self.args_sampler["sampler"] == "edge": + self.sampler = EdgeSampler(self.data, self.args_sampler) + elif self.args_sampler["sampler"] == "rw": + self.sampler = RWSampler(self.data, self.args_sampler) + elif self.args_sampler["sampler"] == "mrw": + self.sampler = MRWSampler(self.data, self.args_sampler) + else: + raise NotImplementedError + + self.batch_idx = np.array(range(len(self.sampler.subgraph_data))) + + def shuffle(self): + random.shuffle(self.batch_idx) + + def __len__(self): + return len(self.sampler.subgraph_data) + + def __getitem__(self, idx): + new_idx = self.batch_idx[idx] + data = self.sampler.subgraph_data[new_idx] + node_idx = self.sampler.subgraph_node_idx[new_idx] + edge_idx = self.sampler.subgraph_edge_idx[new_idx] + + if self.require_norm: + data.norm_aggr = torch.FloatTensor(self.sampler.norm_aggr_train[edge_idx][:]) + data.norm_loss = self.sampler.norm_loss_train[node_idx] + + edge_weight = row_normalization(data.x.shape[0], data.edge_index) + data.edge_weight = edge_weight + + return data + + class NodeSampler(SAINTSampler): r""" randomly select nodes, then adding edges connecting these nodes diff --git a/cogdl/trainers/__init__.py b/cogdl/trainers/__init__.py index 6e6d7b89..91ab816a 100644 --- a/cogdl/trainers/__init__.py +++ b/cogdl/trainers/__init__.py @@ -55,5 +55,5 @@ def build_trainer(args): "clustergcn": "cogdl.trainers.sampled_trainer", "random_partition": "cogdl.trainers.sampled_trainer", "self_auxiliary_task": "cogdl.trainers.self_auxiliary_task_trainer", - "distributed_clustergcn": "cogdl.trainers.distributed_trainer", + "distributed_trainer": "cogdl.trainers.distributed_trainer", } diff --git a/cogdl/trainers/distributed_trainer.py b/cogdl/trainers/distributed_trainer.py index 2dd6608e..a332dec1 100644 --- a/cogdl/trainers/distributed_trainer.py +++ b/cogdl/trainers/distributed_trainer.py @@ -2,6 +2,7 @@ import copy import os +import time import numpy as np import torch import torch.distributed as dist @@ -10,7 +11,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP from tqdm import tqdm -from cogdl.data.sampler import ClusteredDataset +from cogdl.data.sampler import ClusteredDataset, SAINTDataset from cogdl.trainers.base_trainer import BaseTrainer from . import register_trainer @@ -35,10 +36,66 @@ def test_step(model, data, evaluator, loss_fn): return metric, loss -def batcher(data): +def batcher_clustergcn(data): return data[0] +def batcher_saint(data): + return data[0] + + +def sampler_from_args(args): + args_sampler = { + "sampler": args.sampler, + "sample_coverage": args.sample_coverage, + "size_subgraph": args.size_subgraph, + "num_walks": args.num_walks, + "walk_length": args.walk_length, + "size_frontier": args.size_frontier, + } + return args_sampler + + +def get_train_loader(dataset, args, rank): + if args.sampler == "clustergcn": + train_dataset = ClusteredDataset(dataset, args.n_cluster, args.batch_size, log=(rank == 0)) + + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, num_replicas=args.world_size, rank=rank + ) + + train_loader = torch.utils.data.DataLoader( + dataset=train_dataset, + batch_size=1, + shuffle=False, + num_workers=4, + # pin_memory=True, + sampler=train_sampler, + persistent_workers=True, + collate_fn=batcher_clustergcn, + ) + elif args.sampler in ["node", "edge", "rw", "mrw"]: + train_dataset = SAINTDataset(dataset, sampler_from_args(args), log=(rank == 0)) + + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, num_replicas=args.world_size, rank=rank + ) + + train_loader = torch.utils.data.DataLoader( + dataset=train_dataset, + batch_size=1, + shuffle=False, + num_workers=0, + pin_memory=True, + sampler=train_sampler, + collate_fn=batcher_saint, + ) + else: + raise NotImplementedError(f"{args.trainer} is not implemented.") + + return train_dataset, train_loader + + def train(model, dataset, args, rank, evaluator, loss_fn): print(f"Running on rank {rank}.") @@ -53,21 +110,7 @@ def train(model, dataset, args, rank, evaluator, loss_fn): data = dataset[0] - train_dataset = ClusteredDataset(dataset, args.n_cluster, args.batch_size, log=(rank == 0)) - - train_sampler = torch.utils.data.distributed.DistributedSampler( - train_dataset, num_replicas=args.world_size, rank=rank - ) - - train_loader = torch.utils.data.DataLoader( - dataset=train_dataset, - batch_size=1, - shuffle=False, - num_workers=0, - pin_memory=True, - sampler=train_sampler, - collate_fn=batcher, - ) + train_dataset, train_loader = get_train_loader(dataset, args, rank) optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) @@ -117,7 +160,7 @@ def train(model, dataset, args, rank, evaluator, loss_fn): dist.destroy_process_group() -@register_trainer("distributed_clustergcn") +@register_trainer("distributed_trainer") class DistributedClusterGCNTrainer(BaseTrainer): @staticmethod def add_args(parser: argparse.ArgumentParser): @@ -127,6 +170,12 @@ def add_args(parser: argparse.ArgumentParser): parser.add_argument("--batch-size", type=int, default=20) parser.add_argument("--eval-step", type=int, default=10) parser.add_argument("--world-size", type=int, default=2) + parser.add_argument("--sampler", type=str, default="clustergcn") + parser.add_argument('--sample-coverage', default=20, type=float, help='sample coverage ratio') + parser.add_argument('--size-subgraph', default=1200, type=int, help='subgraph size') + parser.add_argument('--num-walks', default=50, type=int, help='number of random walks') + parser.add_argument('--walk-length', default=20, type=int, help='random walk length') + parser.add_argument('--size-frontier', default=20, type=int, help='frontier size in multidimensional random walks') parser.add_argument("--master-port", type=int, default=13579) # fmt: on diff --git a/cogdl/trainers/sampled_trainer.py b/cogdl/trainers/sampled_trainer.py index 4dd36e9f..96406700 100644 --- a/cogdl/trainers/sampled_trainer.py +++ b/cogdl/trainers/sampled_trainer.py @@ -83,6 +83,7 @@ def add_args(parser: argparse.ArgumentParser): parser.add_argument('--walk-length', default=20, type=int, help='random walk length') parser.add_argument('--size-frontier', default=20, type=int, help='frontier size in multidimensional random walks') parser.add_argument('--valid-cpu', action='store_true', help='run validation on cpu') + parser.add_argument("--eval-step", type=int, default=1) # fmt: on @classmethod From 95f4e5b15d6ba981887eebd7757b2c9ca802a703 Mon Sep 17 00:00:00 2001 From: Yukuo Cen Date: Wed, 7 Apr 2021 10:56:03 +0800 Subject: [PATCH 5/5] Trigger travis