Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core[minor], integrations...[patch]: Support ToolCall as Tool input and ToolMessage as Tool output #24038

Merged
merged 33 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
8eb1776
rfc: sep methd for tool message res
baskaryan Jul 9, 2024
2474dcd
fmt
baskaryan Jul 9, 2024
a2d8652
fmt
baskaryan Jul 9, 2024
0d42f29
fmt
baskaryan Jul 10, 2024
f688f87
fmt
baskaryan Jul 10, 2024
0cd8bb9
fmt
baskaryan Jul 10, 2024
4525856
fmt
baskaryan Jul 10, 2024
f1b3b27
fmt
baskaryan Jul 10, 2024
3fc90ec
fmt
baskaryan Jul 10, 2024
fb11a47
fmt
baskaryan Jul 10, 2024
7ebf05d
Merge branch 'master' into bagatur/rfc_sep_method_for_tool_res
baskaryan Jul 10, 2024
b43bf00
fmt
baskaryan Jul 10, 2024
5c88785
fmt
baskaryan Jul 10, 2024
0f42560
fmt
baskaryan Jul 10, 2024
4abc681
fmt
baskaryan Jul 10, 2024
a2e9fb8
Merge branch 'master' into bagatur/rfc_sep_method_for_tool_res
baskaryan Jul 10, 2024
ccbc750
fmt
baskaryan Jul 10, 2024
d7d5a80
fmt
baskaryan Jul 10, 2024
7aefb5f
fmt
baskaryan Jul 10, 2024
f6f5894
fmt
baskaryan Jul 10, 2024
20c05ae
Merge branch 'master' into bagatur/rfc_sep_method_for_tool_res
efriis Jul 10, 2024
8b45880
fmt
baskaryan Jul 10, 2024
72aa42c
Merge branch 'bagatur/rfc_sep_method_for_tool_res' of github.com:lang…
baskaryan Jul 10, 2024
50cd03a
fmt
baskaryan Jul 10, 2024
2e81755
fmt
baskaryan Jul 10, 2024
c24cf5d
fmt
baskaryan Jul 11, 2024
9228c58
fmt
baskaryan Jul 11, 2024
b525d5e
fmt
baskaryan Jul 11, 2024
844586a
fmt
baskaryan Jul 11, 2024
03bcf88
fmt
baskaryan Jul 11, 2024
82ec9f9
fmt
baskaryan Jul 11, 2024
59ce038
Merge branch 'master' into bagatur/rfc_sep_method_for_tool_res
baskaryan Jul 11, 2024
32e4073
fmt
baskaryan Jul 11, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion libs/community/tests/unit_tests/tools/test_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def get_non_abstract_subclasses(cls: Type[BaseTool]) -> List[Type[BaseTool]]:
def test_all_subclasses_accept_run_manager(cls: Type[BaseTool]) -> None:
"""Test that tools defined in this repo accept a run manager argument."""
# This wouldn't be necessary if the BaseTool had a strict API.
if cls._run is not BaseTool._arun:
if cls._run is not BaseTool._run:
run_func = cls._run
params = inspect.signature(run_func).parameters
assert "run_manager" in params
Expand Down
33 changes: 32 additions & 1 deletion libs/core/langchain_core/messages/tool.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

from typing_extensions import TypedDict
from typing_extensions import NotRequired, TypedDict

from langchain_core.messages.base import BaseMessage, BaseMessageChunk, merge_content
from langchain_core.utils._merge import merge_dicts, merge_obj
Expand Down Expand Up @@ -146,6 +146,11 @@ class ToolCall(TypedDict):
An identifier is needed to associate a tool call request with a tool
call result in events when multiple concurrent tool calls are made.
"""
type: NotRequired[Literal["tool_call"]]


def tool_call(*, name: str, args: Dict[str, Any], id: Optional[str]) -> ToolCall:
return ToolCall(name=name, args=args, id=id, type="tool_call")


class ToolCallChunk(TypedDict):
Expand Down Expand Up @@ -176,6 +181,19 @@ class ToolCallChunk(TypedDict):
"""An identifier associated with the tool call."""
index: Optional[int]
"""The index of the tool call in a sequence."""
type: NotRequired[Literal["tool_call_chunk"]]


def tool_call_chunk(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this needed?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

convenience / to make sure ppl don't mess up the type

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is that easier than using ToolCallChunk? Is it validating that there are no extra arguments or typos in the arguments?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

type is hard coded

*,
name: Optional[str] = None,
args: Optional[str] = None,
id: Optional[str] = None,
index: Optional[int] = None,
) -> ToolCallChunk:
return ToolCallChunk(
name=name, args=args, id=id, index=index, type="tool_call_chunk"
)


class InvalidToolCall(TypedDict):
Expand All @@ -193,6 +211,19 @@ class InvalidToolCall(TypedDict):
"""An identifier associated with the tool call."""
error: Optional[str]
"""An error message associated with the tool call."""
type: NotRequired[Literal["invalid_tool_call"]]


def invalid_tool_call(
*,
name: Optional[str] = None,
args: Optional[str] = None,
id: Optional[str] = None,
error: Optional[str] = None,
) -> InvalidToolCall:
return InvalidToolCall(
name=name, args=args, id=id, error=error, type="invalid_tool_call"
)


def default_tool_parser(
Expand Down
9 changes: 8 additions & 1 deletion libs/core/langchain_core/output_parsers/openai_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@

from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage, InvalidToolCall
from langchain_core.messages.tool import (
invalid_tool_call,
)
from langchain_core.messages.tool import (
tool_call as create_tool_call,
)
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel, ValidationError
Expand Down Expand Up @@ -59,6 +65,7 @@ def parse_tool_call(
}
if return_id:
parsed["id"] = raw_tool_call.get("id")
parsed = create_tool_call(**parsed) # type: ignore
return parsed


Expand All @@ -75,7 +82,7 @@ def make_invalid_tool_call(
Returns:
An InvalidToolCall instance with the error message.
"""
return InvalidToolCall(
return invalid_tool_call(
name=raw_tool_call["function"]["name"],
args=raw_tool_call["function"]["arguments"],
id=raw_tool_call.get("id"),
Expand Down
Loading
Loading