Skip to content

Commit

Permalink
enable strict mode for openai tool calls
Browse files Browse the repository at this point in the history
  • Loading branch information
tim-smart committed Sep 19, 2024
1 parent d26e484 commit 4719746
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 52 deletions.
15 changes: 15 additions & 0 deletions packages/ai/ai/src/Completions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,21 @@ export declare namespace Completions {

const constEmptyMap = new Map<never, never>()

/**
* @since 1.0.0
* @category models
*/
export interface CompletionOptions {
readonly system: Option.Option<string>
readonly input: Chunk.NonEmptyChunk<Message>
readonly tools: Array<{
readonly name: string
readonly description: string
readonly parameters: JSONSchema.JsonSchema7
}>
readonly required: boolean | string
}

/**
* @since 1.0.0
* @category constructors
Expand Down
86 changes: 34 additions & 52 deletions packages/ai/openai/src/OpenAiCompletions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,33 +25,38 @@ const make = (options: {
const client = yield* OpenAiClient
const config = yield* OpenAiConfig.getOrUndefined

const makeRequest = ({ input, required, system, tools }: Completions.CompletionOptions) =>
Effect.map(
Effect.context<never>(),
(context): typeof Generated.CreateChatCompletionRequest.Encoded => ({
model: options.model,
...config,
...context.unsafeMap.get(OpenAiConfig.key),
messages: makeMessages(input, system),
tools: tools.length > 0 ?
tools.map((tool) => ({
type: "function",
function: {
name: tool.name,
description: tool.description,
parameters: tool.parameters as any,
strict: true
}
})) :
undefined,
tool_choice: tools.length > 0 ?
typeof required === "boolean" ? (required ? "required" : "auto") : {
type: "function",
function: { name: required }
} :
undefined
})
)

return Completions.make({
create({ input, required, system, tools }) {
return OpenAiConfig.getOrUndefined.pipe(
Effect.flatMap((localConfig) =>
client.client.createChatCompletion({
model: options.model,
...config,
...localConfig,
messages: makeMessages(input, system),
tools: tools.length > 0 ?
tools.map((tool) => ({
type: "function",
function: {
name: tool.name,
description: tool.description,
parameters: tool.parameters as any
}
})) :
undefined,
tool_choice: tools.length > 0 ?
typeof required === "boolean" ? (required ? "required" : "auto") : {
type: "function",
function: { name: required }
} :
undefined
})
),
create(options) {
return makeRequest(options).pipe(
Effect.flatMap(client.client.createChatCompletion),
Effect.catchAll((cause) =>
Effect.fail(
new AiError({
Expand All @@ -65,32 +70,9 @@ const make = (options: {
Effect.flatMap((response) => makeResponse(response, "create"))
)
},
stream({ input, required, system, tools }) {
return OpenAiConfig.getOrUndefined.pipe(
Effect.map((localConfig) =>
client.stream({
model: options.model,
...config,
...localConfig,
messages: makeMessages(input, system),
tools: tools.length > 0 ?
tools.map((tool) => ({
type: "function",
function: {
name: tool.name,
description: tool.description,
parameters: tool.parameters as any
}
})) :
undefined,
tool_choice: tools.length > 0 ?
typeof required === "boolean" ? (required ? "required" : "auto") : {
type: "function",
function: { name: required }
} :
undefined
})
),
stream(options) {
return makeRequest(options).pipe(
Effect.map(client.stream),
Stream.unwrap,
Stream.catchAll((cause) =>
Effect.fail(
Expand Down

0 comments on commit 4719746

Please sign in to comment.