From 66c3f077861853795eefea35702714f266ee07ab Mon Sep 17 00:00:00 2001
From: Ludwik Trammer
Date: Thu, 9 Jan 2025 16:21:16 +0100
Subject: [PATCH 1/6] Feat(conversations): Add last message recontextualizer

---
 .../conversations/recontextualize_message.py  | 52 ++++++++++++
 packages/ragbits-conversations/README.md      |  1 +
 packages/ragbits-conversations/pyproject.toml | 63 ++++++++++++++
 .../src/ragbits/conversations/__init__.py     |  0
 .../ragbits/conversations/history/__init__.py |  0
 .../history/compressors/__init__.py           |  4 +
 .../conversations/history/compressors/base.py | 31 +++++++
 .../conversations/history/compressors/llm.py  | 82 +++++++++++++++++++
 pyproject.toml                                |  4 +
 uv.lock                                       | 28 +++++++
 10 files changed, 265 insertions(+)
 create mode 100644 examples/conversations/recontextualize_message.py
 create mode 100644 packages/ragbits-conversations/README.md
 create mode 100644 packages/ragbits-conversations/pyproject.toml
 create mode 100644 packages/ragbits-conversations/src/ragbits/conversations/__init__.py
 create mode 100644 packages/ragbits-conversations/src/ragbits/conversations/history/__init__.py
 create mode 100644 packages/ragbits-conversations/src/ragbits/conversations/history/compressors/__init__.py
 create mode 100644 packages/ragbits-conversations/src/ragbits/conversations/history/compressors/base.py
 create mode 100644 packages/ragbits-conversations/src/ragbits/conversations/history/compressors/llm.py

diff --git a/examples/conversations/recontextualize_message.py b/examples/conversations/recontextualize_message.py
new file mode 100644
index 000000000..78cca843f
--- /dev/null
+++ b/examples/conversations/recontextualize_message.py
@@ -0,0 +1,52 @@
+"""
+Ragbits Conversations Example: Recontextualize Last Message
+
+This example demonstrates how to use the `RecontextualizeLastMessage` compressor to recontextualize
+the last message in a conversation history.
+"""
+
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+#     "ragbits-conversations",
+# ]
+# ///
+
+import asyncio
+
+from ragbits.conversations.history.compressors.llm import RecontextualizeLastMessage
+from ragbits.core.llms.litellm import LiteLLM
+from ragbits.core.prompt import ChatFormat
+
+# Example conversation history
+conversation: ChatFormat = [
+    {"role": "user", "content": "Who's working on Friday?"},
+    {"role": "assistant", "content": "Jim"},
+    {"role": "user", "content": "Where is he based?"},
+    {"role": "assistant", "content": "At our California Head Office"},
+    {"role": "user", "content": "Is he a senior staff member?"},
+    {"role": "assistant", "content": "Yes, he's a senior manager"},
+    {"role": "user", "content": "What's his phone number (including the prefix for his state)?"},
+]
+
+
+async def main() -> None:
+    """
+    Main function to demonstrate the RecontextualizeLastMessage compressor.
+    """
+    # Initialize the LiteLLM client
+    llm = LiteLLM("gpt-4o")
+
+    # Initialize the RecontextualizeLastMessage compressor
+    compressor = RecontextualizeLastMessage(llm, history_len=10)
+
+    # Compress the conversation history
+    recontextualized_message = await compressor.compress(conversation)
+
+    # Print the recontextualized message
+    print("Recontextualized Message:")
+    print(recontextualized_message)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/packages/ragbits-conversations/README.md b/packages/ragbits-conversations/README.md
new file mode 100644
index 000000000..444078962
--- /dev/null
+++ b/packages/ragbits-conversations/README.md
@@ -0,0 +1 @@
+# Ragbits Conversations
diff --git a/packages/ragbits-conversations/pyproject.toml b/packages/ragbits-conversations/pyproject.toml
new file mode 100644
index 000000000..7b3ff43be
--- /dev/null
+++ b/packages/ragbits-conversations/pyproject.toml
@@ -0,0 +1,63 @@
+[project]
+name = "ragbits-conversations"
+version = "0.6.0"
+description = "Building blocks for rapid development of GenAI applications"
+readme = "README.md"
+requires-python = ">=3.10"
+license = "MIT"
+authors = [
+    { name = "deepsense.ai", email = "ragbits@deepsense.ai"}
+]
+keywords = [
+    "Retrieval Augmented Generation",
+    "RAG",
+    "Large Language Models",
+    "LLMs",
+    "Generative AI",
+    "GenAI",
+    "Prompt Management"
+]
+classifiers = [
+    "Development Status :: 4 - Beta",
+    "Environment :: Console",
+    "Intended Audience :: Science/Research",
+    "License :: OSI Approved :: MIT License",
+    "Natural Language :: English",
+    "Operating System :: OS Independent",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
+    "Programming Language :: Python :: 3.13",
+    "Topic :: Scientific/Engineering :: Artificial Intelligence",
+    "Topic :: Software Development :: Libraries :: Python Modules",
+]
+dependencies = []
+
+[project.urls]
+"Homepage" = "https://github.com/deepsense-ai/ragbits"
+"Bug Reports" = "https://github.com/deepsense-ai/ragbits/issues"
+"Documentation" = "https://ragbits.deepsense.ai/"
+"Source" = "https://github.com/deepsense-ai/ragbits"
+
+[project.optional-dependencies]
+[tool.uv]
+dev-dependencies = [
+    "pre-commit~=3.8.0",
+    "pytest~=8.3.3",
+    "pytest-cov~=5.0.0",
+    "pytest-asyncio~=0.24.0",
+    "pip-licenses>=4.0.0,<5.0.0"
+]
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.hatch.metadata]
+allow-direct-references = true
+
+[tool.hatch.build.targets.wheel]
+packages = ["src/ragbits"]
+
+[tool.pytest.ini_options]
+asyncio_mode = "auto"
diff --git a/packages/ragbits-conversations/src/ragbits/conversations/__init__.py b/packages/ragbits-conversations/src/ragbits/conversations/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/packages/ragbits-conversations/src/ragbits/conversations/history/__init__.py b/packages/ragbits-conversations/src/ragbits/conversations/history/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/__init__.py b/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/__init__.py
new file mode 100644
index 000000000..fcd1df752
--- /dev/null
+++ b/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/__init__.py
@@ -0,0 +1,4 @@
+from .base import ConversationHistoryCompressor
+from .llm import RecontextualizeLastMessage
+
+__all__ = ["ConversationHistoryCompressor", "RecontextualizeLastMessage"]
diff --git a/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/base.py b/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/base.py
new file mode 100644
index 000000000..9b5816e12
--- /dev/null
+++ b/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/base.py
@@ -0,0 +1,31 @@
+from abc import ABC, abstractmethod
+from typing import ClassVar
+
+from ragbits.conversations.history import compressors
+from ragbits.core.prompt.base import ChatFormat
+from ragbits.core.utils.config_handling import WithConstructionConfig
+
+
+class ConversationHistoryCompressor(WithConstructionConfig, ABC):
+    """
+    An abstract class for conversation history compressors,
+    i.e. class that takes the entire conversation history
+    and returns a single string representation of it.
+
+    The exact logic of what the string should include and represent
+    depends on the specific implementation.
+
+    Usually used to provide LLM additional context from the conversation history.
+    """
+
+    default_module: ClassVar = compressors
+    configuration_key: ClassVar = "history_compressor"
+
+    @abstractmethod
+    async def compress(self, conversation: ChatFormat) -> str:
+        """
+        Compresses the conversation history to a single string.
+
+        Args:
+            conversation: List of dicts with "role" and "content" keys, representing the chat history so far.
+        """
diff --git a/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/llm.py b/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/llm.py
new file mode 100644
index 000000000..c94949eda
--- /dev/null
+++ b/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/llm.py
@@ -0,0 +1,82 @@
+from pydantic import BaseModel
+
+from ragbits.conversations.history.compressors import ConversationHistoryCompressor
+from ragbits.core.llms.base import LLM
+from ragbits.core.prompt import ChatFormat, Prompt
+
+
+class LastMessageAndHistory(BaseModel):
+    """
+    A class representing the last message and the history of messages.
+    """
+
+    last_message: str
+    history: list[str]
+
+
+class RecontextualizeLastMessagePrompt(Prompt[LastMessageAndHistory, str]):
+    """
+    A prompt for recontextualizing the last message in the history.
+    """
+
+    system_prompt = """
+    Given a new message and a history of the conversation, create a standalone version of the message.
+    If the message references any context from history, it should be added to the message itself.
+    Return only the recontextualized message.
+    Do NOT return the history, do NOT answer the question, and do NOT add context irrelevant to the message.
+    """
+
+    user_prompt = """
+    Message:
+    {{ last_message }}
+
+    History:
+    {% for message in history %}
+    * {{ message }}
+    {% endfor %}
+    """
+
+
+class RecontextualizeLastMessage(ConversationHistoryCompressor):
+    """
+    A compressor that uses LLM to recontextualize the last message in the history,
+    i.e. create a standalone version of the message that includes necessary context.
+    """
+
+    def __init__(self, llm: LLM, history_len: int = 5, prompt: type[Prompt[LastMessageAndHistory, str]] | None = None):
+        """
+        Initialize the RecontextualizeLastMessage compressor with a LLM.
+
+        Args:
+            llm: A LLM instance to handle recontextualizing the last message.
+            history_len: The number of previous messages to include in the history.
+            prompt: The prompt to use for recontextualizing the last message.
+        """
+        self._llm = llm
+        self._history_len = history_len
+        self._prompt = prompt or RecontextualizeLastMessagePrompt
+
+    async def compress(self, conversation: ChatFormat) -> str:
+        """
+        Contextualize the last message in the conversation history.
+
+        Args:
+            conversation: List of dicts with "role" and "content" keys, representing the chat history so far.
+                The most recent message should be from the user.
+        """
+        if len(conversation) == 0:
+            raise ValueError("Conversation history is empty.")
+
+        last_message = conversation[-1]
+        if last_message["role"] != "user":
+            raise ValueError("RecontextualizeLastMessage expects the last message to be from the user.")
+
+        # Only indlucde "user" and "assistant" messages in the history
+        other_messages = [message for message in conversation[:-1] if message["role"] in ["user", "assistant"]]
+
+        history = [f"{message['role']}: {message['content']}" for message in other_messages[-self._history_len :]]
+
+        input_data = LastMessageAndHistory(last_message=last_message["content"], history=history)
+        prompt = self._prompt(input_data)
+        response = await self._llm.generate(prompt)
+        return response
diff --git a/pyproject.toml b/pyproject.toml
index 2502b26b0..ea1b22632 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,6 +10,7 @@ dependencies = [
     "ragbits-document-search[gcs,huggingface,distributed]",
     "ragbits-evaluate[relari]",
     "ragbits-guardrails[openai]",
+    "ragbits-conversations",
 ]
 
 [tool.uv]
@@ -38,6 +39,7 @@ ragbits-core = { workspace = true }
 ragbits-document-search = { workspace = true }
 ragbits-evaluate = {workspace = true}
 ragbits-guardrails = {workspace = true}
+ragbits-conversations = {workspace = true}
 
 [tool.uv.workspace]
 members = [
@@ -46,6 +48,7 @@ members = [
     "packages/ragbits-document-search",
     "packages/ragbits-evaluate",
     "packages/ragbits-guardrails",
+    "packages/ragbits-conversations",
 ]
 
 [tool.pytest]
@@ -93,6 +96,7 @@ mypy_path = [
     "packages/ragbits-document-search/src",
     "packages/ragbits-evaluate/src",
    "packages/ragbits-guardrails/src",
+    "packages/ragbits-conversations/src",
 ]
 exclude = ["scripts"]
diff --git a/uv.lock b/uv.lock
index 299460ca8..2424ad2e0 100644
--- a/uv.lock
+++ b/uv.lock
@@ -11,6 +11,7 @@ resolution-markers = [
 
 [manifest]
 members = [
     "ragbits-cli",
+    "ragbits-conversations",
     "ragbits-core",
     "ragbits-document-search",
     "ragbits-evaluate",
@@ -3845,6 +3846,31 @@ requires-dist = [
     { name = "typer", specifier = ">=0.12.5" },
 ]
 
+[[package]]
+name = "ragbits-conversations"
+version = "0.6.0"
+source = { editable = "packages/ragbits-conversations" }
+
+[package.dev-dependencies]
+dev = [
+    { name = "pip-licenses" },
+    { name = "pre-commit" },
+    { name = "pytest" },
+    { name = "pytest-asyncio" },
+    { name = "pytest-cov" },
+]
+
+[package.metadata]
+
+[package.metadata.requires-dev]
+dev = [
+    { name = "pip-licenses", specifier = ">=4.0.0,<5.0.0" },
+    { name = "pre-commit", specifier = "~=3.8.0" },
+    { name = "pytest", specifier = "~=8.3.3" },
+    { name = "pytest-asyncio", specifier = "~=0.24.0" },
+    { name = "pytest-cov", specifier = "~=5.0.0" },
+]
+
 [[package]]
 name = "ragbits-core"
 version = "0.6.0"
@@ -4055,6 +4081,7 @@ version = "0.1.0"
 source = { virtual = "." }
 dependencies = [
     { name = "ragbits-cli" },
+    { name = "ragbits-conversations" },
     { name = "ragbits-core", extra = ["chroma", "lab", "local", "otel", "qdrant"] },
     { name = "ragbits-document-search", extra = ["distributed", "gcs", "huggingface"] },
     { name = "ragbits-evaluate", extra = ["relari"] },
@@ -4084,6 +4111,7 @@ dev = [
 [package.metadata]
 requires-dist = [
     { name = "ragbits-cli", editable = "packages/ragbits-cli" },
+    { name = "ragbits-conversations", editable = "packages/ragbits-conversations" },
     { name = "ragbits-core", extras = ["chroma", "lab", "local", "otel", "qdrant"], editable = "packages/ragbits-core" },
     { name = "ragbits-document-search", extras = ["gcs", "huggingface", "distributed"], editable = "packages/ragbits-document-search" },
     { name = "ragbits-evaluate", extras = ["relari"], editable = "packages/ragbits-evaluate" },

From 607bab39eb8d9d1da6b5b673b5110d94cf2a3e3a Mon Sep 17 00:00:00 2001
From: Ludwik Trammer
Date: Thu, 9 Jan 2025 16:28:57 +0100
Subject: [PATCH 2/6] Fix linter errors

---
 packages/ragbits-core/src/ragbits/core/prompt/promptfoo.py | 2 +-
 .../src/ragbits/core/utils/dict_transformations.py         | 2 +-
 .../src/ragbits/evaluate/dataset_generator/prompts/qa.py   | 6 ++----
 3 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/packages/ragbits-core/src/ragbits/core/prompt/promptfoo.py b/packages/ragbits-core/src/ragbits/core/prompt/promptfoo.py
index 73d5f679c..ea9a41932 100644
--- a/packages/ragbits-core/src/ragbits/core/prompt/promptfoo.py
+++ b/packages/ragbits-core/src/ragbits/core/prompt/promptfoo.py
@@ -46,5 +46,5 @@ def generate_configs(
         target_path.mkdir()
     for prompt in prompts:
         with open(target_path / f"{prompt.__qualname__}.yaml", "w", encoding="utf-8") as f:
-            prompt_path = f'file://{prompt.__module__.replace(".", os.sep)}.py:{prompt.__qualname__}.to_promptfoo'
+            prompt_path = f"file://{prompt.__module__.replace('.', os.sep)}.py:{prompt.__qualname__}.to_promptfoo"
             yaml.dump({"prompts": [prompt_path]}, f)
diff --git a/packages/ragbits-core/src/ragbits/core/utils/dict_transformations.py b/packages/ragbits-core/src/ragbits/core/utils/dict_transformations.py
index ae58fc62d..617cce76f 100644
--- a/packages/ragbits-core/src/ragbits/core/utils/dict_transformations.py
+++ b/packages/ragbits-core/src/ragbits/core/utils/dict_transformations.py
@@ -98,7 +98,7 @@ def _decompose_key(key: str) -> tuple[str | int | None, str | int | None]:
                 _current_subkey = int(_key[start_subscript_index:end_subscript_index])
 
                 if len(_key[end_subscript_index:]) > 1:
-                    _current_subkey = f"{_current_subkey}.{_key[end_subscript_index + 2:]}"
+                    _current_subkey = f"{_current_subkey}.{_key[end_subscript_index + 2 :]}"
                 break
             elif char == ".":
                 split_work = _key.split(".", 1)
diff --git a/packages/ragbits-evaluate/src/ragbits/evaluate/dataset_generator/prompts/qa.py b/packages/ragbits-evaluate/src/ragbits/evaluate/dataset_generator/prompts/qa.py
index 40e29078d..d9295ae3c 100644
--- a/packages/ragbits-evaluate/src/ragbits/evaluate/dataset_generator/prompts/qa.py
+++ b/packages/ragbits-evaluate/src/ragbits/evaluate/dataset_generator/prompts/qa.py
@@ -23,7 +23,7 @@ class BasicAnswerGenPrompt(Prompt[BasicAnswerGenInput, str]):
         "If you don't know the answer just say: I don't know."
     )
 
-    user_prompt: str = "Text:\n<|text_start|>\n {{ chunk }} \n<|text_end|>\n\nQuestion:\n " "{{ question }} \n\nAnswer:"
+    user_prompt: str = "Text:\n<|text_start|>\n {{ chunk }} \n<|text_end|>\n\nQuestion:\n {{ question }} \n\nAnswer:"
 
 
 class PassagesGenInput(BaseModel):
@@ -49,9 +49,7 @@ class PassagesGenPrompt(Prompt[PassagesGenInput, str]):
         "FULL SENTENCES"
     )
 
-    user_prompt: str = (
-        "Question:\n {{ question }} \nAnswer:\n {{ basic_answer }} \nChunk:\n " "{{ chunk }}\n\nPassages:"
-    )
+    user_prompt: str = "Question:\n {{ question }} \nAnswer:\n {{ basic_answer }} \nChunk:\n {{ chunk }}\n\nPassages:"
 
 
 class QueryGenInput(BaseModel):

From 5e552b721af2b08ad37cbce70bf7a10a3b7e0b8e Mon Sep 17 00:00:00 2001
From: Ludwik Trammer
Date: Mon, 13 Jan 2025 14:13:57 +0100
Subject: [PATCH 3/6] Add unit tests

---
 .../conversations/history/compressors/llm.py  |   2 +-
 .../tests/unit/history/test_llm_compressor.py | 111 ++++++++++++++++++
 .../src/ragbits/core/llms/mock.py             |  75 ++++++++++++
 3 files changed, 187 insertions(+), 1 deletion(-)
 create mode 100644 packages/ragbits-conversations/tests/unit/history/test_llm_compressor.py
 create mode 100644 packages/ragbits-core/src/ragbits/core/llms/mock.py

diff --git a/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/llm.py b/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/llm.py
index c94949eda..9bf9bfd39 100644
--- a/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/llm.py
+++ b/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/llm.py
@@ -71,7 +71,7 @@ async def compress(self, conversation: ChatFormat) -> str:
         if last_message["role"] != "user":
             raise ValueError("RecontextualizeLastMessage expects the last message to be from the user.")
 
-        # Only indlucde "user" and "assistant" messages in the history
+        # Only include "user" and "assistant" messages in the history
         other_messages = [message for message in conversation[:-1] if message["role"] in ["user", "assistant"]]
 
         history = [f"{message['role']}: {message['content']}" for message in other_messages[-self._history_len :]]
diff --git a/packages/ragbits-conversations/tests/unit/history/test_llm_compressor.py b/packages/ragbits-conversations/tests/unit/history/test_llm_compressor.py
new file mode 100644
index 000000000..2fd1fa522
--- /dev/null
+++ b/packages/ragbits-conversations/tests/unit/history/test_llm_compressor.py
@@ -0,0 +1,111 @@
+import pytest
+
+from ragbits.conversations.history.compressors.llm import LastMessageAndHistory, RecontextualizeLastMessage
+from ragbits.core.llms.mock import MockLLM, MockLLMOptions
+from ragbits.core.prompt import ChatFormat
+from ragbits.core.prompt.prompt import Prompt
+
+
+class MockPrompt(Prompt[LastMessageAndHistory, str]):
+    user_prompt = "mock prompt"
+
+
+async def test_messages_included():
+    conversation: ChatFormat = [
+        {"role": "user", "content": "foo1"},
+        {"role": "assistant", "content": "foo2"},
+        {"role": "user", "content": "foo3"},
+    ]
+    llm = MockLLM(default_options=MockLLMOptions(response="some answer"))
+    compressor = RecontextualizeLastMessage(llm)
+    answer = await compressor.compress(conversation)
+    assert answer == "some answer"
+    user_prompt = llm.calls[0][1]
+    assert user_prompt["role"] == "user"
+    content = user_prompt["content"]
+    assert "foo1" in content
+    assert "foo2" in content
+    assert "foo3" in content
+
+
+async def test_no_messages():
+    conversation: ChatFormat = []
+    compressor = RecontextualizeLastMessage(MockLLM())
+
+    with pytest.raises(ValueError):
+        await compressor.compress(conversation)
+
+
+async def test_last_message_not_user():
+    conversation: ChatFormat = [
+        {"role": "assistant", "content": "foo2"},
+    ]
+    compressor = RecontextualizeLastMessage(MockLLM())
+
+    with pytest.raises(ValueError):
+        await compressor.compress(conversation)
+
+
+async def test_history_len():
+    conversation: ChatFormat = [
+        {"role": "user", "content": "foo1"},
+        {"role": "assistant", "content": "foo2"},
+        {"role": "user", "content": "foo3"},
+        {"role": "user", "content": "foo4"},
+        {"role": "user", "content": "foo5"},
+    ]
+    llm = MockLLM()
+    compressor = RecontextualizeLastMessage(llm, history_len=3)
+    await compressor.compress(conversation)
+    user_prompt = llm.calls[0][1]
+    assert user_prompt["role"] == "user"
+    content = user_prompt["content"]
+
+    # The rephrased message should be included
+    assert "foo5" in content
+
+    # Three previous messages should be included
+    assert "foo2" in content
+    assert "foo3" in content
+    assert "foo4" in content
+
+    # Earlier messages should not be included
+    assert "foo1" not in content
+
+
+async def test_only_user_and_assistant_messages_in_history():
+    conversation: ChatFormat = [
+        {"role": "user", "content": "foo4"},
+        {"role": "system", "content": "foo1"},
+        {"role": "unknown", "content": "foo2"},
+        {"role": "assistant", "content": "foo3"},
+        {"role": "user", "content": "foo4"},
+        {"role": "assistant", "content": "foo5"},
+        {"role": "user", "content": "foo6"},
+    ]
+    llm = MockLLM()
+    compressor = RecontextualizeLastMessage(llm, history_len=4)
+    await compressor.compress(conversation)
+    user_prompt = llm.calls[0][1]
+    assert user_prompt["role"] == "user"
+    content = user_prompt["content"]
+    assert "foo4" in content
+    assert "foo5" in content
+    assert "foo6" in content
+    assert "foo3" in content
+    assert "foo1" not in content
+    assert "foo2" not in content
+
+
+async def test_changing_prompt():
+    conversation: ChatFormat = [
+        {"role": "user", "content": "foo1"},
+        {"role": "assistant", "content": "foo2"},
+        {"role": "user", "content": "foo3"},
+    ]
+    llm = MockLLM()
+    compressor = RecontextualizeLastMessage(llm, prompt=MockPrompt)
+    await compressor.compress(conversation)
+    user_prompt = llm.calls[0][0]
+    assert user_prompt["role"] == "user"
+    assert user_prompt["content"] == "mock prompt"
diff --git a/packages/ragbits-core/src/ragbits/core/llms/mock.py b/packages/ragbits-core/src/ragbits/core/llms/mock.py
new file mode 100644
index 000000000..9a82489fc
--- /dev/null
+++ b/packages/ragbits-core/src/ragbits/core/llms/mock.py
@@ -0,0 +1,75 @@
+from collections.abc import AsyncGenerator
+
+from pydantic import BaseModel
+
+from ragbits.core.options import Options
+from ragbits.core.prompt import ChatFormat
+from ragbits.core.types import NOT_GIVEN, NotGiven
+
+from .base import LLM
+
+
+class MockLLMOptions(Options):
+    """
+    Options for the MockLLM class.
+    """
+
+    response: str | NotGiven = NOT_GIVEN
+    response_stream: list[str] | NotGiven = NOT_GIVEN
+
+
+class MockLLM(LLM[MockLLMOptions]):
+    """
+    Class for mocking interactions with LLMs - useful for testing.
+    """
+
+    options_cls = MockLLMOptions
+
+    def __init__(self, model_name: str = "mock", default_options: MockLLMOptions | None = None) -> None:
+        """
+        Constructs a new MockLLM instance.
+
+        Args:
+            model_name: Name of the model to be used.
+            default_options: Default options to be used.
+        """
+        super().__init__(model_name, default_options=default_options)
+        self.calls: list[ChatFormat] = []
+
+    async def _call(  # noqa: PLR6301
+        self,
+        conversation: ChatFormat,
+        options: MockLLMOptions,
+        json_mode: bool = False,
+        output_schema: type[BaseModel] | dict | None = None,
+    ) -> str:
+        """
+        Mocks the call to the LLM, using the response from the options if provided.
+        """
+        self.calls.append(conversation)
+        if not isinstance(options.response, NotGiven):
+            return options.response
+        return "mocked response"
+
+    async def _call_streaming(  # noqa: PLR6301
+        self,
+        conversation: ChatFormat,
+        options: MockLLMOptions,
+        json_mode: bool = False,
+        output_schema: type[BaseModel] | dict | None = None,
+    ) -> AsyncGenerator[str, None]:
+        """
+        Mocks the call to the LLM, using the response from the options if provided.
+        """
+        self.calls.append(conversation)
+
+        async def generator() -> AsyncGenerator[str, None]:
+            if not isinstance(options.response_stream, NotGiven):
+                for response in options.response_stream:
+                    yield response
+            elif not isinstance(options.response, NotGiven):
+                yield options.response
+            else:
+                yield "mocked response"
+
+        return generator()

From ffca297c408d2cd7536f4706572ce9566800dfb2 Mon Sep 17 00:00:00 2001
From: Ludwik Trammer
Date: Mon, 13 Jan 2025 17:35:20 +0100
Subject: [PATCH 4/6] Improvements from review comments

---
 examples/conversations/recontextualize_message.py | 10 +++++-----
 .../conversations/history/compressors/__init__.py |  4 ++--
 .../conversations/history/compressors/llm.py      | 13 ++++++++-----
 .../tests/unit/history/test_llm_compressor.py     | 14 +++++++-------
 packages/ragbits/pyproject.toml                   |  9 ++++++++-
 5 files changed, 30 insertions(+), 20 deletions(-)

diff --git a/examples/conversations/recontextualize_message.py b/examples/conversations/recontextualize_message.py
index 78cca843f..23248f81b 100644
--- a/examples/conversations/recontextualize_message.py
+++ b/examples/conversations/recontextualize_message.py
@@ -1,7 +1,7 @@
 """
 Ragbits Conversations Example: Recontextualize Last Message
 
-This example demonstrates how to use the `RecontextualizeLastMessage` compressor to recontextualize
+This example demonstrates how to use the `StandaloneMessageCompressor` compressor to recontextualize
 the last message in a conversation history.
 """
 
@@ -14,7 +14,7 @@
 
 import asyncio
 
-from ragbits.conversations.history.compressors.llm import RecontextualizeLastMessage
+from ragbits.conversations.history.compressors.llm import StandaloneMessageCompressor
 from ragbits.core.llms.litellm import LiteLLM
 from ragbits.core.prompt import ChatFormat
 
@@ -32,13 +32,13 @@
 
 async def main() -> None:
     """
-    Main function to demonstrate the RecontextualizeLastMessage compressor.
+    Main function to demonstrate the StandaloneMessageCompressor compressor.
     """
     # Initialize the LiteLLM client
     llm = LiteLLM("gpt-4o")
 
-    # Initialize the RecontextualizeLastMessage compressor
-    compressor = RecontextualizeLastMessage(llm, history_len=10)
+    # Initialize the StandaloneMessageCompressor compressor
+    compressor = StandaloneMessageCompressor(llm, history_len=10)
 
     # Compress the conversation history
     recontextualized_message = await compressor.compress(conversation)
diff --git a/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/__init__.py b/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/__init__.py
index fcd1df752..66ac169ad 100644
--- a/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/__init__.py
+++ b/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/__init__.py
@@ -1,4 +1,4 @@
 from .base import ConversationHistoryCompressor
-from .llm import RecontextualizeLastMessage
+from .llm import StandaloneMessageCompressor
 
-__all__ = ["ConversationHistoryCompressor", "RecontextualizeLastMessage"]
+__all__ = ["ConversationHistoryCompressor", "StandaloneMessageCompressor"]
diff --git a/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/llm.py b/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/llm.py
index 9bf9bfd39..fd741395f 100644
--- a/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/llm.py
+++ b/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/llm.py
@@ -14,7 +14,7 @@ class LastMessageAndHistory(BaseModel):
     history: list[str]
 
 
-class RecontextualizeLastMessagePrompt(Prompt[LastMessageAndHistory, str]):
+class StandaloneMessageCompressorPrompt(Prompt[LastMessageAndHistory, str]):
     """
     A prompt for recontextualizing the last message in the history.
     """
@@ -37,7 +37,7 @@ class RecontextualizeLastMessagePrompt(Prompt[LastMessageAndHistory, str]):
     """
 
 
-class RecontextualizeLastMessage(ConversationHistoryCompressor):
+class StandaloneMessageCompressor(ConversationHistoryCompressor):
     """
     A compressor that uses LLM to recontextualize the last message in the history,
     i.e. create a standalone version of the message that includes necessary context.
@@ -45,7 +45,7 @@ class StandaloneMessageCompressor(ConversationHistoryCompressor):
 
     def __init__(self, llm: LLM, history_len: int = 5, prompt: type[Prompt[LastMessageAndHistory, str]] | None = None):
         """
-        Initialize the RecontextualizeLastMessage compressor with a LLM.
+        Initialize the StandaloneMessageCompressor compressor with a LLM.
 
         Args:
             llm: A LLM instance to handle recontextualizing the last message.
@@ -54,7 +54,7 @@ def __init__(self, llm: LLM, history_len: int = 5, prompt: type[Prompt[LastMessa
         """
         self._llm = llm
         self._history_len = history_len
-        self._prompt = prompt or RecontextualizeLastMessagePrompt
+        self._prompt = prompt or StandaloneMessageCompressorPrompt
 
     async def compress(self, conversation: ChatFormat) -> str:
         """
@@ -69,7 +69,10 @@ async def compress(self, conversation: ChatFormat) -> str:
 
         last_message = conversation[-1]
         if last_message["role"] != "user":
-            raise ValueError("RecontextualizeLastMessage expects the last message to be from the user.")
+            raise ValueError("StandaloneMessageCompressor expects the last message to be from the user.")
+
+        if len(conversation) == 1:
+            return last_message["content"]
 
         # Only include "user" and "assistant" messages in the history
         other_messages = [message for message in conversation[:-1] if message["role"] in ["user", "assistant"]]
diff --git a/packages/ragbits-conversations/tests/unit/history/test_llm_compressor.py b/packages/ragbits-conversations/tests/unit/history/test_llm_compressor.py
index 2fd1fa522..715d8828f 100644
--- a/packages/ragbits-conversations/tests/unit/history/test_llm_compressor.py
+++ b/packages/ragbits-conversations/tests/unit/history/test_llm_compressor.py
@@ -1,6 +1,6 @@
 import pytest
 
-from ragbits.conversations.history.compressors.llm import LastMessageAndHistory, RecontextualizeLastMessage
+from ragbits.conversations.history.compressors.llm import LastMessageAndHistory, StandaloneMessageCompressor
 from ragbits.core.llms.mock import MockLLM, MockLLMOptions
 from ragbits.core.prompt import ChatFormat
 from ragbits.core.prompt.prompt import Prompt
@@ -17,7 +17,7 @@ async def test_messages_included():
         {"role": "user", "content": "foo3"},
     ]
     llm = MockLLM(default_options=MockLLMOptions(response="some answer"))
-    compressor = RecontextualizeLastMessage(llm)
+    compressor = StandaloneMessageCompressor(llm)
     answer = await compressor.compress(conversation)
     assert answer == "some answer"
     user_prompt = llm.calls[0][1]
@@ -30,7 +30,7 @@ async def test_messages_included():
 
 async def test_no_messages():
     conversation: ChatFormat = []
-    compressor = RecontextualizeLastMessage(MockLLM())
+    compressor = StandaloneMessageCompressor(MockLLM())
 
     with pytest.raises(ValueError):
         await compressor.compress(conversation)
@@ -40,7 +40,7 @@ async def test_last_message_not_user():
     conversation: ChatFormat = [
         {"role": "assistant", "content": "foo2"},
     ]
-    compressor = RecontextualizeLastMessage(MockLLM())
+    compressor = StandaloneMessageCompressor(MockLLM())
 
     with pytest.raises(ValueError):
         await compressor.compress(conversation)
@@ -55,7 +55,7 @@ async def test_history_len():
         {"role": "user", "content": "foo5"},
     ]
     llm = MockLLM()
-    compressor = RecontextualizeLastMessage(llm, history_len=3)
+    compressor = StandaloneMessageCompressor(llm, history_len=3)
     await compressor.compress(conversation)
     user_prompt = llm.calls[0][1]
     assert user_prompt["role"] == "user"
@@ -84,7 +84,7 @@ async def test_only_user_and_assistant_messages_in_history():
         {"role": "user", "content": "foo6"},
     ]
     llm = MockLLM()
-    compressor = RecontextualizeLastMessage(llm, history_len=4)
+    compressor = StandaloneMessageCompressor(llm, history_len=4)
     await compressor.compress(conversation)
     user_prompt = llm.calls[0][1]
     assert user_prompt["role"] == "user"
@@ -104,7 +104,7 @@ async def test_changing_prompt():
         {"role": "user", "content": "foo3"},
     ]
     llm = MockLLM()
-    compressor = RecontextualizeLastMessage(llm, prompt=MockPrompt)
+    compressor = StandaloneMessageCompressor(llm, prompt=MockPrompt)
     await compressor.compress(conversation)
     user_prompt = llm.calls[0][0]
     assert user_prompt["role"] == "user"
     assert user_prompt["content"] == "mock prompt"
diff --git a/packages/ragbits/pyproject.toml b/packages/ragbits/pyproject.toml
index cd32f8139..29e436a09 100644
--- a/packages/ragbits/pyproject.toml
+++ b/packages/ragbits/pyproject.toml
@@ -31,7 +31,14 @@ classifiers = [
     "Topic :: Scientific/Engineering :: Artificial Intelligence",
     "Topic :: Software Development :: Libraries :: Python Modules",
 ]
-dependencies = ["ragbits-document-search==0.6.0", "ragbits-cli==0.6.0", "ragbits-evaluate==0.6.0", "ragbits-guardrails==0.6.0", "ragbits-core==0.6.0"]
+dependencies = [
+    "ragbits-document-search==0.6.0",
+    "ragbits-cli==0.6.0",
+    "ragbits-evaluate==0.6.0",
+    "ragbits-guardrails==0.6.0",
+    "ragbits-core==0.6.0",
+    "ragbits-conversations==0.6.0",
+]
 
 [project.urls]
 "Homepage" = "https://github.com/deepsense-ai/ragbits"

From e6972ab21e79d4bc2b95b4688c6d6a934d8d3e86 Mon Sep 17 00:00:00 2001
From: Ludwik Trammer
Date: Mon, 13 Jan 2025 17:43:11 +0100
Subject: [PATCH 5/6] Make the condition make more sense

---
 .../src/ragbits/conversations/history/compressors/llm.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/llm.py b/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/llm.py
index fd741395f..25fbd96b6 100644
--- a/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/llm.py
+++ b/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/llm.py
@@ -71,12 +71,13 @@ async def compress(self, conversation: ChatFormat) -> str:
         if last_message["role"] != "user":
             raise ValueError("StandaloneMessageCompressor expects the last message to be from the user.")
 
-        if len(conversation) == 1:
-            return last_message["content"]
-
         # Only include "user" and "assistant" messages in the history
         other_messages = [message for message in conversation[:-1] if message["role"] in ["user", "assistant"]]
 
+        if not other_messages:
+            # No history to use fro recontextualization, simply return the user message
+            return last_message["content"]
+
         history = [f"{message['role']}: {message['content']}" for message in other_messages[-self._history_len :]]
 
         input_data = LastMessageAndHistory(last_message=last_message["content"], history=history)

From f2312614b0943e68eafd073f4c00b5ce2c450b6a Mon Sep 17 00:00:00 2001
From: Ludwik Trammer
Date: Mon, 13 Jan 2025 17:44:41 +0100
Subject: [PATCH 6/6] Fix a typo

---
 .../src/ragbits/conversations/history/compressors/llm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/llm.py b/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/llm.py
index 25fbd96b6..70797609a 100644
--- a/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/llm.py
+++ b/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/llm.py
@@ -75,7 +75,7 @@ async def compress(self, conversation: ChatFormat) -> str:
         other_messages = [message for message in conversation[:-1] if message["role"] in ["user", "assistant"]]
 
         if not other_messages:
-            # No history to use fro recontextualization, simply return the user message
+            # No history to use for recontextualization, simply return the user message
             return last_message["content"]
 
         history = [f"{message['role']}: {message['content']}" for message in other_messages[-self._history_len :]]
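
Reviewer note: the snippet below exercises the final state of this series end to end without an API key, by pairing the StandaloneMessageCompressor from patch 4 with the MockLLM test double added in patch 3. It is a minimal usage sketch based only on the APIs introduced above; the conversation content and the canned response are illustrative, not part of the patches.

    import asyncio

    from ragbits.conversations.history.compressors.llm import StandaloneMessageCompressor
    from ragbits.core.llms.mock import MockLLM, MockLLMOptions
    from ragbits.core.prompt import ChatFormat

    conversation: ChatFormat = [
        {"role": "user", "content": "Who's working on Friday?"},
        {"role": "assistant", "content": "Jim"},
        {"role": "user", "content": "Where is he based?"},
    ]

    async def main() -> None:
        # MockLLM records every prompt it receives and returns the canned response
        llm = MockLLM(default_options=MockLLMOptions(response="Where is Jim based?"))
        compressor = StandaloneMessageCompressor(llm, history_len=10)

        standalone = await compressor.compress(conversation)
        print(standalone)  # prints the canned response

        # The full chat actually sent to the (mock) LLM, for inspection
        print(llm.calls[0])

    asyncio.run(main())

Note that after patch 5, compress() short-circuits when there is no usable history and returns the user message without ever calling the LLM, so in that case llm.calls would stay empty.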