Always Flush UIDs after Exceptions #491

Merged
merged 4 commits into from Jul 1, 2024
33 changes: 21 additions & 12 deletions mii/batching/ragged_batching.py
@@ -656,6 +656,8 @@ def __call__(self) -> None:
 
     def _get_uid(self) -> int:
         with self.lock:
+            if len(self.uids) >= self.UID_RANGE_UB - self.UID_RANGE_LB:
+                raise RuntimeError("No available choices for a new UID.")
             uid = random.randrange(self.UID_RANGE_LB, self.UID_RANGE_UB)
             while uid in self.uids:
                 uid = random.randrange(self.UID_RANGE_LB, self.UID_RANGE_UB)
@@ -674,21 +676,28 @@ def put_request(self, prompt: str, kwargs: Dict) -> int:
 
         uid = self._get_uid()
 
-        # Temporary hack to avoid non-rank 0 processes not shutting down. See
-        # related TODO above.
-        if not self.is_rank_0:
-            return uid
+        try:
+            # Temporary hack to avoid non-rank 0 processes not shutting down. See
+            # related TODO above.
+            if not self.is_rank_0:
+                return uid
 
-        tid = threading.get_ident()
-        with self.lock:
-            if tid not in self.result_queues:
-                self.result_queues[tid] = queue.Queue()
+            tid = threading.get_ident()
+            with self.lock:
+                if tid not in self.result_queues:
+                    self.result_queues[tid] = queue.Queue()
 
-        input_tokens = self.tokenizer.encode(prompt)
-        request = self.make_request(tid, uid, input_tokens, kwargs)
-        self.request_queue.put(request)
+            input_tokens = self.tokenizer.encode(prompt)
+            request = self.make_request(tid, uid, input_tokens, kwargs)
+            self.request_queue.put(request)
 
-        return uid
+            return uid
+        except:
+            # It is OK to have `self.request_queue.put(request)` in the try block since
+            # it will never raise exceptions with unlimited queue size. If any exception
+            # occurred in the above block, the `request` obj was not enqueued.
+            self.flush_uid(uid)
+            raise
 
     def get_response(self) -> Tuple[int, Response]:
         # TODO: We should avoid any request/response work with non-rank 0, but
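The two hunks above work together: `_get_uid` now fails fast with a `RuntimeError` once every value in `[UID_RANGE_LB, UID_RANGE_UB)` is taken, and `put_request` flushes a freshly allocated UID if anything raises before the request is enqueued, so failed calls no longer leak UIDs. A minimal standalone sketch of the same allocate-then-flush-on-exception pattern, using a hypothetical `UidPool` class (illustrative only, not part of the MII API):

```python
import random
import threading


class UidPool:
    """Hypothetical stand-in for the pipeline's UID bookkeeping."""
    UID_RANGE_LB, UID_RANGE_UB = 1, 10000

    def __init__(self):
        self.uids = set()
        self.lock = threading.Lock()

    def get_uid(self) -> int:
        with self.lock:
            # Fail fast instead of spinning forever once the range is exhausted.
            if len(self.uids) >= self.UID_RANGE_UB - self.UID_RANGE_LB:
                raise RuntimeError("No available choices for a new UID.")
            uid = random.randrange(self.UID_RANGE_LB, self.UID_RANGE_UB)
            while uid in self.uids:
                uid = random.randrange(self.UID_RANGE_LB, self.UID_RANGE_UB)
            self.uids.add(uid)
            return uid

    def flush_uid(self, uid: int) -> None:
        with self.lock:
            self.uids.discard(uid)


def put_request(pool: UidPool, prompt: str) -> int:
    uid = pool.get_uid()
    try:
        # Anything that raises here (tokenization, building the request, ...)
        # would otherwise strand `uid` and slowly drain the UID range.
        if not prompt:
            raise ValueError("empty prompt")
        return uid
    except Exception:
        pool.flush_uid(uid)  # release the UID before propagating the error
        raise
```

With this shape, repeated failures in the body of `put_request` cannot exhaust the pool, and the new guard in `get_uid` turns a would-be infinite search loop into an explicit error.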
59 changes: 31 additions & 28 deletions mii/grpc_related/modelresponse_server.py
@@ -67,34 +67,37 @@ def GeneratorReply(self, request, context):
         prompts, kwargs = task_methods.unpack_request_from_proto(request)
         uids_put_order, uids_running, uids_complete_order, responses = [], [], [], []
 
-        # Put requests for all prompts into the pipeline
-        for p in prompts:
-            request_kwargs = kwargs.copy()
-            uid = self.inference_pipeline.put_request(p, request_kwargs)
-            uids_put_order.append(uid)
-            uids_running.append(uid)
-
-        # Get responses from the pipeline as they are ready, flush finished uids
-        # so new requests can be processed
-        while uids_running:
-            uid, response = self.inference_pipeline.get_response()
-            # TODO: Ugly hack for multi-threading. Will be fixed when we refactor these methods
-            if uid == -1:
-                uid = uids_running[0]
-            responses.append(response)
-            self.inference_pipeline.flush_uid(uid)
-            uids_complete_order.append(uids_put_order.index(uid))
-            uids_running.remove(uid)
-
-        # Sort responses in the order of prompts
-        responses = [
-            r for idx,
-            r in sorted(zip(uids_complete_order,
-                            responses),
-                        key=lambda pair: pair[0])
-        ]
-
-        return task_methods.pack_response_to_proto(responses)
+        try:
+            # Put requests for all prompts into the pipeline
+            for p in prompts:
+                request_kwargs = kwargs.copy()
+                uid = self.inference_pipeline.put_request(p, request_kwargs)
+                uids_put_order.append(uid)
+                uids_running.append(uid)
+
+            # Get responses from the pipeline as they are ready, flush finished uids
+            # so new requests can be processed
+            while uids_running:
+                uid, response = self.inference_pipeline.get_response()
+                # TODO: Ugly hack for multi-threading. Will be fixed when we refactor these methods
+                if uid == -1:
+                    uid = uids_running[0]
+                responses.append(response)
+                self.inference_pipeline.flush_uid(uid)
+                uids_complete_order.append(uids_put_order.index(uid))
+                uids_running.remove(uid)
+
+            # Sort responses in the order of prompts
+            responses = [
+                r for idx,
+                r in sorted(zip(uids_complete_order,
+                                responses),
+                            key=lambda pair: pair[0])
+            ]
+
+            return task_methods.pack_response_to_proto(responses)
+        finally:
+            [self.inference_pipeline.flush_uid(uid) for uid in uids_running]
 
     def GeneratorReplyStream(self, request, context):
         task_methods = self._get_task_methods("GeneratorReply")
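In `GeneratorReply`, the cleanup now sits in a `finally` clause, so any UID still listed in `uids_running` is flushed even if `put_request`, `get_response`, or the response packing raises partway through a batch. A rough, self-contained illustration of that guarantee, with a `FakePipeline` class standing in for `self.inference_pipeline` (all names here are hypothetical and the loop is simplified relative to the real handler):

```python
from typing import List


class FakePipeline:
    """Illustrative stand-in for the MII inference pipeline, not the real API."""
    def __init__(self):
        self.active_uids: List[int] = []
        self._next_uid = 0

    def put_request(self, prompt: str) -> int:
        self._next_uid += 1
        self.active_uids.append(self._next_uid)
        return self._next_uid

    def get_response(self, uid: int) -> str:
        if uid % 2 == 0:
            raise RuntimeError(f"generation failed for uid {uid}")
        return f"response for uid {uid}"

    def flush_uid(self, uid: int) -> None:
        if uid in self.active_uids:
            self.active_uids.remove(uid)


def generator_reply(pipeline: FakePipeline, prompts: List[str]) -> List[str]:
    uids_running: List[int] = []
    responses: List[str] = []
    try:
        for p in prompts:
            uids_running.append(pipeline.put_request(p))
        for uid in list(uids_running):
            responses.append(pipeline.get_response(uid))  # may raise mid-batch
            pipeline.flush_uid(uid)
            uids_running.remove(uid)
        return responses
    finally:
        # Runs on success and failure alike: whatever is still "running" gets
        # flushed, so a failed batch cannot strand UIDs in the pipeline.
        for uid in uids_running:
            pipeline.flush_uid(uid)


pipeline = FakePipeline()
try:
    generator_reply(pipeline, ["a", "b", "c"])  # uid 2 raises partway through
except RuntimeError:
    pass
assert pipeline.active_uids == []  # nothing leaked despite the exception
```

The `finally` block is idempotent with the per-response flush in the loop because only UIDs still present in `uids_running` are flushed again, which mirrors why the PR can add the blanket cleanup without double-freeing UIDs that completed normally.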