feat: support metadata output from LLM #274

Merged: 9 commits, Jan 17, 2025
67 changes: 56 additions & 11 deletions packages/ragbits-core/src/ragbits/core/llms/base.py
@@ -2,7 +2,7 @@
import warnings as wrngs
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from typing import ClassVar, TypeVar, cast, overload
from typing import ClassVar, Generic, TypeVar, cast, overload

from pydantic import BaseModel

@@ -24,6 +24,15 @@ class LLMType(enum.Enum):
STRUCTURED_OUTPUT = "structured_output"


class LLMResponseWithMetadata(BaseModel, Generic[OutputT]):
"""
A schema of output with metadata
"""

content: OutputT
metadata: dict


class LLM(ConfigurableComponent[LLMClientOptionsT], ABC):
"""
Abstract class for interaction with Large Language Model.
@@ -68,7 +77,7 @@ async def generate_raw(
prompt: BasePrompt,
*,
options: LLMClientOptionsT | None = None,
) -> str:
) -> dict:
"""
Prepares and sends a prompt to the LLM and returns the raw response (without parsing).

@@ -77,18 +86,16 @@
options: Options to use for the LLM client.

Returns:
Raw text response from LLM.
Raw response from LLM.
"""
merged_options = (self.default_options | options) if options else self.default_options
response = await self._call(
return await self._call(
conversation=self._format_chat_for_llm(prompt),
options=merged_options,
json_mode=prompt.json_mode,
output_schema=prompt.output_schema(),
)

return response

@overload
async def generate(
self,
@@ -123,11 +130,49 @@ async def generate(
Text response from LLM.
"""
response = await self.generate_raw(prompt, options=options)

content = response.pop("response")
if isinstance(prompt, BasePromptWithParser):
return prompt.parse_response(response)
return prompt.parse_response(content)
return cast(OutputT, content)

@overload
async def generate_with_metadata(
self,
prompt: BasePromptWithParser[OutputT],
*,
options: LLMClientOptionsT | None = None,
) -> LLMResponseWithMetadata[OutputT]: ...

return cast(OutputT, response)
@overload
async def generate_with_metadata(
self,
prompt: BasePrompt,
*,
options: LLMClientOptionsT | None = None,
) -> LLMResponseWithMetadata[OutputT]: ...

async def generate_with_metadata(
self,
prompt: BasePrompt,
*,
options: LLMClientOptionsT | None = None,
) -> LLMResponseWithMetadata[OutputT]:
"""
Prepares and sends a prompt to the LLM and returns response parsed to the
output type of the prompt (if available).

Args:
prompt: Formatted prompt template with conversation and optional response parsing configuration.
options: Options to use for the LLM client.

Returns:
Text response from LLM with metadata.
"""
response = await self.generate_raw(prompt, options=options)
content = response.pop("response")
if isinstance(prompt, BasePromptWithParser):
content = prompt.parse_response(content)
return LLMResponseWithMetadata[type(content)](content=content, metadata=response) # type: ignore

async def generate_streaming(
self,
@@ -167,7 +212,7 @@ async def _call(
options: LLMClientOptionsT,
json_mode: bool = False,
output_schema: type[BaseModel] | dict | None = None,
) -> str:
) -> dict:
"""
Calls LLM inference API.

@@ -178,7 +223,7 @@
output_schema: Schema for structured response (either Pydantic model or a JSON schema).

Returns:
Response string from LLM.
Response dict from LLM.
"""

@abstractmethod
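Taken together, the base-class changes add a parallel `generate_with_metadata` entry point that wraps the parsed output and the remaining raw fields in `LLMResponseWithMetadata`. A minimal usage sketch of the new API (the prompt class, model name, and constructor arguments below are illustrative assumptions, not part of this diff):

```python
import asyncio

from ragbits.core.llms.litellm import LiteLLM
from ragbits.core.prompt import Prompt


class QuestionPrompt(Prompt[None, str]):
    """Hypothetical prompt used only for this example."""

    system_prompt = "Answer in one short sentence."
    user_prompt = "What is the capital of France?"


async def main() -> None:
    llm = LiteLLM(model_name="gpt-4o-mini")  # constructor arguments assumed
    response = await llm.generate_with_metadata(QuestionPrompt())
    print(response.content)   # parsed output, the same value generate() returns
    print(response.metadata)  # leftover raw fields, e.g. token usage for LiteLLM


if __name__ == "__main__":
    asyncio.run(main())
```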
27 changes: 19 additions & 8 deletions packages/ragbits-core/src/ragbits/core/llms/litellm.py
@@ -7,6 +7,7 @@
from pydantic import BaseModel

from ragbits.core.audit import trace
from ragbits.core.llms.base import LLM
from ragbits.core.llms.exceptions import (
LLMConnectionError,
LLMEmptyResponseError,
@@ -17,8 +18,6 @@
from ragbits.core.prompt.base import BasePrompt, ChatFormat
from ragbits.core.types import NOT_GIVEN, NotGiven

from .base import LLM


class LiteLLMOptions(Options):
"""
@@ -34,6 +33,9 @@ class LiteLLMOptions(Options):
stop: str | list[str] | None | NotGiven = NOT_GIVEN
temperature: float | None | NotGiven = NOT_GIVEN
top_p: float | None | NotGiven = NOT_GIVEN
logprobs: bool | None | NotGiven = NOT_GIVEN
top_logprobs: int | None | NotGiven = NOT_GIVEN
logit_bias: dict | None | NotGiven = NOT_GIVEN
mock_response: str | None | NotGiven = NOT_GIVEN


@@ -125,7 +127,7 @@ async def _call(
options: LiteLLMOptions,
json_mode: bool = False,
output_schema: type[BaseModel] | dict | None = None,
) -> str:
) -> dict:
"""
Calls the appropriate LLM endpoint with the given prompt and options.

@@ -155,19 +157,24 @@
options=options.dict(),
) as outputs:
response = await self._get_litellm_response(
conversation=conversation, options=options, response_format=response_format
conversation=conversation,
options=options,
response_format=response_format,
)

if not response.choices: # type: ignore
raise LLMEmptyResponseError()

outputs.response = response.choices[0].message.content # type: ignore

if response.usage: # type: ignore
outputs.completion_tokens = response.usage.completion_tokens # type: ignore
outputs.prompt_tokens = response.usage.prompt_tokens # type: ignore
outputs.total_tokens = response.usage.total_tokens # type: ignore

return outputs.response # type: ignore
if options.logprobs:
outputs.logprobs = response.choices[0].logprobs["content"] # type: ignore

return vars(outputs) # type: ignore

async def _call_streaming(
self,
@@ -195,6 +202,7 @@ async def _call_streaming(
LLMResponseError: If the LLM API response is invalid.
"""
response_format = self._get_response_format(output_schema=output_schema, json_mode=json_mode)

with trace(
messages=conversation,
model=self.model_name,
@@ -204,9 +212,11 @@
options=options.dict(),
) as outputs:
response = await self._get_litellm_response(
conversation=conversation, options=options, response_format=response_format, stream=True
conversation=conversation,
options=options,
response_format=response_format,
stream=True,
)

if not response.completion_stream: # type: ignore
raise LLMEmptyResponseError()

@@ -215,6 +225,7 @@ async def response_to_async_generator(response: CustomStreamWrapper) -> AsyncGen
yield item.choices[0].delta.content or ""

outputs.response = response_to_async_generator(response) # type: ignore

return outputs.response # type: ignore

async def _get_litellm_response(
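Because `LiteLLMOptions` now exposes `logprobs`, `top_logprobs`, and `logit_bias`, and `_call` returns everything recorded on the trace outputs, token-level log-probabilities can be requested per call and read back from the metadata. A hedged sketch of that flow (the exact shape of the logprobs entries depends on the underlying provider; anything beyond the field names visible in this diff is an assumption):

```python
from ragbits.core.llms.litellm import LiteLLM, LiteLLMOptions
from ragbits.core.prompt.base import BasePrompt


async def inspect_logprobs(llm: LiteLLM, prompt: BasePrompt) -> None:
    options = LiteLLMOptions(logprobs=True, top_logprobs=3)
    result = await llm.generate_with_metadata(prompt, options=options)
    # With logprobs enabled, _call copies response.choices[0].logprobs["content"]
    # onto the trace outputs, so it surfaces here alongside the token counts.
    print(result.metadata.get("logprobs"))
    print(result.metadata.get("total_tokens"))
```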
7 changes: 3 additions & 4 deletions packages/ragbits-core/src/ragbits/core/llms/local.py
@@ -13,13 +13,12 @@
except ImportError:
HAS_LOCAL_LLM = False

from ragbits.core.llms.base import LLM
from ragbits.core.options import Options
from ragbits.core.prompt import ChatFormat
from ragbits.core.prompt.base import BasePrompt
from ragbits.core.types import NOT_GIVEN, NotGiven

from .base import LLM


class LocalLLMOptions(Options):
"""
@@ -93,7 +92,7 @@ async def _call(
options: LocalLLMOptions,
json_mode: bool = False,
output_schema: type[BaseModel] | dict | None = None,
) -> str:
) -> dict:
"""
Makes a call to the local LLM with the provided prompt and options.

@@ -117,7 +116,7 @@
)
response = outputs[0][input_ids.shape[-1] :]
decoded_response = self.tokenizer.decode(response, skip_special_tokens=True)
return decoded_response
return {"response": decoded_response}

async def _call_streaming(
self,
9 changes: 4 additions & 5 deletions packages/ragbits-core/src/ragbits/core/llms/mock.py
@@ -2,12 +2,11 @@

from pydantic import BaseModel

from ragbits.core.llms.base import LLM
from ragbits.core.options import Options
from ragbits.core.prompt import ChatFormat
from ragbits.core.types import NOT_GIVEN, NotGiven

from .base import LLM


class MockLLMOptions(Options):
"""
@@ -42,14 +41,14 @@ async def _call( # noqa: PLR6301
options: MockLLMOptions,
json_mode: bool = False,
output_schema: type[BaseModel] | dict | None = None,
) -> str:
) -> dict:
"""
Mocks the call to the LLM, using the response from the options if provided.
"""
self.calls.append(conversation)
if not isinstance(options.response, NotGiven):
return options.response
return "mocked response"
return {"response": options.response}
return {"response": "mocked response"}

async def _call_streaming( # noqa: PLR6301
self,
4 changes: 2 additions & 2 deletions packages/ragbits-core/src/ragbits/core/prompt/lab/app.py
@@ -1,6 +1,6 @@
import asyncio
from dataclasses import dataclass, field, replace
from typing import Any
from typing import Any, cast

try:
import gradio as gr
@@ -107,7 +107,7 @@ def send_prompt_to_llm(state: PromptState) -> str:
raise ValueError("LLM model is not configured.")

try:
response = asyncio.run(state.llm.generate_raw(prompt=state.rendered_prompt))
response = cast(str, asyncio.run(state.llm.generate_raw(prompt=state.rendered_prompt))["response"])
except Exception as e: # pylint: disable=broad-except
response = str(e)

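The lab-app change above is the migration pattern for any direct caller of `generate_raw`: the method now returns a dict rather than a bare string, so the generated text has to be read from the `"response"` key. A minimal before/after sketch, assuming an already configured `llm` and a formatted `prompt`:

```python
from ragbits.core.llms.base import LLM
from ragbits.core.prompt.base import BasePrompt


async def fetch_raw_text(llm: LLM, prompt: BasePrompt) -> str:
    # Before this PR: generate_raw returned the text itself.
    #     return await llm.generate_raw(prompt)
    # After this PR: it returns a dict; the text sits under "response",
    # and the remaining keys (e.g. token counts for LiteLLM) carry metadata.
    raw = await llm.generate_raw(prompt)
    return raw["response"]
```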
18 changes: 16 additions & 2 deletions packages/ragbits-core/tests/unit/llms/test_litellm.py
@@ -85,7 +85,7 @@ async def test_generation_with_parser():
output = await llm.generate(prompt, options=options)
assert output == 42
raw_output = await llm.generate_raw(prompt, options=options)
assert raw_output == "I'm fine, thank you."
assert raw_output["response"] == "I'm fine, thank you."


async def test_generation_with_static_prompt():
@@ -117,7 +117,7 @@ class StaticPromptWithParser(Prompt[None, int]):
output = await llm.generate(prompt, options=options)
assert output == 42
raw_output = await llm.generate_raw(prompt, options=options)
assert raw_output == "42"
assert raw_output["response"] == "42"


async def test_generation_with_pydantic_output():
@@ -140,3 +140,17 @@ class PydanticPrompt(Prompt[None, OutputModel]):
output = await llm.generate(prompt, options=options)
assert output.response == "I'm fine, thank you."
assert output.happiness == 100


async def test_generation_with_metadata():
"""Test generation of a response."""
llm = LiteLLM(api_key="test_key")
prompt = MockPrompt("Hello, how are you?")
options = LiteLLMOptions(mock_response="I'm fine, thank you.")
output = await llm.generate_with_metadata(prompt, options=options)
assert output.content == "I'm fine, thank you."
assert output.metadata == {
"completion_tokens": 20,
"prompt_tokens": 10,
"total_tokens": 30,
}