Skip to content

Commit

Permalink
Actual fix
Browse files Browse the repository at this point in the history
  • Loading branch information
kwen2501 committed Jul 10, 2023
1 parent 494cd55 commit d223f16
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions pippy/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ def _compile(

# Figure out which output is loss from output_chunk_spec
output_loss_value_spec: Any = None
if isinstance(output_chunk_spec, dict):
output_loss_value_spec = {
k: isinstance(v, LossReducer) for k, v in output_chunk_spec.items()
}
if output_chunk_spec is not None:
output_loss_value_spec = fx.node.map_aggregate(
output_chunk_spec, lambda v: isinstance(v, LossReducer)
)

logging.info("[PiPPy] Tracing model ...")
pipe_model = Pipe.from_tracing(
Expand Down Expand Up @@ -239,10 +239,10 @@ def compile_stage(

# Figure out which output is loss from output_chunk_spec
output_loss_value_spec: Any = None
if isinstance(output_chunk_spec, dict):
output_loss_value_spec = {
k: isinstance(v, LossReducer) for k, v in output_chunk_spec.items()
}
if output_chunk_spec is not None:
output_loss_value_spec = fx.node.map_aggregate(
output_chunk_spec, lambda v: isinstance(v, LossReducer)
)

logging.info("[PiPPy] Tracing model ...")
pipe = Pipe.from_tracing(
Expand Down

0 comments on commit d223f16

Please sign in to comment.