Recent vLLMs ask for too much memory: ValueError: No available memory for the cache blocks. Try increasing gpu_memory_utilization when initializing the engine. #2248

Open
pseudotensor opened this issue Dec 24, 2023 · 43 comments
Labels
bug Something isn't working


@pseudotensor

pseudotensor commented Dec 24, 2023

Since vLLM 0.2.5, we can't even run Llama-2 70B 4-bit AWQ on 4*A10G anymore; we have to use an older vLLM. We hit similar problems even when trying to run two 7B models on an 80GB A100.

For small models, like a 7B with 4k tokens, vLLM fails with the "cache blocks" error even though a lot more memory is left.

E.g., building a Docker image with CUDA 11.8 and vLLM 0.2.5 or 0.2.6 and running:

port=5001
tokens=8192
docker run -d \
    --runtime=nvidia \
    --gpus '"device=1"' \
    --shm-size=10.24gb \
    -p $port:$port \
    --entrypoint /h2ogpt_conda/vllm_env/bin/python3.10 \
    -e NCCL_IGNORE_DISABLED_P2P=1 \
    -v /etc/passwd:/etc/passwd:ro \
    -v /etc/group:/etc/group:ro \
    -u `id -u`:`id -g` \
    -v "${HOME}"/.cache:/workspace/.cache \
    --network host \
    gcr.io/vorvan/h2oai/h2ogpt-runtime:0.1.0 -m vllm.entrypoints.openai.api_server \
        --port=$port \
        --host=0.0.0.0 \
        --model=defog/sqlcoder2 \
        --seed 1234 \
        --trust-remote-code \
        --max-num-batched-tokens $tokens \
        --max-model-len=$tokens \
        --gpu-memory-utilization 0.4 \
        --download-dir=/workspace/.cache/huggingface/hub &>> logs.vllm_server.sqlcoder2.txt

port=5002
tokens=4096
docker run -d \
    --runtime=nvidia \
    --gpus '"device=1"' \
    --shm-size=10.24gb \
    -p $port:$port \
    --entrypoint /h2ogpt_conda/vllm_env/bin/python3.10 \
    -e NCCL_IGNORE_DISABLED_P2P=1 \
    -v /etc/passwd:/etc/passwd:ro \
    -v /etc/group:/etc/group:ro \
    -u `id -u`:`id -g` \
    -v "${HOME}"/.cache:/workspace/.cache \
    --network host \
    gcr.io/vorvan/h2oai/h2ogpt-runtime:0.1.0 -m vllm.entrypoints.openai.api_server \
        --port=$port \
        --host=0.0.0.0 \
        --model=NumbersStation/nsql-llama-2-7B \
        --seed 1234 \
        --trust-remote-code \
        --max-num-batched-tokens $tokens \
        --gpu-memory-utilization 0.6 \
        --max-model-len=$tokens \
        --download-dir=/workspace/.cache/huggingface/hub &>> logs.vllm_server.nsql7b.txt

works. However, if the second model is given 0.4, one gets:

Traceback (most recent call last):
  File "/h2ogpt_conda/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/h2ogpt_conda/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/h2ogpt_conda/vllm_env/lib/python3.10/site-packages/vllm/entrypoints/openai/api_server.py", line 729, in <module>
    engine = AsyncLLMEngine.from_engine_args(engine_args)
  File "/h2ogpt_conda/vllm_env/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 496, in from_engine_args
    engine = cls(parallel_config.worker_use_ray,
  File "/h2ogpt_conda/vllm_env/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 269, in __init__
    self.engine = self._init_engine(*args, **kwargs)
  File "/h2ogpt_conda/vllm_env/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 314, in _init_engine
    return engine_class(*args, **kwargs)
  File "/h2ogpt_conda/vllm_env/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 113, in __init__
    self._init_cache()
  File "/h2ogpt_conda/vllm_env/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 227, in _init_cache
    raise ValueError("No available memory for the cache blocks. "
ValueError: No available memory for the cache blocks. Try increasing `gpu_memory_utilization` when initializing the engine.

However, with the 0.6 utilization from before, here is what the GPU looks like:


Sun Dec 24 02:45:53 2023
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100 80GB PCIe          Off | 00000000:00:06.0 Off |                    0 |
| N/A   43C    P0              72W / 300W |  70917MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA A100 80GB PCIe          Off | 00000000:00:07.0 Off |                    0 |
| N/A   45C    P0              66W / 300W |  49136MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      6232      C   /h2ogpt_conda/vllm_env/bin/python3.10     70892MiB |
|    1   N/A  N/A      6966      C   /h2ogpt_conda/vllm_env/bin/python3.10     32430MiB |
|    1   N/A  N/A      7685      C   /h2ogpt_conda/vllm_env/bin/python3.10     16670MiB |

Ignore GPU=0.

So 0.6 utilization uses about 17GB; why would 0.4 utilization out of 80GB be a problem?
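A quick sketch of the arithmetic behind this question, assuming (as explained in the replies below) that the 0.2.5+ profiler charges all memory already in use on the GPU against the new instance's budget of `gpu_memory_utilization * total`. The MiB figures come from the nvidia-smi output above; the helper name is hypothetical.

```python
# Hypothetical helper: memory left for a new vLLM instance's weights,
# activations, and KV cache, assuming the 0.2.5+ profiler counts every
# byte already in use on the GPU against the new instance.

TOTAL_MIB = 81920           # A100 80GB PCIe, from nvidia-smi
FIRST_INSTANCE_MIB = 32430  # sqlcoder2 process (PID 6966), from nvidia-smi


def headroom_mib(utilization: float, total_mib: int, occupied_mib: int) -> float:
    """Budget (utilization * total) minus memory other processes already hold."""
    return utilization * total_mib - occupied_mib


for util in (0.4, 0.6):
    left = headroom_mib(util, TOTAL_MIB, FIRST_INSTANCE_MIB)
    print(f"util={util}: {left:.1f} MiB left for the second instance")
```

Under that assumption, 0.6 leaves about 16.7 GB (close to the 16670MiB the second process actually uses), while 0.4 leaves only about 338 MiB, far less than a 7B model's FP16 weights (roughly 13-14 GB).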

@hanzhi713
Contributor

vLLM 0.2.6 added CUDA graph support, which is enabled by default (probably not a good decision).

CUDA graphs introduce some extra memory overhead. Try adding the --enforce-eager flag to see if it helps; this flag disables CUDA graph execution.

@pseudotensor
Author

pseudotensor commented Dec 24, 2023

Thanks for responding. However, we had problems starting with 0.2.5.

If you need a specific snapshot or something for 4*A10G using 70B AWQ on 0.2.4 vs. 0.2.5, let me know. Or what kind of repro do you need?

@hanzhi713
Contributor

Oh I see. Sorry for not reading your issue carefully. vLLM 0.2.5 changed the way memory is profiled with #2031. While the new profiling method is more accurate, it doesn't seem to account for multiple instances running together or for GPU memory used by other processes.

peak_memory = total_gpu_memory - free_gpu_memory

Here, vLLM effectively attributes all occupied GPU memory to the currently running instance and computes the number of available blocks from that. This may explain the problem when running two 7B models on one GPU. I'm not quite sure about the 4xA10G use case, though. Is the GPU empty or shared with other processes in that case?
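To make the formula concrete, here is a minimal sketch of how a `peak_memory = total - free` measurement turns into a number of KV-cache blocks. It assumes a KV-cache block holds keys and values for every layer for `block_size` tokens, and uses Llama-2-7B FP16 shapes (32 layers, 32 KV heads, head size 128) with 16-token blocks; the 14 GiB profile peak is made up, and the function names are illustrative rather than vLLM's API.

```python
GiB = 1024 ** 3


def cache_block_size(block_size: int, num_layers: int, num_kv_heads: int,
                     head_size: int, dtype_bytes: int) -> int:
    """Bytes for one KV-cache block: K and V for every layer."""
    per_layer = 2 * block_size * num_kv_heads * head_size  # keys + values
    return num_layers * per_layer * dtype_bytes


def num_gpu_blocks(total: int, free: int, utilization: float, block_bytes: int) -> int:
    """Sketch of the post-#2031 estimate, where peak = total - free."""
    peak = total - free  # also includes memory held by *other* processes
    return max(int((total * utilization - peak) // block_bytes), 0)


block = cache_block_size(16, 32, 32, 128, 2)  # 8 MiB per 16-token block
total = 80 * GiB

# Instance alone on the GPU, hypothetical profile peak of 14 GiB:
print(num_gpu_blocks(total, total - 14 * GiB, 0.4, block))
# Same instance, but another process already holds 32 GiB:
print(num_gpu_blocks(total, total - (14 + 32) * GiB, 0.4, block))
```

The only difference between the two calls is memory owned by someone else, yet the second one yields zero blocks, which is exactly the "No available memory for the cache blocks" condition.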

@hanzhi713
Contributor

Just tried to write a fix. You can try it out: #2249

@pseudotensor
Author

Our biggest issue is 70B AWQ on four clean A10G GPUs. Nothing else is on the GPUs.

@Snowdar

Snowdar commented Dec 25, 2023

You could change vllm/worker/worker.py like this (see the #@ notes):

    def load_model(self):
        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() #@ add
        self.gpu_mem_pre_occupied = total_gpu_memory - free_gpu_memory #@ add
        self.model_runner.load_model()

    @torch.inference_mode()
    def profile_num_available_blocks(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        cpu_swap_space: int,
    ) -> Tuple[int, int]:
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.cuda.empty_cache()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()
        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        
        torch.cuda.synchronize()
        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
        peak_memory = total_gpu_memory - free_gpu_memory

        cache_block_size = CacheEngine.get_cache_block_size(
            block_size, self.model_config, self.parallel_config)
        #@ add the self.gpu_mem_pre_occupied to fix the evaluation
        num_gpu_blocks = int(
            (total_gpu_memory * gpu_memory_utilization - peak_memory + self.gpu_mem_pre_occupied) //
            cache_block_size)
        ...

Then total_gpu_memory * gpu_memory_utilization becomes the actual memory budget for this model, as a fraction of the GPU's total memory, unaffected by any other models already loaded.
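The effect of this patch can be checked with plain arithmetic. This is a sketch, not the patched worker itself: it assumes an 8 MiB KV block (e.g., Llama-2-7B FP16 with 16-token blocks), a hypothetical 14 GiB peak for the instance on its own, and a hypothetical 32 GiB held by another model before `load_model()` runs.

```python
GiB = 1024 ** 3
BLOCK_BYTES = 8 * 1024 ** 2  # assumed KV block size (Llama-2-7B FP16, 16 tokens/block)


def num_gpu_blocks_corrected(total: int, free_before_load: int,
                             free_after_profile: int, utilization: float,
                             block_bytes: int = BLOCK_BYTES) -> int:
    """Block estimate with Snowdar's adjustment for memory used before load_model()."""
    pre_occupied = total - free_before_load  # what load_model() records
    peak = total - free_after_profile        # measured after profile_run()
    budget = total * utilization - peak + pre_occupied
    return max(int(budget // block_bytes), 0)


total = 80 * GiB
own_peak = 14 * GiB  # hypothetical peak of this instance alone
for other in (0, 32 * GiB):  # memory another model held before this one loaded
    blocks = num_gpu_blocks_corrected(
        total, total - other, total - other - own_peak, 0.4)
    print(f"other={other // GiB} GiB -> {blocks} blocks")
```

With the adjustment, both cases get the same number of blocks. It does assume the other process's usage stays constant between the two measurements; if another instance is still starting up at the same time, the estimate will still be off.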

@baptistejamin

We are having the exact same issue on our end: cache usage grows and consumes more than the allocated gpu_memory_utilization, even when using enforce-eager.

We had the same problem before with 0.2.1

@ronaldpanape

Having the same issue on CUDA 11.8 with vLLM 0.2.5 and 0.2.6.

@DaBossCoda

same here

@blake-howell-sc

Same issue -- starting with vllm 0.2.5

@buptygz

buptygz commented Jan 2, 2024

Same issue when using vLLM 0.2.6.

@44670
Contributor

44670 commented Jan 2, 2024

same here

@gordicaleksa

same here

@micronetboy

same here

@pseudotensor
Author

@Snowdar @hanzhi713 et al. I want to be clear again. The primary issue is that even a single model sharded across GPUs no longer works. Forget about multiple models per GPU for now.

That is, on AWS 4*A10G, vLLM 0.2.4 and lower work perfectly fine and leave plenty of room without any failure.

However, on 0.2.5+, no matter what settings I use for GPU utilization etc., the Llama 70B AWQ model never fits on the 4 A10Gs, while before it was perfectly fine (even under heavy use for long periods).

@comaniac
Collaborator

comaniac commented Jan 6, 2024

I'm working with v0.2.5 now and ran into this issue for the same reason. My case is deploying a 70B BF16 model on 8xA100-40GB GPUs. I inserted logging into worker.py to get a sense of where this error comes from:

        torch.cuda.empty_cache()
        # Here the free memory is ~22GB per GPU. This is expected, given 40 - (70B params * 2 bytes) / 8 GPUs = 22.5GB

        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.cuda.synchronize()
        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
        # Here the free memory is only 0.26 GB per GPU. It looks like profile_run() consumes nearly all memory,
        # though I don't know why yet.

@comaniac
Collaborator

comaniac commented Jan 9, 2024

I dived in a bit and here are some findings:

  1. When serving large models (e.g., 70B), the model forward pass itself introduces memory fragmentation. I logged the free memory after each decoder layer and found that it decreases after every layer. In the case of the 70B model, after 80 layers the free memory is only ~2 GB out of 40 GB per GPU.
  2. The profile run samples with top_k=vocab-1. This results in somewhat high memory usage when the vocab size is large.
  3. GPU cache block estimation does not consider fragmentation. Combining the above two, the free memory is less than 1GB, which results in a very small batch or even no GPU blocks available for the KV cache.

My temporary solution is as follows:

  1. Manually add torch.cuda.empty_cache() in worker.py before the line free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info(). This removes the impact of fragmentation.
  2. The above change makes OOM possible when actually serving the model, because empty_cache() also hides the memory used by intermediate tensors during the forward pass. As a result, tuning --gpu-memory-utilization becomes more important, as we have to use it to cover the forward pass's intermediate tensors. Here are my testing results with different util values:
    • 0.8: 2828 blocks * 16 tokens/block = 45248 tokens
    • 0.9: 3644 blocks * 16 tokens/block = 58304 tokens
    • 1.0: OOM
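The token counts above are consistent with vLLM's default of 16 tokens per KV-cache block. A quick check, assuming that default block size (the helper name is illustrative):

```python
BLOCK_SIZE = 16  # assumed: vLLM's default number of tokens per KV-cache block


def kv_cache_tokens(num_gpu_blocks: int, block_size: int = BLOCK_SIZE) -> int:
    """Total tokens the KV cache can hold across all sequences."""
    return num_gpu_blocks * block_size


for util, blocks in ((0.8, 2828), (0.9, 3644)):
    print(f"util={util}: {blocks} blocks -> {kv_cache_tokens(blocks)} tokens")
```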

@pseudotensor
Author

pseudotensor commented Jan 16, 2024

Yet another version of this problem is that 01-ai/Yi-34B-Chat used to work perfectly fine on 4*H100 80GB when run like:

python -m vllm.entrypoints.openai.api_server --port=5000 --host=0.0.0.0 --model 01-ai/Yi-34B-Chat --seed 1234 --tensor-parallel-size=4 --trust-remote-code

But it no longer works since 0.2.5+, including 0.2.7. Instead I get:

INFO 01-16 14:40:02 api_server.py:750] args: Namespace(host='0.0.0.0', port=5000, allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], served_model_name=None, ch>
2024-01-16 14:40:04,623 INFO worker.py:1673 -- Started a local Ray instance.
INFO 01-16 14:40:06 llm_engine.py:70] Initializing an LLM engine with config: model='01-ai/Yi-34B-Chat', tokenizer='01-ai/Yi-34B-Chat', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust>
INFO 01-16 14:41:00 llm_engine.py:294] # GPU blocks: 0, # CPU blocks: 4369
Traceback (most recent call last):
  File "/home/fsuser/miniconda3/envs/vllm/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/fsuser/miniconda3/envs/vllm/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/fsuser/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/entrypoints/openai/api_server.py", line 760, in <module>
    engine = AsyncLLMEngine.from_engine_args(engine_args)
  File "/home/fsuser/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 544, in from_engine_args
    engine = cls(parallel_config.worker_use_ray,
  File "/home/fsuser/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 274, in __init__
    self.engine = self._init_engine(*args, **kwargs)
  File "/home/fsuser/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 319, in _init_engine
    return engine_class(*args, **kwargs)
  File "/home/fsuser/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 114, in __init__
    self._init_cache()
  File "/home/fsuser/miniconda3/envs/vllm/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 298, in _init_cache
    raise ValueError("No available memory for the cache blocks. "
ValueError: No available memory for the cache blocks. Try increasing `gpu_memory_utilization` when initializing the engine.

When can we expect a fix? It seems like a pretty serious bug.

BTW, curiously, I ran the exact same command a second time (both times with nothing on the GPUs), and the second time it didn't hit the error. So maybe there is a race in vLLM's memory size detection.

@7flash

7flash commented Jan 19, 2024

I am trying to run this command as given in the docs:

python3 -u -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --model mistralai/Mistral-7B-Instruct-v0.2

It gives me an error

  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 279, in _init_cache
    raise ValueError("No available memory for the cache blocks. "
ValueError: No available memory for the cache blocks. Try increasing `gpu_memory_utilization` when initializing the engine.

What should I do? I am running a RunPod instance with 1x RTX 4000 Ada.

@7flash

7flash commented Jan 22, 2024

I have upgraded to 1x A100 and am now passing --gpu_memory_utilization 0.8, but I still get the same error.

@7flash

7flash commented Jan 22, 2024

The issue was resolved by adding --tensor-parallel-size 1

It helped because I am running a RunPod instance, which, as I understand it, gives me access only to the requested GPUs on the physical machine.

@ZiyueHuang

I also encountered this problem (i.e., OOM, or too few KV cache blocks) on a 70B LLM with v0.2.7 and dug in a bit. Here are my findings.

My dev environment: a machine with 8 A800 GPUs and CUDA 11.3.

Working Solution: Use peak_memory = torch.cuda.max_memory_allocated() in worker.py (basically revert #2031).

Another Working Solution: Update to torch==2.1.2.

Analysis: There is evidence of more memory fragmentation when tp > 1, see here and here. It seems that record_stream (called for NCCL communication) prevents the cached memory blocks of activations from being reused, so the memory consumption of one forward pass grows substantially. Setting TORCH_NCCL_AVOID_RECORD_STREAMS=1 can fix this problem by stashing references to the related memory blocks and doing proper synchronization instead of calling record_stream. This environment variable is already set in vLLM v0.2.7, but the PyTorch version on my dev machine is 2.0.1, which predates TORCH_NCCL_AVOID_RECORD_STREAMS. Updating to torch==2.1.2 solves the problem.
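If you launch vLLM from your own Python code, a minimal sketch of guarding against the old-PyTorch case might look like this. It assumes, per the analysis above, that 2.0.x ignores TORCH_NCCL_AVOID_RECORD_STREAMS while 2.1.x honors it; `supports_avoid_record_streams` is a hypothetical helper, not part of vLLM or PyTorch.

```python
import os


def supports_avoid_record_streams(torch_version: str) -> bool:
    """Assumed: torch >= 2.1 honors TORCH_NCCL_AVOID_RECORD_STREAMS, 2.0.x does not."""
    release = torch_version.split("+")[0]  # drop local tags like "+cu118"
    major, minor = (int(part) for part in release.split(".")[:2])
    return (major, minor) >= (2, 1)


# Set this before torch creates its NCCL process groups; setdefault keeps
# any value the user already exported.
os.environ.setdefault("TORCH_NCCL_AVOID_RECORD_STREAMS", "1")
```

In practice you would pass `torch.__version__` and warn or refuse to start when the check fails, since setting the variable on an older PyTorch has no effect.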

@pseudotensor
Author

@ZiyueHuang I have pytorch 2.1.2 and vllm 0.2.7 and this wasn't solved by that.

@ZiyueHuang

@pseudotensor How about trying reverting #2031?

@pseudotensor
Author

@ZiyueHuang Yes, I'm trying that now.

@pseudotensor pseudotensor reopened this Jan 27, 2024
@pseudotensor
Author

This issue was closed automatically by GitHub; that was not correct.

@pseudotensor
Author

Reverting avoided the error in the title, but unlike 0.2.4 it ran out of GPU memory (OOM) with the same long-context query. FYI @sh1ng

@XBeg9

XBeg9 commented Jan 31, 2024

I've upgraded to the latest ray/ray-llm and am now having this issue. Is there a known hotfix? Adjusting gpu_memory_utilization doesn't help at all.

I'm trying to run TheBloke/Llama-2-70B-chat-AWQ with this config:

engine_config:
  model_id: TheBloke/Llama-2-70B-chat-AWQ
  hf_model_id: TheBloke/Llama-2-70B-chat-AWQ
  type: VLLMEngine
  engine_kwargs:
    quantization: awq
    max_num_batched_tokens: 32768
    max_num_seqs: 256
    gpu_memory_utilization: 0.90
  max_total_tokens: 4096

@sh1ng
Contributor

sh1ng commented Jan 31, 2024

FYI @pseudotensor

I've tested the memory footprint of 0.2.4 and 0.2.7, and these are my findings:

  • I'm sure that Fix peak memory profiling #2031 is correct and should be there.
    |<-------------------------------------total GPU memory---------------------------------------->|
    |<---Allocated by torch allocator--->|<--Allocated by NCCL, cuBLAS, etc-->|<--free GPU memory-->|
    
    Before Fix peak memory profiling #2031, non-torch allocations were completely ignored.
  • Fix peak memory profiling #2031 just computes it correctly. We still need to handle peak memory correctly when multiple memory-consuming processes share a GPU.
  • With 0.2.4 and 0.2.7, a model consumes exactly the same amount of memory (measured both the old and the new way).
  • Changing the NCCL version doesn't change memory consumption significantly (~10MB).
  • With --enforce-eager, memory consumption is slightly lower.
  • Using PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True helps and also makes execution with and without --enforce-eager identical. I'm not sure how stable it is, as it's marked experimental.
  • I believe that by carefully tuning gpu_memory_utilization we can get the original behavior, as I don't see an increase in memory consumption.
  • It's better to dedicate a subset of GPUs fully to a single vLLM model and not share a GPU across multiple models, as the NCCL, cuBLAS, and torch overheads will multiply.
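The diagram above can be turned into a small model of the two ways of measuring peak memory. This is a simplified sketch: it treats the torch allocator's usage as a single number (ignoring cached-but-unused blocks), and every MiB value in the example is made up for illustration.

```python
from dataclasses import dataclass


@dataclass
class GpuMemory:
    """Illustrative snapshot of one GPU, in MiB (all values hypothetical)."""
    total: int
    torch_peak: int        # roughly what torch.cuda.max_memory_allocated() reports
    non_torch: int         # NCCL, cuBLAS workspaces, CUDA context, ...
    other_processes: int   # memory held by other instances on the same GPU

    @property
    def free(self) -> int:
        return self.total - self.torch_peak - self.non_torch - self.other_processes


def peak_old(m: GpuMemory) -> int:
    # Before #2031: only the torch allocator's peak is counted.
    return m.torch_peak


def peak_new(m: GpuMemory) -> int:
    # After #2031: everything that is not free counts, other processes included.
    return m.total - m.free


snapshot = GpuMemory(total=81920, torch_peak=14000, non_torch=1500, other_processes=32430)
print(peak_old(snapshot), peak_new(snapshot))
```

The new measurement correctly adds the non-torch overhead, but it also picks up `other_processes`, which is why sharing a GPU between instances shrinks the KV cache.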

@pseudotensor
Author

pseudotensor commented Jan 31, 2024

@sh1ng Try 4*A10G's with 70B AWQ, it simply doesn't work, but on 0.2.4 works perfectly fine.

@pseudotensor
Author

--enforce_eager solves this issue for 4*A10G 70B AWQ.

The issue with the two 7Bs was how memory is defined (total vs. free), which changed in vLLM. Free memory is hard to manage, and not well defined, when bringing up two 7Bs at the same time. So we have to wait for the first one to come up and then bring up the other with 0.9, not its own fraction of the total (e.g., 0.4-0.5).
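A minimal sketch of the workaround described above, assuming instances are started one at a time and each waits for the previous one to finish profiling: with the 0.2.5+ accounting, each instance's `gpu_memory_utilization` has to cover its own share plus everything started before it. The helper name is hypothetical.

```python
from __future__ import annotations

from itertools import accumulate


def staggered_utilizations(shares: list[float]) -> list[float]:
    """Value to pass as gpu_memory_utilization to each instance, in start order."""
    cumulative = [round(x, 6) for x in accumulate(shares)]
    if cumulative and cumulative[-1] > 1.0:
        raise ValueError("shares add up to more than the whole GPU")
    return cumulative


print(staggered_utilizations([0.4, 0.5]))
```

For two models meant to use 0.4 and 0.5 of the GPU, the second one is started with 0.9, matching the numbers above.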

@oandreeva-nv

enforce_eager didn't do the trick for me. I have two vLLM engines, each hosting a small model, and both fit on the same GPU. Note that I deploy them on Triton server. One loads successfully; the other fails with this issue.

@pseudotensor
Copy link
Author

@oandreeva-nv I also explained above what helped me with the same-GPU issue. Did you try that? vLLM changed its behavior from total to free memory, so it's confusing that the first (say) 7B model should get 0.4 and the second 0.9.

@oandreeva-nv

@pseudotensor, yes, I tried setting the gpu_memory_utilization flag to different values, e.g., 0.4 for the first facebook/opt-125m model and 0.8 for another instance of facebook/opt-125m, but still no success.

@oandreeva-nv

After some fine-tuning of gpu_memory_utilization, I was able to make everything work. Thanks, everyone, for sharing your findings!

@WoosukKwon WoosukKwon reopened this Feb 14, 2024
@WoosukKwon WoosukKwon added the bug Something isn't working label Feb 14, 2024
@okwinds

okwinds commented Mar 15, 2024

Has this issue been resolved? I've encountered the same problem.
I encountered it when loading a 7B model on a graphics card with 24GB of memory.
My vLLM version is 0.3.3.

@hahmad2008

@XBeg9, I have a question about the engine args you used:
max_num_batched_tokens: 32768
max_num_seqs: 256

Does that mean you can handle 32768 tokens within the same batch, and that the number of output sequences per input is 256?

@nicobieber99

nicobieber99 commented Mar 26, 2024

Hi,
I have run into the same issue.
I have vllm==0.3.3

I don't know if anyone is still active on this issue or has found a way to resolve it.
I have tried lowering gpu_memory_utilization all the way down to 0.4, but it still doesn't seem to work and gave me:
ValueError: No available memory for the cache blocks. Try increasing gpu_memory_utilization when initializing the engine.

I am running on 1 NVIDIA T4 GPU, with 40 GB memory.
I'm trying to run llama2-7b-huggingface.
I have downloaded the model, but it should hardly use enough memory to limit vLLM that much...

As a side note, when it tries to run, it says:
GPU blocks: 0, CPU blocks: 512

Help!

@bmcfeeters

@nicobieber99 I managed to get Llama-2 7B working on an NVIDIA A2 GPU (16GB memory) today by setting these parameters. I'm using OpenShift AI with a custom vLLM serving runtime.

        - --gpu-memory-utilization
        - "0.98"
        - --enforce-eager
        - --max-model-len
        - "2048"
      image: docker.io/vllm/vllm-openai:v0.3.3
      env:
        - name: NUMBA_CACHE_DIR
          value: /tmp

I went the other way with the GPU memory utilization setting (raising it) after seeing this post.

I reduced the max model length (context) from Llama-2 7B's default of 4096 down to 2048 after I got this error:
ValueError: The model's max seq len (4096) is larger than the maximum number of tokens that can be stored in KV cache (2256). Try increasing gpu_memory_utilization or decreasing max_model_len when initializing the engine.
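The check behind this error amounts to comparing `max_model_len` with the number of tokens the KV cache can hold (blocks times tokens per block). A sketch, assuming the default block size of 16, under which the 2256 tokens above correspond to 141 blocks; the function names are illustrative, not vLLM's.

```python
def max_model_len_that_fits(num_gpu_blocks: int, block_size: int = 16) -> int:
    """Largest context one sequence can use, given the available KV-cache blocks."""
    return num_gpu_blocks * block_size


def check_max_model_len(max_model_len: int, num_gpu_blocks: int, block_size: int = 16) -> None:
    capacity = max_model_len_that_fits(num_gpu_blocks, block_size)
    if max_model_len > capacity:
        raise ValueError(
            f"max seq len ({max_model_len}) is larger than the KV cache can hold ({capacity})")


check_max_model_len(2048, 141)  # fits: 2048 <= 2256
```

This is why lowering max_model_len to 2048 made startup succeed without any other change.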

It's not ideal that I had to reduce the context, but at least it's working now and may be OK for short Q&A.

@hahmad2008

hahmad2008 commented Apr 21, 2024

Any update on this issue? I am trying to serve two models (TinyLlama 1.1B) on the same GPU cluster, so I use @serve.deployment(ray_actor_options={"num_gpus": 0.4},) and:

    ENGINE_ARGS = AsyncEngineArgs(
        gpu_memory_utilization=0.4,
        model=model_path,
        max_model_len=128,
        enforce_eager=True,
    )

I can start only one model, on a replica with 40% of the GPU; that model reserved 10G/22G of GPU RAM. However, when I tried to start the second model, I got this error, even though it created another replica and the cluster's GPU usage is now 0.8/1.

2024-04-21 15:12:15,310 INFO worker.py:1540 -- Connecting to existing Ray cluster at address: 10.5.8.112:6379...
2024-04-21 15:12:15,317 INFO worker.py:1715 -- Connected to Ray cluster. View the dashboard at 127.0.0.1:8265 
(ServeController pid=11766) INFO 2024-04-21 15:12:15,440 controller 11766 deployment_state.py:1545 - Deploying new version of deployment MyModel in application 'model2'. Setting initial target number of replicas to 1.
(ServeController pid=11766) INFO 2024-04-21 15:12:15,541 controller 11766 deployment_state.py:1829 - Adding 1 replica to deployment MyModel in application 'model2'.
(ServeReplica:model2:MyModel pid=66303) INFO 04-21 15:12:18 llm_engine.py:72] Initializing an LLM engine with config: model='TinyLlama/TinyLlama-1.1B-Chat-v0.1', tokenizer='TinyLlama/TinyLlama-1.1B-Chat-v0.1', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=128, download_dir=None, load_format=auto, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, seed=0)
(ServeReplica:model2:MyModel pid=66303) INFO 04-21 15:12:21 weight_utils.py:164] Using model weights format ['*.safetensors']
(ServeController pid=11766) ERROR 2024-04-21 15:12:24,184 controller 11766 deployment_state.py:658 - Exception in replica 'model2#MyModel#hbNcQm', the replica will be stopped.
(ServeController pid=11766) Traceback (most recent call last):
(ServeController pid=11766)   File "/myenv/lib/python3.9/site-packages/ray/serve/_private/deployment_state.py", line 656, in check_ready
(ServeController pid=11766)     _, self._version = ray.get(self._ready_obj_ref)
(ServeController pid=11766)   File "/myenv/lib/python3.9/site-packages/ray/_private/auto_init_hook.py", line 22, in auto_init_wrapper
(ServeController pid=11766)     return fn(*args, **kwargs)
(ServeController pid=11766)   File "/myenv/lib/python3.9/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
(ServeController pid=11766)     return func(*args, **kwargs)
(ServeController pid=11766)   File "/myenv/lib/python3.9/site-packages/ray/_private/worker.py", line 2624, in get
(ServeController pid=11766)     raise value.as_instanceof_cause()
(ServeController pid=11766) ray.exceptions.RayTaskError(RuntimeError): ray::ServeReplica:model2:MyModel.initialize_and_get_metadata() (pid=66303, ip=10.5.8.112, actor_id=e6ee395511fd55b8b5457d7501000000, repr=<ray.serve._private.replica.ServeReplica:model2:MyModel object at 0x7fc28de7a1c0>)
(ServeController pid=11766)   File "/myenv/lib/python3.9/concurrent/futures/_base.py", line 439, in result
(ServeController pid=11766)     return self.__get_result()
(ServeController pid=11766)   File "/myenv/lib/python3.9/concurrent/futures/_base.py", line 391, in __get_result
(ServeController pid=11766)     raise self._exception
(ServeController pid=11766)   File "/myenv/lib/python3.9/site-packages/ray/serve/_private/replica.py", line 455, in initialize_and_get_metadata
(ServeController pid=11766)     raise RuntimeError(traceback.format_exc()) from None
(ServeController pid=11766) RuntimeError: Traceback (most recent call last):
(ServeController pid=11766)   File "/myenv/lib/python3.9/site-packages/ray/serve/_private/replica.py", line 443, in initialize_and_get_metadata
(ServeController pid=11766)     await self._initialize_replica()
(ServeController pid=11766)   File "/myenv/lib/python3.9/site-packages/ray/serve/_private/replica.py", line 182, in initialize_replica
(ServeController pid=11766)     await sync_to_async(_callable.__init__)(*init_args, **init_kwargs)
(ServeController pid=11766)   File "/myenv/lib/python3.9/site-packages/ray/serve/api.py", line 237, in __init__
(ServeController pid=11766)     cls.__init__(self, *args, **kwargs)
(ServeController pid=11766)   File "python_script_serving.py", line 28, in __init__
(ServeController pid=11766)     self.engine = AsyncLLMEngine.from_engine_args(ENGINE_ARGS)
(ServeController pid=11766)   File "/myenv/lib/python3.9/site-packages/vllm/engine/async_llm_engine.py", line 623, in from_engine_args
(ServeController pid=11766)     engine = cls(parallel_config.worker_use_ray,
(ServeController pid=11766)   File "/myenv/lib/python3.9/site-packages/vllm/engine/async_llm_engine.py", line 319, in __init__
(ServeController pid=11766)     self.engine = self._init_engine(*args, **kwargs)
(ServeController pid=11766)   File "/myenv/lib/python3.9/site-packages/vllm/engine/async_llm_engine.py", line 364, in _init_engine
(ServeController pid=11766)     return engine_class(*args, **kwargs)
(ServeController pid=11766)   File "/myenv/lib/python3.9/site-packages/vllm/engine/llm_engine.py", line 114, in __init__
(ServeController pid=11766)     self._init_cache()
(ServeController pid=11766)   File "/myenv/lib/python3.9/site-packages/vllm/engine/llm_engine.py", line 326, in _init_cache
(ServeController pid=11766)     raise ValueError("No available memory for the cache blocks. "
(ServeController pid=11766) ValueError: No available memory for the cache blocks. Try increasing `gpu_memory_utilization` when initializing the engine.
(ServeReplica:model2:MyModel pid=66303) INFO 04-21 15:12:24 llm_engine.py:322] # GPU blocks: 0, # CPU blocks: 11915
(ServeReplica:model2:MyModel pid=66303) sys:1: RuntimeWarning: coroutine 'ingress.<locals>.decorator.<locals>.ASGIIngressWrapper.__del__' was never awaited

@hahmad2008

#4242

@pbasov

pbasov commented Apr 22, 2024

I have a similar problem with the new nvidia-device-plugin on Kubernetes with KServe. I've enabled NVIDIA MPS, which "over-represents" the single GPU and allows scheduling two GPU containers with half the memory each. nvidia-smi, however, still reports the full amount of memory.

I assumed --gpu-memory-utilization would save me from OOMing, but even a 1.3B AWQ 8-bit model fills an extra 6GB of VRAM with cache. --num-gpu-blocks-override doesn't seem to do anything.

I need to be able to strictly set memory limits for vLLM for my use-case to work.

Please advise.
