From 7123c427136969b52799fff5eeb22fa2c7a7164c Mon Sep 17 00:00:00 2001
From: Jon Perry
Date: Wed, 7 Aug 2024 13:45:44 -0400
Subject: [PATCH] fix: batch requests to the CreateEmbedding stub

If we don't batch the requests, we risk an error when the CreateEmbedding
stub returns a response too large for our gRPC service to handle. Send
the inputs in batches of 500 and stitch the per-batch responses back
together, offsetting each batch's indices so they stay global across the
full input list.

Signed-off-by: Jon Perry
---
 src/leapfrogai_api/backend/grpc_client.py | 27 +++++++++++++++++++--------
 1 file changed, 19 insertions(+), 8 deletions(-)

diff --git a/src/leapfrogai_api/backend/grpc_client.py b/src/leapfrogai_api/backend/grpc_client.py
index 1a3fe07f8..9dbe782de 100644
--- a/src/leapfrogai_api/backend/grpc_client.py
+++ b/src/leapfrogai_api/backend/grpc_client.py
@@ -1,6 +1,6 @@
 """gRPC client for OpenAI models."""
 
-from typing import Iterator, AsyncGenerator, Any
+from typing import Iterator, AsyncGenerator, Any, List
 import grpc
 from fastapi.responses import StreamingResponse
 import leapfrogai_sdk as lfai
@@ -120,14 +120,25 @@ async def create_embeddings(model: Model, request: lfai.EmbeddingRequest):
     """Create embeddings using the specified model."""
     async with grpc.aio.insecure_channel(model.backend) as channel:
         stub = lfai.EmbeddingsServiceStub(channel)
-        e: lfai.EmbeddingResponse = await stub.CreateEmbedding(request)
+        embeddings: List[EmbeddingResponseData] = []
+
+        # Batch inputs 500 at a time; offset indices by i so they stay global
+        for i in range(0, len(request.inputs), 500):
+            batch_inputs = request.inputs[i : i + 500]
+
+            batch_request = lfai.EmbeddingRequest(inputs=batch_inputs)
+            e: lfai.EmbeddingResponse = await stub.CreateEmbedding(batch_request)
+            if e and e.embeddings is not None:
+                data = [
+                    EmbeddingResponseData(
+                        embedding=list(item.embedding), index=i + j
+                    )
+                    for j, item in enumerate(e.embeddings)
+                ]
+                embeddings.extend(data)
+
         return CreateEmbeddingResponse(
-            data=[
-                EmbeddingResponseData(
-                    embedding=list(e.embeddings[i].embedding), index=i
-                )
-                for i in range(len(e.embeddings))
-            ],
+            data=embeddings,
             model=model.name,
             usage=Usage(prompt_tokens=0, total_tokens=0),
        )
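
A minimal standalone sketch of the batching pattern this patch applies, so
it can be exercised without a running gRPC backend. The names here
(embed_in_batches, embed_batch, BATCH_SIZE, fake_embed) are hypothetical
stand-ins for the real stub.CreateEmbedding call, not part of the
leapfrogai_sdk API:

    import asyncio
    from typing import Awaitable, Callable, List

    BATCH_SIZE = 500  # assumed cap; anything that keeps responses under the gRPC limit

    async def embed_in_batches(
        inputs: List[str],
        embed_batch: Callable[[List[str]], Awaitable[List[List[float]]]],
        batch_size: int = BATCH_SIZE,
    ) -> List[dict]:
        """Embed inputs in fixed-size batches, keeping indices global."""
        results: List[dict] = []
        for start in range(0, len(inputs), batch_size):
            batch = inputs[start : start + batch_size]
            vectors = await embed_batch(batch)
            # Offset by `start` so index runs 0..N-1 over the whole input
            # list, not 0..len(batch)-1 within each batch.
            results.extend(
                {"embedding": vec, "index": start + offset}
                for offset, vec in enumerate(vectors)
            )
        return results

    async def fake_embed(batch: List[str]) -> List[List[float]]:
        # Dummy backend: one 4-dim zero vector per input.
        return [[0.0] * 4 for _ in batch]

    data = asyncio.run(embed_in_batches([f"doc {n}" for n in range(1200)], fake_embed))
    assert [d["index"] for d in data] == list(range(1200))

The assert is the property the index offset buys: three batches (500, 500,
200) come back as one contiguous, correctly ordered result list.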