Remove vllm dependency when using ray to run vllm
jieguangzhou committed Jan 3, 2024
1 parent 2521343 commit c32f2c0
Showing 1 changed file with 52 additions and 23 deletions.
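
The change wraps the vLLM engine in a small class that is instantiated through ray.remote, so the vllm import happens only inside the Ray worker (installed there via the actor's runtime_env) and the client process only needs ray. A minimal, standalone sketch of that Ray pattern (the EngineActor name and the model are hypothetical placeholders, not code from this commit) looks like this:

import ray

# The runtime_env installs vllm on the Ray workers, so this driver script
# never has to import vllm itself.
ray.init(address="auto", runtime_env={"pip": ["vllm"]})


@ray.remote(num_gpus=1)
class EngineActor:
    """Hypothetical stand-in for the _VLLMCore wrapper added in this commit."""

    def __init__(self, model: str):
        from vllm import LLM  # imported only inside the worker process

        self.llm = LLM(model=model)

    def generate(self, prompts, **kwargs):
        from vllm import SamplingParams

        outputs = self.llm.generate(prompts, SamplingParams(**kwargs), use_tqdm=False)
        return [output.outputs[0].text for output in outputs]


actor = EngineActor.remote("facebook/opt-125m")
print(ray.get(actor.generate.remote(["Ray actors are"], max_tokens=16)))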
75 changes: 52 additions & 23 deletions superduperdb/ext/llm/vllm.py
@@ -8,6 +8,8 @@
 from superduperdb import logging
 from superduperdb.ext.llm.base import BaseLLMAPI, BaseLLMModel, BaseOpenAI

+__all__ = ["VllmAPI", "VllmModel", "VllmOpenAI"]
+
 VLLM_INFERENCE_PARAMETERS_LIST = [
     "n",
     "best_of",
@@ -119,51 +121,78 @@ class VllmModel(BaseLLMModel):

     def __post_init__(self):
         self.on_ray = self.on_ray or bool(self.ray_address)
+        if 'tensor_parallel_size' not in self.vllm_kwargs:
+            self.vllm_kwargs['tensor_parallel_size'] = self.tensor_parallel_size
+
+        if 'trust_remote_code' not in self.vllm_kwargs:
+            self.vllm_kwargs['trust_remote_code'] = self.trust_remote_code
+
+        if 'model' not in self.vllm_kwargs:
+            self.vllm_kwargs['model'] = self.model_name
+
         super().__post_init__()

     def init(self):
-        try:
-            from vllm import LLM
-        except ImportError:
-            raise Exception("You must install vllm with command 'pip install vllm'")
+        class _VLLMCore:
+            """
+            Wrapper for vllm model to support ray.
+            Implementing the client in this way will no longer require vllm dependencies
+            """
+
+            def __init__(self, **kwargs) -> None:
+                try:
+                    from vllm import LLM
+                except ImportError:
+                    raise Exception(
+                        "You must install vllm with command 'pip install vllm'"
+                    )
+                self.model = LLM(**kwargs)
+
+            def generate(self, prompts: List[str], **kwargs) -> List[str]:
+                from vllm import SamplingParams
+
+                sampling_params = SamplingParams(**kwargs)
+                results = self.model.generate(prompts, sampling_params, use_tqdm=False)
+                results = [result.outputs[0].text for result in results]
+                return results

         if self.on_ray:
             try:
                 import ray
             except ImportError:
                 raise Exception("You must install vllm with command 'pip install ray'")

-            runtime_env = {"pip": ["vllm"]}
+            runtime_env = {
+                "pip": [
+                    "vllm",
+                ]
+            }
             if not ray.is_initialized():
                 ray.init(address=self.ray_address, runtime_env=runtime_env)

-            LLM = ray.remote(LLM).remote
+            if "num_gpus" not in self.ray_config:
+                self.ray_config["num_gpus"] = self.tensor_parallel_size
+            LLM = ray.remote(**self.ray_config)(_VLLMCore).remote
+        else:
+            LLM = _VLLMCore

-        self.llm = LLM(
-            model=self.model_name,
-            tensor_parallel_size=self.tensor_parallel_size,
-            trust_remote_code=self.trust_remote_code,
-            **self.vllm_kwargs,
-        )
+        self.llm = LLM(**self.vllm_kwargs)

     def _batch_generate(self, prompts: List[str], **kwargs: Any) -> List[str]:
-        from vllm import SamplingParams
-
-        # support more parameters
-        sampling_params = SamplingParams(
-            **self.get_kwargs(SamplingParams, kwargs, self.inference_kwargs)
-        )
+        total_kwargs = {}
+        for key, value in {**self.inference_kwargs, **kwargs}.items():
+            if key in VLLM_INFERENCE_PARAMETERS_LIST:
+                total_kwargs[key] = value

         if self.on_ray:
             import ray

-            results = ray.get(
-                self.llm.generate.remote(prompts, sampling_params, use_tqdm=False)
-            )
+            results = ray.get(self.llm.generate.remote(prompts, **total_kwargs))
         else:
-            results = self.llm.generate(prompts, sampling_params, use_tqdm=False)
+            results = self.llm.generate(prompts, **total_kwargs)
+            results = [result.outputs[0].text for result in results]

-        return [result.outputs[0].text for result in results]
+        return results

     def _generate(self, prompt: str, **kwargs: Any) -> str:
         return self._batch_generate([prompt], **kwargs)[0]
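
As a usage sketch of the new interface: the field names below (model_name, on_ray, ray_address, ray_config, vllm_kwargs) come from the diff above, while the identifier argument and the exact BaseLLMModel constructor signature are assumptions, so treat this as illustrative only.

from superduperdb.ext.llm.vllm import VllmModel

llm = VllmModel(
    identifier="llm",                        # assumed BaseLLMModel field
    model_name="mistralai/Mistral-7B-v0.1",  # placeholder model
    on_ray=True,                             # run the engine inside a Ray actor
    ray_address="ray://127.0.0.1:10001",     # placeholder cluster address
    ray_config={"num_gpus": 1},              # forwarded to ray.remote(...)
)

# __post_init__ fills model_name, tensor_parallel_size and trust_remote_code
# into vllm_kwargs when not already set; init() then builds the remote
# _VLLMCore actor, installing vllm on the Ray workers via its runtime_env.
llm.init()

# _generate is the internal method shown in the diff; normal use would go
# through the public predict API of the base class (not shown here).
print(llm._generate("Explain Ray actors in one sentence.", max_tokens=64))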
