Always Flush UIDs after Exceptions (#491)
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
weiqisun and loadams committed Jul 1, 2024
1 parent 7efbb32 commit e421fc2
Showing 2 changed files with 52 additions and 40 deletions.
33 changes: 21 additions & 12 deletions mii/batching/ragged_batching.py
@@ -656,6 +656,8 @@ def __call__(self) -> None:
 
     def _get_uid(self) -> int:
         with self.lock:
+            if len(self.uids) >= self.UID_RANGE_UB - self.UID_RANGE_LB:
+                raise RuntimeError("No available choices for a new UID.")
             uid = random.randrange(self.UID_RANGE_LB, self.UID_RANGE_UB)
             while uid in self.uids:
                 uid = random.randrange(self.UID_RANGE_LB, self.UID_RANGE_UB)
@@ -674,21 +676,28 @@ def put_request(self, prompt: str, kwargs: Dict) -> int:
 
         uid = self._get_uid()
 
-        # Temporary hack to avoid non-rank 0 processes not shutting down. See
-        # related TODO above.
-        if not self.is_rank_0:
-            return uid
+        try:
+            # Temporary hack to avoid non-rank 0 processes not shutting down. See
+            # related TODO above.
+            if not self.is_rank_0:
+                return uid
 
-        tid = threading.get_ident()
-        with self.lock:
-            if tid not in self.result_queues:
-                self.result_queues[tid] = queue.Queue()
+            tid = threading.get_ident()
+            with self.lock:
+                if tid not in self.result_queues:
+                    self.result_queues[tid] = queue.Queue()
 
-        input_tokens = self.tokenizer.encode(prompt)
-        request = self.make_request(tid, uid, input_tokens, kwargs)
-        self.request_queue.put(request)
+            input_tokens = self.tokenizer.encode(prompt)
+            request = self.make_request(tid, uid, input_tokens, kwargs)
+            self.request_queue.put(request)
 
-        return uid
+            return uid
+        except:
+            # It is OK to have `self.request_queue.put(request)` in the try block since
+            # it will never raise exceptions with unlimited queue size. If any exception
+            # occurred in the above block, the `request` obj was not enqueued.
+            self.flush_uid(uid)
+            raise
 
     def get_response(self) -> Tuple[int, Response]:
         # TODO: We should avoid any request/response work with non-rank 0, but
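For readers skimming the diff, here is a minimal, self-contained sketch of the pattern the patch applies in put_request(): allocate a UID, and if anything between allocation and enqueueing raises, flush the UID before re-raising so it cannot leak. The UidPool class, its bounds, and the failure stand-in below are illustrative only, not the MII API; the real code lives on the ragged batching pipeline shown above.

import random
import threading


class UidPool:
    # Illustrative bounds; the real values are attributes of the MII pipeline.
    UID_RANGE_LB = 1
    UID_RANGE_UB = 10000

    def __init__(self):
        self.lock = threading.Lock()
        self.uids = set()

    def get_uid(self) -> int:
        with self.lock:
            # Guard added by this commit: fail fast instead of looping forever
            # once every UID in the range is already taken.
            if len(self.uids) >= self.UID_RANGE_UB - self.UID_RANGE_LB:
                raise RuntimeError("No available choices for a new UID.")
            uid = random.randrange(self.UID_RANGE_LB, self.UID_RANGE_UB)
            while uid in self.uids:
                uid = random.randrange(self.UID_RANGE_LB, self.UID_RANGE_UB)
            self.uids.add(uid)
            return uid

    def flush_uid(self, uid: int) -> None:
        with self.lock:
            self.uids.discard(uid)


def put_request(pool: UidPool, prompt: str) -> int:
    uid = pool.get_uid()
    try:
        # Stand-in for the tokenization and queueing steps, which may raise.
        if not prompt:
            raise ValueError("empty prompt")
        return uid
    except Exception:
        # Without this flush, an exception here would leak the UID forever.
        pool.flush_uid(uid)
        raise


pool = UidPool()
put_request(pool, "hello")      # succeeds; its UID stays registered
try:
    put_request(pool, "")       # fails; its UID is flushed before re-raising
except ValueError:
    pass
assert len(pool.uids) == 1

Note that the patched put_request uses a bare except: rather than except Exception:, so even a BaseException such as KeyboardInterrupt triggers the flush before propagating.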
59 changes: 31 additions & 28 deletions mii/grpc_related/modelresponse_server.py
@@ -67,34 +67,37 @@ def GeneratorReply(self, request, context):
         prompts, kwargs = task_methods.unpack_request_from_proto(request)
         uids_put_order, uids_running, uids_complete_order, responses = [], [], [], []
 
-        # Put requests for all prompts into the pipeline
-        for p in prompts:
-            request_kwargs = kwargs.copy()
-            uid = self.inference_pipeline.put_request(p, request_kwargs)
-            uids_put_order.append(uid)
-            uids_running.append(uid)
-
-        # Get responses from the pipeline as they are ready, flush finished uids
-        # so new requests can be processed
-        while uids_running:
-            uid, response = self.inference_pipeline.get_response()
-            # TODO: Ugly hack for multi-threading. Will be fixed when we refactor these methods
-            if uid == -1:
-                uid = uids_running[0]
-            responses.append(response)
-            self.inference_pipeline.flush_uid(uid)
-            uids_complete_order.append(uids_put_order.index(uid))
-            uids_running.remove(uid)
-
-        # Sort responses in the order of prompts
-        responses = [
-            r for idx,
-            r in sorted(zip(uids_complete_order,
-                            responses),
-                        key=lambda pair: pair[0])
-        ]
-
-        return task_methods.pack_response_to_proto(responses)
+        try:
+            # Put requests for all prompts into the pipeline
+            for p in prompts:
+                request_kwargs = kwargs.copy()
+                uid = self.inference_pipeline.put_request(p, request_kwargs)
+                uids_put_order.append(uid)
+                uids_running.append(uid)
+
+            # Get responses from the pipeline as they are ready, flush finished uids
+            # so new requests can be processed
+            while uids_running:
+                uid, response = self.inference_pipeline.get_response()
+                # TODO: Ugly hack for multi-threading. Will be fixed when we refactor these methods
+                if uid == -1:
+                    uid = uids_running[0]
+                responses.append(response)
+                self.inference_pipeline.flush_uid(uid)
+                uids_complete_order.append(uids_put_order.index(uid))
+                uids_running.remove(uid)
+
+            # Sort responses in the order of prompts
+            responses = [
+                r for idx,
+                r in sorted(zip(uids_complete_order,
+                                responses),
+                            key=lambda pair: pair[0])
+            ]
+
+            return task_methods.pack_response_to_proto(responses)
+        finally:
+            [self.inference_pipeline.flush_uid(uid) for uid in uids_running]
 
     def GeneratorReplyStream(self, request, context):
         task_methods = self._get_task_methods("GeneratorReply")
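A comparable sketch, again with illustrative names rather than the real MII server types, of the try/finally cleanup that GeneratorReply now performs: any UID still marked as running when an exception escapes the loop gets flushed, so the pipeline can keep accepting new requests. The response-ordering logic and the uid == -1 multi-threading hack from the real method are omitted here.

from typing import Any, Dict, List


def generate_batch(pipeline, prompts: List[str], kwargs: Dict[str, Any]) -> Dict[int, Any]:
    # Simplified stand-in for GeneratorReply's request/response loop.
    uids_running: List[int] = []
    responses: Dict[int, Any] = {}
    try:
        # Put a request for every prompt into the pipeline.
        for p in prompts:
            uids_running.append(pipeline.put_request(p, dict(kwargs)))

        # Drain responses as they become ready, flushing each finished UID.
        while uids_running:
            uid, response = pipeline.get_response()
            responses[uid] = response
            pipeline.flush_uid(uid)
            uids_running.remove(uid)
        return responses
    finally:
        # Runs on success and failure alike; on success uids_running is empty,
        # so only abandoned UIDs are flushed here.
        for uid in uids_running:
            pipeline.flush_uid(uid)

The patch itself writes this cleanup as a list comprehension over uids_running; the plain loop above is behaviorally equivalent.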
