# Source code for langchain.agents.self_ask_with_search.base

"""Chain that does self ask with search."""
from typing import Any, Sequence, Union

from pydantic import Field

from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
from langchain.agents.agent_types import AgentType
from langchain.agents.self_ask_with_search.output_parser import SelfAskOutputParser
from langchain.agents.self_ask_with_search.prompt import PROMPT
from langchain.agents.tools import Tool
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.tools.base import BaseTool
from langchain.utilities.google_serper import GoogleSerperAPIWrapper
from langchain.utilities.serpapi import SerpAPIWrapper


class SelfAskWithSearchAgent(Agent):
    """Agent implementing the self-ask-with-search prompting strategy.

    It is driven by the fixed ``PROMPT`` template and accepts exactly one
    tool, which must be named ``Intermediate Answer``.
    """

    # Parser that converts raw LLM output into agent actions; each instance
    # gets its own SelfAskOutputParser via the default factory.
    output_parser: AgentOutputParser = Field(default_factory=SelfAskOutputParser)

    @classmethod
    def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
        """Build the default output parser (any keyword arguments are ignored)."""
        parser = SelfAskOutputParser()
        return parser

    @property
    def _agent_type(self) -> str:
        """Identifier for this agent type."""
        return AgentType.SELF_ASK_WITH_SEARCH

    @classmethod
    def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
        """Return the fixed self-ask prompt; ``tools`` is not consulted."""
        return PROMPT

    @classmethod
    def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
        """Check that the single required ``Intermediate Answer`` tool is given.

        Raises:
            ValueError: if the number of tools is not exactly one, or if the
                one tool supplied has a different name.
        """
        tool_count = len(tools)
        if tool_count != 1:
            raise ValueError(f"Exactly one tool must be specified, but got {tools}")
        found = set(tool.name for tool in tools)
        if found == {"Intermediate Answer"}:
            return
        raise ValueError(
            f"Tool name should be Intermediate Answer, got {found}"
        )

    @property
    def observation_prefix(self) -> str:
        """Text placed in front of each tool observation."""
        prefix = "Intermediate answer: "
        return prefix

    @property
    def llm_prefix(self) -> str:
        """Text placed in front of each LLM call (empty for this agent)."""
        prefix = ""
        return prefix


class SelfAskWithSearchChain(AgentExecutor):
    """Chain that does self ask with search.

    Wraps a search utility in a single ``Intermediate Answer`` tool and runs it
    with :class:`SelfAskWithSearchAgent`.

    Example:
        .. code-block:: python

            from langchain import SelfAskWithSearchChain, OpenAI, GoogleSerperAPIWrapper
            search_chain = GoogleSerperAPIWrapper()
            self_ask = SelfAskWithSearchChain(llm=OpenAI(), search_chain=search_chain)
    """

    def __init__(
        self,
        llm: BaseLLM,
        search_chain: Union[GoogleSerperAPIWrapper, SerpAPIWrapper],
        **kwargs: Any,
    ) -> None:
        """Initialize with just an LLM and a search chain.

        Args:
            llm: Language model that drives the agent.
            search_chain: Search wrapper whose ``run`` method is used as the
                tool function for each intermediate question.
            **kwargs: Passed unchanged to the parent ``AgentExecutor``
                constructor.
        """
        # SelfAskWithSearchAgent._validate_tools accepts only a single tool
        # with exactly this name.
        search_tool = Tool(
            name="Intermediate Answer", func=search_chain.run, description="Search"
        )
        agent = SelfAskWithSearchAgent.from_llm_and_tools(llm, [search_tool])
        super().__init__(agent=agent, tools=[search_tool], **kwargs)