Skip to content

Commit

Permalink
feat(api): Threads and Message Endpoints (#554)
Browse files Browse the repository at this point in the history
* Adds OpenAI compliant endpoints for Threads and Messages
* Adds supabase migrations for threads and messages
* Adds crud operations for threads and messages
* Adds integration tests for threads and messages

---------

Co-authored-by: Jon Perry <yrrepnoj@gmail.com>
Co-authored-by: Gregory Horvath <gphorvath@defenseunicorns.com>
  • Loading branch information
3 people committed Jun 12, 2024
1 parent e910f06 commit 4b69d3c
Show file tree
Hide file tree
Showing 12 changed files with 684 additions and 54 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
-- Create a table to store the OpenAI Thread Objects
create table
thread_objects (
id uuid primary key DEFAULT uuid_generate_v4(),
user_id uuid references auth.users not null,
object text check (object in ('thread')),
created_at bigint default extract(epoch FROM NOW()) NOT NULL,
tool_resources jsonb,
metadata jsonb
);

-- Create a table to store the OpenAI Message Objects
create table
message_objects (
id uuid primary key DEFAULT uuid_generate_v4(),
user_id uuid references auth.users not null,
object text check (object in ('thread.message')),
created_at bigint default extract(epoch FROM NOW()) not null,
thread_id uuid references thread_objects (id) on delete cascade not null,
status text,
incomplete_details jsonb,
completed_at bigint,
incomplete_at bigint,
role text,
content jsonb,
assistant_id uuid, -- No foreign key constraint, can be null and doesn't have to refer to an assistant that exists
run_id uuid, -- No foreign key constraint, can be null and doesn't have to refer to a thread that exists
attachments jsonb,
metadata jsonb
);

-- RLS policies
alter table thread_objects enable row level security;
alter table message_objects enable row level security;

-- Policies for thread_objects
create policy "Individuals can view their own thread_objects." on thread_objects for
select using (auth.uid() = user_id);
create policy "Individuals can create thread_objects." on thread_objects for
insert with check (auth.uid() = user_id);
create policy "Individuals can update their own thread_objects." on thread_objects for
update using (auth.uid() = user_id);
create policy "Individuals can delete their own thread_objects." on thread_objects for
delete using (auth.uid() = user_id);

-- Policies for message_objects
create policy "Individuals can view their own message_objects." on message_objects for
select using (auth.uid() = user_id);
create policy "Individuals can create message_objects." on message_objects for
insert with check (auth.uid() = user_id);
create policy "Individuals can update their own message_objects." on message_objects for
update using (auth.uid() = user_id);
create policy "Individuals can delete their own message_objects." on message_objects for
delete using (auth.uid() = user_id);

-- Indexes for foreign keys for message_objects
CREATE INDEX message_objects_user_id ON message_objects (user_id);
CREATE INDEX message_objects_thread_id ON message_objects (thread_id);
CREATE INDEX message_objects_created_at ON thread_objects (created_at);

-- Indexes for common filtering and sorting for thread_objects
CREATE INDEX thread_objects_id ON thread_objects (id);
CREATE INDEX thread_objects_user_id ON thread_objects (user_id);
CREATE INDEX thread_objects_created_at ON thread_objects (created_at);
1 change: 0 additions & 1 deletion src/leapfrogai_api/.gitignore

This file was deleted.

17 changes: 9 additions & 8 deletions src/leapfrogai_api/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@ install:
dev:
python -m uvicorn main:app --port 3000 --reload --log-level info

test-integration:
cd ../../ && python -m pytest tests/integration/api

define get_jwt_token
echo "Getting JWT token from ${SUPABASE_URL}..."; \
TOKEN_RESPONSE=$$(curl -s -X POST $(1) \
Expand All @@ -23,16 +20,20 @@ define get_jwt_token
echo "Extracting token from $(TOKEN_RESPONSE)"; \
JWT=$$(echo $$TOKEN_RESPONSE | grep -oP '(?<="access_token":")[^"]*'); \
echo -n "$$JWT" | xclip -selection clipboard; \
echo "export SUPABASE_USER_JWT=$$JWT" > .jwt; \
echo "SUPABASE_USER_JWT=$$JWT" > .env; \
echo "SUPABASE_URL=$$SUPABASE_URL" >> .env; \
echo "SUPABASE_ANON_KEY=$$SUPABASE_ANON_KEY" >> .env; \
echo "DONE - JWT token copied to clipboard"
endef

supabase-user:
@read -s -p "Enter your Supabase password: " SUPABASE_PASS; echo; \
user:
@read -s -p "Enter a new DEV API password: " SUPABASE_PASS; echo; \
echo "Creating new supabase user..."; \
$(call get_jwt_token,"${SUPABASE_URL}/auth/v1/signup")

supabase-jwt-token:
@read -s -p "Enter your Supabase password: " SUPABASE_PASS; echo; \
env:
@read -s -p "Enter your DEV API password: " SUPABASE_PASS; echo; \
$(call get_jwt_token,"${SUPABASE_URL}/auth/v1/token?grant_type=password")

test-integration:
cd ../../ && python -m pytest tests/integration/api
17 changes: 8 additions & 9 deletions src/leapfrogai_api/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,17 @@ A mostly OpenAI compliant API surface.
supabase status # to check status and see your keys
```

### Session Authentication

3. Create a local api user
```bash
make supabase-user
make user
```

### Session Authentication

4. Create a JWT token
```bash
make supabase-jwt-token
make jwt
source .env
```
This will copy the JWT token to your clipboard.

Expand All @@ -53,13 +54,11 @@ The integration tests serve to identify any mismatches between components:
Integration tests require a Supabase instance and environment variables configured (see [Local Development](#local-development)).

### Authentication
Tests require a JWT token environment variable `SUPABASE_USER_JWT`:

``` bash
make supabase-jwt-token
source .jwt
```
Tests require a JWT token environment variable `SUPABASE_USER_JWT`. See [Session Authentication](#session-authentication) to set this up.

### Running the tests
After obtaining the JWT token, run the following:
```
make test-integration
```
Expand Down
55 changes: 46 additions & 9 deletions src/leapfrogai_api/backend/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from openai.types import FileObject
from openai.types.beta import VectorStore
from openai.types.beta import Assistant, AssistantTool
from openai.types.beta.threads import Message, MessageContent, TextContentBlock, Text
from openai.types.beta.threads.message import Attachment
from openai.types.beta.assistant import ToolResources


Expand Down Expand Up @@ -263,7 +265,7 @@ class CreateAssistantRequest(BaseModel):
instructions: str | None = "You are a helpful assistant."
tools: list[AssistantTool] | None = [] # This is all we support right now
tool_resources: ToolResources | None = ToolResources()
metadata: object | None = {}
metadata: dict | None = Field(default=None, examples=[{}])
temperature: float | None = 1.0
top_p: float | None = 1.0
response_format: Literal["auto"] | None = "auto" # This is all we support right now
Expand Down Expand Up @@ -310,7 +312,7 @@ class CreateVectorStoreRequest(BaseModel):
expires_after: ExpiresAfter | None = Field(
default=None, examples=[ExpiresAfter(anchor="last_active_at", days=1)]
)
metadata: dict | None = {}
metadata: dict | None = Field(default=None, examples=[{}])

def add_days_to_timestamp(self, timestamp: int, days: int) -> int:
"""
Expand All @@ -335,9 +337,6 @@ def add_days_to_timestamp(self, timestamp: int, days: int) -> int:

return int(new_timestamp)

def can_expire(self) -> bool:
return self.expires_after is not None

def get_expiry(self, last_active_at: int) -> tuple[ExpiresAfter | None, int | None]:
"""
Return expiration details based on the provided last_active_at unix timestamp
Expand All @@ -348,12 +347,12 @@ def get_expiry(self, last_active_at: int) -> tuple[ExpiresAfter | None, int | No
Returns:
A tuple of when the vector store should expire and the timestamp of the expiry date.
"""
if self.can_expire():
if isinstance(self.expires_after, ExpiresAfter):
return self.expires_after, self.add_days_to_timestamp(
last_active_at, self.expires_after.days if self.expires_after else None
last_active_at, self.expires_after.days
)
else:
return None, None

return None, None # Will not expire


class ModifyVectorStoreRequest(CreateVectorStoreRequest):
Expand All @@ -367,6 +366,44 @@ class ListVectorStoresResponse(BaseModel):
data: list[VectorStore] = []


################
# THREADS, RUNS, MESSAGES
################


class CreateThreadRequest(BaseModel):
"""Request object for creating a thread."""

messages: list[Message] | None = Field(default=None, examples=[None])
tool_resources: ToolResources | None = Field(default=None, examples=[None])
metadata: dict | None = Field(default=None, examples=[{}])


class ModifyThreadRequest(BaseModel):
"""Request object for modifying a thread."""

tool_resources: ToolResources | None = Field(default=None, examples=[None])
metadata: dict | None = Field(default=None, examples=[{}])


class CreateMessageRequest(BaseModel):
"""Request object for creating a message."""

role: Literal["user", "assistant"] = Field(default="user")
content: list[MessageContent] = Field(
default=[TextContentBlock(text=Text(value="", annotations=[]), type="text")],
examples=[[TextContentBlock(text=Text(value="", annotations=[]), type="text")]],
)
attachments: list[Attachment] | None = Field(default=None, examples=[None])
metadata: dict | None = Field(default=None, examples=[{}])


class ModifyMessageRequest(BaseModel):
"""Request object for modifying a message."""

metadata: dict | None = Field(default=None, examples=[{}])


################
# LEAPFROGAI RAG
################
Expand Down
File renamed without changes.
53 changes: 53 additions & 0 deletions src/leapfrogai_api/data/crud_message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""CRUD Operations for Message."""

from pydantic import Field
from openai.types.beta.threads import Message
from supabase_py_async import AsyncClient
from leapfrogai_api.data.crud_base import CRUDBase


class AuthMessage(Message):
"""A wrapper for the message that includes a user_id for auth"""

user_id: str = Field(default="")


class CRUDMessage(CRUDBase[AuthMessage]):
"""CRUD Operations for message"""

def __init__(self, db: AsyncClient):
super().__init__(db=db, model=AuthMessage, table_name="message_objects")

async def create(self, object_: Message) -> Message | None:
"""Create new message."""
user_id: str = (await self.db.auth.get_user()).user.id
return await super().create(
object_=AuthMessage(user_id=user_id, **object_.model_dump())
)

async def get(self, filters: dict | None = None) -> Message | None:
"""Get a message by its ID."""
return await super().get(filters=filters)

async def list(self, filters: dict | None = None) -> list[Message] | None:
"""List all messages by thread ID."""
return await super().list(filters=filters)

async def update(self, id_: str, object_: Message) -> Message | None:
"""Update a message by its ID."""

dict_ = object_.model_dump()

data, _count = (
await self.db.table(self.table_name).update(dict_).eq("id", id_).execute()
)

_, response = data

if response:
return self.model(**response[0])
return None

async def delete(self, filters: dict | None = None) -> bool:
"""Delete a message by its ID and thread ID."""
return await super().delete(filters)
44 changes: 44 additions & 0 deletions src/leapfrogai_api/data/crud_thread.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""CRUD Operations for Thread."""

from pydantic import Field
from openai.types.beta import Thread
from supabase_py_async import AsyncClient
from leapfrogai_api.data.crud_base import CRUDBase


class AuthThread(Thread):
"""A wrapper for the thread that includes a user_id for auth"""

user_id: str = Field(default="")


class CRUDThread(CRUDBase[AuthThread]):
"""CRUD Operations for thread"""

def __init__(self, db: AsyncClient):
super().__init__(db=db, model=AuthThread, table_name="thread_objects")

async def create(self, object_: Thread) -> Thread | None:
"""Create new thread."""
user_id: str = (await self.db.auth.get_user()).user.id
return await super().create(
object_=AuthThread(user_id=user_id, **object_.model_dump())
)

async def get(self, filters: dict | None = None) -> Thread | None:
"""Get a vector store by its ID."""

return await super().get(filters=filters)

async def list(self, filters: dict | None = None) -> list[Thread] | None:
"""List all threads."""

return await super().list(filters=filters)

async def update(self, id_: str, object_: Thread) -> Thread | None:
"""Update a thread by its ID."""
return await super().update(id_=id_, object_=object_)

async def delete(self, filters: dict | None = None) -> bool:
"""Delete a thread by its ID."""
return await super().delete(filters=filters)
2 changes: 1 addition & 1 deletion src/leapfrogai_api/routers/openai/assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
ListAssistantsResponse,
ModifyAssistantRequest,
)
from leapfrogai_api.data.crud_assistant_object import CRUDAssistant, FilterAssistant
from leapfrogai_api.data.crud_assistant import CRUDAssistant, FilterAssistant
from leapfrogai_api.routers.supabase_session import Session

router = APIRouter(prefix="/openai/v1/assistants", tags=["openai/assistants"])
Expand Down
Loading

0 comments on commit 4b69d3c

Please sign in to comment.