Source code for langchain.memory.entity

import logging
from abc import ABC, abstractmethod
from itertools import islice
from typing import Any, Dict, Iterable, List, Optional

from pydantic import Field

from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.prompt import (
    ENTITY_EXTRACTION_PROMPT,
    ENTITY_SUMMARIZATION_PROMPT,
)
from langchain.memory.utils import get_prompt_input_key
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel, BaseMessage, get_buffer_string

# Module-level logger named after this module; RedisEntityStore uses it for
# debug traces of reads/writes and for reporting connection errors.
logger = logging.getLogger(__name__)


class BaseEntityStore(ABC):
    """Abstract key/value interface for storing per-entity text.

    Every method is abstract, so a concrete store has to override all five
    before it can be instantiated.
    """

    @abstractmethod
    def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
        """Look up the value stored for ``key``.

        ``default`` is the fallback value the caller supplies for a lookup
        that finds nothing.
        """

    @abstractmethod
    def set(self, key: str, value: Optional[str]) -> None:
        """Store ``value`` under ``key``."""

    @abstractmethod
    def delete(self, key: str) -> None:
        """Remove the value stored under ``key``."""

    @abstractmethod
    def exists(self, key: str) -> bool:
        """Report whether the store holds a value for ``key``."""

    @abstractmethod
    def clear(self) -> None:
        """Remove every entity from the store."""


class InMemoryEntityStore(BaseEntityStore):
    """Basic in-memory entity store.

    Entities live in a plain ``dict`` owned by the instance: nothing is
    persisted, and two stores never see each other's entries.
    """

    store: Dict[str, Optional[str]]

    def __init__(self) -> None:
        """Create an empty store."""
        # The dict must be created per instance. A class-level ``{}`` default
        # is a single object shared by every instance, which would leak
        # entities between otherwise unrelated stores.
        self.store = {}

    def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
        """Return the value stored for ``key``, or ``default`` if absent."""
        return self.store.get(key, default)

    def set(self, key: str, value: Optional[str]) -> None:
        """Store ``value`` under ``key``, replacing any previous value."""
        self.store[key] = value

    def delete(self, key: str) -> None:
        """Remove ``key`` from the store.

        Raises:
            KeyError: if ``key`` is not present.
        """
        del self.store[key]

    def exists(self, key: str) -> bool:
        """Return ``True`` if ``key`` is present in the store."""
        return key in self.store

    def clear(self) -> None:
        """Remove every entity from the store."""
        self.store.clear()
class RedisEntityStore(BaseEntityStore):
    """Entity store that keeps entity values in Redis.

    Each entity is stored under ``<key_prefix>:<session_id>:<entity>``.
    Writes pass ``ttl`` (one day by default) as the key's expiry, and reads
    go through ``GETEX`` with ``recall_ttl`` (three days by default) as the
    new expiry.
    """

    redis_client: Any
    session_id: str = "default"
    key_prefix: str = "memory_store"
    ttl: Optional[int] = 60 * 60 * 24
    recall_ttl: Optional[int] = 60 * 60 * 24 * 3

    def __init__(
        self,
        session_id: str = "default",
        url: str = "redis://localhost:6379/0",
        key_prefix: str = "memory_store",
        ttl: Optional[int] = 60 * 60 * 24,
        recall_ttl: Optional[int] = 60 * 60 * 24 * 3,
        *args: Any,
        **kwargs: Any,
    ):
        """Build a Redis client from ``url`` and record the key settings.

        Args:
            session_id: second component of every key written by this store.
            url: connection URL handed to ``redis.Redis.from_url``.
            key_prefix: first component of every key written by this store.
            ttl: expiry, in seconds, passed when a value is written.
            recall_ttl: expiry, in seconds, passed when a value is read;
                a falsy value falls back to ``ttl``.

        Raises:
            ValueError: if the ``redis`` package cannot be imported.
        """
        # Imported lazily so the module can be loaded without redis installed.
        try:
            import redis
        except ImportError:
            raise ValueError(
                "Could not import redis python package. "
                "Please install it with `pip install redis`."
            )

        super().__init__(*args, **kwargs)

        # A connection error here is only logged; ``redis_client`` is then
        # left unassigned on the instance.
        try:
            self.redis_client = redis.Redis.from_url(url=url, decode_responses=True)
        except redis.exceptions.ConnectionError as error:
            logger.error(error)

        self.session_id = session_id
        self.key_prefix = key_prefix
        self.ttl = ttl
        self.recall_ttl = recall_ttl if recall_ttl else ttl

    @property
    def full_key_prefix(self) -> str:
        """Prefix shared by every key of this store's session."""
        return f"{self.key_prefix}:{self.session_id}"

    def _entity_key(self, key: str) -> str:
        """Return the full Redis key under which entity ``key`` is stored."""
        return f"{self.full_key_prefix}:{key}"

    def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
        """Read an entity value, passing ``recall_ttl`` as its new expiry.

        Returns the stored value; if there is none, ``default``; and if that
        is falsy as well, the empty string.
        """
        stored = self.redis_client.getex(self._entity_key(key), ex=self.recall_ttl)
        res = stored or default or ""
        logger.debug(f"REDIS MEM get '{self.full_key_prefix}:{key}': '{res}'")
        return res

    def set(self, key: str, value: Optional[str]) -> None:
        """Write ``value`` for ``key`` with expiry ``ttl``.

        A falsy ``value`` (``None`` or ``""``) deletes the key instead.
        """
        if value:
            self.redis_client.set(self._entity_key(key), value, ex=self.ttl)
            logger.debug(
                f"REDIS MEM set '{self.full_key_prefix}:{key}': '{value}' EX {self.ttl}"
            )
        else:
            self.delete(key)

    def delete(self, key: str) -> None:
        """Delete the entity stored under ``key``."""
        self.redis_client.delete(self._entity_key(key))

    def exists(self, key: str) -> bool:
        """Return ``True`` when Redis reports exactly one match for ``key``."""
        match_count = self.redis_client.exists(self._entity_key(key))
        return match_count == 1

    def clear(self) -> None:
        """Delete every key under this store's ``full_key_prefix``."""
        matching_keys = iter(self.redis_client.scan_iter(f"{self.full_key_prefix}:*"))
        # Delete in chunks of at most 500 keys per DELETE call.
        while True:
            chunk = list(islice(matching_keys, 500))
            if not chunk:
                break
            self.redis_client.delete(*chunk)
class ConversationEntityMemory(BaseChatMemory):
    """Entity extractor & summarizer to memory.

    ``load_memory_variables`` asks the LLM which entities the latest input
    mentions and returns their stored summaries together with the recent chat
    history. ``save_context`` then asks the LLM for an updated summary of each
    of those entities and writes it to ``entity_store``.

    ``chat_memory``, ``input_key`` and ``return_messages`` are not declared
    here; they are presumably inherited from ``BaseChatMemory``.
    """

    human_prefix: str = "Human"
    ai_prefix: str = "AI"
    llm: BaseLanguageModel
    entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
    entity_summarization_prompt: BasePromptTemplate = ENTITY_SUMMARIZATION_PROMPT
    # Entities found by the latest ``load_memory_variables`` call; these are
    # the ones ``save_context`` re-summarizes. Built through a factory, like
    # ``entity_store``, so the default list is never a shared object.
    entity_cache: List[str] = Field(default_factory=list)
    # Number of most recent exchanges (``k * 2`` messages) used as context.
    k: int = 3
    chat_history_key: str = "history"
    entity_store: BaseEntityStore = Field(default_factory=InMemoryEntityStore)

    @property
    def buffer(self) -> List[BaseMessage]:
        """All messages currently held by ``chat_memory``."""
        return self.chat_memory.messages

    @property
    def memory_variables(self) -> List[str]:
        """Will always return list of memory variables.

        :meta private:
        """
        return ["entities", self.chat_history_key]

    def _recent_messages(self) -> List[BaseMessage]:
        """Return the last ``k * 2`` messages of the buffer."""
        # ``buffer[-0:]`` is the *whole* list, so ``k == 0`` has to be handled
        # explicitly to mean "no history" rather than "all of it".
        if self.k <= 0:
            return []
        return self.buffer[-self.k * 2 :]

    def _resolve_input_key(self, inputs: Dict[str, Any]) -> str:
        """Return the key of ``inputs`` that holds the user's input text."""
        if self.input_key is None:
            return get_prompt_input_key(inputs, self.memory_variables)
        return self.input_key

    @staticmethod
    def _parse_entities(output: str) -> List[str]:
        """Turn the extraction chain's comma-separated output into names."""
        if output.strip() == "NONE":
            return []
        stripped = (part.strip() for part in output.split(","))
        # Drop blanks (trailing comma, empty output) so no entity is stored
        # under an empty key, and drop repeats (keeping first-seen order) so
        # each entity is summarized only once per turn.
        return list(dict.fromkeys(name for name in stripped if name))

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return the recent history and summaries of mentioned entities.

        Runs the entity-extraction prompt over the recent history and the
        current input, looks each extracted entity up in ``entity_store``
        (``""`` when unknown) and remembers the entities in ``entity_cache``
        for the following ``save_context`` call.

        Returns:
            A dict with ``chat_history_key`` mapped to the recent history (a
            message list when ``return_messages`` is set, otherwise the
            formatted string) and ``"entities"`` mapped to
            ``{entity: summary}``.
        """
        chain = LLMChain(llm=self.llm, prompt=self.entity_extraction_prompt)
        prompt_input_key = self._resolve_input_key(inputs)
        recent_messages = self._recent_messages()
        buffer_string = get_buffer_string(
            recent_messages,
            human_prefix=self.human_prefix,
            ai_prefix=self.ai_prefix,
        )
        output = chain.predict(
            history=buffer_string,
            input=inputs[prompt_input_key],
        )
        entities = self._parse_entities(output)
        entity_summaries = {
            entity: self.entity_store.get(entity, "") for entity in entities
        }
        self.entity_cache = entities
        buffer: Any = recent_messages if self.return_messages else buffer_string
        return {
            self.chat_history_key: buffer,
            "entities": entity_summaries,
        }

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save context from this conversation to buffer.

        After the parent class has recorded the exchange, each entity in
        ``entity_cache`` gets a fresh summary from the summarization prompt
        (fed the existing summary, the recent history and the current input),
        which is written back to ``entity_store``.
        """
        super().save_context(inputs, outputs)
        prompt_input_key = self._resolve_input_key(inputs)
        buffer_string = get_buffer_string(
            self._recent_messages(),
            human_prefix=self.human_prefix,
            ai_prefix=self.ai_prefix,
        )
        input_data = inputs[prompt_input_key]
        chain = LLMChain(llm=self.llm, prompt=self.entity_summarization_prompt)
        for entity in self.entity_cache:
            existing_summary = self.entity_store.get(entity, "")
            output = chain.predict(
                summary=existing_summary,
                entity=entity,
                history=buffer_string,
                input=input_data,
            )
            self.entity_store.set(entity, output.strip())

    def clear(self) -> None:
        """Clear memory contents."""
        self.chat_memory.clear()
        self.entity_store.clear()