Skip to content

Commit

Permalink
A graph-based pipeline splitting (#1080)
Browse files Browse the repository at this point in the history
An automatic graph-based pipeline splitting algorithm. The goal of the
method is to split the computation graph into stages to minimize the
communication between the stages while trying to balance the
computation. The optimization is done via solving a mixed-integer linear
program (MILP) using `scipy`.

Mean batch time in seconds, measured over 50 batches (after a warmup), for
various models using the manual split, `--autosplit`, and the new
`--graphsplit`:

| model | nproc-per-node | manual | autosplit | graphsplit |
|--------|--------|--------|--------|--------|
| pippy_bert | 2 | 0.1082 | 0.1279 | 0.1083 |
| pippy_bert | 4 | 0.0670 | 0.0798 | 0.0671 |
| pippy_gpt2 | 2 | 0.0388 | 0.0550 | 0.0391 |
| pippy_gpt2 | 4 | 0.0201 |  0.0271 | 0.0205 |
| pippy_fnet | 2 | 0.0324 | 0.0420 | 0.0323 |
| pippy_fnet | 4 | 0.0221 |  crash | 0.0218 |
| pippy_blenderbot | 2 | 0.4805 | 0.4991 | 0.4839 |
| pippy_blenderbot | 4 | 0.2421 |  0.2593 | 0.2436 |

That is, the results of graph-split are almost identical to manual
splitting, indicating that no manual model annotation is needed.
  • Loading branch information
spupyrev authored May 31, 2024
1 parent 395801c commit 5e1d719
Show file tree
Hide file tree
Showing 6 changed files with 636 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
__pycache__
build
pippy.egg-info
torchpippy.egg-info
pippy/version.py
dist
.idea/
Expand Down
9 changes: 8 additions & 1 deletion examples/huggingface/pippy_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def run(args):
config.n_embd = args.n_embd or config.n_embd
config.n_layer = args.n_layer or config.n_layer
config.n_head = args.n_head or config.n_head
print("Using device:", args.device)
print("[Rank {}] Using device: {}".format(args.rank, args.device))

# Create model
model_class = GPT2ForSequenceClassification
Expand All @@ -38,13 +38,19 @@ def run(args):
example_inputs = generate_inputs_for_model(
model_class, gpt2, model_name, args.batch_size, args.device)

assert not args.autosplit or not args.graphsplit

split_policy = None
split_spec = None

if args.autosplit:
# Automatic split
from pippy import split_into_equal_size
split_policy = split_into_equal_size(args.world_size)
elif args.graphsplit:
# Graph-based split
from pippy import split_by_graph
split_policy = split_by_graph(args.world_size)
else:
# Use manual split spec
decoders_per_rank = (gpt2.config.n_layer + args.world_size - 1) // args.world_size
Expand Down Expand Up @@ -106,6 +112,7 @@ def run(args):
parser.add_argument('--n_layer', type=int, default=None)
parser.add_argument('--n_head', type=int, default=None)
parser.add_argument('--autosplit', action="store_true")
parser.add_argument('--graphsplit', action="store_true")

args = parser.parse_args()

Expand Down
56 changes: 56 additions & 0 deletions pippy/ModelSplit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import torch
import torch.fx as fx

from pippy.graphsplit import split_by_graph_with_num_stages

from ._IR import aten_pipe_split_alias


Expand Down Expand Up @@ -202,3 +204,57 @@ def _split_into_nstages_equal_size(
return gm

return _split_into_nstages_equal_size


"""
Create a Callable that splits a model into a given number of stages, based on the computation graph, while
trying to minimize the communication between the stages and to balance the computation
Input:
    nstages: the number of stages to split the module into
Output:
    a Callable that transforms an input `fx.GraphModule` into an output `fx.GraphModule` that has `pipe_split` inserted
    between `nstages` stages
"""


def split_by_graph(nstages: int) -> Callable[[fx.GraphModule], fx.GraphModule]:
    def _split_by_graph(
        gm: fx.GraphModule,
    ) -> fx.GraphModule:
        # Compute per-node parameter sizes and ask the MILP-based solver for
        # a node -> stage-index assignment.
        node_param_sizes = _analyze_node_size(gm)
        node2stage = split_by_graph_with_num_stages(
            gm, nstages, node_param_sizes
        )

        # Remove existing split points. Materialize the node list first so we
        # never erase nodes from the graph while iterating over it.
        for node in list(gm.graph.nodes):
            if "pipe_split" in node.name:
                gm.graph.erase_node(node)

        # Group nodes by assigned stage in a single pass over the graph,
        # preserving graph order within each stage (the original rescanned
        # the whole node list once per stage).
        stage2nodes = {}
        for node in gm.graph.nodes:
            if node in node2stage:
                stage2nodes.setdefault(node2stage[node], []).append(node)

        # Modify the graph by grouping nodes on the same stage and adding
        # pipe_splits between the stages
        last_node = None
        for stage_idx in range(nstages):
            for node in stage2nodes.get(stage_idx, []):
                # Move the node right after the previously placed one, unless
                # it is already in position.
                if last_node is not None and last_node.next != node:
                    last_node.append(node)
                last_node = node
            # Insert pipe_split nodes after each stage, except the last one
            if stage_idx + 1 != nstages and last_node is not None:
                with gm.graph.inserting_after(last_node):
                    last_node = gm.graph.call_function(
                        aten_pipe_split_alias, (), {}
                    )

        # Since we transformed the graph, recompile the module
        gm.recompile()
        gm.graph.lint()
        return gm

    return _split_by_graph
7 changes: 6 additions & 1 deletion pippy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
)
from ._PipelineStage import PipelineStage
from .ManualPipelineStage import ManualPipelineStage
from .ModelSplit import split_into_equal_size, split_on_size_threshold
from .ModelSplit import (
split_by_graph,
split_into_equal_size,
split_on_size_threshold,
)
from .PipelineSchedule import (
Schedule1F1B,
ScheduleGPipe,
Expand All @@ -27,6 +31,7 @@
"annotate_split_points",
"split_into_equal_size",
"split_on_size_threshold",
"split_by_graph",
"pipeline",
"Schedule1F1B",
"ScheduleGPipe",
Expand Down
Loading

0 comments on commit 5e1d719

Please sign in to comment.