Add explanations of MII code into comments #493

Merged: 4 commits, Jul 1, 2024
28 changes: 26 additions & 2 deletions mii/api.py
@@ -146,12 +146,36 @@ def serve(
**kwargs,
)

# Eventually we will move away from generating score files, leaving this
# here as a placeholder for now.
# TODO: The score file creation behavior should be changed, as AML deployment
# support no longer works. Given the integration of MII/FastGen within AML
# deployment containers, we can remove that deployment type and the need to
# create a score file. Instead, we should create a config file (perhaps a
# pickled MIIConfig?) and save that where we would save the score file. This
# config can then be loaded and used similar to how the score file was used.
# Additionally, some code for standing up the deployment server will need to
# move from the score file template file to the `MIIServer` class:
# MIIServer(mii_config)
create_score_file(mii_config)

if mii_config.deployment_type == DeploymentType.LOCAL:
# Imports the created score file and executes the init() function, then
# returns an MIIClient object. With the changes suggested in the comment
# above, importing the score file would not be necessary.

# How the grpc server is created:
# 1. The score.py file init() function makes a call to mii.backend.server.MIIServer()
# 2. MIIServer.__init__() starts load balancer, REST API, and inference
# model processes via the mii.launch.multi_gpu_server script.
# Load balancer -> mii.grpc_related.modelresponse_server.serve_load_balancing
# REST API -> mii.grpc_related.restful_gateway.RestfulGatewayThread
# Inference model -> mii.api.async_pipeline & mii.grpc_related.modelresponse_server.serve_inference
# 3. Load balancer and inference model create grpc.server() processes
# (via mii.grpc_related.modelresponse_server._do_serve)
# 4. An MIIClient() is created that uses a "stub" (via
# mii.grpc_related.proto.modelresponse_pb2_grpc.ModelResponseStub) that
# can send/receive messages to/from the load balancer process. The load
# balancer process then acts as a middle layer between the client(s) and
# the model inference server(s)
import_score_file(mii_config.deployment_name, DeploymentType.LOCAL).init()
return MIIClient(mii_config=mii_config)
if mii_config.deployment_type == DeploymentType.AML:
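The TODO above suggests replacing the generated score file with a serialized MIIConfig. A minimal sketch of what that could look like, assuming pickling the config object is acceptable; the helper names and save location are hypothetical and do not exist in MII.

```python
import os
import pickle

from mii.config import MIIConfig

# Hypothetical helpers illustrating the "save a config instead of a score
# file" idea from the TODO above; these functions do not exist in MII.


def save_mii_config(mii_config: MIIConfig, config_dir: str) -> str:
    """Persist the deployment config where the score file would have gone."""
    os.makedirs(config_dir, exist_ok=True)
    config_path = os.path.join(config_dir, f"{mii_config.deployment_name}.pkl")
    with open(config_path, "wb") as f:
        pickle.dump(mii_config, f)
    return config_path


def load_mii_config(config_path: str) -> MIIConfig:
    """Reload the saved config so the deployment can be stood up from it."""
    with open(config_path, "rb") as f:
        return pickle.load(f)
```

With something along these lines, the local deployment path could construct MIIServer(mii_config) directly from the loaded config rather than importing a generated score file.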
2 changes: 2 additions & 0 deletions mii/backend/client.py
@@ -41,6 +41,8 @@ def __init__(self, mii_config: MIIConfig, host: str = "localhost") -> None:
self.port = mii_config.port_number
self.asyncio_loop = asyncio.get_event_loop()
channel = create_channel(host, self.port)
# This stub allows the client to send/receive messages with
# the load balancer process
self.stub = modelresponse_pb2_grpc.ModelResponseStub(channel)

def __call__(self, *args, **kwargs) -> List[Response]:
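The stub mentioned above follows the standard generated-gRPC client pattern: a channel to the load balancer wrapped by the generated ModelResponseStub. A minimal sketch of that setup, assuming an insecure asyncio channel (the real create_channel helper may set additional channel options):

```python
import grpc

from mii.grpc_related.proto import modelresponse_pb2_grpc


def open_load_balancer_stub(host: str, port: int) -> modelresponse_pb2_grpc.ModelResponseStub:
    # All client traffic goes through the load balancer process via this stub;
    # clients never talk to a model replica directly.
    channel = grpc.aio.insecure_channel(f"{host}:{port}")
    return modelresponse_pb2_grpc.ModelResponseStub(channel)
```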
3 changes: 3 additions & 0 deletions mii/backend/server.py
@@ -43,6 +43,9 @@ def __init__(self, mii_config: MIIConfig) -> None:

mii_config.generate_replica_configs()

# Spin up all the processes necessary for the server (i.e., load
# balancer process, each DeepSpeed model replica, and optionally the
# REST API process)
processes = self._initialize_service(mii_config)
self._wait_until_server_is_live(processes,
mii_config.model_config.replica_configs)
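The comment above summarizes what _initialize_service does. A simplified, hypothetical sketch of that fan-out pattern; build_launch_cmd and the role strings are placeholders, since the real command line for mii.launch.multi_gpu_server is MII-internal:

```python
import subprocess
import sys
from typing import List

# Hypothetical sketch of the process fan-out described above. The command
# built here is a placeholder; MII constructs the real arguments for the
# mii.launch.multi_gpu_server script internally.


def build_launch_cmd(role: str, deployment_name: str) -> List[str]:
    return [sys.executable, "-m", "mii.launch.multi_gpu_server", role, deployment_name]


def launch_server_processes(deployment_name: str,
                            num_replicas: int,
                            enable_restful_api: bool) -> List[subprocess.Popen]:
    roles = ["load-balancer"] + [f"replica-{i}" for i in range(num_replicas)]
    if enable_restful_api:
        roles.append("restful-gateway")
    # One OS process per role: the load balancer, each model replica, and
    # optionally the REST gateway.
    return [subprocess.Popen(build_launch_cmd(role, deployment_name)) for role in roles]
```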
6 changes: 6 additions & 0 deletions mii/batching/postprocess.py
@@ -14,6 +14,12 @@ def run_batch_processing(input_tensor: torch.Tensor,
requests: "RaggedRequestBatch",
processor_fns: Dict[str,
Any]) -> torch.Tensor:
"""
Runs the post-processing steps for batched requests. If we apply the
post-processing one-by-one for each request, performance takes a big hit.
Instead, we identify all the requests that need to be processed by a given
post-processor, sampler, etc. and perform the action on a batch of requests.
"""
idx_list: List[int] = []
output_list: List[torch.Tensor] = []

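The docstring above describes the batching strategy: group the requests that share a post-processor and run that processor once over the whole group. A self-contained sketch of that pattern; the request objects and their post_processor_key attribute are stand-ins for MII's internal types:

```python
from collections import defaultdict
from typing import Callable, Dict, List

import torch

# Illustrative grouping pattern: each post-processor runs once over the rows
# of every request that needs it, rather than once per request. The
# `post_processor_key` attribute is a stand-in for how MII keys processors.


def batched_postprocess(logits: torch.Tensor,
                        requests: List[object],
                        processors: Dict[str, Callable[[torch.Tensor], torch.Tensor]]) -> torch.Tensor:
    # Group row indices by the processor each request requires
    groups: Dict[str, List[int]] = defaultdict(list)
    for idx, req in enumerate(requests):
        groups[req.post_processor_key].append(idx)

    output = logits.clone()
    for key, indices in groups.items():
        rows = torch.tensor(indices, dtype=torch.long, device=logits.device)
        # One batched call per processor instead of one call per request
        output[rows] = processors[key](logits[rows])
    return output
```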
26 changes: 24 additions & 2 deletions mii/batching/ragged_batching.py
@@ -44,13 +44,17 @@ def __init__(self, inference_engine, tokenizer, model_config):
self.vocab_size = tokenizer.vocab_size
self.model_config = model_config
self.zmq_port = model_config.zmq_port_number

# Set max sequence length from either user-passed model_config or from
# HF model_config
if model_config.max_length is not None:
self.max_length = model_config.max_length
else:
self.max_length = inference_engine._policy._checkpoint_engine.model_config.max_seq_length
self.sync_debug = model_config.sync_debug
self.profile_model_time = model_config.profile_model_time

# Create queues and bookkeeping values for scheduling requests and
# collecting results
self.request_queue: queue.Queue = queue.Queue()
self.result_queues: Dict[int, queue.Queue] = {}
self.scheduled_requests: RequestBatch = RequestBatch()
@@ -59,23 +63,29 @@ def __init__(self, inference_engine, tokenizer, model_config):
self.scheduled_seq_num = 0
self.scheduled_req_blocks = 0

# TODO: we will need to prune self._post_processors for long running deployments
# TODO: Each request we process can have a unique post_processor (e.g.,
# different temperature value). We will need to prune
# self._post_processors for long running deployments
self._post_processors = {}
self.logit_processor = run_batch_logit_processing
self.sampler = run_batch_sampler
self.stop_criterion = run_batch_stop_criterion

# If profiling is enabled, these are used to capture/generate data
self._timers: SynchronizedWallClockTimer = SynchronizedWallClockTimer()
self._profiled_times: DefaultDict[str, List[int]] = defaultdict(list)
self._iters: int = 0
self._num_generated_tokens: int = 0

# Use ZMQ because it is light-weight and fast for passing simple
# messages (i.e., token sequences) between each TP process of the
# inference engine
self._zmq_context = zmq.Context()
torch.cuda.synchronize()
if self.is_rank_0:
self.socket = self._zmq_context.socket(zmq.PUB)
self.socket.bind(f"tcp://*:{self.zmq_port}")
time.sleep(1) # Give the subscriber a change to connect
time.sleep(1) # Give the subscriber a chance to connect
else:
self.socket = self._zmq_context.socket(zmq.SUB)
self.socket.connect(f"tcp://localhost:{self.zmq_port}")
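The block above is a standard ZeroMQ PUB/SUB broadcast: rank 0 publishes and every other tensor-parallel rank subscribes. A minimal standalone sketch of that pattern (the payload and port here are illustrative; the real code sends MII request/token data):

```python
import time

import zmq

# Minimal PUB/SUB broadcast: rank 0 publishes token data, other TP ranks
# subscribe. Port and payload are illustrative.


def setup_broadcast_socket(is_rank_0: bool, port: int) -> zmq.Socket:
    ctx = zmq.Context()
    if is_rank_0:
        sock = ctx.socket(zmq.PUB)
        sock.bind(f"tcp://*:{port}")
        time.sleep(1)  # give subscribers a chance to connect before publishing
    else:
        sock = ctx.socket(zmq.SUB)
        sock.connect(f"tcp://localhost:{port}")
        sock.setsockopt_string(zmq.SUBSCRIBE, "")  # receive every message
    return sock


def broadcast_tokens(sock: zmq.Socket, is_rank_0: bool, tokens=None):
    if is_rank_0:
        sock.send_pyobj(tokens)    # rank 0 -> all subscribers
        return tokens
    return sock.recv_pyobj()       # other ranks block until data arrives
```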
Expand All @@ -92,6 +102,10 @@ def is_rank_0(self) -> bool:

@profiler
def generate(self) -> None:
"""
This is the main loop of FastGen: it schedules batches of requests on the
inference engine and collects the generated results.
"""

# 1. Get a batch of requests, broadcast to all ranks
scheduled_requests = self._bcast_requests()

@@ -154,6 +168,9 @@ def _print_profiled_times(self) -> None:

@sync_debug
def _bcast_requests(self, force=False) -> RequestBatch:
# Rank 0 is the main process that does scheduling of requests on the
# inference engine. When new requests are to be placed on the engine,
# the prompt tokens must be broadcast to all TP processes.
if self.is_rank_0:
if not self.scheduled_requests and not force:
return self.scheduled_requests
@@ -183,6 +200,8 @@ def _process_logits(
next_token_logits: torch.Tensor,
running_requests: RequestBatch) -> Tuple[torch.Tensor,
torch.Tensor]:
# Processes the generated logits, runs post-processing, gets the next token,
# and checks the stop criteria at each round of generation for all requests.
next_token_logits = next_token_logits[:, :self.vocab_size]
next_token_logits = self.logit_processor(next_token_logits,
running_requests,
@@ -199,6 +218,9 @@

@sync_debug
def _generate_output(self, r: Request) -> bool:
# Gather generated tokens and put them in the result queue. For
# streaming, this happens at every generated token. For non-streaming,
# this happens only when a stop criterion is met.
outputs = []
if r.stream:
outputs.append((
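Taken together, the comments in generate(), _bcast_requests(), _process_logits(), and _generate_output() describe one iteration of the FastGen loop. A condensed, hypothetical sketch of that control flow; the callables and request attributes below are stand-ins for the real MII objects:

```python
from typing import Callable, List, Tuple

# Condensed sketch of one iteration of the loop described above. The callables
# stand in for _bcast_requests, the engine forward pass, _process_logits, and
# _generate_output; the request attributes are likewise illustrative.


def fastgen_iteration(bcast_requests: Callable[[], List[object]],
                      run_engine: Callable[[List[object]], object],
                      process_logits: Callable[[object, List[object]], Tuple[list, list]],
                      emit_output: Callable[[object], None]) -> None:
    # 1. Rank 0 schedules a batch; prompt tokens are broadcast to all TP ranks.
    requests = bcast_requests()
    if not requests:
        return

    # 2. One forward pass of the ragged-batched inference engine.
    logits = run_engine(requests)

    # 3. Batched post-processing: logits -> next token and stop flag per request.
    next_tokens, done_flags = process_logits(logits, requests)

    # 4. Streaming requests emit every token; non-streaming requests emit only
    #    once a stop criterion is met.
    for request, token, done in zip(requests, next_tokens, done_flags):
        request.generated_tokens.append(token)
        if getattr(request, "stream", False) or done:
            emit_output(request)
```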
6 changes: 6 additions & 0 deletions mii/config.py
@@ -387,10 +387,16 @@ def _allocate_devices(hostfile_path: str,


def get_mii_config(model_or_deployment_name: str) -> MIIConfig:
"""
Looks for the score file of the given model or deployment name, loads it,
and returns the MIIConfig object.
"""
try:
deployment_name = model_or_deployment_name
mii_config = import_score_file(deployment_name, DeploymentType.LOCAL).mii_config
except:
# If a deployment name is not given, then one was generated
# automatically from the model name, so try that
try:
deployment_name = generate_deployment_name(
model_name_or_path=model_or_deployment_name)
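Given the fallback described above, callers can pass either the deployment name or the original model name; a brief usage sketch (the names here are examples only):

```python
from mii.config import get_mii_config

# Both forms should resolve to the same MIIConfig thanks to the fallback above.
config = get_mii_config("my-deployment")                # explicit deployment name
config = get_mii_config("mistralai/Mistral-7B-v0.1")    # model name; deployment name is derived
```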