Skip to content

Commit

Permalink
Add ability to specify topics for memory extraction and retrieval (#93)
Browse files Browse the repository at this point in the history
This PR introduces the ability for the client to specify topics for
extraction and retrieval.

Specifically, the memory_module config now accepts a list of topics,
each defined as a `Topic`,
e.g.:
```
Topic(name="Device Type", description="The type of device the user has"),
Topic(name="Operating System", description="The user's operating system"),
Topic(name="Device year", description="The year of the user's device"),
```

Here, we are telling the system to specifically look for these topics
during extraction. After extraction, when we store the memories, we also
store the associated topic.
By default, we provide some general topics that the system uses - these
retain the current behavior of the system (See [default
topics](https://github.com/microsoft/teams-memory-agents-py/blob/61a1538c7888ddcda4e5f4a4e953d7df91729479/packages/memory_module/config.py#L22-L29)).

During retrieval, memories can now also be queried by topic. If the caller
specifies a topic but no query, all memories stored for that user under that
topic are returned, ordered from newest to oldest. If a query is also
specified, both the query and the topic are taken into consideration.
The new signature for retrieval is as follows:
```python
@abstractmethod
async def retrieve_memories(
    self,
    user_id: Optional[str],
    config: RetrievalConfig,
) -> List[Memory]:
    """Retrieve relevant memories based on a query."""
    pass
```
Here `RetrievalConfig` is:
```python
class RetrievalConfig(BaseModel):
    query: Optional[str] = None
    topic: Optional[Topic] = None
    limit: Optional[int] = None
```

TODO:
[ ] Add a test for topic retrieval

Addresses:
#40
  • Loading branch information
heyitsaamir authored Jan 10, 2025
1 parent 6471376 commit 069677b
Show file tree
Hide file tree
Showing 16 changed files with 718 additions and 252 deletions.
8 changes: 4 additions & 4 deletions packages/evals/benchmark_memory_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@

from memory_module.config import LLMConfig, MemoryModuleConfig
from memory_module.core.memory_module import MemoryModule
from memory_module.interfaces.types import AssistantMessage, UserMessage
from memory_module.interfaces.types import AssistantMessage, RetrievalConfig, UserMessage

from evals.helpers import Dataset, DatasetItem, load_dataset, setup_mlflow
from evals.helpers import Dataset, DatasetItem, SessionMessage, load_dataset, setup_mlflow
from evals.metrics import string_check_metric

setup_mlflow(experiment_name="memory_module")
Expand Down Expand Up @@ -48,7 +48,7 @@ def __exit__(self, exc_type, exc_value, traceback):
os.remove(self._db_path)


async def add_messages(memory_module: MemoryModule, messages: List[dict]):
async def add_messages(memory_module: MemoryModule, messages: List[SessionMessage]):
def create_message(**kwargs):
params = {
"id": str(uuid.uuid4()),
Expand Down Expand Up @@ -94,7 +94,7 @@ async def benchmark_memory_module(input: DatasetItem):
# buffer size has to be the same as the session length to trigger sm processing
with MemoryModuleManager(buffer_size=len(session)) as memory_module:
await add_messages(memory_module, messages=session)
memories = await memory_module.retrieve_memories(query, user_id=None, limit=None)
memories = await memory_module.retrieve_memories(None, RetrievalConfig(query=query, limit=None))

return {
"input": {
Expand Down
2 changes: 2 additions & 0 deletions packages/memory_module/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Memory,
Message,
MessageInput,
RetrievalConfig,
ShortTermMemoryRetrievalConfig,
UserMessage,
UserMessageInput,
Expand All @@ -29,6 +30,7 @@
"MessageInput",
"AssistantMessage",
"AssistantMessageInput",
"RetrievalConfig",
"ShortTermMemoryRetrievalConfig",
"MemoryMiddleware",
]
17 changes: 17 additions & 0 deletions packages/memory_module/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from pydantic import BaseModel, ConfigDict, Field

from memory_module.interfaces.types import Topic


class LLMConfig(BaseModel):
"""Configuration for LLM service."""
Expand All @@ -16,6 +18,18 @@ class LLMConfig(BaseModel):
embedding_model: Optional[str] = None


# Topics used when a MemoryModuleConfig is built without an explicit `topics` value
# (it is the `default` of that field). Each entry pairs a topic name with the
# description that is rendered into the semantic-fact extraction prompt.
DEFAULT_TOPICS = [
Topic(
name="General Interests and Preferences",
description="When a user mentions specific events or actions, focus on the underlying interests, hobbies, or preferences they reveal (e.g., if the user mentions attending a conference, focus on the topic of the conference, not the date or location).", # noqa: E501
),
Topic(
name="General Facts about the user",
description="Facts that describe relevant information about the user, such as details about where they live or things they own.", # noqa: E501
),
]


class MemoryModuleConfig(BaseModel):
"""Configuration for memory module components.
Expand All @@ -35,4 +49,7 @@ class MemoryModuleConfig(BaseModel):
description="Seconds to wait before processing a conversation",
)
llm: LLMConfig = Field(description="LLM service configuration")
topics: list[Topic] = Field(
default=DEFAULT_TOPICS, description="List of topics that the memory module should listen to", min_length=1
)
enable_logging: bool = Field(default=False, description="Enable verbose logging for memory module")
59 changes: 45 additions & 14 deletions packages/memory_module/core/memory_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
from memory_module.interfaces.base_memory_storage import BaseMemoryStorage
from memory_module.interfaces.types import (
BaseMemoryInput,
EmbedText,
Memory,
MemoryType,
Message,
MessageInput,
RetrievalConfig,
ShortTermMemoryRetrievalConfig,
TextEmbedding,
Topic,
)
from memory_module.services.llm_service import LLMService
from memory_module.storage.in_memory_storage import InMemoryStorage
Expand Down Expand Up @@ -51,6 +53,11 @@ class SemanticFact(BaseModel):
default_factory=set,
description="The indices of the messages that the fact was extracted from.",
)
# TODO: Add a validator to ensure that topics are valid
topics: Optional[List[str]] = Field(
default=None,
description="The name of the topic that the fact is most relevant to.", # noqa: E501
)


class SemanticMemoryExtraction(BaseModel):
Expand Down Expand Up @@ -106,6 +113,7 @@ def __init__(
self.storage: BaseMemoryStorage = storage or (
SQLiteMemoryStorage(db_path=config.db_path) if config.db_path is not None else InMemoryStorage()
)
self.topics = config.topics

async def process_semantic_messages(
self,
Expand Down Expand Up @@ -133,7 +141,7 @@ async def process_semantic_messages(

if extraction.action == "add" and extraction.facts:
for fact in extraction.facts:
decision = await self._get_add_memory_processing_decision(fact.text, author_id)
decision = await self._get_add_memory_processing_decision(fact, author_id)
if decision.decision == "ignore":
logger.info(f"Decision to ignore fact {fact.text}")
continue
Expand All @@ -145,6 +153,7 @@ async def process_semantic_messages(
user_id=author_id,
message_attributions=list(message_ids),
memory_type=MemoryType.SEMANTIC,
topics=fact.topics,
)
embed_vectors = await self._get_semantic_fact_embeddings(fact.text, metadata)
await self.storage.store_memory(memory, embedding_vectors=embed_vectors)
Expand All @@ -154,16 +163,37 @@ async def process_episodic_messages(self, messages: List[Message]) -> None:
# TODO: Implement episodic memory processing
await self._extract_episodic_memory_from_messages(messages)

async def retrieve_memories(self, query: str, user_id: Optional[str], limit: Optional[int]) -> List[Memory]:
async def retrieve_memories(
    self,
    user_id: Optional[str],
    config: RetrievalConfig,
) -> List[Memory]:
    """Unpack a RetrievalConfig and delegate to `_retrieve_memories`.

    The config carries at most one topic, while the internal helper takes a
    list of topics, so a set topic is wrapped in a one-element list and an
    unset topic is forwarded as None.
    """
    topic_filter = [config.topic] if config.topic else None
    return await self._retrieve_memories(user_id, config.query, topic_filter, config.limit)

async def _retrieve_memories(
self,
user_id: Optional[str],
query: Optional[str],
topics: Optional[List[Topic]],
limit: Optional[int],
) -> List[Memory]:
"""Retrieve memories based on a query.
Steps:
1. Convert query to embedding
2. Find relevant memories
3. Possibly rerank or filter results
"""
embedText = EmbedText(text=query, embedding_vector=await self._get_query_embedding(query))
return await self.storage.retrieve_memories(embedText, user_id, limit)
if query:
text_embedding = TextEmbedding(text=query, embedding_vector=await self._get_query_embedding(query))
else:
text_embedding = None

return await self.storage.retrieve_memories(
user_id=user_id, text_embedding=text_embedding, topics=topics, limit=limit
)

async def update_memory(self, memory_id: str, updated_memory: str) -> None:
metadata = await self._extract_metadata_from_fact(updated_memory)
Expand Down Expand Up @@ -195,10 +225,13 @@ async def remove_messages(self, message_ids: List[str]) -> None:
logger.info("messages {} are removed".format(",".join(message_ids)))

async def _get_add_memory_processing_decision(
self, new_memory_fact: str, user_id: Optional[str]
self, new_memory_fact: SemanticFact, user_id: Optional[str]
) -> ProcessSemanticMemoryDecision:
similar_memories = await self.retrieve_memories(new_memory_fact, user_id, None)
decision = await self._extract_memory_processing_decision(new_memory_fact, similar_memories, user_id)
# topics = (
# [topic for topic in self.topics if topic.name in new_memory_fact.topics] if new_memory_fact.topics else None # noqa: E501
# )
similar_memories = await self._retrieve_memories(user_id, new_memory_fact.text, None, None)
decision = await self._extract_memory_processing_decision(new_memory_fact.text, similar_memories, user_id)
return decision

async def _extract_memory_processing_decision(
Expand Down Expand Up @@ -306,6 +339,9 @@ async def _extract_semantic_fact_from_messages(
else:
# we explicitly ignore internal messages
continue
topics = "\n".join(
[f"<MEMORY_TOPIC NAME={topic.name}>{topic.description}</MEMORY_TOPIC>" for topic in self.topics]
)

existing_memories_str = ""
if existing_memories:
Expand All @@ -318,11 +354,7 @@ async def _extract_semantic_fact_from_messages(
that will remain relevant over time, even if the user is mentioning short-term plans or events.
Prioritize:
- General Interests and Preferences: When a user mentions specific events or actions, focus on the underlying
interests, hobbies, or preferences they reveal (e.g., if the user mentions attending a conference, focus on the topic of the conference,
not the date or location).
- Facts or Details about user: Extract facts that describe relevant information about the user, such as details about things they own.
- Facts about the user that the assistant might find useful.
{topics}
Avoid:
- Extraction memories that already exist in the system. If a fact is already stored, ignore it.
Expand All @@ -335,7 +367,6 @@ async def _extract_semantic_fact_from_messages(
{messages_str}
</TRANSCRIPT>
""" # noqa: E501

llm_messages = [
{"role": "system", "content": system_message},
{
Expand Down
18 changes: 14 additions & 4 deletions packages/memory_module/core/memory_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@
from memory_module.interfaces.base_memory_core import BaseMemoryCore
from memory_module.interfaces.base_memory_module import BaseMemoryModule
from memory_module.interfaces.base_message_queue import BaseMessageQueue
from memory_module.interfaces.types import Memory, Message, MessageInput, ShortTermMemoryRetrievalConfig
from memory_module.interfaces.types import (
Memory,
Message,
MessageInput,
RetrievalConfig,
ShortTermMemoryRetrievalConfig,
)
from memory_module.services.llm_service import LLMService
from memory_module.utils.logging import configure_logging

Expand Down Expand Up @@ -50,10 +56,14 @@ async def add_message(self, message: MessageInput) -> Message:
await self.message_queue.enqueue(message_res)
return message_res

async def retrieve_memories(self, query: str, user_id: Optional[str], limit: Optional[int]) -> List[Memory]:
async def retrieve_memories(
self,
user_id: Optional[str],
config: RetrievalConfig,
) -> List[Memory]:
"""Retrieve relevant memories based on a query."""
logger.debug(f"retrieve memories from (query: {query}, user_id: {user_id}, limit: {limit})")
memories = await self.memory_core.retrieve_memories(query, user_id, limit)
logger.debug(f"retrieve memories from (query: {config.query}, user_id: {user_id}, limit: {config.limit})")
memories = await self.memory_core.retrieve_memories(user_id=user_id, config=config)
logger.debug(f"retrieved memories: {memories}")
return memories

Expand Down
14 changes: 12 additions & 2 deletions packages/memory_module/interfaces/base_memory_core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Optional

from memory_module.interfaces.types import Memory, Message, MessageInput, ShortTermMemoryRetrievalConfig
from memory_module.interfaces.types import (
Memory,
Message,
MessageInput,
RetrievalConfig,
ShortTermMemoryRetrievalConfig,
)


class BaseMemoryCore(ABC):
Expand All @@ -22,7 +28,11 @@ async def process_episodic_messages(self, messages: List[Message]) -> None:
pass

@abstractmethod
async def retrieve_memories(self, query: str, user_id: Optional[str], limit: Optional[int]) -> List[Memory]:
async def retrieve_memories(
self,
user_id: Optional[str],
config: RetrievalConfig,
) -> List[Memory]:
"""Retrieve memories based on a query."""
pass

Expand Down
14 changes: 12 additions & 2 deletions packages/memory_module/interfaces/base_memory_module.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Optional

from memory_module.interfaces.types import Memory, Message, MessageInput, ShortTermMemoryRetrievalConfig
from memory_module.interfaces.types import (
Memory,
Message,
MessageInput,
RetrievalConfig,
ShortTermMemoryRetrievalConfig,
)


class BaseMemoryModule(ABC):
Expand All @@ -13,7 +19,11 @@ async def add_message(self, message: MessageInput) -> Message:
pass

@abstractmethod
async def retrieve_memories(self, query: str, user_id: Optional[str], limit: Optional[int]) -> List[Memory]:
async def retrieve_memories(
self,
user_id: Optional[str],
config: RetrievalConfig,
) -> List[Memory]:
"""Retrieve relevant memories based on a query."""
pass

Expand Down
10 changes: 8 additions & 2 deletions packages/memory_module/interfaces/base_memory_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@

from memory_module.interfaces.types import (
BaseMemoryInput,
EmbedText,
Memory,
Message,
MessageInput,
ShortTermMemoryRetrievalConfig,
TextEmbedding,
Topic,
)


Expand Down Expand Up @@ -47,7 +48,12 @@ async def store_short_term_memory(self, message: MessageInput) -> Message:

@abstractmethod
async def retrieve_memories(
self, embedText: EmbedText, user_id: Optional[str], limit: Optional[int] = None
self,
*,
user_id: Optional[str],
text_embedding: Optional[TextEmbedding] = None,
topics: Optional[List[Topic]] = None,
limit: Optional[int] = None,
) -> List[Memory]:
"""Retrieve memories based on a query.
Expand Down
38 changes: 36 additions & 2 deletions packages/memory_module/interfaces/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,12 @@ class BaseMemoryInput(BaseModel):
memory_type: MemoryType
user_id: Optional[str] = None
message_attributions: Optional[List[str]] = Field(default=[])
topics: Optional[List[str]] = None


class Topic(BaseModel):
    # A named subject area: its name is stored alongside memories as a tag, and its
    # description is what the extraction prompt shows for the topic.
    name: str = Field(
        description="A unique name of the topic that the memory module should listen to",
    )
    description: str = Field(
        description="Description of the topic",
    )


class Memory(BaseMemoryInput):
Expand All @@ -112,12 +118,40 @@ class Memory(BaseMemoryInput):
id: str


class EmbedText(BaseModel):
class TextEmbedding(BaseModel):
text: str
embedding_vector: List[float]


class ShortTermMemoryRetrievalConfig(BaseModel):
class RetrievalConfig(BaseModel):
    """Configuration for memory retrieval operations.

    This class defines the parameters used to retrieve memories from storage. Memories can be
    retrieved either by a semantic search query or by filtering for a specific topic or both.
    In case of both, the memories are retrieved by the intersection of the two sets.
    """

    # Free-text search input; may be omitted as long as `topic` is given.
    query: Optional[str] = Field(
        default=None,
        description="A natural language query to search for semantically similar memories",
    )
    # Topic filter; may be omitted as long as `query` is given.
    topic: Optional[Topic] = Field(
        default=None,
        description="Topic to filter memories by. Only memories tagged with this topic will be retrieved",
    )
    # Leaving this unset means no cap on the number of results.
    limit: Optional[int] = Field(
        default=None,
        description="Maximum number of memories to retrieve. If not specified, all matching memories are returned",
    )

    @model_validator(mode="after")
    def check_parameters(self) -> "RetrievalConfig":
        """Reject a config that supplies neither a query nor a topic."""
        # NOTE(review): only None counts as "missing" here, so query="" passes this check
        # even though the core retrieval path tests `if query:` — confirm that is intended.
        # NOTE(review): ShortTermMemoryRetrievalConfig subclasses this model and therefore
        # inherits this check — confirm that short-term lookups given only
        # n_messages / last_minutes / before are still meant to validate.
        has_criterion = self.query is not None or self.topic is not None
        if not has_criterion:
            raise ValueError("Either query or topic must be provided")
        return self


class ShortTermMemoryRetrievalConfig(RetrievalConfig):
n_messages: Optional[int] = None # Number of messages to retrieve
last_minutes: Optional[float] = None # Time frame in minutes
before: Optional[datetime] = None # Retrieve messages up until a specific timestamp
Expand Down
Loading

0 comments on commit 069677b

Please sign in to comment.