From cbfa6bd752e5c97e38a1c3519894394d17650ed5 Mon Sep 17 00:00:00 2001 From: kdziedzic68 Date: Wed, 15 Jan 2025 10:34:42 +0100 Subject: [PATCH 1/8] fix linters --- packages/ragbits-core/src/ragbits/core/llms/base.py | 13 +++++++------ .../ragbits-core/src/ragbits/core/llms/litellm.py | 11 +++++++---- .../ragbits-core/src/ragbits/core/llms/local.py | 5 +++-- .../ragbits-core/src/ragbits/core/prompt/lab/app.py | 4 ++-- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/packages/ragbits-core/src/ragbits/core/llms/base.py b/packages/ragbits-core/src/ragbits/core/llms/base.py index b5fc2a802..d2c4fdbfb 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/base.py +++ b/packages/ragbits-core/src/ragbits/core/llms/base.py @@ -2,7 +2,7 @@ import warnings as wrngs from abc import ABC, abstractmethod from collections.abc import AsyncGenerator -from typing import ClassVar, TypeVar, cast, overload +from typing import Any, ClassVar, TypeVar, cast, overload from pydantic import BaseModel @@ -68,7 +68,7 @@ async def generate_raw( prompt: BasePrompt, *, options: LLMClientOptionsT | None = None, - ) -> str: + ) -> dict[str, Any]: """ Prepares and sends a prompt to the LLM and returns the raw response (without parsing). @@ -123,11 +123,12 @@ async def generate( Text response from LLM. """ response = await self.generate_raw(prompt, options=options) + response_str = cast(str, response["response"]) if isinstance(prompt, BasePromptWithParser): - return prompt.parse_response(response) + return prompt.parse_response(response_str) - return cast(OutputT, response) + return cast(OutputT, response_str) async def generate_streaming( self, @@ -167,7 +168,7 @@ async def _call( options: LLMClientOptionsT, json_mode: bool = False, output_schema: type[BaseModel] | dict | None = None, - ) -> str: + ) -> dict[str, Any]: """ Calls LLM inference API. @@ -178,7 +179,7 @@ async def _call( output_schema: Schema for structured response (either Pydantic model or a JSON schema). Returns: - Response string from LLM. + Response dict from LLM. """ @abstractmethod diff --git a/packages/ragbits-core/src/ragbits/core/llms/litellm.py b/packages/ragbits-core/src/ragbits/core/llms/litellm.py index 467807cf1..dfe6100b7 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/litellm.py +++ b/packages/ragbits-core/src/ragbits/core/llms/litellm.py @@ -1,6 +1,7 @@ import base64 import warnings from collections.abc import AsyncGenerator +from typing import Any import litellm from litellm.utils import CustomStreamWrapper, ModelResponse @@ -35,6 +36,8 @@ class LiteLLMOptions(Options): temperature: float | None | NotGiven = NOT_GIVEN top_p: float | None | NotGiven = NOT_GIVEN mock_response: str | None | NotGiven = NOT_GIVEN + logprobs: bool | None | NotGiven = NOT_GIVEN + top_logprobs: int | None | NotGiven = NOT_GIVEN class LiteLLM(LLM[LiteLLMOptions]): @@ -125,7 +128,7 @@ async def _call( options: LiteLLMOptions, json_mode: bool = False, output_schema: type[BaseModel] | dict | None = None, - ) -> str: + ) -> dict[str, Any]: """ Calls the appropriate LLM endpoint with the given prompt and options. @@ -145,7 +148,6 @@ async def _call( LLMResponseError: If the LLM API response is invalid. 
""" response_format = self._get_response_format(output_schema=output_schema, json_mode=json_mode) - with trace( messages=conversation, model=self.model_name, @@ -166,8 +168,9 @@ async def _call( outputs.completion_tokens = response.usage.completion_tokens # type: ignore outputs.prompt_tokens = response.usage.prompt_tokens # type: ignore outputs.total_tokens = response.usage.total_tokens # type: ignore - - return outputs.response # type: ignore + if options.logprobs: + outputs.logprobs = response.choices[0].logprobs # type: ignore + return vars(outputs) # type: ignore async def _call_streaming( self, diff --git a/packages/ragbits-core/src/ragbits/core/llms/local.py b/packages/ragbits-core/src/ragbits/core/llms/local.py index f704f68aa..86f3b021b 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/local.py +++ b/packages/ragbits-core/src/ragbits/core/llms/local.py @@ -1,6 +1,7 @@ import asyncio import threading from collections.abc import AsyncGenerator +from typing import Any from pydantic import BaseModel @@ -93,7 +94,7 @@ async def _call( options: LocalLLMOptions, json_mode: bool = False, output_schema: type[BaseModel] | dict | None = None, - ) -> str: + ) -> dict[str, Any]: """ Makes a call to the local LLM with the provided prompt and options. @@ -117,7 +118,7 @@ async def _call( ) response = outputs[0][input_ids.shape[-1] :] decoded_response = self.tokenizer.decode(response, skip_special_tokens=True) - return decoded_response + return {"response": decoded_response} async def _call_streaming( self, diff --git a/packages/ragbits-core/src/ragbits/core/prompt/lab/app.py b/packages/ragbits-core/src/ragbits/core/prompt/lab/app.py index 1fd7846c7..81858ba3a 100644 --- a/packages/ragbits-core/src/ragbits/core/prompt/lab/app.py +++ b/packages/ragbits-core/src/ragbits/core/prompt/lab/app.py @@ -1,6 +1,6 @@ import asyncio from dataclasses import dataclass, field, replace -from typing import Any +from typing import Any, cast try: import gradio as gr @@ -107,7 +107,7 @@ def send_prompt_to_llm(state: PromptState) -> str: raise ValueError("LLM model is not configured.") try: - response = asyncio.run(state.llm.generate_raw(prompt=state.rendered_prompt)) + response = cast(str, asyncio.run(state.llm.generate_raw(prompt=state.rendered_prompt))["response"]) except Exception as e: # pylint: disable=broad-except response = str(e) From f78c0cd82088f61946fb0c8defad21e0447ae1e2 Mon Sep 17 00:00:00 2001 From: kdziedzic68 Date: Wed, 15 Jan 2025 10:45:57 +0100 Subject: [PATCH 2/8] fix tests and formatter --- packages/ragbits-core/src/ragbits/core/llms/base.py | 2 +- packages/ragbits-core/src/ragbits/core/llms/litellm.py | 2 +- packages/ragbits-core/tests/unit/llms/test_litellm.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/ragbits-core/src/ragbits/core/llms/base.py b/packages/ragbits-core/src/ragbits/core/llms/base.py index d2c4fdbfb..6cb714962 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/base.py +++ b/packages/ragbits-core/src/ragbits/core/llms/base.py @@ -77,7 +77,7 @@ async def generate_raw( options: Options to use for the LLM client. Returns: - Raw text response from LLM. + Raw response from LLM. 
""" merged_options = (self.default_options | options) if options else self.default_options response = await self._call( diff --git a/packages/ragbits-core/src/ragbits/core/llms/litellm.py b/packages/ragbits-core/src/ragbits/core/llms/litellm.py index dfe6100b7..bbf331ba3 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/litellm.py +++ b/packages/ragbits-core/src/ragbits/core/llms/litellm.py @@ -169,7 +169,7 @@ async def _call( outputs.prompt_tokens = response.usage.prompt_tokens # type: ignore outputs.total_tokens = response.usage.total_tokens # type: ignore if options.logprobs: - outputs.logprobs = response.choices[0].logprobs # type: ignore + outputs.logprobs = response.choices[0].logprobs # type: ignore return vars(outputs) # type: ignore async def _call_streaming( diff --git a/packages/ragbits-core/tests/unit/llms/test_litellm.py b/packages/ragbits-core/tests/unit/llms/test_litellm.py index 15397e24b..ae69aac0c 100644 --- a/packages/ragbits-core/tests/unit/llms/test_litellm.py +++ b/packages/ragbits-core/tests/unit/llms/test_litellm.py @@ -85,7 +85,7 @@ async def test_generation_with_parser(): output = await llm.generate(prompt, options=options) assert output == 42 raw_output = await llm.generate_raw(prompt, options=options) - assert raw_output == "I'm fine, thank you." + assert raw_output["response"] == "I'm fine, thank you." async def test_generation_with_static_prompt(): @@ -117,7 +117,7 @@ class StaticPromptWithParser(Prompt[None, int]): output = await llm.generate(prompt, options=options) assert output == 42 raw_output = await llm.generate_raw(prompt, options=options) - assert raw_output == "42" + assert raw_output["response"] == "42" async def test_generation_with_pydantic_output(): From 773bf6b89e0000509487fdb2e833a7cbd3e442a3 Mon Sep 17 00:00:00 2001 From: kdziedzic68 Date: Wed, 15 Jan 2025 11:58:28 +0100 Subject: [PATCH 3/8] rebase --- packages/ragbits-core/src/ragbits/core/llms/mock.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/packages/ragbits-core/src/ragbits/core/llms/mock.py b/packages/ragbits-core/src/ragbits/core/llms/mock.py index 9a82489fc..87a42f99c 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/mock.py +++ b/packages/ragbits-core/src/ragbits/core/llms/mock.py @@ -1,4 +1,5 @@ from collections.abc import AsyncGenerator +from typing import Any from pydantic import BaseModel @@ -42,14 +43,14 @@ async def _call( # noqa: PLR6301 options: MockLLMOptions, json_mode: bool = False, output_schema: type[BaseModel] | dict | None = None, - ) -> str: + ) -> dict[str, Any]: """ Mocks the call to the LLM, using the response from the options if provided. 
""" self.calls.append(conversation) if not isinstance(options.response, NotGiven): - return options.response - return "mocked response" + return {"response": options.response} + return {"response": "mocked response"} async def _call_streaming( # noqa: PLR6301 self, From 731d9e41d33e2ae815ee529f07c186a3bed0b662 Mon Sep 17 00:00:00 2001 From: kdziedzic68 Date: Wed, 15 Jan 2025 12:59:04 +0100 Subject: [PATCH 4/8] fixing mypy --- .../src/ragbits/core/llms/base.py | 41 ++++++++++++++++--- 1 file changed, 36 insertions(+), 5 deletions(-) diff --git a/packages/ragbits-core/src/ragbits/core/llms/base.py b/packages/ragbits-core/src/ragbits/core/llms/base.py index 6cb714962..e7b994142 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/base.py +++ b/packages/ragbits-core/src/ragbits/core/llms/base.py @@ -2,7 +2,7 @@ import warnings as wrngs from abc import ABC, abstractmethod from collections.abc import AsyncGenerator -from typing import Any, ClassVar, TypeVar, cast, overload +from typing import Any, ClassVar, Generic, TypeVar, cast, overload from pydantic import BaseModel @@ -24,6 +24,14 @@ class LLMType(enum.Enum): STRUCTURED_OUTPUT = "structured_output" +class LLMResponseWithMetadata(BaseModel, Generic[OutputT]): + """ + A schema of output with metadata + """ + llm_result: OutputT + metadata: dict[str, Any] + + class LLM(ConfigurableComponent[LLMClientOptionsT], ABC): """ Abstract class for interaction with Large Language Model. @@ -123,12 +131,28 @@ async def generate( Text response from LLM. """ response = await self.generate_raw(prompt, options=options) - response_str = cast(str, response["response"]) + return cast(OutputT, self._format_raw_llm_response(llm_response=response, prompt=prompt)) - if isinstance(prompt, BasePromptWithParser): - return prompt.parse_response(response_str) + async def generate_with_metadata( + self, + prompt: BasePrompt, + *, + options: LLMClientOptionsT | None = None, + ) -> LLMResponseWithMetadata[Any]: + """ + Prepares and sends a prompt to the LLM and returns response parsed to the + output type of the prompt (if available). + + Args: + prompt: Formatted prompt template with conversation and optional response parsing configuration. + options: Options to use for the LLM client. - return cast(OutputT, response_str) + Returns: + Text response from LLM. 
+ """ + response = await self.generate_raw(prompt, options=options) + llm_result = self._format_raw_llm_response(llm_response=response, prompt=prompt) + return LLMResponseWithMetadata(llm_result=llm_result, metadata=response) async def generate_streaming( self, @@ -161,6 +185,13 @@ def _format_chat_for_llm(self, prompt: BasePrompt) -> ChatFormat: wrngs.warn(message=f"Image input not implemented for {self.__class__.__name__}") return prompt.chat + @staticmethod + def _format_raw_llm_response(llm_response: dict[str, Any], prompt: BasePrompt) -> Any: # noqa: ANN401 + response_str = llm_response.pop("response") + if isinstance(prompt, BasePromptWithParser): + return prompt.parse_response(response_str) + return response_str + @abstractmethod async def _call( self, From ef0503d411951249b96ae51ee45969a12e036331 Mon Sep 17 00:00:00 2001 From: kdziedzic68 Date: Wed, 15 Jan 2025 15:14:10 +0100 Subject: [PATCH 5/8] fix typing --- .../src/ragbits/core/llms/base.py | 43 +++++++++++++------ 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/packages/ragbits-core/src/ragbits/core/llms/base.py b/packages/ragbits-core/src/ragbits/core/llms/base.py index e7b994142..14fe6f74c 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/base.py +++ b/packages/ragbits-core/src/ragbits/core/llms/base.py @@ -28,6 +28,7 @@ class LLMResponseWithMetadata(BaseModel, Generic[OutputT]): """ A schema of output with metadata """ + llm_result: OutputT metadata: dict[str, Any] @@ -130,15 +131,34 @@ async def generate( Returns: Text response from LLM. """ - response = await self.generate_raw(prompt, options=options) - return cast(OutputT, self._format_raw_llm_response(llm_response=response, prompt=prompt)) + raw_response = await self.generate_raw(prompt, options=options) + user_response = raw_response.pop("response") + if isinstance(prompt, BasePromptWithParser): + return prompt.parse_response(user_response) + return cast(OutputT, user_response) + + @overload + async def generate_with_metadata( + self, + prompt: BasePromptWithParser[OutputT], + *, + options: LLMClientOptionsT | None = None, + ) -> LLMResponseWithMetadata[OutputT]: ... + @overload async def generate_with_metadata( self, prompt: BasePrompt, *, options: LLMClientOptionsT | None = None, - ) -> LLMResponseWithMetadata[Any]: + ) -> LLMResponseWithMetadata[OutputT]: ... + + async def generate_with_metadata( + self, + prompt: BasePrompt, + *, + options: LLMClientOptionsT | None = None, + ) -> LLMResponseWithMetadata[OutputT]: """ Prepares and sends a prompt to the LLM and returns response parsed to the output type of the prompt (if available). @@ -150,9 +170,13 @@ async def generate_with_metadata( Returns: Text response from LLM. 
""" - response = await self.generate_raw(prompt, options=options) - llm_result = self._format_raw_llm_response(llm_response=response, prompt=prompt) - return LLMResponseWithMetadata(llm_result=llm_result, metadata=response) + raw_response = await self.generate_raw(prompt, options=options) + user_response = raw_response.pop("response") + if isinstance(prompt, BasePromptWithParser): + user_response = prompt.parse_response(user_response) + return LLMResponseWithMetadata(llm_result=user_response, metadata=raw_response) + user_response = cast(OutputT, user_response) + return LLMResponseWithMetadata(llm_result=user_response, metadata=raw_response) async def generate_streaming( self, @@ -185,13 +209,6 @@ def _format_chat_for_llm(self, prompt: BasePrompt) -> ChatFormat: wrngs.warn(message=f"Image input not implemented for {self.__class__.__name__}") return prompt.chat - @staticmethod - def _format_raw_llm_response(llm_response: dict[str, Any], prompt: BasePrompt) -> Any: # noqa: ANN401 - response_str = llm_response.pop("response") - if isinstance(prompt, BasePromptWithParser): - return prompt.parse_response(response_str) - return response_str - @abstractmethod async def _call( self, From 14681b98120d35602a29d760021005eaac980642 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Fri, 17 Jan 2025 19:01:13 +0100 Subject: [PATCH 6/8] fix validation error + add test --- .../src/ragbits/core/llms/base.py | 22 ++++++-------- .../src/ragbits/core/llms/litellm.py | 30 ++++++++++++------- .../tests/unit/llms/test_litellm.py | 14 +++++++++ 3 files changed, 42 insertions(+), 24 deletions(-) diff --git a/packages/ragbits-core/src/ragbits/core/llms/base.py b/packages/ragbits-core/src/ragbits/core/llms/base.py index 14fe6f74c..36359395e 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/base.py +++ b/packages/ragbits-core/src/ragbits/core/llms/base.py @@ -29,7 +29,7 @@ class LLMResponseWithMetadata(BaseModel, Generic[OutputT]): A schema of output with metadata """ - llm_result: OutputT + content: OutputT metadata: dict[str, Any] @@ -77,7 +77,7 @@ async def generate_raw( prompt: BasePrompt, *, options: LLMClientOptionsT | None = None, - ) -> dict[str, Any]: + ) -> dict: """ Prepares and sends a prompt to the LLM and returns the raw response (without parsing). @@ -89,15 +89,13 @@ async def generate_raw( Raw response from LLM. """ merged_options = (self.default_options | options) if options else self.default_options - response = await self._call( + return await self._call( conversation=self._format_chat_for_llm(prompt), options=merged_options, json_mode=prompt.json_mode, output_schema=prompt.output_schema(), ) - return response - @overload async def generate( self, @@ -168,15 +166,13 @@ async def generate_with_metadata( options: Options to use for the LLM client. Returns: - Text response from LLM. + Text response from LLM with metadata. 
""" - raw_response = await self.generate_raw(prompt, options=options) - user_response = raw_response.pop("response") + response = await self.generate_raw(prompt, options=options) + content = response.pop("response") if isinstance(prompt, BasePromptWithParser): - user_response = prompt.parse_response(user_response) - return LLMResponseWithMetadata(llm_result=user_response, metadata=raw_response) - user_response = cast(OutputT, user_response) - return LLMResponseWithMetadata(llm_result=user_response, metadata=raw_response) + content = prompt.parse_response(content) + return LLMResponseWithMetadata[type(content)](content=content, metadata=response) # type: ignore async def generate_streaming( self, @@ -216,7 +212,7 @@ async def _call( options: LLMClientOptionsT, json_mode: bool = False, output_schema: type[BaseModel] | dict | None = None, - ) -> dict[str, Any]: + ) -> dict: """ Calls LLM inference API. diff --git a/packages/ragbits-core/src/ragbits/core/llms/litellm.py b/packages/ragbits-core/src/ragbits/core/llms/litellm.py index bbf331ba3..8e2b7bd6c 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/litellm.py +++ b/packages/ragbits-core/src/ragbits/core/llms/litellm.py @@ -1,13 +1,13 @@ import base64 import warnings from collections.abc import AsyncGenerator -from typing import Any import litellm from litellm.utils import CustomStreamWrapper, ModelResponse from pydantic import BaseModel from ragbits.core.audit import trace +from ragbits.core.llms.base import LLM from ragbits.core.llms.exceptions import ( LLMConnectionError, LLMEmptyResponseError, @@ -18,8 +18,6 @@ from ragbits.core.prompt.base import BasePrompt, ChatFormat from ragbits.core.types import NOT_GIVEN, NotGiven -from .base import LLM - class LiteLLMOptions(Options): """ @@ -35,9 +33,10 @@ class LiteLLMOptions(Options): stop: str | list[str] | None | NotGiven = NOT_GIVEN temperature: float | None | NotGiven = NOT_GIVEN top_p: float | None | NotGiven = NOT_GIVEN - mock_response: str | None | NotGiven = NOT_GIVEN logprobs: bool | None | NotGiven = NOT_GIVEN top_logprobs: int | None | NotGiven = NOT_GIVEN + logit_bias: dict | None | NotGiven = NOT_GIVEN + mock_response: str | None | NotGiven = NOT_GIVEN class LiteLLM(LLM[LiteLLMOptions]): @@ -128,7 +127,7 @@ async def _call( options: LiteLLMOptions, json_mode: bool = False, output_schema: type[BaseModel] | dict | None = None, - ) -> dict[str, Any]: + ) -> dict: """ Calls the appropriate LLM endpoint with the given prompt and options. @@ -148,6 +147,7 @@ async def _call( LLMResponseError: If the LLM API response is invalid. 
""" response_format = self._get_response_format(output_schema=output_schema, json_mode=json_mode) + with trace( messages=conversation, model=self.model_name, @@ -157,19 +157,23 @@ async def _call( options=options.dict(), ) as outputs: response = await self._get_litellm_response( - conversation=conversation, options=options, response_format=response_format + conversation=conversation, + options=options, + response_format=response_format, ) - if not response.choices: # type: ignore raise LLMEmptyResponseError() outputs.response = response.choices[0].message.content # type: ignore + if response.usage: # type: ignore outputs.completion_tokens = response.usage.completion_tokens # type: ignore outputs.prompt_tokens = response.usage.prompt_tokens # type: ignore outputs.total_tokens = response.usage.total_tokens # type: ignore - if options.logprobs: - outputs.logprobs = response.choices[0].logprobs # type: ignore + + if options.logprobs: + outputs.logprobs = response.choices[0].logprobs["content"] # type: ignore + return vars(outputs) # type: ignore async def _call_streaming( @@ -198,6 +202,7 @@ async def _call_streaming( LLMResponseError: If the LLM API response is invalid. """ response_format = self._get_response_format(output_schema=output_schema, json_mode=json_mode) + with trace( messages=conversation, model=self.model_name, @@ -207,9 +212,11 @@ async def _call_streaming( options=options.dict(), ) as outputs: response = await self._get_litellm_response( - conversation=conversation, options=options, response_format=response_format, stream=True + conversation=conversation, + options=options, + response_format=response_format, + stream=True, ) - if not response.completion_stream: # type: ignore raise LLMEmptyResponseError() @@ -218,6 +225,7 @@ async def response_to_async_generator(response: CustomStreamWrapper) -> AsyncGen yield item.choices[0].delta.content or "" outputs.response = response_to_async_generator(response) # type: ignore + return outputs.response # type: ignore async def _get_litellm_response( diff --git a/packages/ragbits-core/tests/unit/llms/test_litellm.py b/packages/ragbits-core/tests/unit/llms/test_litellm.py index ae69aac0c..906180773 100644 --- a/packages/ragbits-core/tests/unit/llms/test_litellm.py +++ b/packages/ragbits-core/tests/unit/llms/test_litellm.py @@ -140,3 +140,17 @@ class PydanticPrompt(Prompt[None, OutputModel]): output = await llm.generate(prompt, options=options) assert output.response == "I'm fine, thank you." assert output.happiness == 100 + + +async def test_generation_with_metadata(): + """Test generation of a response.""" + llm = LiteLLM(api_key="test_key") + prompt = MockPrompt("Hello, how are you?") + options = LiteLLMOptions(mock_response="I'm fine, thank you.") + output = await llm.generate_with_metadata(prompt, options=options) + assert output.content == "I'm fine, thank you." 
+ assert output.metadata == { + 'completion_tokens': 20, + 'prompt_tokens': 10, + 'total_tokens': 30, + } From 45cdcfdabeba1c02b6b0a08f5b53cee2267ffc36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Fri, 17 Jan 2025 19:07:27 +0100 Subject: [PATCH 7/8] linters --- packages/ragbits-core/src/ragbits/core/llms/base.py | 12 ++++++------ packages/ragbits-core/src/ragbits/core/llms/local.py | 6 ++---- .../ragbits-core/tests/unit/llms/test_litellm.py | 6 +++--- 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/packages/ragbits-core/src/ragbits/core/llms/base.py b/packages/ragbits-core/src/ragbits/core/llms/base.py index 36359395e..8462ef9e1 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/base.py +++ b/packages/ragbits-core/src/ragbits/core/llms/base.py @@ -2,7 +2,7 @@ import warnings as wrngs from abc import ABC, abstractmethod from collections.abc import AsyncGenerator -from typing import Any, ClassVar, Generic, TypeVar, cast, overload +from typing import ClassVar, Generic, TypeVar, cast, overload from pydantic import BaseModel @@ -30,7 +30,7 @@ class LLMResponseWithMetadata(BaseModel, Generic[OutputT]): """ content: OutputT - metadata: dict[str, Any] + metadata: dict class LLM(ConfigurableComponent[LLMClientOptionsT], ABC): @@ -129,11 +129,11 @@ async def generate( Returns: Text response from LLM. """ - raw_response = await self.generate_raw(prompt, options=options) - user_response = raw_response.pop("response") + response = await self.generate_raw(prompt, options=options) + content = response.pop("response") if isinstance(prompt, BasePromptWithParser): - return prompt.parse_response(user_response) - return cast(OutputT, user_response) + return prompt.parse_response(content) + return cast(OutputT, content) @overload async def generate_with_metadata( diff --git a/packages/ragbits-core/src/ragbits/core/llms/local.py b/packages/ragbits-core/src/ragbits/core/llms/local.py index 86f3b021b..f36c6b856 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/local.py +++ b/packages/ragbits-core/src/ragbits/core/llms/local.py @@ -1,7 +1,6 @@ import asyncio import threading from collections.abc import AsyncGenerator -from typing import Any from pydantic import BaseModel @@ -14,13 +13,12 @@ except ImportError: HAS_LOCAL_LLM = False +from ragbits.core.llms.base import LLM from ragbits.core.options import Options from ragbits.core.prompt import ChatFormat from ragbits.core.prompt.base import BasePrompt from ragbits.core.types import NOT_GIVEN, NotGiven -from .base import LLM - class LocalLLMOptions(Options): """ @@ -94,7 +92,7 @@ async def _call( options: LocalLLMOptions, json_mode: bool = False, output_schema: type[BaseModel] | dict | None = None, - ) -> dict[str, Any]: + ) -> dict: """ Makes a call to the local LLM with the provided prompt and options. diff --git a/packages/ragbits-core/tests/unit/llms/test_litellm.py b/packages/ragbits-core/tests/unit/llms/test_litellm.py index 906180773..78c312a40 100644 --- a/packages/ragbits-core/tests/unit/llms/test_litellm.py +++ b/packages/ragbits-core/tests/unit/llms/test_litellm.py @@ -150,7 +150,7 @@ async def test_generation_with_metadata(): output = await llm.generate_with_metadata(prompt, options=options) assert output.content == "I'm fine, thank you." 
assert output.metadata == { - 'completion_tokens': 20, - 'prompt_tokens': 10, - 'total_tokens': 30, + "completion_tokens": 20, + "prompt_tokens": 10, + "total_tokens": 30, } From 1113b2bc7bb73faa872caea3d8ab707cd00a37f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Fri, 17 Jan 2025 19:10:37 +0100 Subject: [PATCH 8/8] types --- packages/ragbits-core/src/ragbits/core/llms/mock.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/packages/ragbits-core/src/ragbits/core/llms/mock.py b/packages/ragbits-core/src/ragbits/core/llms/mock.py index 87a42f99c..ffd6195f2 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/mock.py +++ b/packages/ragbits-core/src/ragbits/core/llms/mock.py @@ -1,14 +1,12 @@ from collections.abc import AsyncGenerator -from typing import Any from pydantic import BaseModel +from ragbits.core.llms.base import LLM from ragbits.core.options import Options from ragbits.core.prompt import ChatFormat from ragbits.core.types import NOT_GIVEN, NotGiven -from .base import LLM - class MockLLMOptions(Options): """ @@ -43,7 +41,7 @@ async def _call( # noqa: PLR6301 options: MockLLMOptions, json_mode: bool = False, output_schema: type[BaseModel] | dict | None = None, - ) -> dict[str, Any]: + ) -> dict: """ Mocks the call to the LLM, using the response from the options if provided. """
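
---

A minimal usage sketch of the API this series introduces: generate_with_metadata now returns an LLMResponseWithMetadata whose content is the parsed prompt output and whose metadata carries token usage (and log-probabilities when requested via the new logprobs option). The GreetingPrompt class and the exact import paths are assumptions inferred from the file layout shown in the diffs, not code from the patches; the call pattern mirrors test_generation_with_metadata above.

import asyncio

from ragbits.core.llms.litellm import LiteLLM, LiteLLMOptions
from ragbits.core.prompt import Prompt


class GreetingPrompt(Prompt[None, str]):
    # Illustrative prompt class; not part of the patches above.
    system_prompt = "You are a helpful assistant."
    user_prompt = "Hello, how are you?"


async def main() -> None:
    llm = LiteLLM(api_key="test_key")
    # mock_response keeps the example offline; against a real endpoint,
    # logprobs=True (added by these patches to LiteLLMOptions) would also
    # attach token log-probabilities to the response metadata.
    options = LiteLLMOptions(mock_response="I'm fine, thank you.")

    result = await llm.generate_with_metadata(GreetingPrompt(), options=options)
    print(result.content)   # parsed output, here just the mocked string
    print(result.metadata)  # e.g. completion_tokens / prompt_tokens / total_tokens


asyncio.run(main())

The same prompt object still works with generate() (content only) and generate_raw() (the raw dict with a "response" key), so existing callers only need changes where they previously treated the raw result as a plain string.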