-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(api): Threads and Message Endpoints (#554)
* Adds OpenAI compliant endpoints for Threads and Messages * Adds supabase migrations for threads and messages * Adds crud operations for threads and messages * Adds integration tests for threads and messages --------- Co-authored-by: Jon Perry <yrrepnoj@gmail.com> Co-authored-by: Gregory Horvath <gphorvath@defenseunicorns.com>
- Loading branch information
1 parent
e910f06
commit 4b69d3c
Showing
12 changed files
with
684 additions
and
54 deletions.
There are no files selected for viewing
64 changes: 64 additions & 0 deletions
64
packages/api/supabase/migrations/20240522141100_threads_runs_messages_steps.sql
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
-- Create a table to store the OpenAI Thread Objects | ||
create table | ||
thread_objects ( | ||
id uuid primary key DEFAULT uuid_generate_v4(), | ||
user_id uuid references auth.users not null, | ||
object text check (object in ('thread')), | ||
created_at bigint default extract(epoch FROM NOW()) NOT NULL, | ||
tool_resources jsonb, | ||
metadata jsonb | ||
); | ||
|
||
-- Create a table to store the OpenAI Message Objects | ||
create table | ||
message_objects ( | ||
id uuid primary key DEFAULT uuid_generate_v4(), | ||
user_id uuid references auth.users not null, | ||
object text check (object in ('thread.message')), | ||
created_at bigint default extract(epoch FROM NOW()) not null, | ||
thread_id uuid references thread_objects (id) on delete cascade not null, | ||
status text, | ||
incomplete_details jsonb, | ||
completed_at bigint, | ||
incomplete_at bigint, | ||
role text, | ||
content jsonb, | ||
assistant_id uuid, -- No foreign key constraint, can be null and doesn't have to refer to an assistant that exists | ||
run_id uuid, -- No foreign key constraint, can be null and doesn't have to refer to a thread that exists | ||
attachments jsonb, | ||
metadata jsonb | ||
); | ||
|
||
-- RLS policies | ||
alter table thread_objects enable row level security; | ||
alter table message_objects enable row level security; | ||
|
||
-- Policies for thread_objects | ||
create policy "Individuals can view their own thread_objects." on thread_objects for | ||
select using (auth.uid() = user_id); | ||
create policy "Individuals can create thread_objects." on thread_objects for | ||
insert with check (auth.uid() = user_id); | ||
create policy "Individuals can update their own thread_objects." on thread_objects for | ||
update using (auth.uid() = user_id); | ||
create policy "Individuals can delete their own thread_objects." on thread_objects for | ||
delete using (auth.uid() = user_id); | ||
|
||
-- Policies for message_objects | ||
create policy "Individuals can view their own message_objects." on message_objects for | ||
select using (auth.uid() = user_id); | ||
create policy "Individuals can create message_objects." on message_objects for | ||
insert with check (auth.uid() = user_id); | ||
create policy "Individuals can update their own message_objects." on message_objects for | ||
update using (auth.uid() = user_id); | ||
create policy "Individuals can delete their own message_objects." on message_objects for | ||
delete using (auth.uid() = user_id); | ||
|
||
-- Indexes for foreign keys for message_objects | ||
CREATE INDEX message_objects_user_id ON message_objects (user_id); | ||
CREATE INDEX message_objects_thread_id ON message_objects (thread_id); | ||
CREATE INDEX message_objects_created_at ON thread_objects (created_at); | ||
|
||
-- Indexes for common filtering and sorting for thread_objects | ||
CREATE INDEX thread_objects_id ON thread_objects (id); | ||
CREATE INDEX thread_objects_user_id ON thread_objects (user_id); | ||
CREATE INDEX thread_objects_created_at ON thread_objects (created_at); |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
"""CRUD Operations for Message.""" | ||
|
||
from pydantic import Field | ||
from openai.types.beta.threads import Message | ||
from supabase_py_async import AsyncClient | ||
from leapfrogai_api.data.crud_base import CRUDBase | ||
|
||
|
||
class AuthMessage(Message): | ||
"""A wrapper for the message that includes a user_id for auth""" | ||
|
||
user_id: str = Field(default="") | ||
|
||
|
||
class CRUDMessage(CRUDBase[AuthMessage]): | ||
"""CRUD Operations for message""" | ||
|
||
def __init__(self, db: AsyncClient): | ||
super().__init__(db=db, model=AuthMessage, table_name="message_objects") | ||
|
||
async def create(self, object_: Message) -> Message | None: | ||
"""Create new message.""" | ||
user_id: str = (await self.db.auth.get_user()).user.id | ||
return await super().create( | ||
object_=AuthMessage(user_id=user_id, **object_.model_dump()) | ||
) | ||
|
||
async def get(self, filters: dict | None = None) -> Message | None: | ||
"""Get a message by its ID.""" | ||
return await super().get(filters=filters) | ||
|
||
async def list(self, filters: dict | None = None) -> list[Message] | None: | ||
"""List all messages by thread ID.""" | ||
return await super().list(filters=filters) | ||
|
||
async def update(self, id_: str, object_: Message) -> Message | None: | ||
"""Update a message by its ID.""" | ||
|
||
dict_ = object_.model_dump() | ||
|
||
data, _count = ( | ||
await self.db.table(self.table_name).update(dict_).eq("id", id_).execute() | ||
) | ||
|
||
_, response = data | ||
|
||
if response: | ||
return self.model(**response[0]) | ||
return None | ||
|
||
async def delete(self, filters: dict | None = None) -> bool: | ||
"""Delete a message by its ID and thread ID.""" | ||
return await super().delete(filters) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
"""CRUD Operations for Thread.""" | ||
|
||
from pydantic import Field | ||
from openai.types.beta import Thread | ||
from supabase_py_async import AsyncClient | ||
from leapfrogai_api.data.crud_base import CRUDBase | ||
|
||
|
||
class AuthThread(Thread): | ||
"""A wrapper for the thread that includes a user_id for auth""" | ||
|
||
user_id: str = Field(default="") | ||
|
||
|
||
class CRUDThread(CRUDBase[AuthThread]): | ||
"""CRUD Operations for thread""" | ||
|
||
def __init__(self, db: AsyncClient): | ||
super().__init__(db=db, model=AuthThread, table_name="thread_objects") | ||
|
||
async def create(self, object_: Thread) -> Thread | None: | ||
"""Create new thread.""" | ||
user_id: str = (await self.db.auth.get_user()).user.id | ||
return await super().create( | ||
object_=AuthThread(user_id=user_id, **object_.model_dump()) | ||
) | ||
|
||
async def get(self, filters: dict | None = None) -> Thread | None: | ||
"""Get a vector store by its ID.""" | ||
|
||
return await super().get(filters=filters) | ||
|
||
async def list(self, filters: dict | None = None) -> list[Thread] | None: | ||
"""List all threads.""" | ||
|
||
return await super().list(filters=filters) | ||
|
||
async def update(self, id_: str, object_: Thread) -> Thread | None: | ||
"""Update a thread by its ID.""" | ||
return await super().update(id_=id_, object_=object_) | ||
|
||
async def delete(self, filters: dict | None = None) -> bool: | ||
"""Delete a thread by its ID.""" | ||
return await super().delete(filters=filters) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.