diff --git a/superduperdb/ext/llm/vllm.py b/superduperdb/ext/llm/vllm.py
index 6aa2529985..c54b61e51e 100644
--- a/superduperdb/ext/llm/vllm.py
+++ b/superduperdb/ext/llm/vllm.py
@@ -8,6 +8,8 @@
 from superduperdb import logging
 from superduperdb.ext.llm.base import BaseLLMAPI, BaseLLMModel, BaseOpenAI
 
+__all__ = ["VllmAPI", "VllmModel", "VllmOpenAI"]
+
 VLLM_INFERENCE_PARAMETERS_LIST = [
     "n",
     "best_of",
@@ -119,13 +121,40 @@ class VllmModel(BaseLLMModel):
 
     def __post_init__(self):
         self.on_ray = self.on_ray or bool(self.ray_address)
+        if 'tensor_parallel_size' not in self.vllm_kwargs:
+            self.vllm_kwargs['tensor_parallel_size'] = self.tensor_parallel_size
+
+        if 'trust_remote_code' not in self.vllm_kwargs:
+            self.vllm_kwargs['trust_remote_code'] = self.trust_remote_code
+
+        if 'model' not in self.vllm_kwargs:
+            self.vllm_kwargs['model'] = self.model_name
+
         super().__post_init__()
 
     def init(self):
-        try:
-            from vllm import LLM
-        except ImportError:
-            raise Exception("You must install vllm with command 'pip install vllm'")
+        class _VLLMCore:
+            """
+            Wrapper for vllm model to support ray.
+            Implementing the client in this way will no longer require vllm dependencies
+            """
+
+            def __init__(self, **kwargs) -> None:
+                try:
+                    from vllm import LLM
+                except ImportError:
+                    raise Exception(
+                        "You must install vllm with command 'pip install vllm'"
+                    )
+                self.model = LLM(**kwargs)
+
+            def generate(self, prompts: List[str], **kwargs) -> List[str]:
+                from vllm import SamplingParams
+
+                sampling_params = SamplingParams(**kwargs)
+                results = self.model.generate(prompts, sampling_params, use_tqdm=False)
+                results = [result.outputs[0].text for result in results]
+                return results
 
         if self.on_ray:
             try:
@@ -133,37 +162,37 @@ def init(self):
             except ImportError:
                 raise Exception("You must install vllm with command 'pip install ray'")
 
-            runtime_env = {"pip": ["vllm"]}
+            runtime_env = {
+                "pip": [
+                    "vllm",
+                ]
+            }
             if not ray.is_initialized():
                 ray.init(address=self.ray_address, runtime_env=runtime_env)
 
-            LLM = ray.remote(LLM).remote
+            if "num_gpus" not in self.ray_config:
+                self.ray_config["num_gpus"] = self.tensor_parallel_size
+            LLM = ray.remote(**self.ray_config)(_VLLMCore).remote
+        else:
+            LLM = _VLLMCore
 
-        self.llm = LLM(
-            model=self.model_name,
-            tensor_parallel_size=self.tensor_parallel_size,
-            trust_remote_code=self.trust_remote_code,
-            **self.vllm_kwargs,
-        )
+        self.llm = LLM(**self.vllm_kwargs)
 
     def _batch_generate(self, prompts: List[str], **kwargs: Any) -> List[str]:
-        from vllm import SamplingParams
-
-        # support more parameters
-        sampling_params = SamplingParams(
-            **self.get_kwargs(SamplingParams, kwargs, self.inference_kwargs)
-        )
+        total_kwargs = {}
+        for key, value in {**self.inference_kwargs, **kwargs}.items():
+            if key in VLLM_INFERENCE_PARAMETERS_LIST:
+                total_kwargs[key] = value
 
         if self.on_ray:
             import ray
 
-            results = ray.get(
-                self.llm.generate.remote(prompts, sampling_params, use_tqdm=False)
-            )
+            results = ray.get(self.llm.generate.remote(prompts, **total_kwargs))
         else:
-            results = self.llm.generate(prompts, sampling_params, use_tqdm=False)
+            results = self.llm.generate(prompts, **total_kwargs)
+        results = [result.outputs[0].text for result in results]
 
-        return [result.outputs[0].text for result in results]
+        return results
 
     def _generate(self, prompt: str, **kwargs: Any) -> str:
         return self._batch_generate([prompt], **kwargs)[0]
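
Hypothetical usage sketch (not part of the change set above): it shows how the new vllm_kwargs / ray_config plumbing could be exercised. The field names are taken from the attributes referenced in the diff (model_name, tensor_parallel_size, trust_remote_code, ray_address, ray_config, vllm_kwargs, inference_kwargs); the exact VllmModel constructor signature, the identifier field, and the concrete kwarg values are assumptions.

    # Assumed import path, matching the file touched by this diff.
    from superduperdb.ext.llm.vllm import VllmModel

    llm = VllmModel(
        identifier="mistral-vllm",              # assumed field from the base model class
        model_name="mistralai/Mistral-7B-v0.1", # copied into vllm_kwargs["model"] by __post_init__
        tensor_parallel_size=2,                 # also becomes ray_config["num_gpus"] when running on ray
        trust_remote_code=True,
        ray_address="ray://head-node:10001",    # a non-empty address switches on_ray to True
        ray_config={"num_cpus": 4},             # forwarded to ray.remote(**self.ray_config)
        vllm_kwargs={"dtype": "float16"},       # forwarded verbatim to vllm.LLM(**kwargs)
        inference_kwargs={"temperature": 0.2},  # filtered against VLLM_INFERENCE_PARAMETERS_LIST
    )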