# Source code for langchain.chains.qa_with_sources.base

"""Question answering with sources over documents."""

from __future__ import annotations

import re
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple

from pydantic import Extra, root_validator

from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.qa_with_sources.loading import load_qa_with_sources_chain
from langchain.chains.qa_with_sources.map_reduce_prompt import (
    COMBINE_PROMPT,
    EXAMPLE_PROMPT,
    QUESTION_PROMPT,
)
from langchain.docstore.document import Document
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel


class BaseQAWithSourcesChain(Chain, ABC):
    """Question answering with sources over documents.

    Subclasses supply the documents by implementing ``_get_docs`` /
    ``_aget_docs``. This base class runs ``combine_documents_chain`` over
    them and splits the resulting text into an answer and a ``SOURCES:``
    section.
    """

    combine_documents_chain: BaseCombineDocumentsChain
    """Chain to use to combine documents."""
    question_key: str = "question"  #: :meta private:
    input_docs_key: str = "docs"  #: :meta private:
    answer_key: str = "answer"  #: :meta private:
    sources_answer_key: str = "sources"  #: :meta private:
    return_source_documents: bool = False
    """Return the source documents."""

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        document_prompt: BasePromptTemplate = EXAMPLE_PROMPT,
        question_prompt: BasePromptTemplate = QUESTION_PROMPT,
        combine_prompt: BasePromptTemplate = COMBINE_PROMPT,
        **kwargs: Any,
    ) -> BaseQAWithSourcesChain:
        """Construct the chain from an LLM.

        Builds a ``MapReduceDocumentsChain``. Its per-document LLM chain uses
        ``question_prompt``, with each document bound to the ``context``
        variable. Its combine step is a ``StuffDocumentsChain`` that uses
        ``combine_prompt``; there, documents are rendered with
        ``document_prompt`` and bound to the ``summaries`` variable.

        Args:
            llm: Language model used by both inner LLM chains.
            document_prompt: Prompt used to render each document for the
                combine step.
            question_prompt: Prompt for the per-document LLM chain.
            combine_prompt: Prompt for the combine LLM chain.
            **kwargs: Extra fields forwarded to the ``cls`` constructor.
        """
        llm_question_chain = LLMChain(llm=llm, prompt=question_prompt)
        llm_combine_chain = LLMChain(llm=llm, prompt=combine_prompt)
        combine_results_chain = StuffDocumentsChain(
            llm_chain=llm_combine_chain,
            document_prompt=document_prompt,
            document_variable_name="summaries",
        )
        combine_document_chain = MapReduceDocumentsChain(
            llm_chain=llm_question_chain,
            combine_document_chain=combine_results_chain,
            document_variable_name="context",
        )
        return cls(
            combine_documents_chain=combine_document_chain,
            **kwargs,
        )

    @classmethod
    def from_chain_type(
        cls,
        llm: BaseLanguageModel,
        chain_type: str = "stuff",
        chain_type_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> BaseQAWithSourcesChain:
        """Load chain from chain type.

        Args:
            llm: Language model passed to ``load_qa_with_sources_chain``.
            chain_type: Chain type name understood by
                ``load_qa_with_sources_chain`` (defaults to ``"stuff"``).
            chain_type_kwargs: Optional keyword arguments forwarded to
                ``load_qa_with_sources_chain``.
            **kwargs: Extra fields forwarded to the ``cls`` constructor.
        """
        _chain_kwargs = chain_type_kwargs or {}
        combine_document_chain = load_qa_with_sources_chain(
            llm, chain_type=chain_type, **_chain_kwargs
        )
        return cls(combine_documents_chain=combine_document_chain, **kwargs)

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.question_key]

    @property
    def output_keys(self) -> List[str]:
        """Return output key.

        Always contains the answer and sources keys, plus
        ``"source_documents"`` when ``return_source_documents`` is set.

        :meta private:
        """
        _output_keys = [self.answer_key, self.sources_answer_key]
        if self.return_source_documents:
            _output_keys = _output_keys + ["source_documents"]
        return _output_keys

    @root_validator(pre=True)
    def validate_naming(cls, values: Dict) -> Dict:
        """Fix backwards compatibility in naming.

        Accepts the legacy ``combine_document_chain`` keyword and renames it
        to ``combine_documents_chain`` before field validation runs.
        """
        if "combine_document_chain" in values:
            values["combine_documents_chain"] = values.pop("combine_document_chain")
        return values

    @staticmethod
    def _split_sources(answer: str) -> Tuple[str, str]:
        """Split raw chain output into ``(answer, sources)``.

        Text before the first ``SOURCES:`` marker (followed by whitespace) is
        the answer, and everything after it is the sources string. Without a
        marker, the whole text is the answer and sources is ``""``.
        """
        # maxsplit=1: an unbounded split followed by two-way unpacking raised
        # ValueError whenever the model repeated the marker.
        parts = re.split(r"SOURCES:\s", answer, maxsplit=1)
        if len(parts) == 2:
            return parts[0], parts[1]
        return answer, ""

    def _build_result(self, answer: str, docs: List[Document]) -> Dict[str, Any]:
        """Assemble the output dict shared by ``_call`` and ``_acall``."""
        answer, sources = self._split_sources(answer)
        result: Dict[str, Any] = {
            self.answer_key: answer,
            self.sources_answer_key: sources,
        }
        if self.return_source_documents:
            result["source_documents"] = docs
        return result

    @abstractmethod
    def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
        """Get docs to run questioning over."""

    def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Run the combine chain over the docs and split out the sources.

        ``_get_docs`` runs first, so any keys it removes from ``inputs`` are
        not forwarded to the combine chain.
        """
        docs = self._get_docs(inputs)
        answer = self.combine_documents_chain.run(input_documents=docs, **inputs)
        return self._build_result(answer, docs)

    @abstractmethod
    async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
        """Get docs to run questioning over."""

    async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Async counterpart of ``_call``."""
        docs = await self._aget_docs(inputs)
        answer = await self.combine_documents_chain.arun(
            input_documents=docs, **inputs
        )
        return self._build_result(answer, docs)


class QAWithSourcesChain(BaseQAWithSourcesChain):
    """Question answering with sources over documents.

    The documents are supplied directly by the caller, under
    ``input_docs_key``, alongside the question.
    """

    input_docs_key: str = "docs"  #: :meta private:

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_docs_key, self.question_key]

    def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
        """Remove and return the caller-supplied documents from ``inputs``.

        NOTE: this mutates ``inputs`` via ``pop``, so the documents are no
        longer present in the dict afterwards.
        """
        return inputs.pop(self.input_docs_key)

    async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
        """Async variant of ``_get_docs``; same pop-from-``inputs`` behavior."""
        return self._get_docs(inputs)

    @property
    def _chain_type(self) -> str:
        return "qa_with_sources_chain"