Skip to content

Commit

Permalink
fix: updates to the ai sdk provider
Browse files Browse the repository at this point in the history
  • Loading branch information
dosco committed Jul 29, 2024
1 parent 4bf9ebf commit ff77e1f
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 67 deletions.
108 changes: 49 additions & 59 deletions src/ai-sdk-provider/index.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
// cspell:ignore Streamable

import {
type ReadableStream,
TransformStream,
Expand All @@ -13,7 +15,7 @@ import {
type LanguageModelV1StreamPart
} from '@ai-sdk/provider';
import type {
AxAgent,
AxAgentic,
AxAIService,
AxChatRequest,
AxChatResponse,
Expand All @@ -23,7 +25,9 @@ import type {
AxGenIn,
AxGenOut
} from '@ax-llm/ax/index.js';
import type { CoreMessage } from 'ai';
import { customAlphabet } from 'nanoid';
import type { ReactNode } from 'react';
import { z } from 'zod';

type Writeable<T> = { -readonly [P in keyof T]: T[P] };
Expand All @@ -33,53 +37,47 @@ type AxConfig = {
fetch?: typeof fetch;
};

type generateFunction<T> = ((input: T) => Promise<unknown>) | undefined;
type Streamable = ReactNode | Promise<ReactNode>;
type Renderer<T> = (
args: T
) =>
| Streamable
| Generator<Streamable, Streamable, void>
| AsyncGenerator<Streamable, Streamable, void>;

export const nanoid = customAlphabet(
'0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz',
7
);

interface AIStateValue {
messages: unknown[];
}

export interface AxAISdkAIState {
get: () => AIStateValue;
update: (newState: Readonly<AIStateValue>) => void;
done: ((newState: Readonly<AIStateValue>) => void) | (() => void);
}

export class AxAgentProvider<IN extends AxGenIn, OUT extends AxGenOut> {
private readonly config?: AxConfig;
private readonly agent: AxAgent<IN, OUT>;
private readonly funcInfo: AxFunction;
private generateFunction?: generateFunction<OUT>;
private state: AxAISdkAIState;
private generateFunction: Renderer<OUT>;
private updateState: (msgs: readonly CoreMessage[]) => void;

constructor({
agent,
state,
updateState,
generate,
config
}: Readonly<{
agent: AxAgent<IN, OUT>;
state: Readonly<AxAISdkAIState>;
generate?: generateFunction<OUT>;
agent: AxAgentic;
updateState: (msgs: readonly CoreMessage[]) => void;
generate: Renderer<OUT>;
config?: Readonly<AxConfig>;
}>) {
this.agent = agent;
this.funcInfo = agent.getFunction();
this.generateFunction = generate;
this.state = state;
this.updateState = updateState;
this.config = config;
}

get description() {
return this.funcInfo.description;
}

get parameters(): unknown {
get parameters(): z.ZodTypeAny {
const schema = this.funcInfo.parameters ?? {
type: 'object',
properties: {}
Expand All @@ -88,47 +86,39 @@ export class AxAgentProvider<IN extends AxGenIn, OUT extends AxGenOut> {
return convertToZodSchema(schema);
}

get generate(): generateFunction<IN> {
return async (input: IN): Promise<unknown> => {
const res = await this.agent.forward(input);
get generate(): Renderer<IN> {
const fn = async (input: IN) => {
const res = (await this.funcInfo.func(input)) as OUT;
const toolCallId = nanoid();

this.state.done({
...this.state.get(),
messages: [
...this.state.get().messages,
{
id: nanoid(),
role: 'assistant',
content: [
{
type: 'tool-call',
toolName: this.funcInfo.name,
toolCallId,
args: input
}
]
},
{
id: nanoid(),
role: 'tool',
content: [
{
type: 'tool-result',
toolName: this.funcInfo.name,
toolCallId,
result: res
}
]
}
]
});
this.updateState([
{
role: 'assistant',
content: [
{
type: 'tool-call',
toolName: this.funcInfo.name,
toolCallId,
args: input
}
]
},
{
role: 'tool',
content: [
{
type: 'tool-result',
toolName: this.funcInfo.name,
toolCallId,
result: res
}
]
}
]);

if (this.generateFunction) {
return await this.generateFunction(res);
}
return res;
return this.generateFunction(res);
};
return fn as Renderer<IN>;
}
}

Expand Down
1 change: 1 addition & 0 deletions src/ai-sdk-provider/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"zod": "^3.23.8"
},
"devDependencies": {
"@types/react": "^18.3.3",
"npm-run-all": "^4.1.5",
"tsx": "^4.7.1"
},
Expand Down
2 changes: 1 addition & 1 deletion src/ax/ai/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ export type AxFunction = {
name: string;
description: string;
parameters?: AxFunctionJSONSchema;
func?: AxFunctionHandler;
func: AxFunctionHandler;
};

export type AxChatResponseResult = {
Expand Down
4 changes: 0 additions & 4 deletions src/ax/dsp/functions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,6 @@ export class AxFunctionProcessor {
func: Readonly<AxChatResponseFunctionCall>,
options?: Readonly<AxAIServiceActionOptions>
): Promise<AxFunctionExec> => {
if (!fnSpec.func) {
throw new Error(`Function handler for ${fnSpec.name} not implemented`);
}

let args;

if (typeof func.args === 'string' && func.args.length > 0) {
Expand Down
6 changes: 3 additions & 3 deletions src/docs/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
"astro": "astro"
},
"dependencies": {
"@astrojs/check": "^0.7.0",
"@astrojs/starlight": "^0.24.5",
"@astrojs/check": "^0.8.3",
"@astrojs/starlight": "^0.25.3",
"@astrojs/tailwind": "^5.1.0",
"@fontsource/roboto": "^5.0.13",
"astro": "^4.10.2",
"astro": "^4.12.2",
"astro-imagetools": "^0.9.0",
"sharp": "^0.32.5",
"tailwindcss": "^3.4.4"
Expand Down

0 comments on commit ff77e1f

Please sign in to comment.