feat(api): Runs endpoints (#583)
* Splits up router into Runs, Messages, Threads, and Run-Steps
* Implements Runs endpoints
* Adds schema for Runs
* Adds Tests for Runs, Messages, and Threads
* Fixes a few issues in other parts of the code to ensure the Runs endpoints work correctly
* Removes the old RAG deployment from UDS Bundles

---------

Co-authored-by: Gregory Horvath <gphorvath@defenseunicorns.com>
CollectiveUnicorn and gphorvath committed Jun 18, 2024
1 parent c3d4883 commit fecf0f8
Showing 44 changed files with 2,353 additions and 585 deletions.
2 changes: 1 addition & 1 deletion packages/api/chart/values.yaml
@@ -10,7 +10,7 @@ image:
fsGroup: 65532

supabase:
url: "https://supabase-kong.###ZARF_VAR_HOSTED_DOMAIN###"
url: "http://supabase-kong.leapfrogai.svc.cluster.local:80"

api:
replicas: 1
52 changes: 52 additions & 0 deletions packages/api/supabase/migrations/20240611111500_runs.sql
@@ -0,0 +1,52 @@
-- Create a table to store the OpenAI Run Objects
create table
run_objects (
id uuid primary key default uuid_generate_v4(),
user_id uuid references auth.users not null,
object text check (object in ('thread.run')),
created_at bigint default extract(epoch FROM NOW()) not null,
thread_id uuid references thread_objects (id) on delete cascade not null,
assistant_id uuid references assistant_objects (id) on delete cascade not null,
status text,
required_action jsonb,
last_error jsonb,
expires_at bigint,
started_at bigint,
cancelled_at bigint,
failed_at bigint,
completed_at bigint,
model text,
instructions text,
tools jsonb,
metadata jsonb,
parallel_tool_calls boolean,
stream boolean,
file_ids uuid[],
incomplete_details jsonb,
usage jsonb,
temperature float,
top_p float,
max_prompt_tokens int,
max_completion_tokens int,
truncation_strategy jsonb,
tool_choice jsonb,
response_format jsonb
);

-- RLS policies
alter table run_objects enable row level security;

-- Policies for run_objects
create policy "Individuals can view their own run_objects." on run_objects for
select using (auth.uid() = user_id);
create policy "Individuals can create run_objects." on run_objects for
insert with check (auth.uid() = user_id);
create policy "Individuals can update their own run_objects." on run_objects for
update using (auth.uid() = user_id);
create policy "Individuals can delete their own run_objects." on run_objects for
delete using (auth.uid() = user_id);

-- Indexes for common filtering and sorting for run_objects
CREATE INDEX run_objects_id ON run_objects (id);
CREATE INDEX run_objects_user_id ON run_objects (user_id);
CREATE INDEX run_objects_created_at ON run_objects (created_at);
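For context, here is a minimal sketch (not part of this commit) of how the new table might be read through the async Supabase client this API already uses; the helper name and client wiring are assumptions for illustration.

from supabase_py_async import AsyncClient


async def list_runs_for_thread(db: AsyncClient, thread_id: str):
    """Hypothetical helper: fetch the current user's runs for one thread.

    The RLS policies above mean this SELECT only ever returns rows whose
    user_id matches auth.uid(), so no explicit user filter is needed.
    """
    response = await (
        db.table("run_objects")
        .select("*")
        .eq("thread_id", thread_id)
        .order("created_at", desc=True)  # served by the run_objects_created_at index
        .execute()
    )
    return response.data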
2 changes: 1 addition & 1 deletion src/leapfrogai_api/Makefile
@@ -36,4 +36,4 @@ env:
$(call get_jwt_token,"${SUPABASE_URL}/auth/v1/token?grant_type=password")

test-integration:
-	cd ../../ && python -m pytest tests/integration/api
+	cd ../../ && python -m pytest tests/integration/api/ -vv -s
107 changes: 107 additions & 0 deletions src/leapfrogai_api/backend/converters.py
@@ -0,0 +1,107 @@
"""Converters for the LeapfrogAI API"""

from typing import Iterable
from openai.types.beta import AssistantStreamEvent
from openai.types.beta.assistant_stream_event import ThreadMessageDelta
from openai.types.beta.threads.file_citation_annotation import FileCitation
from openai.types.beta.threads import (
MessageContentPartParam,
MessageContent,
TextContentBlock,
Text,
Message,
MessageDeltaEvent,
MessageDelta,
TextDeltaBlock,
TextDelta,
FileCitationAnnotation,
)


def from_assistant_stream_event_to_str(stream_event: AssistantStreamEvent):
return f"event: {stream_event.event}\ndata: {stream_event.data.model_dump_json()}"


def from_content_param_to_content(
thread_message_content: str | Iterable[MessageContentPartParam],
) -> MessageContent:
"""Converts messages from MessageContentPartParam to MessageContent"""
if isinstance(thread_message_content, str):
return TextContentBlock(
text=Text(annotations=[], value=thread_message_content),
type="text",
)
else:
result: str = ""

for message_content_part in thread_message_content:
if isinstance(text := message_content_part.get("text"), str):
result += text

return TextContentBlock(
text=Text(annotations=[], value=result),
type="text",
)


def from_text_to_message(text: str, file_ids: list[str]) -> Message:
all_file_ids: str = ""

for file_id in file_ids:
all_file_ids += f" [{file_id}]"

message_content: TextContentBlock = TextContentBlock(
text=Text(
annotations=[
FileCitationAnnotation(
text=f"[{file_id}]",
file_citation=FileCitation(file_id=file_id, quote=""),
start_index=0,
end_index=0,
type="file_citation",
)
for file_id in file_ids
],
value=text + all_file_ids,
),
type="text",
)

new_message = Message(
id="",
created_at=0,
object="thread.message",
status="in_progress",
thread_id="",
content=[message_content],
role="assistant",
metadata=None,
)

return new_message


async def from_chat_completion_choice_to_thread_message_delta(
index, random_uuid, streaming_response
) -> ThreadMessageDelta:
thread_message_event: ThreadMessageDelta = ThreadMessageDelta(
data=MessageDeltaEvent(
id=str(random_uuid),
delta=MessageDelta(
content=[
TextDeltaBlock(
index=index,
type="text",
text=TextDelta(
annotations=[],
value=streaming_response.choices[0].chat_item.content,
),
)
],
role="assistant",
),
object="thread.message.delta",
),
event="thread.message.delta",
)
return thread_message_event
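As a quick reference (not part of the diff), here is how from_text_to_message behaves given the definition above; the values are illustrative.

msg = from_text_to_message("The answer is 42.", ["file-abc123"])

# Each file ID is appended to the text and mirrored as a citation annotation.
print(msg.content[0].text.value)             # "The answer is 42. [file-abc123]"
print(len(msg.content[0].text.annotations))  # 1 (a FileCitationAnnotation)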
21 changes: 20 additions & 1 deletion src/leapfrogai_api/backend/grpc_client.py
@@ -1,6 +1,6 @@
"""gRPC client for OpenAI models."""

-from typing import Iterator
+from typing import Iterator, AsyncGenerator, Any
import grpc
from fastapi.responses import StreamingResponse
import leapfrogai_sdk as lfai
@@ -16,6 +16,9 @@
EmbeddingResponseData,
Usage,
)
+from leapfrogai_sdk.chat.chat_pb2 import (
+    ChatCompletionResponse as ProtobufChatCompletionResponse,
+)
from leapfrogai_api.utils.config import Model


@@ -66,6 +69,22 @@ async def stream_chat_completion(model: Model, request: lfai.ChatCompletionRequest
return StreamingResponse(recv_chat(stream), media_type="text/event-stream")


+async def stream_chat_completion_raw(
+    model: Model, request: lfai.ChatCompletionRequest
+) -> AsyncGenerator[ProtobufChatCompletionResponse, Any]:
+    """Stream chat completion using the specified model."""
+    async with grpc.aio.insecure_channel(model.backend) as channel:
+        stub = lfai.ChatCompletionStreamServiceStub(channel)
+        stream: grpc.aio.UnaryStreamCall[
+            lfai.ChatCompletionRequest, lfai.ChatCompletionResponse
+        ] = stub.ChatCompleteStream(request)
+
+        await stream.wait_for_connection()
+
+        async for response in stream:
+            yield response


# TODO: Clean up completion() and stream_completion() to reduce code duplication
async def chat_completion(model: Model, request: lfai.ChatCompletionRequest):
"""Complete chat using the specified model."""
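For illustration (not part of the commit), a consumer of the new raw stream might look like the sketch below; the accumulation logic and names are assumptions, while the choices[0].chat_item.content access mirrors the converter added in this PR.

import leapfrogai_sdk as lfai
from leapfrogai_api.utils.config import Model


async def collect_stream_text(model: Model, request: lfai.ChatCompletionRequest) -> str:
    """Hypothetical helper: join the streamed deltas into one string."""
    chunks: list[str] = []
    async for response in stream_chat_completion_raw(model, request):
        # Each item is a raw protobuf ChatCompletionResponse, not an SSE string,
        # so callers can post-process deltas before formatting them.
        chunks.append(response.choices[0].chat_item.content)
    return "".join(chunks)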
4 changes: 2 additions & 2 deletions src/leapfrogai_api/backend/helpers.py
@@ -1,6 +1,6 @@
"""Helper functions for the OpenAI backend."""

-from typing import BinaryIO, Iterator
+from typing import BinaryIO, Iterator, AsyncGenerator, Any
import grpc
import leapfrogai_sdk as lfai
from leapfrogai_api.backend.types import (
@@ -48,7 +48,7 @@ async def recv_chat(
stream: grpc.aio.UnaryStreamCall[
lfai.ChatCompletionRequest, lfai.ChatCompletionResponse
],
-):
+) -> AsyncGenerator[str, Any]:
"""Generator that yields chat completion responses as Server-Sent Events."""
async for c in stream:
yield (
5 changes: 4 additions & 1 deletion src/leapfrogai_api/backend/rag/query.py
@@ -2,6 +2,7 @@

from supabase_py_async import AsyncClient
from leapfrogai_api.backend.rag.index import IndexingService
+from postgrest.base_request_builder import SingleAPIResponse


class QueryService:
@@ -11,7 +12,9 @@ def __init__(self, db: AsyncClient) -> None:
"""Initializes the QueryService."""
self.db = db

-    async def query_rag(self, query: str, vector_store_id: str, k: int = 5):
+    async def query_rag(
+        self, query: str, vector_store_id: str, k: int = 5
+    ) -> SingleAPIResponse:
"""
Query the Vector Store.
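A hedged sketch of calling the now-annotated query_rag; the caller, client setup, and the .data access are placeholders and assumptions, not code from this commit.

async def top_chunks(db, query: str, vector_store_id: str):
    """Hypothetical caller: return the k most similar chunks for a query."""
    service = QueryService(db)
    response = await service.query_rag(query, vector_store_id, k=5)
    # Assumes the postgrest SingleAPIResponse exposes the matched rows on .data.
    return response.data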
95 changes: 60 additions & 35 deletions src/leapfrogai_api/backend/types.py
@@ -5,15 +5,24 @@
import datetime
from enum import Enum
from typing import Literal
from pydantic import BaseModel, Field

from fastapi import UploadFile, Form, File
from openai.types.beta.vector_store import ExpiresAfter
from openai.types import FileObject
from openai.types.beta import VectorStore
from openai.types.beta import Assistant, AssistantTool
from openai.types.beta.threads import Message, MessageContent, TextContentBlock, Text
from openai.types.beta.threads.message import Attachment
from openai.types.beta.assistant import ToolResources
from openai.types.beta import VectorStore
from openai.types.beta.assistant import (
ToolResources as BetaAssistantToolResources,
ToolResourcesFileSearch,
)
from openai.types.beta.assistant_tool import FileSearchTool
from openai.types.beta.thread import ToolResources as BetaThreadToolResources
from openai.types.beta.thread_create_params import (
ToolResourcesFileSearchVectorStoreChunkingStrategy,
ToolResourcesFileSearchVectorStoreChunkingStrategyAuto,
)
from openai.types.beta.threads.text_content_block_param import TextContentBlockParam
from openai.types.beta.vector_store import ExpiresAfter
from pydantic import BaseModel, Field


##########
@@ -101,8 +110,8 @@ class ChatFunction(BaseModel):
class ChatMessage(BaseModel):
"""Message object for chat completion."""

-    role: str
-    content: str
+    role: Literal["user", "assistant", "system", "function"]
+    content: str | list[TextContentBlockParam]


class ChatDelta(BaseModel):
@@ -259,16 +268,29 @@ class ListFilesResponse(BaseModel):
class CreateAssistantRequest(BaseModel):
"""Request object for creating an assistant."""

-    model: str = "mistral"
-    name: str | None = "Froggy Assistant"
-    description: str | None = "A helpful assistant."
-    instructions: str | None = "You are a helpful assistant."
-    tools: list[AssistantTool] | None = []  # This is all we support right now
-    tool_resources: ToolResources | None = ToolResources()
-    metadata: dict | None = Field(default=None, examples=[{}])
-    temperature: float | None = 1.0
-    top_p: float | None = 1.0
-    response_format: Literal["auto"] | None = "auto"  # This is all we support right now
+    model: str = Field(default="llama-cpp-python", examples=["llama-cpp-python"])
+    name: str | None = Field(default=None, examples=["Froggy Assistant"])
+    description: str | None = Field(default=None, examples=["A helpful assistant."])
+    instructions: str | None = Field(
+        default=None, examples=["You are a helpful assistant."]
+    )
+    tools: list[AssistantTool] | None = Field(
+        default=None, examples=[[FileSearchTool(type="file_search")]]
+    )
+    tool_resources: BetaAssistantToolResources | None = Field(
+        default=None,
+        examples=[
+            BetaAssistantToolResources(
+                file_search=ToolResourcesFileSearch(vector_store_ids=[])
+            )
+        ],
+    )
+    metadata: dict | None = Field(default={}, examples=[{}])
+    temperature: float | None = Field(default=None, examples=[1.0])
+    top_p: float | None = Field(default=None, examples=[1.0])
+    response_format: Literal["auto"] | None = Field(
+        default=None, examples=["auto"]
+    )  # This is all we support right now


class ModifyAssistantRequest(CreateAssistantRequest):
@@ -304,6 +326,21 @@ class VectorStoreStatus(Enum):
COMPLETED = "completed"


+class CreateVectorStoreFileRequest(BaseModel):
+    """Request object for creating a vector store file."""
+
+    chunking_strategy: ToolResourcesFileSearchVectorStoreChunkingStrategy | None = (
+        Field(
+            default=None,
+            examples=[
+                ToolResourcesFileSearchVectorStoreChunkingStrategyAuto(type="auto")
+            ],
+        )
+    )
+
+    file_id: str = Field(default="", examples=[""])


class CreateVectorStoreRequest(BaseModel):
"""Request object for creating a vector store."""

@@ -371,30 +408,18 @@ class ListVectorStoresResponse(BaseModel):
################


class CreateThreadRequest(BaseModel):
"""Request object for creating a thread."""
class ModifyRunRequest(BaseModel):
"""Request object for modifying a run."""

messages: list[Message] | None = Field(default=None, examples=[None])
tool_resources: ToolResources | None = Field(default=None, examples=[None])
metadata: dict | None = Field(default=None, examples=[{}])
metadata: dict[str, str] | None = Field(default=None, examples=[{}])


class ModifyThreadRequest(BaseModel):
"""Request object for modifying a thread."""

tool_resources: ToolResources | None = Field(default=None, examples=[None])
metadata: dict | None = Field(default=None, examples=[{}])


class CreateMessageRequest(BaseModel):
"""Request object for creating a message."""

role: Literal["user", "assistant"] = Field(default="user")
content: list[MessageContent] = Field(
default=[TextContentBlock(text=Text(value="", annotations=[]), type="text")],
examples=[[TextContentBlock(text=Text(value="", annotations=[]), type="text")]],
tool_resources: BetaThreadToolResources | None = Field(
default=None, examples=[None]
)
attachments: list[Attachment] | None = Field(default=None, examples=[None])
metadata: dict | None = Field(default=None, examples=[{}])


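To make the reworked CreateAssistantRequest defaults concrete, a small construction sketch (illustrative only, not from the diff):

req = CreateAssistantRequest(
    model="llama-cpp-python",
    name="Froggy Assistant",
    instructions="You are a helpful assistant.",
    tools=[FileSearchTool(type="file_search")],
)

# Everything except `model` now defaults to None; metadata defaults to {}.
print(req.temperature)  # None
print(req.metadata)     # {}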