Skip to content

Commit

Permalink
fix: streaming fix in ai sdk provider
Browse files Browse the repository at this point in the history
  • Loading branch information
dosco committed Jul 26, 2024
1 parent da8aff7 commit 192adac
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 37 deletions.
94 changes: 57 additions & 37 deletions src/ai-sdk-provider/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@ export class AxAgentFramework implements LanguageModelV1 {
options: Readonly<Parameters<LanguageModelV1['doStream']>[0]>
): Promise<Awaited<ReturnType<LanguageModelV1['doStream']>>> {
const { req, warnings } = createChatRequest(options);
const res = (await this.ai.chat(req)) as ReadableStream<AxChatResponse>;

const res = (await this.ai.chat(req, {
stream: true
})) as ReadableStream<AxChatResponse>;

return {
stream: res.pipeThrough(new AxToSDKTransformer()),
Expand Down Expand Up @@ -205,11 +208,15 @@ function convertToAxChatPrompt(
}
}

messages.push({
role: 'assistant',
content: text,
functionCalls: toolCalls
});
const functionCalls = toolCalls.length === 0 ? undefined : toolCalls;

if (functionCalls || text.length > 0) {
messages.push({
role: 'assistant',
content: text,
functionCalls
});
}

break;
}
Expand Down Expand Up @@ -322,42 +329,55 @@ class AxToSDKTransformer extends TransformStream<
LanguageModelV1StreamPart,
{ type: 'finish' }
>['finishReason'] = 'other';
constructor() {
const transformer = {
transform: (
chunk: Readonly<AxChatResponse>,
controller: TransformStreamDefaultController<LanguageModelV1StreamPart>
) => {
const choice = chunk.results.at(0);
if (!choice) {
return;
}

transform(
chunk: Readonly<AxChatResponse>,
controller: TransformStreamDefaultController<LanguageModelV1StreamPart>
) {
const choice = chunk.results.at(0);
if (!choice) {
throw new Error('No choice returned');
}
if (choice.functionCalls) {
for (const tc of choice.functionCalls) {
if (choice.functionCalls) {
for (const tc of choice.functionCalls) {
controller.enqueue({
type: 'tool-call',
toolCallType: 'function',
toolCallId: tc.id,
toolName: tc.function.name,
args: JSON.stringify(tc.function.params)
});
this.finishReason = 'tool-calls';
}
}

if (choice.content && choice.content.length > 0) {
controller.enqueue({
type: 'text-delta',
textDelta: choice.content ?? ''
});
}
this.finishReason = mapAxFinishReason(choice.finishReason);
},
flush: (
controller: TransformStreamDefaultController<LanguageModelV1StreamPart>
) => {
controller.enqueue({
type: 'tool-call',
toolCallType: 'function',
toolCallId: tc.id,
toolName: tc.function.name,
args: JSON.stringify(tc.function.params!)
type: 'finish',
finishReason: this.finishReason,
usage: this.usage
});
this.finishReason = 'tool-calls';
}
}
};

controller.enqueue({
type: 'text-delta',
textDelta: choice.content ?? ''
});
this.finishReason = mapAxFinishReason(choice.finishReason);
}
super(transformer);

flush(
controller: TransformStreamDefaultController<LanguageModelV1StreamPart>
) {
controller.enqueue({
type: 'finish',
finishReason: this.finishReason,
usage: this.usage
});
this.usage = {
promptTokens: 0,
completionTokens: 0
};
this.finishReason = 'other';
}
}
2 changes: 2 additions & 0 deletions src/ax/ai/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,8 @@ export class AxBaseAI<
const logChatRequest = (req: Readonly<AxChatRequest>) => {
const items = req.chatPrompt?.map((msg) => {
switch (msg.role) {
case 'system':
return `System: ${colorLog.whiteBright(msg.content)}`;
case 'function':
return `Function Result: ${colorLog.whiteBright(msg.result)}`;
case 'user': {
Expand Down

0 comments on commit 192adac

Please sign in to comment.