Source code for langchain.chains.retrieval_qa.base

"""Chain for question-answering against a vector database."""
from __future__ import annotations

import warnings
from abc import abstractmethod
from typing import Any, Dict, List, Optional

from pydantic import Extra, Field, root_validator

from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR
from langchain.prompts import PromptTemplate
from langchain.schema import BaseLanguageModel, BaseRetriever, Document
from langchain.vectorstores.base import VectorStore


class BaseRetrievalQA(Chain):
    # Shared question-answering pipeline: subclasses supply the documents via
    # `_get_docs` / `_aget_docs`; this class passes them, together with the
    # question, to `combine_documents_chain` and packages the answer.

    combine_documents_chain: BaseCombineDocumentsChain
    """Chain to use to combine the documents."""
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    return_source_documents: bool = False
    """Return the source documents."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True
        allow_population_by_field_name = True

    @property
    def input_keys(self) -> List[str]:
        """Single-element list holding the configured input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Output key, followed by ``"source_documents"`` when those are returned.

        :meta private:
        """
        return (
            [self.output_key, "source_documents"]
            if self.return_source_documents
            else [self.output_key]
        )

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: Optional[PromptTemplate] = None,
        **kwargs: Any,
    ) -> BaseRetrievalQA:
        """Build the chain from an LLM, combining documents with a stuff chain.

        When ``prompt`` is not given, one is chosen for ``llm`` through
        ``PROMPT_SELECTOR``. Remaining keyword arguments are forwarded to the
        class constructor.
        """
        qa_prompt = prompt if prompt else PROMPT_SELECTOR.get_prompt(llm)
        answer_chain = LLMChain(llm=llm, prompt=qa_prompt)
        # Each retrieved document is rendered with this template before being
        # placed into the "context" variable of the QA prompt.
        per_document_prompt = PromptTemplate(
            input_variables=["page_content"], template="Context:\n{page_content}"
        )
        stuff_chain = StuffDocumentsChain(
            llm_chain=answer_chain,
            document_variable_name="context",
            document_prompt=per_document_prompt,
        )
        return cls(combine_documents_chain=stuff_chain, **kwargs)

    @classmethod
    def from_chain_type(
        cls,
        llm: BaseLanguageModel,
        chain_type: str = "stuff",
        chain_type_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> BaseRetrievalQA:
        """Build the chain around a documents chain loaded by ``load_qa_chain``.

        ``chain_type_kwargs`` (if any) go to ``load_qa_chain``; all other
        keyword arguments go to the class constructor.
        """
        loader_options = chain_type_kwargs or {}
        docs_chain = load_qa_chain(llm, chain_type=chain_type, **loader_options)
        return cls(combine_documents_chain=docs_chain, **kwargs)

    @abstractmethod
    def _get_docs(self, question: str) -> List[Document]:
        """Return the documents that the question should be answered from."""

    def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
        """Fetch documents for the query and run the combine-documents chain.

        The answer is stored under ``output_key``. If
        ``return_source_documents`` is true, the retrieved documents are also
        returned under the key ``'source_documents'``.

        Example:
            .. code-block:: python

                res = indexqa({'query': 'This is my query'})
                answer, docs = res['result'], res['source_documents']
        """
        query = inputs[self.input_key]
        retrieved = self._get_docs(query)

        outputs: Dict[str, Any] = {
            self.output_key: self.combine_documents_chain.run(
                input_documents=retrieved, question=query
            )
        }
        if self.return_source_documents:
            outputs["source_documents"] = retrieved
        return outputs

    @abstractmethod
    async def _aget_docs(self, question: str) -> List[Document]:
        """Asynchronously return the documents to answer the question from."""

    async def _acall(self, inputs: Dict[str, str]) -> Dict[str, Any]:
        """Asynchronous counterpart of ``_call``.

        Awaits ``_aget_docs`` and the combine-documents chain's ``arun``. The
        answer is stored under ``output_key``; if ``return_source_documents``
        is true, the retrieved documents are also returned under the key
        ``'source_documents'``.

        Example:
            .. code-block:: python

                res = indexqa({'query': 'This is my query'})
                answer, docs = res['result'], res['source_documents']
        """
        query = inputs[self.input_key]
        retrieved = await self._aget_docs(query)

        outputs: Dict[str, Any] = {
            self.output_key: await self.combine_documents_chain.arun(
                input_documents=retrieved, question=query
            )
        }
        if self.return_source_documents:
            outputs["source_documents"] = retrieved
        return outputs


class RetrievalQA(BaseRetrievalQA):
    """Chain for question-answering against an index.

    Documents are obtained from ``retriever``; everything else is inherited
    from ``BaseRetrievalQA``.

    Example:
        .. code-block:: python

            from langchain.llms import OpenAI
            from langchain.chains import RetrievalQA
            from langchain.vectorstores import FAISS
            from langchain.vectorstores.base import VectorStoreRetriever

            retriever = VectorStoreRetriever(vectorstore=FAISS(...))
            retrievalQA = RetrievalQA.from_llm(llm=OpenAI(), retriever=retriever)

    """

    # Declared with ``exclude=True`` on the pydantic field.
    retriever: BaseRetriever = Field(exclude=True)

    def _get_docs(self, question: str) -> List[Document]:
        """Return the retriever's relevant documents for ``question``."""
        return self.retriever.get_relevant_documents(question)

    async def _aget_docs(self, question: str) -> List[Document]:
        """Asynchronously return the retriever's relevant documents."""
        return await self.retriever.aget_relevant_documents(question)
class VectorDBQA(BaseRetrievalQA):
    """Chain for question-answering against a vector database."""

    vectorstore: VectorStore = Field(exclude=True, alias="vectorstore")
    """Vector Database to connect to."""
    k: int = 4
    """Number of documents to query for."""
    search_type: str = "similarity"
    """Search type to use over vectorstore. `similarity` or `mmr`."""
    search_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Extra search args."""

    @root_validator()
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Emit the deprecation warning and hand ``values`` back unchanged."""
        message = (
            "`VectorDBQA` is deprecated - "
            "please use `from langchain.chains import RetrievalQA`"
        )
        warnings.warn(message)
        return values

    @root_validator()
    def validate_search_type(cls, values: Dict) -> Dict:
        """Validate search type."""
        # Nothing to check when the key is absent from ``values``.
        if "search_type" not in values:
            return values
        requested = values["search_type"]
        if requested not in ("similarity", "mmr"):
            raise ValueError(f"search_type of {requested} not allowed.")
        return values

    def _get_docs(self, question: str) -> List[Document]:
        """Query the vectorstore for ``k`` documents using ``search_type``.

        Raises:
            ValueError: if ``search_type`` is neither ``similarity`` nor ``mmr``.
        """
        mode = self.search_type
        if mode == "similarity":
            search = self.vectorstore.similarity_search
        elif mode == "mmr":
            search = self.vectorstore.max_marginal_relevance_search
        else:
            raise ValueError(f"search_type of {mode} not allowed.")
        return search(question, k=self.k, **self.search_kwargs)

    async def _aget_docs(self, question: str) -> List[Document]:
        """Async retrieval is not available for this chain; always raises."""
        raise NotImplementedError("VectorDBQA does not support async")

    @property
    def _chain_type(self) -> str:
        """Return the chain type."""
        return "vector_db_qa"