Skip to content

Commit

Permalink
add client using OpenAI library
Browse files Browse the repository at this point in the history
  • Loading branch information
ks6088ts committed Jun 23, 2024
1 parent e5e3cc3 commit 1d2c585
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 40 deletions.
70 changes: 31 additions & 39 deletions backend/internals/azure_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
from collections.abc import AsyncIterable
from logging import getLogger

from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import AzureChatOpenAI
from openai import AzureOpenAI

from backend.settings.azure_openai import Settings

Expand All @@ -15,40 +13,39 @@ class Client:
def __init__(self, settings: Settings) -> None:
self.settings = settings

def get_client(self) -> AzureChatOpenAI:
return AzureChatOpenAI(
def get_client(self) -> AzureOpenAI:
return AzureOpenAI(
api_key=self.settings.azure_openai_api_key,
api_version=self.settings.azure_openai_api_version,
azure_endpoint=self.settings.azure_openai_endpoint,
azure_deployment=self.settings.azure_openai_gpt_model,
)

def create_chat_completions(
self,
content: str,
) -> str:
response = self.get_client().invoke(
[
HumanMessage(
content=content,
),
]
response = self.get_client().chat.completions.create(
model=self.settings.azure_openai_gpt_model,
messages=[
{"role": "user", "content": content},
],
stream=False,
)
logger.info(response)
return response.content
return response.choices[0].message.content

async def create_chat_completions_stream(
self,
content: str,
) -> AsyncIterable[str]:
llm = self.get_client()
messages = [HumanMessagePromptTemplate.from_template(template="{message}")]
prompt = ChatPromptTemplate.from_messages(messages)
chain = prompt | llm
res = chain.astream({"message": content})
async for msg in res:
logger.info(msg)
yield msg.content
_ = self.get_client().chat.completions.create(
model=self.settings.azure_openai_gpt_model,
messages=[
{"role": "user", "content": content},
],
stream=True,
)
assert False, "Yet to be implemented."

def create_chat_completions_with_vision(
self,
Expand All @@ -58,24 +55,19 @@ def create_chat_completions_with_vision(
) -> str:
encoded_image = b64encode(image).decode("ascii")

response = self.get_client().invoke(
[
SystemMessage(
content=system_prompt,
),
HumanMessage(
content=[
{
"type": "text",
"text": user_prompt,
},
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
},
response = self.get_client().chat.completions.create(
model=self.settings.azure_openai_gpt_model,
messages=[
{"role": "system", "content": system_prompt},
{
"role": "user",
"content": [
{"type": "text", "text": user_prompt},
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}},
],
),
]
},
],
stream=False,
)
logger.info(response)
return response.content
return response.choices[0].message.content
81 changes: 81 additions & 0 deletions backend/internals/azure_openai_langchain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from base64 import b64encode
from collections.abc import AsyncIterable
from logging import getLogger

from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import AzureChatOpenAI

from backend.settings.azure_openai import Settings

logger = getLogger(__name__)


class Client:
    """Thin wrapper around LangChain's ``AzureChatOpenAI`` chat client.

    Builds a client from Azure OpenAI settings and exposes plain-string
    helpers for text completion, streamed completion, and vision prompts.
    """

    def __init__(self, settings: Settings) -> None:
        # Settings carry the Azure endpoint, key, API version and deployment.
        self.settings = settings

    def get_client(self) -> AzureChatOpenAI:
        """Construct a fresh ``AzureChatOpenAI`` client from the stored settings."""
        s = self.settings
        return AzureChatOpenAI(
            api_key=s.azure_openai_api_key,
            api_version=s.azure_openai_api_version,
            azure_endpoint=s.azure_openai_endpoint,
            azure_deployment=s.azure_openai_gpt_model,
        )

    def create_chat_completions(
        self,
        content: str,
    ) -> str:
        """Send ``content`` as a single user message and return the reply text."""
        reply = self.get_client().invoke([HumanMessage(content=content)])
        logger.info(reply)
        return reply.content

    async def create_chat_completions_stream(
        self,
        content: str,
    ) -> AsyncIterable[str]:
        """Yield the reply to ``content`` incrementally, one chunk's text at a time."""
        prompt = ChatPromptTemplate.from_messages(
            [HumanMessagePromptTemplate.from_template(template="{message}")]
        )
        chain = prompt | self.get_client()
        async for chunk in chain.astream({"message": content}):
            logger.info(chunk)
            yield chunk.content

    def create_chat_completions_with_vision(
        self,
        system_prompt: str,
        user_prompt: str,
        image: bytes,
    ) -> str:
        """Ask the model about ``image`` (raw bytes, sent as base64 JPEG data URL).

        The request pairs ``system_prompt`` with a user message combining
        ``user_prompt`` text and the encoded image; returns the reply text.
        """
        encoded_image = b64encode(image).decode("ascii")
        user_parts = [
            {
                "type": "text",
                "text": user_prompt,
            },
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
            },
        ]
        reply = self.get_client().invoke(
            [
                SystemMessage(content=system_prompt),
                HumanMessage(content=user_parts),
            ]
        )
        logger.info(reply)
        return reply.content
2 changes: 1 addition & 1 deletion backend/routers/azure_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from fastapi import APIRouter, UploadFile
from fastapi.responses import StreamingResponse

from backend.internals.azure_openai import Client
from backend.internals.azure_openai_langchain import Client
from backend.schemas import azure_openai as azure_openai_schemas
from backend.settings.azure_openai import Settings

Expand Down

0 comments on commit 1d2c585

Please sign in to comment.