From 8128e24a42af400bb1b490356e3e32809714ecfe Mon Sep 17 00:00:00 2001
From: Michael Wyatt
Date: Wed, 12 Jun 2024 12:07:35 -0700
Subject: [PATCH 1/2] add code comments

---
 mii/api.py                      | 28 ++++++++++++++++++++++++++--
 mii/backend/client.py           |  2 ++
 mii/backend/server.py           |  3 +++
 mii/batching/ragged_batching.py | 26 ++++++++++++++++++++++++--
 mii/config.py                   |  6 ++++++
 5 files changed, 61 insertions(+), 4 deletions(-)

diff --git a/mii/api.py b/mii/api.py
index d909c837..77ed6e19 100644
--- a/mii/api.py
+++ b/mii/api.py
@@ -146,12 +146,36 @@ def serve(
         **kwargs,
     )
 
-    # Eventually we will move away from generating score files, leaving this
-    # here as a placeholder for now.
+    # TODO: The score file creation behavior should be changed, since AML
+    # deployment support no longer works. Given the integration of MII/FastGen
+    # within AML deployment containers, we can remove that deployment type and
+    # the need to create a score file. Instead, we should create a config file
+    # (perhaps a pickled MIIConfig?) and save it where we would save the score
+    # file. This config can then be loaded and used much like the score file
+    # was. Additionally, some code for standing up the deployment server will
+    # need to move from the score file template to the `MIIServer` class:
     # MIIServer(mii_config)
     create_score_file(mii_config)
 
     if mii_config.deployment_type == DeploymentType.LOCAL:
+        # Imports the created score file and executes its init() function, then
+        # returns a MIIClient object. With the changes suggested in the comment
+        # above, importing the score file would not be necessary.
+
+        # How the gRPC server is created:
+        # 1. The score.py file's init() function calls mii.backend.server.MIIServer()
+        # 2. MIIServer.__init__() starts the load balancer, REST API, and inference
+        #    model processes via the mii.launch.multi_gpu_server script:
+        #    Load balancer -> mii.grpc_related.modelresponse_server.serve_load_balancing
+        #    REST API -> mii.grpc_related.restful_gateway.RestfulGatewayThread
+        #    Inference model -> mii.api.async_pipeline & mii.grpc_related.modelresponse_server.serve_inference
+        # 3. The load balancer and inference model processes create grpc.server()
+        #    instances (via mii.grpc_related.modelresponse_server._do_serve)
+        # 4. An MIIClient() is created that uses a "stub" (via
+        #    mii.grpc_related.proto.modelresponse_pb2_grpc.ModelResponseStub) to
+        #    send/receive messages to/from the load balancer process. The load
+        #    balancer process then acts as a middle layer between the client(s)
+        #    and the model inference server(s).
         import_score_file(mii_config.deployment_name, DeploymentType.LOCAL).init()
         return MIIClient(mii_config=mii_config)
     if mii_config.deployment_type == DeploymentType.AML:
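
Note on the flow documented in the mii/api.py comments above: for a LOCAL deployment, the whole path reduces to a few lines of user code. The sketch below is a minimal illustration assuming the public mii.serve() / client.generate() / client.terminate_server() API described in the project README; the model name and generation arguments are placeholders, not part of this patch.

import mii

# serve() writes the score file, runs its init() (which starts the load
# balancer, REST API, and model replica processes), and returns an MIIClient.
client = mii.serve("mistralai/Mistral-7B-v0.1")

# Each call travels client -> load balancer -> model replica(s) over gRPC.
response = client.generate(["DeepSpeed is", "Seattle is"], max_new_tokens=128)
print(response)

# Tear down the persistent deployment when finished.
client.terminate_server()
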
diff --git a/mii/backend/client.py b/mii/backend/client.py
index 32847fad..cb4acc17 100644
--- a/mii/backend/client.py
+++ b/mii/backend/client.py
@@ -41,6 +41,8 @@ def __init__(self, mii_config: MIIConfig, host: str = "localhost") -> None:
         self.port = mii_config.port_number
         self.asyncio_loop = asyncio.get_event_loop()
         channel = create_channel(host, self.port)
+        # This stub allows the client to send/receive messages to/from the
+        # load balancer process
         self.stub = modelresponse_pb2_grpc.ModelResponseStub(channel)
 
     def __call__(self, *args, **kwargs) -> List[Response]:
diff --git a/mii/backend/server.py b/mii/backend/server.py
index ffa78f77..02e055d5 100644
--- a/mii/backend/server.py
+++ b/mii/backend/server.py
@@ -43,6 +43,9 @@ def __init__(self, mii_config: MIIConfig) -> None:
 
         mii_config.generate_replica_configs()
 
+        # Spin up all the processes necessary for the server (i.e., the load
+        # balancer process, each DeepSpeed model replica, and optionally the
+        # REST API process)
         processes = self._initialize_service(mii_config)
         self._wait_until_server_is_live(processes,
                                         mii_config.model_config.replica_configs)
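
The "stub" referenced in the mii/backend/client.py comment above follows the standard gRPC Python pattern: open a channel to the serving process and wrap it in a generated stub class so RPC methods look like ordinary Python calls. A generic sketch of that pattern using gRPC's stock hello-world service rather than MII's actual ModelResponse proto (GreeterStub, SayHello, and HelloRequest come from the helloworld_pb2* modules generated from gRPC's example helloworld.proto):

import grpc
from helloworld_pb2 import HelloRequest
from helloworld_pb2_grpc import GreeterStub

# Open a channel to the serving process (in MII, the load balancer plays this
# role) and bind a stub to it.
channel = grpc.insecure_channel("localhost:50051")
stub = GreeterStub(channel)

# Calling a stub method serializes the request, sends it over the channel,
# and blocks until the server's reply is deserialized and returned.
reply = stub.SayHello(HelloRequest(name="MII"))
print(reply.message)
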
diff --git a/mii/batching/ragged_batching.py b/mii/batching/ragged_batching.py
index a4b49ff3..b7d7c6e4 100644
--- a/mii/batching/ragged_batching.py
+++ b/mii/batching/ragged_batching.py
@@ -44,6 +44,9 @@ def __init__(self, inference_engine, tokenizer, model_config):
         self.vocab_size = tokenizer.vocab_size
         self.model_config = model_config
         self.zmq_port = model_config.zmq_port_number
+
+        # Set the max sequence length from either the user-passed model_config
+        # or from the HF model config
         if model_config.max_length is not None:
             self.max_length = model_config.max_length
         else:
@@ -51,6 +54,7 @@ def __init__(self, inference_engine, tokenizer, model_config):
         self.sync_debug = model_config.sync_debug
         self.profile_model_time = model_config.profile_model_time
 
+        # Create queues and other values for scheduling of requests and results
         self.request_queue: queue.Queue = queue.Queue()
         self.result_queues: Dict[int, queue.Queue] = {}
         self.scheduled_requests: RequestBatch = RequestBatch()
@@ -59,23 +63,29 @@ def __init__(self, inference_engine, tokenizer, model_config):
         self.scheduled_seq_num = 0
         self.scheduled_req_blocks = 0
 
-        # TODO: we will need to prune self._post_processors for long running deployments
+        # TODO: Each request we process can have a unique post_processor (e.g.,
+        # a different temperature value). We will need to prune
+        # self._post_processors for long-running deployments
         self._post_processors = {}
         self.logit_processor = run_batch_logit_processing
         self.sampler = run_batch_sampler
         self.stop_criterion = run_batch_stop_criterion
 
+        # If profiling is enabled, these are used to capture and report profiling data
         self._timers: SynchronizedWallClockTimer = SynchronizedWallClockTimer()
         self._profiled_times: DefaultDict[str, List[int]] = defaultdict(list)
         self._iters: int = 0
         self._num_generated_tokens: int = 0
 
+        # Use ZMQ because it is lightweight and fast for passing simple
+        # messages (i.e., token sequences) between each TP process of the
+        # inference engine
         self._zmq_context = zmq.Context()
         torch.cuda.synchronize()
         if self.is_rank_0:
             self.socket = self._zmq_context.socket(zmq.PUB)
             self.socket.bind(f"tcp://*:{self.zmq_port}")
-            time.sleep(1)  # Give the subscriber a change to connect
+            time.sleep(1)  # Give the subscriber a chance to connect
         else:
             self.socket = self._zmq_context.socket(zmq.SUB)
             self.socket.connect(f"tcp://localhost:{self.zmq_port}")
@@ -92,6 +102,10 @@ def is_rank_0(self) -> bool:
 
     @profiler
     def generate(self) -> None:
+        """
+        This is the main loop of FastGen: it schedules requests and collects generated results.
+        """
+        # 1. Get a batch of requests, broadcast to all ranks
         scheduled_requests = self._bcast_requests()
 
@@ -154,6 +168,9 @@ def _print_profiled_times(self) -> None:
 
     @sync_debug
     def _bcast_requests(self, force=False) -> RequestBatch:
+        # Rank 0 is the main process that schedules requests on the inference
+        # engine. When new requests are to be placed on the engine, the prompt
+        # tokens must be broadcast to all TP processes.
         if self.is_rank_0:
             if not self.scheduled_requests and not force:
                 return self.scheduled_requests
@@ -183,6 +200,8 @@ def _process_logits(
             next_token_logits: torch.Tensor,
             running_requests: RequestBatch) -> Tuple[torch.Tensor,
                                                       torch.Tensor]:
+        # Processes the generated logits, runs post-processing, gets the next
+        # token, and checks stop criteria at each round of generation for all requests.
         next_token_logits = next_token_logits[:, :self.vocab_size]
         next_token_logits = self.logit_processor(next_token_logits,
                                                  running_requests,
@@ -199,6 +218,9 @@ def _generate_output(self, r: Request) -> bool:
+        # Gather generated tokens and put them in the result queue. For
+        # streaming, this happens at every generated token. For non-streaming,
+        # this happens only when a stop criterion is met.
         outputs = []
         if r.stream:
             outputs.append((
diff --git a/mii/config.py b/mii/config.py
index 8e9c5cd7..565cdbbc 100644
--- a/mii/config.py
+++ b/mii/config.py
@@ -387,10 +387,16 @@ def _allocate_devices(hostfile_path: str,
 
 
 def get_mii_config(model_or_deployment_name: str) -> MIIConfig:
+    """
+    Looks for the score file of the given model or deployment name, loads it,
+    and returns the MIIConfig object.
+    """
     try:
         deployment_name = model_or_deployment_name
         mii_config = import_score_file(deployment_name,
                                        DeploymentType.LOCAL).mii_config
     except:
+        # If a deployment name was not given, then one was generated
+        # automatically from the model name, so try that
         try:
             deployment_name = generate_deployment_name(
                 model_name_or_path=model_or_deployment_name)
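
The ZMQ comments added to mii/batching/ragged_batching.py above describe a plain PUB/SUB broadcast: rank 0 binds a PUB socket and the other tensor-parallel ranks connect SUB sockets, so small Python objects such as token-id lists can be fanned out cheaply. A standalone sketch of that pattern with pyzmq (the port number and payload are illustrative, not MII's configured zmq_port_number):

import time
import zmq

PORT = 25555  # illustrative only


def publisher(token_ids):
    # Rank 0 side: bind a PUB socket and broadcast the object to subscribers.
    ctx = zmq.Context()
    sock = ctx.socket(zmq.PUB)
    sock.bind(f"tcp://*:{PORT}")
    time.sleep(1)  # give subscribers a chance to connect before publishing
    sock.send_pyobj(token_ids)


def subscriber():
    # Non-zero rank side: connect a SUB socket and block until a message arrives.
    ctx = zmq.Context()
    sock = ctx.socket(zmq.SUB)
    sock.connect(f"tcp://localhost:{PORT}")
    sock.setsockopt_string(zmq.SUBSCRIBE, "")  # receive all messages
    return sock.recv_pyobj()
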
+ """ try: deployment_name = model_or_deployment_name mii_config = import_score_file(deployment_name, DeploymentType.LOCAL).mii_config except: + # If a deployment name is not given, then one was generated + # automatically from the model name, so try that try: deployment_name = generate_deployment_name( model_name_or_path=model_or_deployment_name) From ba36252d6a35233abb9ca7c69175f183ef0adc85 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Wed, 12 Jun 2024 12:10:04 -0700 Subject: [PATCH 2/2] explain batched post-processing --- mii/batching/postprocess.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mii/batching/postprocess.py b/mii/batching/postprocess.py index 1b8ff6e7..46fc0f80 100644 --- a/mii/batching/postprocess.py +++ b/mii/batching/postprocess.py @@ -14,6 +14,12 @@ def run_batch_processing(input_tensor: torch.Tensor, requests: "RaggedRequestBatch", processor_fns: Dict[str, Any]) -> torch.Tensor: + """ + Runs the post-processing steps for batched requests. If we apply the + post-processing one-by-one for each request performance takes a big hit. + Instead, we identify all the requests that need to be processed by a given + post-processor, sampler, etc. and perform the action on a batch of requests. + """ idx_list: List[int] = [] output_list: List[torch.Tensor] = []