Skip to content

Commit

Permalink
core[minor], integrations...[patch]: Support ToolCall as Tool input a…
Browse files Browse the repository at this point in the history
…nd ToolMessage as Tool output (#24038)

Changes:
- ToolCall, InvalidToolCall and ToolCallChunk can all accept a "type"
parameter now
- LLM integration packages add "type" to all the above
- Tool supports ToolCall inputs that have "type" specified
- Tool outputs ToolMessage when a ToolCall is passed as input
- Tools can separately specify ToolMessage.content and
ToolMessage.raw_output
- Tools emit events for validation errors (using on_tool_error and
on_tool_end)

Example:
```python
@tool("structured_api", response_format="content_and_raw_output")
def _mock_structured_tool_with_raw_output(
    arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> Tuple[str, dict]:
    """A Structured Tool"""
    return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}


def test_tool_call_input_tool_message_with_raw_output() -> None:
    tool_call: Dict = {
        "name": "structured_api",
        "args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}},
        "id": "123",
        "type": "tool_call",
    }
    expected = ToolMessage("1 True", raw_output=tool_call["args"], tool_call_id="123")
    tool = _mock_structured_tool_with_raw_output
    actual = tool.invoke(tool_call)
    assert actual == expected

    tool_call.pop("type")
    with pytest.raises(ValidationError):
        tool.invoke(tool_call)

    actual_content = tool.invoke(tool_call["args"])
    assert actual_content == expected.content
```

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
  • Loading branch information
baskaryan and efriis authored Jul 11, 2024
1 parent eeb9960 commit 5fd1e67
Show file tree
Hide file tree
Showing 22 changed files with 647 additions and 327 deletions.
2 changes: 1 addition & 1 deletion libs/community/tests/unit_tests/tools/test_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def get_non_abstract_subclasses(cls: Type[BaseTool]) -> List[Type[BaseTool]]:
def test_all_subclasses_accept_run_manager(cls: Type[BaseTool]) -> None:
"""Test that tools defined in this repo accept a run manager argument."""
# This wouldn't be necessary if the BaseTool had a strict API.
if cls._run is not BaseTool._arun:
if cls._run is not BaseTool._run:
run_func = cls._run
params = inspect.signature(run_func).parameters
assert "run_manager" in params
Expand Down
33 changes: 32 additions & 1 deletion libs/core/langchain_core/messages/tool.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

from typing_extensions import TypedDict
from typing_extensions import NotRequired, TypedDict

from langchain_core.messages.base import BaseMessage, BaseMessageChunk, merge_content
from langchain_core.utils._merge import merge_dicts, merge_obj
Expand Down Expand Up @@ -146,6 +146,11 @@ class ToolCall(TypedDict):
An identifier is needed to associate a tool call request with a tool
call result in events when multiple concurrent tool calls are made.
"""
type: NotRequired[Literal["tool_call"]]


def tool_call(*, name: str, args: Dict[str, Any], id: Optional[str]) -> ToolCall:
return ToolCall(name=name, args=args, id=id, type="tool_call")


class ToolCallChunk(TypedDict):
Expand Down Expand Up @@ -176,6 +181,19 @@ class ToolCallChunk(TypedDict):
"""An identifier associated with the tool call."""
index: Optional[int]
"""The index of the tool call in a sequence."""
type: NotRequired[Literal["tool_call_chunk"]]


def tool_call_chunk(
*,
name: Optional[str] = None,
args: Optional[str] = None,
id: Optional[str] = None,
index: Optional[int] = None,
) -> ToolCallChunk:
return ToolCallChunk(
name=name, args=args, id=id, index=index, type="tool_call_chunk"
)


class InvalidToolCall(TypedDict):
Expand All @@ -193,6 +211,19 @@ class InvalidToolCall(TypedDict):
"""An identifier associated with the tool call."""
error: Optional[str]
"""An error message associated with the tool call."""
type: NotRequired[Literal["invalid_tool_call"]]


def invalid_tool_call(
*,
name: Optional[str] = None,
args: Optional[str] = None,
id: Optional[str] = None,
error: Optional[str] = None,
) -> InvalidToolCall:
return InvalidToolCall(
name=name, args=args, id=id, error=error, type="invalid_tool_call"
)


def default_tool_parser(
Expand Down
9 changes: 8 additions & 1 deletion libs/core/langchain_core/output_parsers/openai_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@

from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage, InvalidToolCall
from langchain_core.messages.tool import (
invalid_tool_call,
)
from langchain_core.messages.tool import (
tool_call as create_tool_call,
)
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel, ValidationError
Expand Down Expand Up @@ -59,6 +65,7 @@ def parse_tool_call(
}
if return_id:
parsed["id"] = raw_tool_call.get("id")
parsed = create_tool_call(**parsed) # type: ignore
return parsed


Expand All @@ -75,7 +82,7 @@ def make_invalid_tool_call(
Returns:
An InvalidToolCall instance with the error message.
"""
return InvalidToolCall(
return invalid_tool_call(
name=raw_tool_call["function"]["name"],
args=raw_tool_call["function"]["arguments"],
id=raw_tool_call.get("id"),
Expand Down
Loading

0 comments on commit 5fd1e67

Please sign in to comment.