diff --git a/libs/community/tests/unit_tests/tools/test_signatures.py b/libs/community/tests/unit_tests/tools/test_signatures.py index 2dc0292cbf924..105a2f4ac96a3 100644 --- a/libs/community/tests/unit_tests/tools/test_signatures.py +++ b/libs/community/tests/unit_tests/tools/test_signatures.py @@ -39,7 +39,7 @@ def get_non_abstract_subclasses(cls: Type[BaseTool]) -> List[Type[BaseTool]]: def test_all_subclasses_accept_run_manager(cls: Type[BaseTool]) -> None: """Test that tools defined in this repo accept a run manager argument.""" # This wouldn't be necessary if the BaseTool had a strict API. - if cls._run is not BaseTool._arun: + if cls._run is not BaseTool._run: run_func = cls._run params = inspect.signature(run_func).parameters assert "run_manager" in params diff --git a/libs/core/langchain_core/messages/tool.py b/libs/core/langchain_core/messages/tool.py index 1c85a6521f6a1..61a8d6252f1b5 100644 --- a/libs/core/langchain_core/messages/tool.py +++ b/libs/core/langchain_core/messages/tool.py @@ -1,7 +1,7 @@ import json from typing import Any, Dict, List, Literal, Optional, Tuple, Union -from typing_extensions import TypedDict +from typing_extensions import NotRequired, TypedDict from langchain_core.messages.base import BaseMessage, BaseMessageChunk, merge_content from langchain_core.utils._merge import merge_dicts, merge_obj @@ -146,6 +146,11 @@ class ToolCall(TypedDict): An identifier is needed to associate a tool call request with a tool call result in events when multiple concurrent tool calls are made. """ + type: NotRequired[Literal["tool_call"]] + + +def tool_call(*, name: str, args: Dict[str, Any], id: Optional[str]) -> ToolCall: + return ToolCall(name=name, args=args, id=id, type="tool_call") class ToolCallChunk(TypedDict): @@ -176,6 +181,19 @@ class ToolCallChunk(TypedDict): """An identifier associated with the tool call.""" index: Optional[int] """The index of the tool call in a sequence.""" + type: NotRequired[Literal["tool_call_chunk"]] + + +def tool_call_chunk( + *, + name: Optional[str] = None, + args: Optional[str] = None, + id: Optional[str] = None, + index: Optional[int] = None, +) -> ToolCallChunk: + return ToolCallChunk( + name=name, args=args, id=id, index=index, type="tool_call_chunk" + ) class InvalidToolCall(TypedDict): @@ -193,6 +211,19 @@ class InvalidToolCall(TypedDict): """An identifier associated with the tool call.""" error: Optional[str] """An error message associated with the tool call.""" + type: NotRequired[Literal["invalid_tool_call"]] + + +def invalid_tool_call( + *, + name: Optional[str] = None, + args: Optional[str] = None, + id: Optional[str] = None, + error: Optional[str] = None, +) -> InvalidToolCall: + return InvalidToolCall( + name=name, args=args, id=id, error=error, type="invalid_tool_call" + ) def default_tool_parser( diff --git a/libs/core/langchain_core/output_parsers/openai_tools.py b/libs/core/langchain_core/output_parsers/openai_tools.py index acc7fdb94febe..4122cbe63a693 100644 --- a/libs/core/langchain_core/output_parsers/openai_tools.py +++ b/libs/core/langchain_core/output_parsers/openai_tools.py @@ -5,6 +5,12 @@ from langchain_core.exceptions import OutputParserException from langchain_core.messages import AIMessage, InvalidToolCall +from langchain_core.messages.tool import ( + invalid_tool_call, +) +from langchain_core.messages.tool import ( + tool_call as create_tool_call, +) from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser from langchain_core.outputs import ChatGeneration, Generation from 
langchain_core.pydantic_v1 import BaseModel, ValidationError @@ -59,6 +65,7 @@ def parse_tool_call( } if return_id: parsed["id"] = raw_tool_call.get("id") + parsed = create_tool_call(**parsed) # type: ignore return parsed @@ -75,7 +82,7 @@ def make_invalid_tool_call( Returns: An InvalidToolCall instance with the error message. """ - return InvalidToolCall( + return invalid_tool_call( name=raw_tool_call["function"]["name"], args=raw_tool_call["function"]["arguments"], id=raw_tool_call.get("id"), diff --git a/libs/core/langchain_core/tools.py b/libs/core/langchain_core/tools.py index 9c5aaa48e65d4..4a552f301a5a4 100644 --- a/libs/core/langchain_core/tools.py +++ b/libs/core/langchain_core/tools.py @@ -21,6 +21,7 @@ import asyncio import inspect +import json import textwrap import uuid import warnings @@ -34,6 +35,7 @@ Callable, Dict, List, + Literal, Optional, Sequence, Tuple, @@ -42,7 +44,7 @@ get_type_hints, ) -from typing_extensions import Annotated, get_args, get_origin +from typing_extensions import Annotated, cast, get_args, get_origin from langchain_core._api import deprecated from langchain_core.callbacks import ( @@ -56,6 +58,7 @@ Callbacks, ) from langchain_core.load.serializable import Serializable +from langchain_core.messages.tool import ToolCall, ToolMessage from langchain_core.prompts import ( BasePromptTemplate, PromptTemplate, @@ -306,7 +309,7 @@ class ToolException(Exception): pass -class BaseTool(RunnableSerializable[Union[str, Dict], Any]): +class BaseTool(RunnableSerializable[Union[str, Dict, ToolCall], Any]): """Interface LangChain tools must implement.""" def __init_subclass__(cls, **kwargs: Any) -> None: @@ -378,6 +381,14 @@ class ChildTool(BaseTool): ] = False """Handle the content of the ValidationError thrown.""" + response_format: Literal["content", "content_and_raw_output"] = "content" + """The tool response format. + + If "content" then the output of the tool is interpreted as the contents of a + ToolMessage. If "content_and_raw_output" then the output is expected to be a + two-tuple corresponding to the (content, raw_output) of a ToolMessage. 
+ """ + class Config(Serializable.Config): """Configuration for this pydantic object.""" @@ -410,46 +421,25 @@ def get_input_schema( def invoke( self, - input: Union[str, Dict], + input: Union[str, Dict, ToolCall], config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Any: - config = ensure_config(config) - return self.run( - input, - callbacks=config.get("callbacks"), - tags=config.get("tags"), - metadata=config.get("metadata"), - run_name=config.get("run_name"), - run_id=config.pop("run_id", None), - config=config, - **kwargs, - ) + tool_input, kwargs = _prep_run_args(input, config, **kwargs) + return self.run(tool_input, **kwargs) async def ainvoke( self, - input: Union[str, Dict], + input: Union[str, Dict, ToolCall], config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Any: - config = ensure_config(config) - return await self.arun( - input, - callbacks=config.get("callbacks"), - tags=config.get("tags"), - metadata=config.get("metadata"), - run_name=config.get("run_name"), - run_id=config.pop("run_id", None), - config=config, - **kwargs, - ) + tool_input, kwargs = _prep_run_args(input, config, **kwargs) + return await self.arun(tool_input, **kwargs) # --- Tool --- - def _parse_input( - self, - tool_input: Union[str, Dict], - ) -> Union[str, Dict[str, Any]]: + def _parse_input(self, tool_input: Union[str, Dict]) -> Union[str, Dict[str, Any]]: """Convert tool input to pydantic model.""" input_args = self.args_schema if isinstance(tool_input, str): @@ -465,7 +455,7 @@ def _parse_input( for k, v in result.dict().items() if k in tool_input } - return tool_input + return tool_input @root_validator(pre=True) def raise_deprecation(cls, values: Dict) -> Dict: @@ -479,30 +469,27 @@ def raise_deprecation(cls, values: Dict) -> Dict: return values @abstractmethod - def _run( - self, - *args: Any, - **kwargs: Any, - ) -> Any: + def _run(self, *args: Any, **kwargs: Any) -> Any: """Use the tool. Add run_manager: Optional[CallbackManagerForToolRun] = None - to child implementations to enable tracing, + to child implementations to enable tracing. """ - async def _arun( - self, - *args: Any, - **kwargs: Any, - ) -> Any: + async def _arun(self, *args: Any, **kwargs: Any) -> Any: """Use the tool asynchronously. Add run_manager: Optional[AsyncCallbackManagerForToolRun] = None - to child implementations to enable tracing, + to child implementations to enable tracing. """ + if kwargs.get("run_manager") and signature(self._run).parameters.get( + "run_manager" + ): + kwargs["run_manager"] = kwargs["run_manager"].get_sync() return await run_in_executor(None, self._run, *args, **kwargs) def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: + tool_input = self._parse_input(tool_input) # For backwards compatibility, if run_input is a string, # pass as a positional argument. 
if isinstance(tool_input, str): @@ -523,24 +510,20 @@ def run( run_name: Optional[str] = None, run_id: Optional[uuid.UUID] = None, config: Optional[RunnableConfig] = None, + tool_call_id: Optional[str] = None, **kwargs: Any, ) -> Any: """Run the tool.""" - if not self.verbose and verbose is not None: - verbose_ = verbose - else: - verbose_ = self.verbose callback_manager = CallbackManager.configure( callbacks, self.callbacks, - verbose_, + self.verbose or bool(verbose), tags, self.tags, metadata, self.metadata, ) - # TODO: maybe also pass through run_manager is _run supports kwargs - new_arg_supported = signature(self._run).parameters.get("run_manager") + run_manager = callback_manager.on_tool_start( {"name": self.name, "description": self.description}, tool_input if isinstance(tool_input, str) else str(tool_input), @@ -550,67 +533,52 @@ # Inputs by definition should always be dicts. # For now, it's unclear whether this assumption is ever violated, # but if it is we will send a `None` value to the callback instead - # And will need to address issue via a patch. - inputs=None if isinstance(tool_input, str) else tool_input, + # TODO: will need to address issue via a patch. + inputs=tool_input if isinstance(tool_input, dict) else None, **kwargs, ) + + content = None + raw_output = None + error_to_raise: Union[Exception, KeyboardInterrupt, None] = None try: - child_config = patch_config( - config, - callbacks=run_manager.get_child(), - ) + child_config = patch_config(config, callbacks=run_manager.get_child()) context = copy_context() context.run(_set_config_context, child_config) - parsed_input = self._parse_input(tool_input) - tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input) - observation = ( - context.run( - self._run, *tool_args, run_manager=run_manager, **tool_kwargs - ) - if new_arg_supported - else context.run(self._run, *tool_args, **tool_kwargs) - ) + tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input) + if signature(self._run).parameters.get("run_manager"): + tool_kwargs["run_manager"] = run_manager + response = context.run(self._run, *tool_args, **tool_kwargs) + if self.response_format == "content_and_raw_output": + if not isinstance(response, tuple) or len(response) != 2: + raise ValueError( + "Since response_format='content_and_raw_output', " + "a two-tuple of the message content and raw tool output is " + f"expected. Instead, the tool returned a response of type " + f"{type(response)}." + ) + content, raw_output = response + else: + content = response except ValidationError as e: if not self.handle_validation_error: - raise e - elif isinstance(self.handle_validation_error, bool): - observation = "Tool input validation error" - elif isinstance(self.handle_validation_error, str): - observation = self.handle_validation_error - elif callable(self.handle_validation_error): - observation = self.handle_validation_error(e) + error_to_raise = e else: - raise ValueError( f"Got unexpected type of `handle_validation_error`. Expected bool, " f"str or callable. 
Received: {self.handle_validation_error}" - ) - return observation + content = _handle_validation_error(e, flag=self.handle_validation_error) except ToolException as e: if not self.handle_tool_error: - run_manager.on_tool_error(e) - raise e - elif isinstance(self.handle_tool_error, bool): - if e.args: - observation = e.args[0] - else: - observation = "Tool execution error" - elif isinstance(self.handle_tool_error, str): - observation = self.handle_tool_error - elif callable(self.handle_tool_error): - observation = self.handle_tool_error(e) + error_to_raise = e else: - raise ValueError( - f"Got unexpected type of `handle_tool_error`. Expected bool, str " - f"or callable. Received: {self.handle_tool_error}" - ) - run_manager.on_tool_end(observation, color="red", name=self.name, **kwargs) - return observation + content = _handle_tool_error(e, flag=self.handle_tool_error) except (Exception, KeyboardInterrupt) as e: - run_manager.on_tool_error(e) - raise e - else: - run_manager.on_tool_end(observation, color=color, name=self.name, **kwargs) - return observation + error_to_raise = e + + if error_to_raise: + run_manager.on_tool_error(error_to_raise) + raise error_to_raise + output = _format_output(content, raw_output, tool_call_id) + run_manager.on_tool_end(output, color=color, name=self.name, **kwargs) + return output async def arun( self, @@ -625,99 +593,80 @@ async def arun( run_name: Optional[str] = None, run_id: Optional[uuid.UUID] = None, config: Optional[RunnableConfig] = None, + tool_call_id: Optional[str] = None, **kwargs: Any, ) -> Any: """Run the tool asynchronously.""" - if not self.verbose and verbose is not None: - verbose_ = verbose - else: - verbose_ = self.verbose callback_manager = AsyncCallbackManager.configure( callbacks, self.callbacks, - verbose_, + self.verbose or bool(verbose), tags, self.tags, metadata, self.metadata, ) - new_arg_supported = signature(self._arun).parameters.get("run_manager") run_manager = await callback_manager.on_tool_start( {"name": self.name, "description": self.description}, tool_input if isinstance(tool_input, str) else str(tool_input), color=start_color, name=run_name, - inputs=tool_input, run_id=run_id, + # Inputs by definition should always be dicts. + # For now, it's unclear whether this assumption is ever violated, + # but if it is we will send a `None` value to the callback instead + # TODO: will need to address issue via a patch. 
+ inputs=tool_input if isinstance(tool_input, dict) else None, **kwargs, ) + content = None + raw_output = None + error_to_raise: Optional[Union[Exception, KeyboardInterrupt]] = None try: - parsed_input = self._parse_input(tool_input) - # We then call the tool on the tool input to get an observation - tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input) - child_config = patch_config( - config, - callbacks=run_manager.get_child(), - ) + tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input) + child_config = patch_config(config, callbacks=run_manager.get_child()) context = copy_context() context.run(_set_config_context, child_config) - coro = ( - context.run( - self._arun, *tool_args, run_manager=run_manager, **tool_kwargs - ) - if new_arg_supported - else context.run(self._arun, *tool_args, **tool_kwargs) - ) + if self.__class__._arun is BaseTool._arun or signature( + self._arun + ).parameters.get("run_manager"): + tool_kwargs["run_manager"] = run_manager + coro = context.run(self._arun, *tool_args, **tool_kwargs) if accepts_context(asyncio.create_task): - observation = await asyncio.create_task(coro, context=context) # type: ignore + response = await asyncio.create_task(coro, context=context) # type: ignore else: - observation = await coro - + response = await coro + if self.response_format == "content_and_raw_output": + if not isinstance(response, tuple) or len(response) != 2: + raise ValueError( + "Since response_format='content_and_raw_output', " + "a two-tuple of the message content and raw tool output is " + f"expected. Instead, the tool returned a response of type " + f"{type(response)}." + ) + content, raw_output = response + else: + content = response except ValidationError as e: if not self.handle_validation_error: - raise e - elif isinstance(self.handle_validation_error, bool): - observation = "Tool input validation error" - elif isinstance(self.handle_validation_error, str): - observation = self.handle_validation_error - elif callable(self.handle_validation_error): - observation = self.handle_validation_error(e) + error_to_raise = e else: - raise ValueError( - f"Got unexpected type of `handle_validation_error`. Expected bool, " - f"str or callable. Received: {self.handle_validation_error}" - ) - return observation + content = _handle_validation_error(e, flag=self.handle_validation_error) except ToolException as e: if not self.handle_tool_error: - await run_manager.on_tool_error(e) - raise e - elif isinstance(self.handle_tool_error, bool): - if e.args: - observation = e.args[0] - else: - observation = "Tool execution error" - elif isinstance(self.handle_tool_error, str): - observation = self.handle_tool_error - elif callable(self.handle_tool_error): - observation = self.handle_tool_error(e) + error_to_raise = e else: - raise ValueError( - f"Got unexpected type of `handle_tool_error`. Expected bool, str " - f"or callable. 
Received: {self.handle_tool_error}" - ) - await run_manager.on_tool_end( - observation, color="red", name=self.name, **kwargs - ) - return observation + content = _handle_tool_error(e, flag=self.handle_tool_error) except (Exception, KeyboardInterrupt) as e: - await run_manager.on_tool_error(e) - raise e - else: - await run_manager.on_tool_end( - observation, color=color, name=self.name, **kwargs - ) - return observation + error_to_raise = e + + if error_to_raise: + await run_manager.on_tool_error(error_to_raise) + raise error_to_raise + + output = _format_output(content, raw_output, tool_call_id) + await run_manager.on_tool_end(output, color=color, name=self.name, **kwargs) + return output @deprecated("0.1.47", alternative="invoke", removal="0.3.0") def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str: @@ -738,7 +687,7 @@ class Tool(BaseTool): async def ainvoke( self, - input: Union[str, Dict], + input: Union[str, Dict, ToolCall], config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Any: @@ -780,17 +729,10 @@ def _run( ) -> Any: """Use the tool.""" if self.func: - new_argument_supported = signature(self.func).parameters.get("callbacks") - return ( - self.func( - *args, - callbacks=run_manager.get_child() if run_manager else None, - **kwargs, - ) - if new_argument_supported - else self.func(*args, **kwargs) - ) - raise NotImplementedError("Tool does not support sync") + if run_manager and signature(self.func).parameters.get("callbacks"): + kwargs["callbacks"] = run_manager.get_child() + return self.func(*args, **kwargs) + raise NotImplementedError("Tool does not support sync invocation.") async def _arun( self, @@ -800,26 +742,13 @@ async def _arun( ) -> Any: """Use the tool asynchronously.""" if self.coroutine: - new_argument_supported = signature(self.coroutine).parameters.get( - "callbacks" - ) - return ( - await self.coroutine( - *args, - callbacks=run_manager.get_child() if run_manager else None, - **kwargs, - ) - if new_argument_supported - else await self.coroutine(*args, **kwargs) - ) - else: - return await run_in_executor( - None, - self._run, - run_manager=run_manager.get_sync() if run_manager else None, - *args, - **kwargs, - ) + if run_manager and signature(self.coroutine).parameters.get("callbacks"): + kwargs["callbacks"] = run_manager.get_child() + return await self.coroutine(*args, **kwargs) + + # NOTE: this code is unreachable since _arun is only called if coroutine is not + # None. + return await super()._arun(*args, run_manager=run_manager, **kwargs) # TODO: this is for backwards compatibility, remove in future def __init__( @@ -870,9 +799,10 @@ class StructuredTool(BaseTool): # --- Runnable --- + # TODO: Is this needed? 
async def ainvoke( self, - input: Union[str, Dict], + input: Union[str, Dict, ToolCall], config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Any: @@ -897,45 +827,26 @@ def _run( ) -> Any: """Use the tool.""" if self.func: - new_argument_supported = signature(self.func).parameters.get("callbacks") - return ( - self.func( - *args, - callbacks=run_manager.get_child() if run_manager else None, - **kwargs, - ) - if new_argument_supported - else self.func(*args, **kwargs) - ) - raise NotImplementedError("Tool does not support sync") + if run_manager and signature(self.func).parameters.get("callbacks"): + kwargs["callbacks"] = run_manager.get_child() + return self.func(*args, **kwargs) + raise NotImplementedError("StructuredTool does not support sync invocation.") async def _arun( self, *args: Any, run_manager: Optional[AsyncCallbackManagerForToolRun] = None, **kwargs: Any, - ) -> str: + ) -> Any: """Use the tool asynchronously.""" if self.coroutine: - new_argument_supported = signature(self.coroutine).parameters.get( - "callbacks" - ) - return ( - await self.coroutine( - *args, - callbacks=run_manager.get_child() if run_manager else None, - **kwargs, - ) - if new_argument_supported - else await self.coroutine(*args, **kwargs) - ) - return await run_in_executor( - None, - self._run, - run_manager=run_manager.get_sync() if run_manager else None, - *args, - **kwargs, - ) + if run_manager and signature(self.coroutine).parameters.get("callbacks"): + kwargs["callbacks"] = run_manager.get_child() + return await self.coroutine(*args, **kwargs) + + # NOTE: this code is unreachable since _arun is only called if coroutine is not + # None. + return await super()._arun(*args, run_manager=run_manager, **kwargs) @classmethod def from_function( @@ -947,6 +858,8 @@ def from_function( return_direct: bool = False, args_schema: Optional[Type[BaseModel]] = None, infer_schema: bool = True, + *, + response_format: Literal["content", "content_and_raw_output"] = "content", parse_docstring: bool = False, error_on_invalid_docstring: bool = False, **kwargs: Any, @@ -963,6 +876,10 @@ def from_function( return_direct: Whether to return the result directly or as a callback args_schema: The schema of the tool's input arguments infer_schema: Whether to infer the schema from the function's signature + response_format: The tool response format. If "content" then the output of + the tool is interpreted as the contents of a ToolMessage. If + "content_and_raw_output" then the output is expected to be a two-tuple + corresponding to the (content, raw_output) of a ToolMessage. parse_docstring: if ``infer_schema`` and ``parse_docstring``, will attempt to parse parameter descriptions from Google Style function docstrings. error_on_invalid_docstring: if ``parse_docstring`` is provided, configures @@ -1020,6 +937,7 @@ def add(a: int, b: int) -> int: args_schema=_args_schema, # type: ignore[arg-type] description=description_, return_direct=return_direct, + response_format=response_format, **kwargs, ) @@ -1029,6 +947,7 @@ def tool( return_direct: bool = False, args_schema: Optional[Type[BaseModel]] = None, infer_schema: bool = True, + response_format: Literal["content", "content_and_raw_output"] = "content", parse_docstring: bool = False, error_on_invalid_docstring: bool = True, ) -> Callable: @@ -1042,6 +961,10 @@ def tool( infer_schema: Whether to infer the schema of the arguments from the function's signature. This also makes the resultant tool accept a dictionary input to its `run()` function. 
+ response_format: The tool response format. If "content" then the output of + the tool is interpreted as the contents of a ToolMessage. If + "content_and_raw_output" then the output is expected to be a two-tuple + corresponding to the (content, raw_output) of a ToolMessage. parse_docstring: if ``infer_schema`` and ``parse_docstring``, will attempt to parse parameter descriptions from Google Style function docstrings. error_on_invalid_docstring: if ``parse_docstring`` is provided, configures @@ -1064,8 +987,12 @@ def search_api(query: str) -> str: # Searches the API for the query. return - .. versionadded:: 0.2.14 - Parse Google-style docstrings: + @tool(response_format="content_and_raw_output") + def search_api(query: str) -> Tuple[str, dict]: + return "partial json of results", {"full": "object of results"} + + .. versionadded:: 0.2.14 + Parse Google-style docstrings: .. code-block:: python @@ -1179,6 +1106,7 @@ def invoke_wrapper( return_direct=return_direct, args_schema=schema, infer_schema=infer_schema, + response_format=response_format, parse_docstring=parse_docstring, error_on_invalid_docstring=error_on_invalid_docstring, ) @@ -1195,6 +1123,7 @@ def invoke_wrapper( description=f"{tool_name} tool", return_direct=return_direct, coroutine=coroutine, + response_format=response_format, ) return _make_tool @@ -1350,6 +1279,103 @@ def get_tools(self) -> List[BaseTool]: """Get the tools in the toolkit.""" +def _is_tool_call(x: Any) -> bool: + return isinstance(x, dict) and x.get("type") == "tool_call" + + +def _handle_validation_error( + e: ValidationError, + *, + flag: Union[Literal[True], str, Callable[[ValidationError], str]], +) -> str: + if isinstance(flag, bool): + content = "Tool input validation error" + elif isinstance(flag, str): + content = flag + elif callable(flag): + content = flag(e) + else: + raise ValueError( + f"Got unexpected type of `handle_validation_error`. Expected bool, " + f"str or callable. Received: {flag}" + ) + return content + + +def _handle_tool_error( + e: ToolException, + *, + flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]], +) -> str: + if isinstance(flag, bool): + if e.args: + content = e.args[0] + else: + content = "Tool execution error" + elif isinstance(flag, str): + content = flag + elif callable(flag): + content = flag(e) + else: + raise ValueError( + f"Got unexpected type of `handle_tool_error`. Expected bool, str " + f"or callable. Received: {flag}" + ) + return content + + +def _prep_run_args( + input: Union[str, dict, ToolCall], + config: Optional[RunnableConfig], + **kwargs: Any, +) -> Tuple[Union[str, Dict], Dict]: + config = ensure_config(config) + if _is_tool_call(input): + tool_call_id: Optional[str] = cast(ToolCall, input)["id"] + tool_input: Union[str, dict] = cast(ToolCall, input)["args"] + else: + tool_call_id = None + tool_input = cast(Union[str, dict], input) + return ( + tool_input, + dict( + callbacks=config.get("callbacks"), + tags=config.get("tags"), + metadata=config.get("metadata"), + run_name=config.get("run_name"), + run_id=config.pop("run_id", None), + config=config, + tool_call_id=tool_call_id, + **kwargs, + ), + ) + + +def _format_output( + content: Any, raw_output: Any, tool_call_id: Optional[str] +) -> Union[ToolMessage, Any]: + if tool_call_id: + # NOTE: This will fail to stringify lists which aren't actually content blocks + # but whose first element happens to be a string or dict. Tools should avoid + # returning such contents. 
+ if not isinstance(content, str) and not ( + isinstance(content, list) + and content + and isinstance(content[0], (str, dict)) + ): + content = _stringify(content) + return ToolMessage(content, raw_output=raw_output, tool_call_id=tool_call_id) + else: + return content + + +def _stringify(content: Any) -> str: + try: + return json.dumps(content) + except Exception: + return str(content) + + def _get_description_from_runnable(runnable: Runnable) -> str: """Generate a placeholder description of a runnable.""" input_schema = runnable.input_schema.schema() diff --git a/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr b/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr index 8b64b0d81de4f..199af98ece4ee 100644 --- a/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr +++ b/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr @@ -317,6 +317,13 @@ 'title': 'Name', 'type': 'string', }), + 'type': dict({ + 'enum': list([ + 'invalid_tool_call', + ]), + 'title': 'Type', + 'type': 'string', + }), }), 'required': list([ 'name', @@ -419,6 +426,13 @@ 'title': 'Name', 'type': 'string', }), + 'type': dict({ + 'enum': list([ + 'tool_call', + ]), + 'title': 'Type', + 'type': 'string', + }), }), 'required': list([ 'name', @@ -908,6 +922,13 @@ 'title': 'Name', 'type': 'string', }), + 'type': dict({ + 'enum': list([ + 'invalid_tool_call', + ]), + 'title': 'Type', + 'type': 'string', + }), }), 'required': list([ 'name', @@ -1010,6 +1031,13 @@ 'title': 'Name', 'type': 'string', }), + 'type': dict({ + 'enum': list([ + 'tool_call', + ]), + 'title': 'Type', + 'type': 'string', + }), }), 'required': list([ 'name', diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr index af617e3c11cc6..1d097ae4fda9f 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_graph.ambr @@ -674,6 +674,13 @@ 'title': 'Name', 'type': 'string', }), + 'type': dict({ + 'enum': list([ + 'invalid_tool_call', + ]), + 'title': 'Type', + 'type': 'string', + }), }), 'required': list([ 'name', @@ -776,6 +783,13 @@ 'title': 'Name', 'type': 'string', }), + 'type': dict({ + 'enum': list([ + 'tool_call', + ]), + 'title': 'Type', + 'type': 'string', + }), }), 'required': list([ 'name', diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr index ce3b6b109da2c..f29de06b133de 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr @@ -5577,6 +5577,13 @@ 'title': 'Name', 'type': 'string', }), + 'type': dict({ + 'enum': list([ + 'invalid_tool_call', + ]), + 'title': 'Type', + 'type': 'string', + }), }), 'required': list([ 'name', @@ -5701,6 +5708,13 @@ 'title': 'Name', 'type': 'string', }), + 'type': dict({ + 'enum': list([ + 'tool_call', + ]), + 'title': 'Type', + 'type': 'string', + }), }), 'required': list([ 'name', @@ -6237,6 +6251,13 @@ 'title': 'Name', 'type': 'string', }), + 'type': dict({ + 'enum': list([ + 'invalid_tool_call', + ]), + 'title': 'Type', + 'type': 'string', + }), }), 'required': list([ 'name', @@ -6361,6 +6382,13 @@ 'title': 'Name', 'type': 'string', }), + 'type': dict({ + 'enum': list([ + 'tool_call', + ]), + 'title': 'Type', + 'type': 'string', + }), }), 'required': list([ 'name', @@ -6834,6 +6862,13 @@ 
'title': 'Name', 'type': 'string', }), + 'type': dict({ + 'enum': list([ + 'invalid_tool_call', + ]), + 'title': 'Type', + 'type': 'string', + }), }), 'required': list([ 'name', @@ -6936,6 +6971,13 @@ 'title': 'Name', 'type': 'string', }), + 'type': dict({ + 'enum': list([ + 'tool_call', + ]), + 'title': 'Type', + 'type': 'string', + }), }), 'required': list([ 'name', @@ -7444,6 +7486,13 @@ 'title': 'Name', 'type': 'string', }), + 'type': dict({ + 'enum': list([ + 'invalid_tool_call', + ]), + 'title': 'Type', + 'type': 'string', + }), }), 'required': list([ 'name', @@ -7568,6 +7617,13 @@ 'title': 'Name', 'type': 'string', }), + 'type': dict({ + 'enum': list([ + 'tool_call', + ]), + 'title': 'Type', + 'type': 'string', + }), }), 'required': list([ 'name', @@ -8068,6 +8124,13 @@ 'title': 'Name', 'type': 'string', }), + 'type': dict({ + 'enum': list([ + 'invalid_tool_call', + ]), + 'title': 'Type', + 'type': 'string', + }), }), 'required': list([ 'name', @@ -8203,6 +8266,13 @@ 'title': 'Name', 'type': 'string', }), + 'type': dict({ + 'enum': list([ + 'tool_call', + ]), + 'title': 'Type', + 'type': 'string', + }), }), 'required': list([ 'name', @@ -8683,6 +8753,13 @@ 'title': 'Name', 'type': 'string', }), + 'type': dict({ + 'enum': list([ + 'invalid_tool_call', + ]), + 'title': 'Type', + 'type': 'string', + }), }), 'required': list([ 'name', @@ -8785,6 +8862,13 @@ 'title': 'Name', 'type': 'string', }), + 'type': dict({ + 'enum': list([ + 'tool_call', + ]), + 'title': 'Type', + 'type': 'string', + }), }), 'required': list([ 'name', @@ -9238,6 +9322,13 @@ 'title': 'Name', 'type': 'string', }), + 'type': dict({ + 'enum': list([ + 'invalid_tool_call', + ]), + 'title': 'Type', + 'type': 'string', + }), }), 'required': list([ 'name', @@ -9340,6 +9431,13 @@ 'title': 'Name', 'type': 'string', }), + 'type': dict({ + 'enum': list([ + 'tool_call', + ]), + 'title': 'Type', + 'type': 'string', + }), }), 'required': list([ 'name', @@ -9880,6 +9978,13 @@ 'title': 'Name', 'type': 'string', }), + 'type': dict({ + 'enum': list([ + 'invalid_tool_call', + ]), + 'title': 'Type', + 'type': 'string', + }), }), 'required': list([ 'name', @@ -10004,6 +10109,13 @@ 'title': 'Name', 'type': 'string', }), + 'type': dict({ + 'enum': list([ + 'tool_call', + ]), + 'title': 'Type', + 'type': 'string', + }), }), 'required': list([ 'name', diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 61266838ecca3..478707ea555c1 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -8,7 +8,7 @@ from datetime import datetime from enum import Enum from functools import partial -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union import pytest from typing_extensions import Annotated, TypedDict @@ -17,6 +17,7 @@ AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) +from langchain_core.messages import ToolMessage from langchain_core.pydantic_v1 import BaseModel, ValidationError from langchain_core.runnables import Runnable, RunnableLambda, ensure_config from langchain_core.tools import ( @@ -1067,6 +1068,65 @@ def foo( } +def test_tool_call_input_tool_message_output() -> None: + tool_call = { + "name": "structured_api", + "args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}}, + "id": "123", + "type": "tool_call", + } + tool = _MockStructuredTool() + expected = ToolMessage("1 True {'img': 'base64string...'}", 
tool_call_id="123") + actual = tool.invoke(tool_call) + assert actual == expected + + tool_call.pop("type") + with pytest.raises(ValidationError): + tool.invoke(tool_call) + + +class _MockStructuredToolWithRawOutput(BaseTool): + name: str = "structured_api" + args_schema: Type[BaseModel] = _MockSchema + description: str = "A Structured Tool" + response_format: Literal["content_and_raw_output"] = "content_and_raw_output" + + def _run( + self, arg1: int, arg2: bool, arg3: Optional[dict] = None + ) -> Tuple[str, dict]: + return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3} + + +@tool("structured_api", response_format="content_and_raw_output") +def _mock_structured_tool_with_raw_output( + arg1: int, arg2: bool, arg3: Optional[dict] = None +) -> Tuple[str, dict]: + """A Structured Tool""" + return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3} + + +@pytest.mark.parametrize( + "tool", [_MockStructuredToolWithRawOutput(), _mock_structured_tool_with_raw_output] +) +def test_tool_call_input_tool_message_with_raw_output(tool: BaseTool) -> None: + tool_call: Dict = { + "name": "structured_api", + "args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}}, + "id": "123", + "type": "tool_call", + } + expected = ToolMessage("1 True", raw_output=tool_call["args"], tool_call_id="123") + actual = tool.invoke(tool_call) + assert actual == expected + + tool_call.pop("type") + with pytest.raises(ValidationError): + tool.invoke(tool_call) + + actual_content = tool.invoke(tool_call["args"]) + assert actual_content == expected.content + + def test_convert_from_runnable_dict() -> None: # Test with typed dict input class Args(TypedDict): diff --git a/libs/langchain/poetry.lock b/libs/langchain/poetry.lock index 000516ed8f80f..a406f2ed0123e 100644 --- a/libs/langchain/poetry.lock +++ b/libs/langchain/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "aiohttp" @@ -1760,7 +1760,7 @@ files = [ [[package]] name = "langchain-core" -version = "0.2.12" +version = "0.2.13" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" @@ -1784,7 +1784,7 @@ url = "../core" [[package]] name = "langchain-openai" -version = "0.1.14" +version = "0.1.15" description = "An integration package connecting OpenAI and LangChain" optional = true python-versions = ">=3.8.1,<4.0" @@ -1792,7 +1792,7 @@ files = [] develop = true [package.dependencies] -langchain-core = ">=0.2.2,<0.3" +langchain-core = "^0.2.13" openai = "^1.32.0" tiktoken = ">=0.7,<1" @@ -1834,13 +1834,13 @@ types-requests = ">=2.31.0.2,<3.0.0.0" [[package]] name = "langsmith" -version = "0.1.84" +version = "0.1.85" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." 
optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langsmith-0.1.84-py3-none-any.whl", hash = "sha256:01f3c6390dba26c583bac8dd0e551ce3d0509c7f55cad714db0b5c8d36e4c7ff"}, - {file = "langsmith-0.1.84.tar.gz", hash = "sha256:5220c0439838b9a5bd320fd3686be505c5083dcee22d2452006c23891153bea1"}, + {file = "langsmith-0.1.85-py3-none-any.whl", hash = "sha256:c1f94384f10cea96f7b4d33fd3db7ec180c03c7468877d50846f881d2017ff94"}, + {file = "langsmith-0.1.85.tar.gz", hash = "sha256:acff31f9e53efa48586cf8e32f65625a335c74d7c4fa306d1655ac18452296f6"}, ] [package.dependencies] @@ -2350,13 +2350,13 @@ files = [ [[package]] name = "openai" -version = "1.35.10" +version = "1.35.13" description = "The official Python library for the openai API" optional = true python-versions = ">=3.7.1" files = [ - {file = "openai-1.35.10-py3-none-any.whl", hash = "sha256:962cb5c23224b5cbd16078308dabab97a08b0a5ad736a4fdb3dc2ffc44ac974f"}, - {file = "openai-1.35.10.tar.gz", hash = "sha256:85966949f4f960f3e4b239a659f9fd64d3a97ecc43c44dc0a044b5c7f11cccc6"}, + {file = "openai-1.35.13-py3-none-any.whl", hash = "sha256:36ec3e93e0d1f243f69be85c89b9221a471c3e450dfd9df16c9829e3cdf63e60"}, + {file = "openai-1.35.13.tar.gz", hash = "sha256:c684f3945608baf7d2dcc0ef3ee6f3e27e4c66f21076df0b47be45d57e6ae6e4"}, ] [package.dependencies] @@ -4141,13 +4141,13 @@ urllib3 = ">=2" [[package]] name = "types-setuptools" -version = "70.2.0.20240704" +version = "70.3.0.20240710" description = "Typing stubs for setuptools" optional = false python-versions = ">=3.8" files = [ - {file = "types-setuptools-70.2.0.20240704.tar.gz", hash = "sha256:2f8d28d16ca1607080f9fdf19595bd49c942884b2bbd6529c9b8a9a8fc8db911"}, - {file = "types_setuptools-70.2.0.20240704-py3-none-any.whl", hash = "sha256:6b892d5441c2ed58dd255724516e3df1db54892fb20597599aea66d04c3e4d7f"}, + {file = "types-setuptools-70.3.0.20240710.tar.gz", hash = "sha256:842cbf399812d2b65042c9d6ff35113bbf282dee38794779aa1f94e597bafc35"}, + {file = "types_setuptools-70.3.0.20240710-py3-none-any.whl", hash = "sha256:bd0db2a4b9f2c49ac5564be4e0fb3125c4c46b1f73eafdcbceffa5b005cceca4"}, ] [[package]] diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py index 61377694be523..6f943340ac441 100644 --- a/libs/partners/anthropic/langchain_anthropic/chat_models.py +++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py @@ -43,6 +43,7 @@ ToolMessage, ) from langchain_core.messages.ai import UsageMetadata +from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk from langchain_core.output_parsers import ( JsonOutputKeyToolsParser, PydanticToolsParser, @@ -1102,12 +1103,12 @@ def _make_message_chunk_from_anthropic_event( warnings.warn("Received unexpected tool content block.") content_block = event.content_block.model_dump() content_block["index"] = event.index - tool_call_chunk = { - "index": event.index, - "id": event.content_block.id, - "name": event.content_block.name, - "args": "", - } + tool_call_chunk = create_tool_call_chunk( + index=event.index, + id=event.content_block.id, + name=event.content_block.name, + args="", + ) message_chunk = AIMessageChunk( content=[content_block], tool_call_chunks=[tool_call_chunk], # type: ignore diff --git a/libs/partners/anthropic/langchain_anthropic/output_parsers.py b/libs/partners/anthropic/langchain_anthropic/output_parsers.py index c8d6aa3aeecc0..cd9f5308ddc51 100644 --- a/libs/partners/anthropic/langchain_anthropic/output_parsers.py +++ 
b/libs/partners/anthropic/langchain_anthropic/output_parsers.py @@ -1,6 +1,7 @@ from typing import Any, List, Optional, Type, Union, cast from langchain_core.messages import AIMessage, ToolCall +from langchain_core.messages.tool import tool_call from langchain_core.output_parsers import BaseGenerationOutputParser from langchain_core.outputs import ChatGeneration, Generation from langchain_core.pydantic_v1 import BaseModel @@ -79,7 +80,7 @@ def extract_tool_calls(content: Union[str, List[Union[str, dict]]]) -> List[Tool if block["type"] != "tool_use": continue tool_calls.append( - ToolCall(name=block["name"], args=block["input"], id=block["id"]) + tool_call(name=block["name"], args=block["input"], id=block["id"]) ) return tool_calls else: diff --git a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py index a3a1d90f6a967..de6cdc0d139ba 100644 --- a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py +++ b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py @@ -365,10 +365,7 @@ async def test_astreaming() -> None: def test_tool_use() -> None: - llm = ChatAnthropic( # type: ignore[call-arg] - model=MODEL_NAME, - ) - + llm = ChatAnthropic(model=MODEL_NAME) # type: ignore[call-arg] llm_with_tools = llm.bind_tools( [ { @@ -478,6 +475,7 @@ def type_letter(letter: str) -> str: "name": "type_letter", "args": {"letter": "d"}, "id": "toolu_01V6d6W32QGGSmQm4BT98EKk", + "type": "tool_call", }, ], ), diff --git a/libs/partners/anthropic/tests/unit_tests/test_output_parsers.py b/libs/partners/anthropic/tests/unit_tests/test_output_parsers.py index a163702e544da..84e2e7506f8e4 100644 --- a/libs/partners/anthropic/tests/unit_tests/test_output_parsers.py +++ b/libs/partners/anthropic/tests/unit_tests/test_output_parsers.py @@ -33,8 +33,20 @@ class _Foo2(BaseModel): def test_tools_output_parser() -> None: output_parser = ToolsOutputParser() expected = [ - {"name": "_Foo1", "args": {"bar": 0}, "id": "1", "index": 1}, - {"name": "_Foo2", "args": {"baz": "a"}, "id": "2", "index": 3}, + { + "name": "_Foo1", + "args": {"bar": 0}, + "id": "1", + "index": 1, + "type": "tool_call", + }, + { + "name": "_Foo2", + "args": {"baz": "a"}, + "id": "2", + "index": 3, + "type": "tool_call", + }, ] actual = output_parser.parse_result(_RESULT) assert expected == actual @@ -56,7 +68,13 @@ def test_tools_output_parser_args_only() -> None: def test_tools_output_parser_first_tool_only() -> None: output_parser = ToolsOutputParser(first_tool_only=True) - expected: Any = {"name": "_Foo1", "args": {"bar": 0}, "id": "1", "index": 1} + expected: Any = { + "name": "_Foo1", + "args": {"bar": 0}, + "id": "1", + "index": 1, + "type": "tool_call", + } actual = output_parser.parse_result(_RESULT) assert expected == actual @@ -81,7 +99,14 @@ class ChartType(BaseModel): ) message = AIMessage( "", - tool_calls=[{"name": "ChartType", "args": {"chart_type": "pie"}, "id": "foo"}], + tool_calls=[ + { + "name": "ChartType", + "args": {"chart_type": "pie"}, + "id": "foo", + "type": "tool_call", + } + ], ) actual = output_parser.invoke(message) expected = ChartType(chart_type="pie") diff --git a/libs/partners/azure-dynamic-sessions/langchain_azure_dynamic_sessions/tools/sessions.py b/libs/partners/azure-dynamic-sessions/langchain_azure_dynamic_sessions/tools/sessions.py index b33a52da931a8..334ecfd017356 100644 --- a/libs/partners/azure-dynamic-sessions/langchain_azure_dynamic_sessions/tools/sessions.py +++ 
b/libs/partners/azure-dynamic-sessions/langchain_azure_dynamic_sessions/tools/sessions.py @@ -9,10 +9,11 @@ import os import re import urllib +from copy import deepcopy from dataclasses import dataclass from datetime import datetime, timedelta, timezone from io import BytesIO -from typing import Any, BinaryIO, Callable, List, Optional +from typing import Any, BinaryIO, Callable, List, Literal, Optional, Tuple from uuid import uuid4 import requests @@ -126,6 +127,8 @@ class SessionsPythonREPLTool(BaseTool): session_id: str = str(uuid4()) """The session ID to use for the code interpreter. Defaults to a random UUID.""" + response_format: Literal["content_and_raw_output"] = "content_and_raw_output" + def _build_url(self, path: str) -> str: pool_management_endpoint = self.pool_management_endpoint if not pool_management_endpoint: @@ -164,16 +167,16 @@ def execute(self, python_code: str) -> Any: properties = response_json.get("properties", {}) return properties - def _run(self, python_code: str) -> Any: + def _run(self, python_code: str, **kwargs: Any) -> Tuple[str, dict]: response = self.execute(python_code) # if the result is an image, remove the base64 data - result = response.get("result") + result = deepcopy(response.get("result")) if isinstance(result, dict): if result.get("type") == "image" and "base64_data" in result: result.pop("base64_data") - return json.dumps( + content = json.dumps( { "result": result, "stdout": response.get("stdout"), @@ -181,6 +184,7 @@ def _run(self, python_code: str) -> Any: }, indent=2, ) + return content, response def upload_file( self, diff --git a/libs/partners/fireworks/langchain_fireworks/chat_models.py b/libs/partners/fireworks/langchain_fireworks/chat_models.py index 64257e9f33dcf..4e3ce14816155 100644 --- a/libs/partners/fireworks/langchain_fireworks/chat_models.py +++ b/libs/partners/fireworks/langchain_fireworks/chat_models.py @@ -54,6 +54,12 @@ ToolMessage, ToolMessageChunk, ) +from langchain_core.messages.tool import ( + ToolCallChunk, +) +from langchain_core.messages.tool import ( + tool_call_chunk as create_tool_call_chunk, +) from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser from langchain_core.output_parsers.base import OutputParserLike from langchain_core.output_parsers.openai_tools import ( @@ -199,6 +205,7 @@ def _convert_chunk_to_message_chunk( role = cast(str, _dict.get("role")) content = cast(str, _dict.get("content") or "") additional_kwargs: Dict = {} + tool_call_chunks: List[ToolCallChunk] = [] if _dict.get("function_call"): function_call = dict(_dict["function_call"]) if "name" in function_call and function_call["name"] is None: @@ -206,21 +213,18 @@ def _convert_chunk_to_message_chunk( additional_kwargs["function_call"] = function_call if raw_tool_calls := _dict.get("tool_calls"): additional_kwargs["tool_calls"] = raw_tool_calls - try: - tool_call_chunks = [ - { - "name": rtc["function"].get("name"), - "args": rtc["function"].get("arguments"), - "id": rtc.get("id"), - "index": rtc["index"], - } - for rtc in raw_tool_calls - ] - except KeyError: - pass - else: - tool_call_chunks = [] - + for rtc in raw_tool_calls: + try: + tool_call_chunks.append( + create_tool_call_chunk( + name=rtc["function"].get("name"), + args=rtc["function"].get("arguments"), + id=rtc.get("id"), + index=rtc.get("index"), + ) + ) + except KeyError: + pass if role == "user" or default_class == HumanMessageChunk: return HumanMessageChunk(content=content) elif role == "assistant" or default_class == AIMessageChunk: @@ -237,7 +241,7 @@ 
def _convert_chunk_to_message_chunk( return AIMessageChunk( content=content, additional_kwargs=additional_kwargs, - tool_call_chunks=tool_call_chunks, # type: ignore[arg-type] + tool_call_chunks=tool_call_chunks, usage_metadata=usage_metadata, # type: ignore[arg-type] ) elif role == "system" or default_class == SystemMessageChunk: diff --git a/libs/partners/groq/langchain_groq/chat_models.py b/libs/partners/groq/langchain_groq/chat_models.py index efb5ce002231c..f7ce27fa9569d 100644 --- a/libs/partners/groq/langchain_groq/chat_models.py +++ b/libs/partners/groq/langchain_groq/chat_models.py @@ -53,6 +53,7 @@ ToolMessage, ToolMessageChunk, ) +from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk from langchain_core.output_parsers import ( JsonOutputParser, PydanticOutputParser, @@ -511,19 +512,19 @@ def _stream( generation = chat_result.generations[0] message = cast(AIMessage, generation.message) tool_call_chunks = [ - { - "name": rtc["function"].get("name"), - "args": rtc["function"].get("arguments"), - "id": rtc.get("id"), - "index": rtc.get("index"), - } + create_tool_call_chunk( + name=rtc["function"].get("name"), + args=rtc["function"].get("arguments"), + id=rtc.get("id"), + index=rtc.get("index"), + ) for rtc in message.additional_kwargs.get("tool_calls", []) ] chunk_ = ChatGenerationChunk( message=AIMessageChunk( content=message.content, additional_kwargs=message.additional_kwargs, - tool_call_chunks=tool_call_chunks, # type: ignore[arg-type] + tool_call_chunks=tool_call_chunks, usage_metadata=message.usage_metadata, ), generation_info=generation.generation_info, diff --git a/libs/partners/groq/tests/unit_tests/test_chat_models.py b/libs/partners/groq/tests/unit_tests/test_chat_models.py index c061691612f47..698b50f7d1f49 100644 --- a/libs/partners/groq/tests/unit_tests/test_chat_models.py +++ b/libs/partners/groq/tests/unit_tests/test_chat_models.py @@ -77,6 +77,7 @@ def test__convert_dict_to_message_tool_call() -> None: name="GenerateUsername", args={"name": "Sally", "hair_color": "green"}, id="call_wm0JY6CdwOMZ4eTxHWUThDNz", + type="tool_call", ) ], ) @@ -112,6 +113,7 @@ def test__convert_dict_to_message_tool_call() -> None: args="oops", id="call_wm0JY6CdwOMZ4eTxHWUThDNz", error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. 
Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501 + type="invalid_tool_call", ), ], tool_calls=[ @@ -119,6 +121,7 @@ def test__convert_dict_to_message_tool_call() -> None: name="GenerateUsername", args={"name": "Sally", "hair_color": "green"}, id="call_abc123", + type="tool_call", ), ], ) diff --git a/libs/partners/ibm/langchain_ibm/chat_models.py b/libs/partners/ibm/langchain_ibm/chat_models.py index e7fa66ea4ed0b..96a51b422662e 100644 --- a/libs/partners/ibm/langchain_ibm/chat_models.py +++ b/libs/partners/ibm/langchain_ibm/chat_models.py @@ -42,10 +42,12 @@ HumanMessageChunk, SystemMessage, SystemMessageChunk, + ToolCallChunk, ToolMessage, ToolMessageChunk, convert_to_messages, ) +from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser from langchain_core.output_parsers.base import OutputParserLike from langchain_core.output_parsers.openai_tools import ( @@ -174,6 +176,7 @@ def _convert_delta_to_message_chunk( role = cast(str, _dict.get("role")) content = cast(str, _dict.get("content") or "") additional_kwargs: Dict = {} + tool_call_chunks: List[ToolCallChunk] = [] if _dict.get("function_call"): function_call = dict(_dict["function_call"]) if "name" in function_call and function_call["name"] is None: @@ -181,21 +184,18 @@ def _convert_delta_to_message_chunk( additional_kwargs["function_call"] = function_call if raw_tool_calls := _dict.get("tool_calls"): additional_kwargs["tool_calls"] = raw_tool_calls - try: - tool_call_chunks = [ - { - "name": rtc["function"].get("name"), - "args": rtc["function"].get("arguments"), - "id": rtc.get("id"), - "index": rtc["index"], - } - for rtc in raw_tool_calls - ] - except KeyError: - pass - else: - tool_call_chunks = [] - + for rtc in raw_tool_calls: + try: + tool_call_chunks.append( + create_tool_call_chunk( + name=rtc["function"].get("name"), + args=rtc["function"].get("arguments"), + id=rtc.get("id"), + index=rtc.get("index"), + ) + ) + except KeyError: + pass if role == "user" or default_class == HumanMessageChunk: return HumanMessageChunk(content=content) elif role == "assistant" or default_class == AIMessageChunk: diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py index 1d510e1c23f02..81256171ce64a 100644 --- a/libs/partners/mistralai/langchain_mistralai/chat_models.py +++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py @@ -50,6 +50,7 @@ ToolCall, ToolMessage, ) +from langchain_core.messages.tool import tool_call_chunk from langchain_core.output_parsers import ( JsonOutputParser, PydanticOutputParser, @@ -103,19 +104,10 @@ def _convert_mistral_chat_message_to_message( dict, parse_tool_call(raw_tool_call, return_id=True) ) if not parsed["id"]: - tool_call_id = uuid.uuid4().hex[:] - tool_calls.append( - { - **parsed, - **{"id": tool_call_id}, - }, - ) - else: - tool_calls.append(parsed) + parsed["id"] = uuid.uuid4().hex[:] + tool_calls.append(parsed) except Exception as e: - invalid_tool_calls.append( - dict(make_invalid_tool_call(raw_tool_call, str(e))) - ) + invalid_tool_calls.append(make_invalid_tool_call(raw_tool_call, str(e))) return AIMessage( content=content, additional_kwargs=additional_kwargs, @@ -206,12 +198,12 @@ def _convert_chunk_to_message_chunk( else: tool_call_id = raw_tool_call.get("id") tool_call_chunks.append( - { - "name": raw_tool_call["function"].get("name"), - "args": 
raw_tool_call["function"].get("arguments"), - "id": tool_call_id, - "index": raw_tool_call.get("index"), - } + tool_call_chunk( + name=raw_tool_call["function"].get("name"), + args=raw_tool_call["function"].get("arguments"), + id=tool_call_id, + index=raw_tool_call.get("index"), + ) ) except KeyError: pass diff --git a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py index 737b4dd1978ae..011bb0fa4ff88 100644 --- a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py +++ b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py @@ -144,6 +144,7 @@ def test__convert_dict_to_message_tool_call() -> None: name="GenerateUsername", args={"name": "Sally", "hair_color": "green"}, id="abc123", + type="tool_call", ) ], ) @@ -178,6 +179,7 @@ def test__convert_dict_to_message_tool_call() -> None: args="oops", error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501 id="abc123", + type="invalid_tool_call", ), ], tool_calls=[ @@ -185,6 +187,7 @@ def test__convert_dict_to_message_tool_call() -> None: name="GenerateUsername", args={"name": "Sally", "hair_color": "green"}, id="def456", + type="tool_call", ), ], ) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index db5f2eec5a9e9..fc94175efa784 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -63,6 +63,7 @@ ToolMessageChunk, ) from langchain_core.messages.ai import UsageMetadata +from langchain_core.messages.tool import tool_call_chunk from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser from langchain_core.output_parsers.base import OutputParserLike from langchain_core.output_parsers.openai_tools import ( @@ -244,12 +245,12 @@ def _convert_delta_to_message_chunk( additional_kwargs["tool_calls"] = raw_tool_calls try: tool_call_chunks = [ - { - "name": rtc["function"].get("name"), - "args": rtc["function"].get("arguments"), - "id": rtc.get("id"), - "index": rtc["index"], - } + tool_call_chunk( + name=rtc["function"].get("name"), + args=rtc["function"].get("arguments"), + id=rtc.get("id"), + index=rtc["index"], + ) for rtc in raw_tool_calls ] except KeyError: diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py index 55826046c8733..94bf2277c6fb8 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py @@ -117,6 +117,7 @@ def test__convert_dict_to_message_tool_call() -> None: name="GenerateUsername", args={"name": "Sally", "hair_color": "green"}, id="call_wm0JY6CdwOMZ4eTxHWUThDNz", + type="tool_call", ) ], ) @@ -151,6 +152,7 @@ def test__convert_dict_to_message_tool_call() -> None: args="oops", id="call_wm0JY6CdwOMZ4eTxHWUThDNz", error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. 
Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501 + type="invalid_tool_call", ) ], tool_calls=[ @@ -158,6 +160,7 @@ def test__convert_dict_to_message_tool_call() -> None: name="GenerateUsername", args={"name": "Sally", "hair_color": "green"}, id="call_abc123", + type="tool_call", ) ], ) @@ -353,7 +356,10 @@ def test_get_num_tokens_from_messages() -> None: ), AIMessage("a nice bird"), AIMessage( - "", tool_calls=[ToolCall(id="foo", name="bar", args={"arg1": "arg1"})] + "", + tool_calls=[ + ToolCall(id="foo", name="bar", args={"arg1": "arg1"}, type="tool_call") + ], ), AIMessage( "", @@ -362,7 +368,10 @@ def test_get_num_tokens_from_messages() -> None: }, ), AIMessage( - "text", tool_calls=[ToolCall(id="foo", name="bar", args={"arg1": "arg1"})] + "text", + tool_calls=[ + ToolCall(id="foo", name="bar", args={"arg1": "arg1"}, type="tool_call") + ], ), ToolMessage("foobar", tool_call_id="foo"), ]
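Taken together, the changes above let a tool accept a `ToolCall` dict as input and return a `ToolMessage`, optionally carrying a raw payload via `response_format="content_and_raw_output"`. Below is a minimal usage sketch, assuming the patch above is applied; the `multiply` tool, its arguments, and its return values are illustrative and not part of the changeset:

```python
from typing import Tuple

from langchain_core.messages import ToolMessage
from langchain_core.tools import tool


# Hypothetical tool (not from the diff) exercising the new response_format option.
@tool("multiply", response_format="content_and_raw_output")
def multiply(a: int, b: int) -> Tuple[str, dict]:
    """Multiply two integers."""
    result = a * b
    # The first element becomes the ToolMessage content; the second is
    # attached to the ToolMessage as raw_output.
    return f"{a} * {b} = {result}", {"a": a, "b": b, "result": result}


# A dict with type="tool_call" is treated as a ToolCall: the tool runs on its
# "args", and the output is wrapped in a ToolMessage whose tool_call_id is
# copied from the call's "id".
message = multiply.invoke(
    {
        "name": "multiply",
        "args": {"a": 2, "b": 3},
        "id": "call_123",
        "type": "tool_call",
    }
)
assert isinstance(message, ToolMessage)
assert message.content == "2 * 3 = 6"
assert message.tool_call_id == "call_123"
assert message.raw_output == {"a": 2, "b": 3, "result": 6}

# Plain dict input still returns the bare content, as before.
assert multiply.invoke({"a": 2, "b": 3}) == "2 * 3 = 6"
```

As the new unit tests above show, passing the same dict without the `"type": "tool_call"` key raises a `ValidationError`, because the leftover `"name"` and `"id"` keys are then validated against the tool's argument schema.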