Skip to content

Commit

Permalink
[engine/schedule] use p2p_v2 to reconstruct pipeline_schedule
Browse files Browse the repository at this point in the history
  • Loading branch information
LSTM-Kirigaya committed Aug 8, 2022
1 parent 6c5147a commit 4cb7559
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 12 deletions.
5 changes: 0 additions & 5 deletions colossalai/engine/schedule/_pipeline_schedule_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,7 @@ def forward_backward_step(self,

# Run warmup forward passes.
for i in range(num_warmup_microbatches):
# print(Back.BLUE, "rank {}".format(local_rank), Style.RESET_ALL, "ready to recv_forward")
input_obj = comm.recv_forward()
# print(Back.BLUE, "rank {}".format(local_rank), Style.RESET_ALL, "finish recv_forward")

output_obj = self._forward_step(engine,
input_obj,
Expand All @@ -117,14 +115,11 @@ def forward_backward_step(self,
accum_loss=accum_loss)

comm.send_forward(output_obj)
# print(Back.BLUE, "rank {}".format(local_rank), Style.RESET_ALL, "finish send_forward")

if not forward_only:
input_objs.append(input_obj)
output_objs.append(output_obj)

# print(Back.GREEN, "rank {}".format(local_rank), Style.RESET_ALL, "warmup finish")

# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,7 @@ def run_trainer(rank, world_size, port):
train_dataloader=train_dataloader)

engine._schedule = PipelineScheduleV2(num_microbatches=gpc.config.NUM_MICRO_BATCHES)
# print("enter" * 20)
# # test v2 schedule
# try:
# engine._schedule = PipelineSchedule(gpc.config.NUM_MICRO_BATCHES)
# except Exception as e:
# print(e)
# return

logger = get_dist_logger()

trainer = Trainer(engine=engine, logger=logger)
Expand Down

0 comments on commit 4cb7559

Please sign in to comment.