Commit

setup training workflow
eddogola committed Jul 10, 2023
1 parent a40a9e7 commit ab5ba4f
Showing 1 changed file with 74 additions and 6 deletions.
80 changes: 74 additions & 6 deletions examples/mnist/new_pippy_mnist.py
@@ -1,17 +1,21 @@
from tqdm import tqdm
import argparse
import os

import torch
from torch import nn
import torch.optim as optim
import torch.distributed as dist
from torchvision import datasets, transforms
from torch.nn.functional import cross_entropy
from torch.utils.data import DistributedSampler, DataLoader
from torch.utils.data import DataLoader

from pippy.microbatch import sum_reducer, TensorChunkSpec
from pippy.IR import LossWrapper, PipeSplitWrapper
from pippy.compile import compile_stage

USE_TQDM = bool(int(os.getenv("USE_TQDM", 1)))
LR_VERBOSE = bool(int(os.getenv("LR_VERBOSE", 1)))
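# USE_TQDM toggles the per-batch progress bar; LR_VERBOSE asks the LR scheduler
# to print its learning rate each time it steps.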


def run_worker(args):
    # define transforms
@@ -65,12 +69,74 @@ def forward(self, input, target):
        device=args.device,
        group=None,
        example_inputs=[x, target],
        # output_chunk_spec={
        #     "loss": sum_reducer,
        #     "logits": TensorChunkSpec(0),
        # },
    )
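
    # `stage` holds only this rank's slice of the model, which is why the
    # optimizer and LR scheduler below are built from `stage.submod.parameters()`.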

    # setup optimizer
    optimizer = optim.Adam(stage.submod.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8)
    # setup lr scheduler
    lr_sched = optim.lr_scheduler.LinearLR(optimizer, verbose=LR_VERBOSE)

    loaders = {
        "train": train_dataloader,
        "valid": valid_dataloader,
    }

    batches_events_contexts = []
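
    # In the loop below each rank drives its own pipeline stage: rank 0 feeds the
    # images, the last rank feeds the targets and receives the stage output, and
    # intermediate ranks are invoked with no arguments.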

    for epoch in range(args.max_epochs):
        print(f"Epoch: {epoch + 1} of {args.max_epochs}")

        for k, dataloader in loaders.items():
            epoch_correct = 0
            epoch_all = 0
            for i, (x_batch, y_batch) in enumerate(tqdm(dataloader) if USE_TQDM else dataloader):
                x_batch = x_batch.to(args.device)
                y_batch = y_batch.to(args.device)

if k == "train":
stage.train()
optimizer.zero_grad()

if args.rank == 0:
out = stage(x_batch)
elif args.rank == args.world_size - 1:
out = stage(y_batch)
else:
stage()

# outp, loss = stage(x_batch, y_batch)
preds = out.argmax(-1)
correct = (preds == y_batch).sum()
all = len(y_batch)
epoch_correct += correct.item()
epoch_all += all
optimizer.step()
                else:
                    stage.eval()
                    with torch.no_grad():
                        # same per-rank call convention as in training
                        if args.rank == 0:
                            stage(x_batch)
                        elif args.rank == args.world_size - 1:
                            out = stage(y_batch)
                        else:
                            stage()

                    # outp, _ = stage(x_batch, y_batch)
                    if args.rank == args.world_size - 1:
                        preds = out.argmax(-1)
                        correct = (preds == y_batch).sum()
                        epoch_correct += correct.item()
                        epoch_all += len(y_batch)

                if args.visualize:
                    batches_events_contexts.append(stage.retrieve_events())

            # only the last rank accumulates accuracy, so only it reports
            if args.rank == args.world_size - 1:
                print(f"Loader: {k} Accuracy: {epoch_correct / epoch_all}")

if k == "train":
lr_sched.step()
# if LR_VERBOSE:
# print(f"Pipe ")


    dist.barrier()
    print(f"Rank {args.rank} completed!")

@@ -91,8 +157,10 @@ def main(args=None):
"--master_port", type=str, default=os.getenv("MASTER_PORT", "29500")
)
parser.add_argument("--cuda", type=int, default=int(torch.cuda.is_available()))
parser.add_argument("--max_epochs", type=int, default=10)
parser.add_argument("--batch_size", type=int, default=10)
parser.add_argument("--chunks", type=int, default=4)
parser.add_argument("--visualize", type=int, default=1, choices=[0, 1])
args = parser.parse_args(args)
if args.cuda:
dev_id = args.rank % torch.cuda.device_count()
