Skip to content

Commit

Permalink
Add aggregation
Browse files Browse the repository at this point in the history
  • Loading branch information
jacoblee93 committed Jul 30, 2024
1 parent c66494e commit c60b0ce
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 1 deletion.
29 changes: 28 additions & 1 deletion js/src/tests/wrapped_ai_sdk.int.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import { openai } from "@ai-sdk/openai";
import { generateObject, generateText, streamObject, streamText } from "ai";
import {
generateObject,
generateText,
streamObject,
streamText,
tool,
} from "ai";
import { z } from "zod";
import { wrapAISDKModel } from "../wrappers/vercel.js";

Expand All @@ -12,6 +18,27 @@ test("AI SDK generateText", async () => {
console.log(text);
});

test("AI SDK generateText with a tool", async () => {
const modelWithTracing = wrapAISDKModel(openai("gpt-4o-mini"));
const { text } = await generateText({
model: modelWithTracing,
prompt:
"Write a vegetarian lasagna recipe for 4 people. Get ingredients first.",
tools: {
getIngredients: tool({
description: "get a list of ingredients",
parameters: z.object({
ingredients: z.array(z.string()),
}),
execute: async () =>
JSON.stringify(["pasta", "tomato", "cheese", "onions"]),
}),
},
maxToolRoundtrips: 2,
});
console.log(text);
});

test("AI SDK generateObject", async () => {
const modelWithTracing = wrapAISDKModel(openai("gpt-4o-mini"));
const { object } = await generateObject({
Expand Down
31 changes: 31 additions & 0 deletions js/src/wrappers/vercel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,45 @@ export const wrapAISDKModel = <T extends object>(
const originalValue = target[propKey as keyof T];
if (typeof originalValue === "function") {
let __finalTracedIteratorKey;
let aggregator;
if (propKey === "doStream") {
__finalTracedIteratorKey = "stream";
aggregator = (chunks: any[]) => {
return chunks.reduce(
(aggregated, chunk) => {
console.log(chunk);
if (chunk.type === "text-delta") {
return {
...aggregated,
text: aggregated.text + chunk.textDelta,
};
} else if (chunk.type === "tool-call") {
return {
...aggregated,
...chunk,
};
} else if (chunk.type === "finish") {
return {
...aggregated,
usage: chunk.usage,
finishReason: chunk.finishReason,
};
} else {
return aggregated;
}
},
{
text: "",
}
);
};
}
return traceable(originalValue.bind(target), {
run_type: "llm",
name: runName,
...options,
__finalTracedIteratorKey,
aggregator,
});
} else if (
originalValue != null &&
Expand Down

0 comments on commit c60b0ce

Please sign in to comment.