fix: batch requests to the CreateEmbedding stub (#887)
If we don't batch the requests, we risk an error when the CreateEmbedding stub returns a response that is too large for our gRPC service to handle.

Signed-off-by: Jon Perry <yrrepnoj@gmail.com>
YrrepNoj committed Aug 9, 2024
1 parent 8cd3011 commit 8ed4328
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions src/leapfrogai_api/backend/grpc_client.py
@@ -1,6 +1,6 @@
 """gRPC client for OpenAI models."""
 
-from typing import Iterator, AsyncGenerator, Any
+from typing import Iterator, AsyncGenerator, Any, List
 import grpc
 from fastapi.responses import StreamingResponse
 import leapfrogai_sdk as lfai
@@ -120,14 +120,25 @@ async def create_embeddings(model: Model, request: lfai.EmbeddingRequest):
     """Create embeddings using the specified model."""
     async with grpc.aio.insecure_channel(model.backend) as channel:
         stub = lfai.EmbeddingsServiceStub(channel)
-        e: lfai.EmbeddingResponse = await stub.CreateEmbedding(request)
+        embeddings: List[EmbeddingResponseData] = []
+
+        # Loop through inputs - 500 at a time
+        for i in range(0, len(request.inputs), 500):
+            request_embeddings = request.inputs[i : i + 500]
+
+            range_request = lfai.EmbeddingRequest(inputs=request_embeddings)
+            e: lfai.EmbeddingResponse = await stub.CreateEmbedding(range_request)
+            if e and e.embeddings is not None:
+                data = [
+                    EmbeddingResponseData(
+                        embedding=list(e.embeddings[i].embedding), index=i
+                    )
+                    for i in range(len(e.embeddings))
+                ]
+                embeddings.extend(data)
+
         return CreateEmbeddingResponse(
-            data=[
-                EmbeddingResponseData(
-                    embedding=list(e.embeddings[i].embedding), index=i
-                )
-                for i in range(len(e.embeddings))
-            ],
+            data=embeddings,
             model=model.name,
             usage=Usage(prompt_tokens=0, total_tokens=0),
         )
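
For reference, the batching pattern introduced by this commit amounts to slicing the input list into fixed-size chunks, issuing one CreateEmbedding call per chunk, and concatenating the results so that no single gRPC response grows too large. The following is a minimal, self-contained sketch of that pattern only; fake_create_embedding and BATCH_SIZE are hypothetical stand-ins for illustration and are not part of the LeapfrogAI SDK.

from typing import List

BATCH_SIZE = 500  # same chunk size as the diff above


def fake_create_embedding(inputs: List[str]) -> List[List[float]]:
    """Hypothetical stand-in for the CreateEmbedding stub: one small vector per input."""
    return [[float(len(text))] for text in inputs]


def batched_embeddings(inputs: List[str]) -> List[List[float]]:
    """Gather embeddings chunk by chunk so no single response grows too large."""
    results: List[List[float]] = []
    for start in range(0, len(inputs), BATCH_SIZE):
        chunk = inputs[start : start + BATCH_SIZE]
        # One call per chunk keeps each response bounded in size.
        results.extend(fake_create_embedding(chunk))
    return results


if __name__ == "__main__":
    vectors = batched_embeddings([f"doc {n}" for n in range(1200)])
    print(len(vectors))  # 1200 -- three calls of 500, 500, and 200 inputs

Running the sketch over 1200 inputs issues three calls (500, 500, and 200 inputs) and still returns all 1200 vectors in order.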
