From a0a438d65804e2a1660c76e5b32bbf9b0cff1a0f Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Thu, 16 Jan 2025 10:20:24 -0800 Subject: [PATCH 1/6] Combined Persona and Prompt API --- backend/ee/onyx/server/seeding.py | 6 +- backend/onyx/chat/chat_utils.py | 2 +- .../chat/prompt_builder/citations_prompt.py | 2 +- backend/onyx/chat/prompt_builder/utils.py | 20 -- backend/onyx/db/persona.py | 232 +++------------- backend/onyx/db/prompts.py | 132 +++++++++ backend/onyx/db/slack_channel_config.py | 2 +- backend/onyx/main.py | 2 - backend/onyx/prompts/direct_qa_prompts.py | 26 -- backend/onyx/seeding/load_yamls.py | 8 +- backend/onyx/server/features/persona/api.py | 59 ++-- .../onyx/server/features/persona/models.py | 22 +- backend/onyx/server/features/prompt/api.py | 152 ----------- backend/onyx/server/features/prompt/models.py | 10 - .../openai_assistants_api/asssistants_api.py | 2 +- .../app/admin/assistants/AssistantEditor.tsx | 44 +-- web/src/app/admin/assistants/lib.ts | 251 +++++------------- 17 files changed, 310 insertions(+), 662 deletions(-) create mode 100644 backend/onyx/db/prompts.py delete mode 100644 backend/onyx/server/features/prompt/api.py diff --git a/backend/ee/onyx/server/seeding.py b/backend/ee/onyx/server/seeding.py index 1481920890d..49efec72774 100644 --- a/backend/ee/onyx/server/seeding.py +++ b/backend/ee/onyx/server/seeding.py @@ -24,7 +24,7 @@ from onyx.db.llm import upsert_llm_provider from onyx.db.models import Tool from onyx.db.persona import upsert_persona -from onyx.server.features.persona.models import CreatePersonaRequest +from onyx.server.features.persona.models import PersonaUpsertRequest from onyx.server.manage.llm.models import LLMProviderUpsertRequest from onyx.server.settings.models import Settings from onyx.server.settings.store import store_settings as store_base_settings @@ -57,7 +57,7 @@ class SeedConfiguration(BaseModel): llms: list[LLMProviderUpsertRequest] | None = None admin_user_emails: list[str] | None = None seeded_logo_path: str | None = None - personas: list[CreatePersonaRequest] | None = None + personas: list[PersonaUpsertRequest] | None = None settings: Settings | None = None enterprise_settings: EnterpriseSettings | None = None @@ -128,7 +128,7 @@ def _seed_llms( ) -def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) -> None: +def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None: if personas: logger.notice("Seeding Personas") for persona in personas: diff --git a/backend/onyx/chat/chat_utils.py b/backend/onyx/chat/chat_utils.py index 70083a0da24..b14a005f386 100644 --- a/backend/onyx/chat/chat_utils.py +++ b/backend/onyx/chat/chat_utils.py @@ -25,7 +25,7 @@ from onyx.db.models import Prompt from onyx.db.models import Tool from onyx.db.models import User -from onyx.db.persona import get_prompts_by_ids +from onyx.db.prompts import get_prompts_by_ids from onyx.llm.models import PreviousMessage from onyx.natural_language_processing.utils import BaseTokenizer from onyx.server.query_and_chat.models import CreateChatMessageRequest diff --git a/backend/onyx/chat/prompt_builder/citations_prompt.py b/backend/onyx/chat/prompt_builder/citations_prompt.py index f2d88cc1283..52043abdf14 100644 --- a/backend/onyx/chat/prompt_builder/citations_prompt.py +++ b/backend/onyx/chat/prompt_builder/citations_prompt.py @@ -6,7 +6,7 @@ from onyx.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS from onyx.context.search.models import InferenceChunk from onyx.db.models import Persona -from onyx.db.persona import get_default_prompt__read_only +from onyx.db.prompts import get_default_prompt__read_only from onyx.db.search_settings import get_multilingual_expansion from onyx.llm.factory import get_llms_for_persona from onyx.llm.factory import get_main_llm_from_tuple diff --git a/backend/onyx/chat/prompt_builder/utils.py b/backend/onyx/chat/prompt_builder/utils.py index 13084fcd188..7c231234c34 100644 --- a/backend/onyx/chat/prompt_builder/utils.py +++ b/backend/onyx/chat/prompt_builder/utils.py @@ -7,26 +7,6 @@ from onyx.file_store.models import InMemoryChatFile from onyx.llm.models import PreviousMessage from onyx.llm.utils import build_content_with_imgs -from onyx.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT -from onyx.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT - - -def build_dummy_prompt( - system_prompt: str, task_prompt: str, retrieval_disabled: bool -) -> str: - if retrieval_disabled: - return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format( - user_query="", - system_prompt=system_prompt, - task_prompt=task_prompt, - ).strip() - - return PARAMATERIZED_PROMPT.format( - context_docs_str="", - user_query="", - system_prompt=system_prompt, - task_prompt=task_prompt, - ).strip() def translate_onyx_msg_to_langchain( diff --git a/backend/onyx/db/persona.py b/backend/onyx/db/persona.py index de277c32f86..97d2410598d 100644 --- a/backend/onyx/db/persona.py +++ b/backend/onyx/db/persona.py @@ -1,6 +1,5 @@ from collections.abc import Sequence from datetime import datetime -from functools import lru_cache from uuid import UUID from fastapi import HTTPException @@ -8,7 +7,6 @@ from sqlalchemy import exists from sqlalchemy import func from sqlalchemy import not_ -from sqlalchemy import or_ from sqlalchemy import Select from sqlalchemy import select from sqlalchemy import update @@ -23,7 +21,6 @@ from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW from onyx.context.search.enums import RecencyBiasSetting from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX -from onyx.db.engine import get_sqlalchemy_engine from onyx.db.models import DocumentSet from onyx.db.models import Persona from onyx.db.models import Persona__User @@ -35,8 +32,8 @@ from onyx.db.models import User from onyx.db.models import User__UserGroup from onyx.db.models import UserGroup -from onyx.server.features.persona.models import CreatePersonaRequest from onyx.server.features.persona.models import PersonaSnapshot +from onyx.server.features.persona.models import PersonaUpsertRequest from onyx.utils.logger import setup_logger from onyx.utils.variable_functionality import fetch_versioned_implementation @@ -107,9 +104,6 @@ def _add_user_filters( return stmt.where(where_clause) -# fetch_persona_by_id is used to fetch a persona by its ID. It is used to fetch a persona by its ID. - - def fetch_persona_by_id_for_user( db_session: Session, persona_id: int, user: User | None, get_editable: bool = True ) -> Persona: @@ -184,7 +178,7 @@ def make_persona_private( def create_update_persona( persona_id: int | None, - create_persona_request: CreatePersonaRequest, + create_persona_request: PersonaUpsertRequest, user: User | None, db_session: Session, ) -> PersonaSnapshot: @@ -192,14 +186,36 @@ def create_update_persona( # Permission to actually use these is checked later try: - persona_data = { - "persona_id": persona_id, - "user": user, - "db_session": db_session, - **create_persona_request.model_dump(exclude={"users", "groups"}), - } - - persona = upsert_persona(**persona_data) + all_prompt_ids = create_persona_request.prompt_ids or [] + if create_persona_request.existing_prompt_id: + all_prompt_ids.append(create_persona_request.existing_prompt_id) + if not all_prompt_ids: + raise ValueError("No prompt IDs provided") + persona = upsert_persona( + persona_id=persona_id, + user=user, + db_session=db_session, + description=create_persona_request.description, + name=create_persona_request.name, + prompt_ids=all_prompt_ids, + document_set_ids=create_persona_request.document_set_ids, + tool_ids=create_persona_request.tool_ids, + is_public=create_persona_request.is_public, + recency_bias=create_persona_request.recency_bias, + llm_model_provider_override=create_persona_request.llm_model_provider_override, + llm_model_version_override=create_persona_request.llm_model_version_override, + starter_messages=create_persona_request.starter_messages, + icon_color=create_persona_request.icon_color, + icon_shape=create_persona_request.icon_shape, + uploaded_image_id=create_persona_request.uploaded_image_id, + display_priority=create_persona_request.display_priority, + remove_image=create_persona_request.remove_image, + search_start_date=create_persona_request.search_start_date, + label_ids=create_persona_request.label_ids, + num_chunks=create_persona_request.num_chunks, + llm_relevance_filter=create_persona_request.llm_relevance_filter, + llm_filter_extraction=create_persona_request.llm_filter_extraction, + ) versioned_make_persona_private = fetch_versioned_implementation( "onyx.db.persona", "make_persona_private" @@ -265,24 +281,6 @@ def update_persona_public_status( db_session.commit() -def get_prompts( - user_id: UUID | None, - db_session: Session, - include_default: bool = True, - include_deleted: bool = False, -) -> Sequence[Prompt]: - stmt = select(Prompt).where( - or_(Prompt.user_id == user_id, Prompt.user_id.is_(None)) - ) - - if not include_default: - stmt = stmt.where(Prompt.default_prompt.is_(False)) - if not include_deleted: - stmt = stmt.where(Prompt.deleted.is_(False)) - - return db_session.scalars(stmt).all() - - def get_personas_for_user( # if user is `None` assume the user is an admin or auth is disabled user: User | None, @@ -374,65 +372,6 @@ def update_all_personas_display_priority( db_session.commit() -def upsert_prompt( - user: User | None, - name: str, - description: str, - system_prompt: str, - task_prompt: str, - include_citations: bool, - datetime_aware: bool, - personas: list[Persona] | None, - db_session: Session, - prompt_id: int | None = None, - default_prompt: bool = True, - commit: bool = True, -) -> Prompt: - if prompt_id is not None: - prompt = db_session.query(Prompt).filter_by(id=prompt_id).first() - else: - prompt = get_prompt_by_name(prompt_name=name, user=user, db_session=db_session) - - if prompt: - if not default_prompt and prompt.default_prompt: - raise ValueError("Cannot update default prompt with non-default.") - - prompt.name = name - prompt.description = description - prompt.system_prompt = system_prompt - prompt.task_prompt = task_prompt - prompt.include_citations = include_citations - prompt.datetime_aware = datetime_aware - prompt.default_prompt = default_prompt - - if personas is not None: - prompt.personas.clear() - prompt.personas = personas - - else: - prompt = Prompt( - id=prompt_id, - user_id=user.id if user else None, - name=name, - description=description, - system_prompt=system_prompt, - task_prompt=task_prompt, - include_citations=include_citations, - datetime_aware=datetime_aware, - default_prompt=default_prompt, - personas=personas or [], - ) - db_session.add(prompt) - - if commit: - db_session.commit() - else: - # Flush the session so that the Prompt has an ID - db_session.flush() - - return prompt - - def upsert_persona( user: User | None, name: str, @@ -477,6 +416,15 @@ def upsert_persona( persona_name=name, user=user, db_session=db_session ) + if existing_persona: + # this checks if the user has permission to edit the persona + # will raise an Exception if the user does not have permission + existing_persona = fetch_persona_by_id_for_user( + db_session=db_session, + persona_id=existing_persona.id, + user=user, + get_editable=True, + ) # Fetch and attach tools by IDs tools = None if tool_ids is not None: @@ -522,15 +470,6 @@ def upsert_persona( if existing_persona.builtin_persona and not builtin_persona: raise ValueError("Cannot update builtin persona with non-builtin.") - # this checks if the user has permission to edit the persona - # will raise an Exception if the user does not have permission - existing_persona = fetch_persona_by_id_for_user( - db_session=db_session, - persona_id=existing_persona.id, - user=user, - get_editable=True, - ) - # The following update excludes `default`, `built-in`, and display priority. # Display priority is handled separately in the `display-priority` endpoint. # `default` and `built-in` properties can only be set when creating a persona. @@ -619,16 +558,6 @@ def upsert_persona( return persona -def mark_prompt_as_deleted( - prompt_id: int, - user: User | None, - db_session: Session, -) -> None: - prompt = get_prompt_by_id(prompt_id=prompt_id, user=user, db_session=db_session) - prompt.deleted = True - db_session.commit() - - def delete_old_default_personas( db_session: Session, ) -> None: @@ -666,69 +595,6 @@ def validate_persona_tools(tools: list[Tool]) -> None: ) -def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> list[Prompt]: - """Unsafe, can fetch prompts from all users""" - if not prompt_ids: - return [] - prompts = db_session.scalars( - select(Prompt).where(Prompt.id.in_(prompt_ids)).where(Prompt.deleted.is_(False)) - ).all() - - return list(prompts) - - -def get_prompt_by_id( - prompt_id: int, - user: User | None, - db_session: Session, - include_deleted: bool = False, -) -> Prompt: - stmt = select(Prompt).where(Prompt.id == prompt_id) - - # if user is not specified OR they are an admin, they should - # have access to all prompts, so this where clause is not needed - if user and user.role != UserRole.ADMIN: - stmt = stmt.where(or_(Prompt.user_id == user.id, Prompt.user_id.is_(None))) - - if not include_deleted: - stmt = stmt.where(Prompt.deleted.is_(False)) - - result = db_session.execute(stmt) - prompt = result.scalar_one_or_none() - - if prompt is None: - raise ValueError( - f"Prompt with ID {prompt_id} does not exist or does not belong to user" - ) - - return prompt - - -def _get_default_prompt(db_session: Session) -> Prompt: - stmt = select(Prompt).where(Prompt.id == 0) - result = db_session.execute(stmt) - prompt = result.scalar_one_or_none() - - if prompt is None: - raise RuntimeError("Default Prompt not found") - - return prompt - - -def get_default_prompt(db_session: Session) -> Prompt: - return _get_default_prompt(db_session) - - -@lru_cache() -def get_default_prompt__read_only() -> Prompt: - """Due to the way lru_cache / SQLAlchemy works, this can cause issues - when trying to attach the returned `Prompt` object to a `Persona`. If you are - doing anything other than reading, you should use the `get_default_prompt` - method instead.""" - with Session(get_sqlalchemy_engine()) as db_session: - return _get_default_prompt(db_session) - - # TODO: since this gets called with every chat message, could it be more efficient to pregenerate # a direct mapping indicating whether a user has access to a specific persona? def get_persona_by_id( @@ -800,22 +666,6 @@ def get_personas_by_ids( return personas -def get_prompt_by_name( - prompt_name: str, user: User | None, db_session: Session -) -> Prompt | None: - stmt = select(Prompt).where(Prompt.name == prompt_name) - - # if user is not specified OR they are an admin, they should - # have access to all prompts, so this where clause is not needed - if user and user.role != UserRole.ADMIN: - stmt = stmt.where(Prompt.user_id == user.id) - - # Order by ID to ensure consistent result when multiple prompts exist - stmt = stmt.order_by(Prompt.id).limit(1) - result = db_session.execute(stmt).scalar_one_or_none() - return result - - def delete_persona_by_name( persona_name: str, db_session: Session, is_default: bool = True ) -> None: diff --git a/backend/onyx/db/prompts.py b/backend/onyx/db/prompts.py new file mode 100644 index 00000000000..0dacad9879c --- /dev/null +++ b/backend/onyx/db/prompts.py @@ -0,0 +1,132 @@ +from functools import lru_cache + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from onyx.auth.schemas import UserRole +from onyx.db.engine import get_sqlalchemy_engine +from onyx.db.models import Persona +from onyx.db.models import Prompt +from onyx.db.models import User +from onyx.utils.logger import setup_logger + + +# Note: As prompts are fairly innocuous/harmless, there are no protections +# to prevent users from messing with prompts of other users. + +logger = setup_logger() + + +def _get_default_prompt(db_session: Session) -> Prompt: + stmt = select(Prompt).where(Prompt.id == 0) + result = db_session.execute(stmt) + prompt = result.scalar_one_or_none() + + if prompt is None: + raise RuntimeError("Default Prompt not found") + + return prompt + + +def get_default_prompt(db_session: Session) -> Prompt: + return _get_default_prompt(db_session) + + +@lru_cache() +def get_default_prompt__read_only() -> Prompt: + """Due to the way lru_cache / SQLAlchemy works, this can cause issues + when trying to attach the returned `Prompt` object to a `Persona`. If you are + doing anything other than reading, you should use the `get_default_prompt` + method instead.""" + with Session(get_sqlalchemy_engine()) as db_session: + return _get_default_prompt(db_session) + + +def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> list[Prompt]: + """Unsafe, can fetch prompts from all users""" + if not prompt_ids: + return [] + prompts = db_session.scalars( + select(Prompt).where(Prompt.id.in_(prompt_ids)).where(Prompt.deleted.is_(False)) + ).all() + + return list(prompts) + + +def get_prompt_by_name( + prompt_name: str, user: User | None, db_session: Session +) -> Prompt | None: + stmt = select(Prompt).where(Prompt.name == prompt_name) + + # if user is not specified OR they are an admin, they should + # have access to all prompts, so this where clause is not needed + if user and user.role != UserRole.ADMIN: + stmt = stmt.where(Prompt.user_id == user.id) + + # Order by ID to ensure consistent result when multiple prompts exist + stmt = stmt.order_by(Prompt.id).limit(1) + result = db_session.execute(stmt).scalar_one_or_none() + return result + + +def build_prompt_name_from_persona_name(persona_name: str) -> str: + return f"default-prompt__{persona_name}" + + +def upsert_prompt( + db_session: Session, + user: User | None, + name: str, + system_prompt: str, + task_prompt: str, + prompt_id: int | None = None, + personas: list[Persona] | None = None, + datetime_aware: bool = True, + include_citations: bool = False, + default_prompt: bool = True, + # Support backwards compatibility + description: str | None = None, +) -> Prompt: + if description is None: + description = f"Default prompt for {name}" + + if prompt_id is not None: + prompt = db_session.query(Prompt).filter_by(id=prompt_id).first() + else: + prompt = get_prompt_by_name(prompt_name=name, user=user, db_session=db_session) + + if prompt: + if not default_prompt and prompt.default_prompt: + raise ValueError("Cannot update default prompt with non-default.") + + prompt.name = name + prompt.description = description + prompt.system_prompt = system_prompt + prompt.task_prompt = task_prompt + prompt.include_citations = include_citations + prompt.datetime_aware = datetime_aware + prompt.default_prompt = default_prompt + + if personas is not None: + prompt.personas.clear() + prompt.personas = personas + + else: + prompt = Prompt( + id=prompt_id, + user_id=user.id if user else None, + name=name, + description=description, + system_prompt=system_prompt, + task_prompt=task_prompt, + include_citations=include_citations, + datetime_aware=datetime_aware, + default_prompt=default_prompt, + personas=personas or [], + ) + db_session.add(prompt) + + # Flush the session so that the Prompt has an ID + db_session.flush() + + return prompt diff --git a/backend/onyx/db/slack_channel_config.py b/backend/onyx/db/slack_channel_config.py index ee6dcc2767a..a3b84533e6a 100644 --- a/backend/onyx/db/slack_channel_config.py +++ b/backend/onyx/db/slack_channel_config.py @@ -12,9 +12,9 @@ from onyx.db.models import Persona__DocumentSet from onyx.db.models import SlackChannelConfig from onyx.db.models import User -from onyx.db.persona import get_default_prompt from onyx.db.persona import mark_persona_as_deleted from onyx.db.persona import upsert_persona +from onyx.db.prompts import get_default_prompt from onyx.utils.errors import EERequiredError from onyx.utils.variable_functionality import ( fetch_versioned_implementation_with_fallback, diff --git a/backend/onyx/main.py b/backend/onyx/main.py index 382d101b845..c2917c0e41e 100644 --- a/backend/onyx/main.py +++ b/backend/onyx/main.py @@ -64,7 +64,6 @@ from onyx.server.features.notifications.api import router as notification_router from onyx.server.features.persona.api import admin_router as admin_persona_router from onyx.server.features.persona.api import basic_router as persona_router -from onyx.server.features.prompt.api import basic_router as prompt_router from onyx.server.features.tool.api import admin_router as admin_tool_router from onyx.server.features.tool.api import router as tool_router from onyx.server.gpts.api import router as gpts_router @@ -296,7 +295,6 @@ def get_application() -> FastAPI: include_router_with_global_prefix_prepended(application, persona_router) include_router_with_global_prefix_prepended(application, admin_persona_router) include_router_with_global_prefix_prepended(application, notification_router) - include_router_with_global_prefix_prepended(application, prompt_router) include_router_with_global_prefix_prepended(application, tool_router) include_router_with_global_prefix_prepended(application, admin_tool_router) include_router_with_global_prefix_prepended(application, state_router) diff --git a/backend/onyx/prompts/direct_qa_prompts.py b/backend/onyx/prompts/direct_qa_prompts.py index 2b98c486405..133205dfd38 100644 --- a/backend/onyx/prompts/direct_qa_prompts.py +++ b/backend/onyx/prompts/direct_qa_prompts.py @@ -118,32 +118,6 @@ """ -# This is only for visualization for the users to specify their own prompts -# The actual flow does not work like this -PARAMATERIZED_PROMPT = f""" -{{system_prompt}} - -CONTEXT: -{GENERAL_SEP_PAT} -{{context_docs_str}} -{GENERAL_SEP_PAT} - -{{task_prompt}} - -{QUESTION_PAT.upper()} {{user_query}} -RESPONSE: -""".strip() - -PARAMATERIZED_PROMPT_WITHOUT_CONTEXT = f""" -{{system_prompt}} - -{{task_prompt}} - -{QUESTION_PAT.upper()} {{user_query}} -RESPONSE: -""".strip() - - # CURRENTLY DISABLED, CANNOT USE THIS ONE # Default chain-of-thought style json prompt which uses multiple docs # This one has a section for the LLM to output some non-answer "thoughts" diff --git a/backend/onyx/seeding/load_yamls.py b/backend/onyx/seeding/load_yamls.py index 8d06ec181ac..d00af08dfb2 100644 --- a/backend/onyx/seeding/load_yamls.py +++ b/backend/onyx/seeding/load_yamls.py @@ -12,9 +12,9 @@ from onyx.db.models import Persona from onyx.db.models import Prompt as PromptDBModel from onyx.db.models import Tool as ToolDBModel -from onyx.db.persona import get_prompt_by_name from onyx.db.persona import upsert_persona -from onyx.db.persona import upsert_prompt +from onyx.db.prompts import get_prompt_by_name +from onyx.db.prompts import upsert_prompt def load_prompts_from_yaml( @@ -26,6 +26,7 @@ def load_prompts_from_yaml( all_prompts = data.get("prompts", []) for prompt in all_prompts: upsert_prompt( + db_session=db_session, user=None, prompt_id=prompt.get("id"), name=prompt["name"], @@ -36,9 +37,8 @@ def load_prompts_from_yaml( datetime_aware=prompt.get("datetime_aware", True), default_prompt=True, personas=None, - db_session=db_session, - commit=True, ) + db_session.commit() def load_input_prompts_from_yaml( diff --git a/backend/onyx/server/features/persona/api.py b/backend/onyx/server/features/persona/api.py index a318ef7b5c6..a95736cf5b4 100644 --- a/backend/onyx/server/features/persona/api.py +++ b/backend/onyx/server/features/persona/api.py @@ -14,7 +14,6 @@ from onyx.auth.users import current_curator_or_admin_user from onyx.auth.users import current_limited_user from onyx.auth.users import current_user -from onyx.chat.prompt_builder.utils import build_dummy_prompt from onyx.configs.constants import FileOrigin from onyx.configs.constants import MilestoneRecordType from onyx.configs.constants import NotificationType @@ -36,19 +35,21 @@ from onyx.db.persona import update_persona_public_status from onyx.db.persona import update_persona_shared_users from onyx.db.persona import update_persona_visibility +from onyx.db.prompts import build_prompt_name_from_persona_name +from onyx.db.prompts import upsert_prompt from onyx.file_store.file_store import get_default_file_store from onyx.file_store.models import ChatFileType from onyx.secondary_llm_flows.starter_message_creation import ( generate_starter_messages, ) -from onyx.server.features.persona.models import CreatePersonaRequest from onyx.server.features.persona.models import GenerateStarterMessageRequest from onyx.server.features.persona.models import ImageGenerationToolStatus from onyx.server.features.persona.models import PersonaLabelCreate from onyx.server.features.persona.models import PersonaLabelResponse from onyx.server.features.persona.models import PersonaSharedNotificationData from onyx.server.features.persona.models import PersonaSnapshot -from onyx.server.features.persona.models import PromptTemplateResponse +from onyx.server.features.persona.models import PersonaUpsertRequest +from onyx.server.features.prompt.models import PromptSnapshot from onyx.server.models import DisplayPriorityRequest from onyx.tools.utils import is_image_generation_available from onyx.utils.logger import setup_logger @@ -173,18 +174,29 @@ def upload_file( @basic_router.post("") def create_persona( - create_persona_request: CreatePersonaRequest, + persona_upsert_request: PersonaUpsertRequest, user: User | None = Depends(current_user), db_session: Session = Depends(get_session), tenant_id: str | None = Depends(get_current_tenant_id), ) -> PersonaSnapshot: + prompt = upsert_prompt( + db_session=db_session, + user=user, + name=build_prompt_name_from_persona_name(persona_upsert_request.name), + system_prompt=persona_upsert_request.system_prompt, + task_prompt=persona_upsert_request.task_prompt, + include_citations=persona_upsert_request.include_citations, + prompt_id=persona_upsert_request.existing_prompt_id, + ) + prompt_snapshot = PromptSnapshot.from_model(prompt) + persona_upsert_request.existing_prompt_id = prompt.id persona_snapshot = create_update_persona( persona_id=None, - create_persona_request=create_persona_request, + create_persona_request=persona_upsert_request, user=user, db_session=db_session, ) - + persona_snapshot.prompts = [prompt_snapshot] create_milestone_and_report( user=user, distinct_id=tenant_id or "N/A", @@ -202,16 +214,29 @@ def create_persona( @basic_router.patch("/{persona_id}") def update_persona( persona_id: int, - update_persona_request: CreatePersonaRequest, + persona_upsert_request: PersonaUpsertRequest, user: User | None = Depends(current_user), db_session: Session = Depends(get_session), ) -> PersonaSnapshot: - return create_update_persona( + prompt = upsert_prompt( + db_session=db_session, + user=user, + name=build_prompt_name_from_persona_name(persona_upsert_request.name), + system_prompt=persona_upsert_request.system_prompt, + task_prompt=persona_upsert_request.task_prompt, + include_citations=persona_upsert_request.include_citations, + prompt_id=persona_upsert_request.existing_prompt_id, + ) + prompt_snapshot = PromptSnapshot.from_model(prompt) + persona_upsert_request.existing_prompt_id = prompt.id + persona_snapshot = create_update_persona( persona_id=persona_id, - create_persona_request=update_persona_request, + create_persona_request=persona_upsert_request, user=user, db_session=db_session, ) + persona_snapshot.prompts = [prompt_snapshot] + return persona_snapshot class PersonaLabelPatchRequest(BaseModel): @@ -365,22 +390,6 @@ def get_persona( ) -@basic_router.get("/utils/prompt-explorer") -def build_final_template_prompt( - system_prompt: str, - task_prompt: str, - retrieval_disabled: bool = False, - _: User | None = Depends(current_user), -) -> PromptTemplateResponse: - return PromptTemplateResponse( - final_prompt_template=build_dummy_prompt( - system_prompt=system_prompt, - task_prompt=task_prompt, - retrieval_disabled=retrieval_disabled, - ) - ) - - @basic_router.post("/assistant-prompt-refresh") def build_assistant_prompts( generate_persona_prompt_request: GenerateStarterMessageRequest, diff --git a/backend/onyx/server/features/persona/models.py b/backend/onyx/server/features/persona/models.py index fd41f43bdd7..5431f8688f1 100644 --- a/backend/onyx/server/features/persona/models.py +++ b/backend/onyx/server/features/persona/models.py @@ -27,32 +27,36 @@ class GenerateStarterMessageRequest(BaseModel): generation_count: int -class CreatePersonaRequest(BaseModel): +class PersonaUpsertRequest(BaseModel): name: str description: str + existing_prompt_id: int | None = None + system_prompt: str + task_prompt: str + document_set_ids: list[int] num_chunks: float - llm_relevance_filter: bool + include_citations: bool is_public: bool - llm_filter_extraction: bool recency_bias: RecencyBiasSetting prompt_ids: list[int] - document_set_ids: list[int] - # e.g. ID of SearchTool or ImageGenerationTool or - tool_ids: list[int] + llm_filter_extraction: bool + llm_relevance_filter: bool llm_model_provider_override: str | None = None llm_model_version_override: str | None = None starter_messages: list[StarterMessage] | None = None # For Private Personas, who should be able to access these users: list[UUID] = Field(default_factory=list) groups: list[int] = Field(default_factory=list) + # e.g. ID of SearchTool or ImageGenerationTool or + tool_ids: list[int] icon_color: str | None = None icon_shape: int | None = None - uploaded_image_id: str | None = None # New field for uploaded image remove_image: bool | None = None - is_default_persona: bool = False - display_priority: int | None = None + uploaded_image_id: str | None = None # New field for uploaded image search_start_date: datetime | None = None label_ids: list[int] | None = None + is_default_persona: bool = False + display_priority: int | None = None class PersonaSnapshot(BaseModel): diff --git a/backend/onyx/server/features/prompt/api.py b/backend/onyx/server/features/prompt/api.py deleted file mode 100644 index 5432fa96100..00000000000 --- a/backend/onyx/server/features/prompt/api.py +++ /dev/null @@ -1,152 +0,0 @@ -from fastapi import APIRouter -from fastapi import Depends -from fastapi import HTTPException -from sqlalchemy.orm import Session -from starlette import status - -from onyx.auth.users import current_user -from onyx.db.engine import get_session -from onyx.db.models import User -from onyx.db.persona import get_personas_by_ids -from onyx.db.persona import get_prompt_by_id -from onyx.db.persona import get_prompts -from onyx.db.persona import mark_prompt_as_deleted -from onyx.db.persona import upsert_prompt -from onyx.server.features.prompt.models import CreatePromptRequest -from onyx.server.features.prompt.models import PromptSnapshot -from onyx.utils.logger import setup_logger - - -# Note: As prompts are fairly innocuous/harmless, there are no protections -# to prevent users from messing with prompts of other users. - -logger = setup_logger() - -basic_router = APIRouter(prefix="/prompt") - - -def create_update_prompt( - prompt_id: int | None, - create_prompt_request: CreatePromptRequest, - user: User | None, - db_session: Session, -) -> PromptSnapshot: - personas = ( - list( - get_personas_by_ids( - persona_ids=create_prompt_request.persona_ids, - db_session=db_session, - ) - ) - if create_prompt_request.persona_ids - else [] - ) - - prompt = upsert_prompt( - prompt_id=prompt_id, - user=user, - name=create_prompt_request.name, - description=create_prompt_request.description, - system_prompt=create_prompt_request.system_prompt, - task_prompt=create_prompt_request.task_prompt, - include_citations=create_prompt_request.include_citations, - datetime_aware=create_prompt_request.datetime_aware, - personas=personas, - db_session=db_session, - ) - return PromptSnapshot.from_model(prompt) - - -@basic_router.post("") -def create_prompt( - create_prompt_request: CreatePromptRequest, - user: User | None = Depends(current_user), - db_session: Session = Depends(get_session), -) -> PromptSnapshot: - try: - return create_update_prompt( - prompt_id=None, - create_prompt_request=create_prompt_request, - user=user, - db_session=db_session, - ) - except ValueError as ve: - logger.exception(ve) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Failed to create Persona, invalid info.", - ) - except Exception as e: - logger.exception(e) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="An unexpected error occurred. Please try again later.", - ) - - -@basic_router.patch("/{prompt_id}") -def update_prompt( - prompt_id: int, - update_prompt_request: CreatePromptRequest, - user: User | None = Depends(current_user), - db_session: Session = Depends(get_session), -) -> PromptSnapshot: - try: - return create_update_prompt( - prompt_id=prompt_id, - create_prompt_request=update_prompt_request, - user=user, - db_session=db_session, - ) - except ValueError as ve: - logger.exception(ve) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Failed to create Persona, invalid info.", - ) - except Exception as e: - logger.exception(e) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="An unexpected error occurred. Please try again later.", - ) - - -@basic_router.delete("/{prompt_id}") -def delete_prompt( - prompt_id: int, - user: User | None = Depends(current_user), - db_session: Session = Depends(get_session), -) -> None: - mark_prompt_as_deleted( - prompt_id=prompt_id, - user=user, - db_session=db_session, - ) - - -@basic_router.get("") -def list_prompts( - user: User | None = Depends(current_user), - db_session: Session = Depends(get_session), -) -> list[PromptSnapshot]: - user_id = user.id if user is not None else None - return [ - PromptSnapshot.from_model(prompt) - for prompt in get_prompts(user_id=user_id, db_session=db_session) - ] - - -@basic_router.get("/{prompt_id}") -def get_prompt( - prompt_id: int, - user: User | None = Depends(current_user), - db_session: Session = Depends(get_session), -) -> PromptSnapshot: - return PromptSnapshot.from_model( - get_prompt_by_id( - prompt_id=prompt_id, - user=user, - db_session=db_session, - ) - ) diff --git a/backend/onyx/server/features/prompt/models.py b/backend/onyx/server/features/prompt/models.py index c15127d5ed8..fb7a160166a 100644 --- a/backend/onyx/server/features/prompt/models.py +++ b/backend/onyx/server/features/prompt/models.py @@ -3,16 +3,6 @@ from onyx.db.models import Prompt -class CreatePromptRequest(BaseModel): - name: str - description: str - system_prompt: str - task_prompt: str - include_citations: bool = False - datetime_aware: bool = False - persona_ids: list[int] | None = None - - class PromptSnapshot(BaseModel): id: int name: str diff --git a/backend/onyx/server/openai_assistants_api/asssistants_api.py b/backend/onyx/server/openai_assistants_api/asssistants_api.py index 78ef8d45f8a..2e8ecc1e354 100644 --- a/backend/onyx/server/openai_assistants_api/asssistants_api.py +++ b/backend/onyx/server/openai_assistants_api/asssistants_api.py @@ -18,7 +18,7 @@ from onyx.db.persona import get_personas_for_user from onyx.db.persona import mark_persona_as_deleted from onyx.db.persona import upsert_persona -from onyx.db.persona import upsert_prompt +from onyx.db.prompts import upsert_prompt from onyx.db.tools import get_tool_by_name from onyx.utils.logger import setup_logger diff --git a/web/src/app/admin/assistants/AssistantEditor.tsx b/web/src/app/admin/assistants/AssistantEditor.tsx index 0e3943bd646..ea22ad4f23a 100644 --- a/web/src/app/admin/assistants/AssistantEditor.tsx +++ b/web/src/app/admin/assistants/AssistantEditor.tsx @@ -6,14 +6,11 @@ import { generateRandomIconShape } from "@/lib/assistantIconUtils"; import { CCPairBasicInfo, DocumentSet, User, UserGroup } from "@/lib/types"; import { Separator } from "@/components/ui/separator"; import { Button } from "@/components/ui/button"; -import { Textarea } from "@/components/ui/textarea"; -import { IsPublicGroupSelector } from "@/components/IsPublicGroupSelector"; import { ArrayHelpers, FieldArray, Form, Formik, FormikProps } from "formik"; import { BooleanFormField, Label, - SelectorFormField, TextFormField, } from "@/components/admin/connectors/Field"; @@ -46,10 +43,10 @@ import { Persona, PersonaLabel, StarterMessage, - StarterMessageBase, } from "./interfaces"; import { createPersonaLabel, + PersonaUpsertParameters, createPersona, deletePersonaLabel, updatePersonaLabel, @@ -67,30 +64,19 @@ import { useAssistants } from "@/components/context/AssistantsContext"; import { debounce } from "lodash"; import { FullLLMProvider } from "../configuration/llm/interfaces"; import StarterMessagesList from "./StarterMessageList"; -import { LabelCard } from "./LabelCard"; import { Switch } from "@/components/ui/switch"; import { generateIdenticon } from "@/components/assistants/AssistantIcon"; import { BackButton } from "@/components/BackButton"; import { Checkbox } from "@/components/ui/checkbox"; -import { Input } from "@/components/ui/input"; import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle"; -import { AssistantVisibilityPopover } from "@/app/assistants/mine/AssistantVisibilityPopover"; import { MinimalUserSnapshot } from "@/lib/types"; import { useUserGroups } from "@/lib/hooks"; -import { useUsers } from "@/lib/hooks"; -import { AllUsersResponse } from "@/lib/types"; -// import { Badge } from "@/components/ui/Badge"; -// import { -// addUsersToAssistantSharedList, -// shareAssistantWithGroups, -// } from "@/lib/assistants/shareAssistant"; import { SearchMultiSelectDropdown, Option as DropdownOption, } from "@/components/Dropdown"; -import { Badge } from "@/components/ui/badge"; import { SourceChip } from "@/app/chat/input/ChatInputBar"; -import { GroupIcon, TagIcon, UserIcon } from "lucide-react"; +import { TagIcon, UserIcon } from "lucide-react"; import { LLMSelector } from "@/components/llm/LLMSelector"; import useSWR from "swr"; import { errorHandlingFetcher } from "@/lib/fetcher"; @@ -268,7 +254,6 @@ export function AssistantEditor({ labels: existingPersona?.labels ?? null, // EE Only - groups: existingPersona?.groups ?? [], label_ids: existingPersona?.labels?.map((label) => label.id) ?? [], selectedUsers: existingPersona?.users?.filter( @@ -418,7 +403,6 @@ export function AssistantEditor({ icon_shape: Yup.number(), uploaded_image: Yup.mixed().nullable(), // EE Only - groups: Yup.array().of(Yup.number()), label_ids: Yup.array().of(Yup.number()), selectedUsers: Yup.array().of(Yup.object()), selectedGroups: Yup.array().of(Yup.number()), @@ -494,12 +478,13 @@ export function AssistantEditor({ })); // don't set groups if marked as public - const groups = values.is_public ? [] : values.groups; - - const submissionData = { + const groups = values.is_public ? [] : values.selectedGroups; + const submissionData: PersonaUpsertParameters = { ...values, + existing_prompt_id: existingPrompt?.id ?? null, + is_default_persona: admin!, starter_messages: starterMessages, - groups: values.is_public ? [] : values.selectedGroups, + groups: groups, users: values.is_public ? undefined : [ @@ -514,25 +499,14 @@ export function AssistantEditor({ num_chunks: numChunks, }; - let promptResponse; let personaResponse; if (isUpdate) { - [promptResponse, personaResponse] = await updatePersona({ - id: existingPersona.id, - existingPromptId: existingPrompt?.id, - ...submissionData, - }); + personaResponse = await updatePersona(existingPersona.id, submissionData); } else { - [promptResponse, personaResponse] = await createPersona({ - ...submissionData, - is_default_persona: admin!, - }); + personaResponse = await createPersona(submissionData); } let error = null; - if (!promptResponse.ok) { - error = await promptResponse.text(); - } if (!personaResponse) { error = "Failed to create Assistant - no response received"; diff --git a/web/src/app/admin/assistants/lib.ts b/web/src/app/admin/assistants/lib.ts index f631665df8a..377f09657f9 100644 --- a/web/src/app/admin/assistants/lib.ts +++ b/web/src/app/admin/assistants/lib.ts @@ -1,15 +1,19 @@ import { FullLLMProvider } from "../configuration/llm/interfaces"; import { Persona, StarterMessage } from "./interfaces"; -interface PersonaCreationRequest { +interface PersonaUpsertRequest { name: string; description: string; + existing_prompt_id: number | null; system_prompt: string; task_prompt: string; document_set_ids: number[]; num_chunks: number | null; include_citations: boolean; is_public: boolean; + recency_bias: string; + prompt_ids: number[]; + llm_filter_extraction: boolean; llm_relevance_filter: boolean | null; llm_model_provider_override: string | null; llm_model_version_override: string | null; @@ -20,18 +24,18 @@ interface PersonaCreationRequest { icon_color: string | null; icon_shape: number | null; remove_image?: boolean; - uploaded_image: File | null; + uploaded_image_id: string | null; search_start_date: Date | null; is_default_persona: boolean; - label_ids?: number[]; + display_priority: number | null; + label_ids: number[] | null; } -interface PersonaUpdateRequest { - id: number; - existingPromptId: number | undefined; +export interface PersonaUpsertParameters { name: string; description: string; system_prompt: string; + existing_prompt_id: number | null; task_prompt: string; document_set_ids: number[]; num_chunks: number | null; @@ -46,68 +50,11 @@ interface PersonaUpdateRequest { tool_ids: number[]; icon_color: string | null; icon_shape: number | null; - remove_image: boolean; - uploaded_image: File | null; + remove_image?: boolean; search_start_date: Date | null; - label_ids?: number[]; -} - -function promptNameFromPersonaName(personaName: string) { - return `default-prompt__${personaName}`; -} - -function createPrompt({ - personaName, - systemPrompt, - taskPrompt, - includeCitations, -}: { - personaName: string; - systemPrompt: string; - taskPrompt: string; - includeCitations: boolean; -}) { - return fetch("/api/prompt", { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ - name: promptNameFromPersonaName(personaName), - description: `Default prompt for persona ${personaName}`, - system_prompt: systemPrompt, - task_prompt: taskPrompt, - include_citations: includeCitations, - }), - }); -} - -function updatePrompt({ - promptId, - personaName, - systemPrompt, - taskPrompt, - includeCitations, -}: { - promptId: number; - personaName: string; - systemPrompt: string; - taskPrompt: string; - includeCitations: boolean; -}) { - return fetch(`/api/prompt/${promptId}`, { - method: "PATCH", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ - name: promptNameFromPersonaName(personaName), - description: `Default prompt for persona ${personaName}`, - system_prompt: systemPrompt, - task_prompt: taskPrompt, - include_citations: includeCitations, - }), - }); + uploaded_image: File | null; + is_default_persona: boolean; + label_ids: number[] | null; } export const createPersonaLabel = (name: string) => { @@ -144,17 +91,18 @@ export const updatePersonaLabel = ( }); }; -function buildPersonaAPIBody( - creationRequest: PersonaCreationRequest | PersonaUpdateRequest, - promptId: number, +function buildPersonaUpsertRequest( + creationRequest: PersonaUpsertParameters, uploaded_image_id: string | null -) { +): PersonaUpsertRequest { const { name, description, + system_prompt, + task_prompt, document_set_ids, num_chunks, - llm_relevance_filter, + include_citations, is_public, groups, users, @@ -163,37 +111,37 @@ function buildPersonaAPIBody( icon_shape, remove_image, search_start_date, - label_ids, } = creationRequest; - - const is_default_persona = - "is_default_persona" in creationRequest - ? creationRequest.is_default_persona - : false; - return { name, description, + system_prompt, + task_prompt, + document_set_ids, num_chunks, - llm_relevance_filter, - llm_filter_extraction: false, + include_citations, is_public, - recency_bias: "base_decay", - prompt_ids: [promptId], - document_set_ids, - llm_model_provider_override: creationRequest.llm_model_provider_override, - llm_model_version_override: creationRequest.llm_model_version_override, - starter_messages: creationRequest.starter_messages, - users, + uploaded_image_id, groups, + users, tool_ids, icon_color, icon_shape, - uploaded_image_id, remove_image, search_start_date, - is_default_persona, - label_ids, + is_default_persona: creationRequest.is_default_persona ?? false, + existing_prompt_id: null, + recency_bias: "base_decay", + prompt_ids: [], + llm_filter_extraction: false, + llm_relevance_filter: creationRequest.llm_relevance_filter ?? null, + llm_model_provider_override: + creationRequest.llm_model_provider_override ?? null, + llm_model_version_override: + creationRequest.llm_model_version_override ?? null, + starter_messages: creationRequest.starter_messages ?? null, + display_priority: null, + label_ids: creationRequest.label_ids ?? null, }; } @@ -215,92 +163,52 @@ export async function uploadFile(file: File): Promise { } export async function createPersona( - personaCreationRequest: PersonaCreationRequest -): Promise<[Response, Response | null]> { - // first create prompt - const createPromptResponse = await createPrompt({ - personaName: personaCreationRequest.name, - systemPrompt: personaCreationRequest.system_prompt, - taskPrompt: personaCreationRequest.task_prompt, - includeCitations: personaCreationRequest.include_citations, - }); - const promptId = createPromptResponse.ok - ? (await createPromptResponse.json()).id - : null; - + personaUpsertParams: PersonaUpsertParameters +): Promise { let fileId = null; - if (personaCreationRequest.uploaded_image) { - fileId = await uploadFile(personaCreationRequest.uploaded_image); + if (personaUpsertParams.uploaded_image) { + fileId = await uploadFile(personaUpsertParams.uploaded_image); if (!fileId) { - return [createPromptResponse, null]; + return null; } } - const createPersonaResponse = - promptId !== null - ? await fetch("/api/persona", { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify( - buildPersonaAPIBody(personaCreationRequest, promptId, fileId) - ), - }) - : null; + const createPersonaResponse = await fetch("/api/persona", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify( + buildPersonaUpsertRequest(personaUpsertParams, fileId) + ), + }); - return [createPromptResponse, createPersonaResponse]; + return createPersonaResponse; } export async function updatePersona( - personaUpdateRequest: PersonaUpdateRequest -): Promise<[Response, Response | null]> { - const { id, existingPromptId } = personaUpdateRequest; - - let promptResponse; - let promptId: number | null = null; - if (existingPromptId !== undefined) { - promptResponse = await updatePrompt({ - promptId: existingPromptId, - personaName: personaUpdateRequest.name, - systemPrompt: personaUpdateRequest.system_prompt, - taskPrompt: personaUpdateRequest.task_prompt, - includeCitations: personaUpdateRequest.include_citations, - }); - promptId = existingPromptId; - } else { - promptResponse = await createPrompt({ - personaName: personaUpdateRequest.name, - systemPrompt: personaUpdateRequest.system_prompt, - taskPrompt: personaUpdateRequest.task_prompt, - includeCitations: personaUpdateRequest.include_citations, - }); - promptId = promptResponse.ok - ? ((await promptResponse.json()).id as number) - : null; - } + id: number, + personaUpsertParams: PersonaUpsertParameters +): Promise { let fileId = null; - if (personaUpdateRequest.uploaded_image) { - fileId = await uploadFile(personaUpdateRequest.uploaded_image); + if (personaUpsertParams.uploaded_image) { + fileId = await uploadFile(personaUpsertParams.uploaded_image); if (!fileId) { - return [promptResponse, null]; + return null; } } - const updatePersonaResponse = - promptResponse.ok && promptId !== null - ? await fetch(`/api/persona/${id}`, { - method: "PATCH", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify( - buildPersonaAPIBody(personaUpdateRequest, promptId, fileId) - ), - }) - : null; + const updatePersonaResponse = await fetch(`/api/persona/${id}`, { + method: "PATCH", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify( + buildPersonaUpsertRequest(personaUpsertParams, fileId) + ), + }); - return [promptResponse, updatePersonaResponse]; + return updatePersonaResponse; } export function deletePersona(personaId: number) { @@ -309,25 +217,6 @@ export function deletePersona(personaId: number) { }); } -export function buildFinalPrompt( - systemPrompt: string, - taskPrompt: string, - retrievalDisabled: boolean -) { - const queryString = Object.entries({ - system_prompt: systemPrompt, - task_prompt: taskPrompt, - retrieval_disabled: retrievalDisabled, - }) - .map( - ([key, value]) => - `${encodeURIComponent(key)}=${encodeURIComponent(value)}` - ) - .join("&"); - - return fetch(`/api/persona/utils/prompt-explorer?${queryString}`); -} - function smallerNumberFirstComparator(a: number, b: number) { return a > b ? 1 : -1; } From bed53ab3d3ec6670a7de6ebd9b058274836b2174 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Thu, 16 Jan 2025 11:42:20 -0800 Subject: [PATCH 2/6] quality --- web/src/app/admin/assistants/AssistantEditor.tsx | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/web/src/app/admin/assistants/AssistantEditor.tsx b/web/src/app/admin/assistants/AssistantEditor.tsx index ea22ad4f23a..268008d27ab 100644 --- a/web/src/app/admin/assistants/AssistantEditor.tsx +++ b/web/src/app/admin/assistants/AssistantEditor.tsx @@ -39,11 +39,7 @@ import { FiInfo, FiRefreshCcw, FiUsers } from "react-icons/fi"; import * as Yup from "yup"; import CollapsibleSection from "./CollapsibleSection"; import { SuccessfulPersonaUpdateRedirectType } from "./enums"; -import { - Persona, - PersonaLabel, - StarterMessage, -} from "./interfaces"; +import { Persona, PersonaLabel, StarterMessage } from "./interfaces"; import { createPersonaLabel, PersonaUpsertParameters, @@ -501,7 +497,10 @@ export function AssistantEditor({ let personaResponse; if (isUpdate) { - personaResponse = await updatePersona(existingPersona.id, submissionData); + personaResponse = await updatePersona( + existingPersona.id, + submissionData + ); } else { personaResponse = await createPersona(submissionData); } From 87fd7b0c3ee9050ac5b9757ba1c41b58ee73e845 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Thu, 16 Jan 2025 13:18:41 -0800 Subject: [PATCH 3/6] added tests --- .../common_utils/managers/persona.py | 118 ++++++++---- .../integration/common_utils/test_models.py | 2 +- .../permissions/test_persona_permissions.py | 175 ++++++++++++++++++ 3 files changed, 255 insertions(+), 40 deletions(-) create mode 100644 backend/tests/integration/tests/permissions/test_persona_permissions.py diff --git a/backend/tests/integration/common_utils/managers/persona.py b/backend/tests/integration/common_utils/managers/persona.py index e2e22981afa..de2164e1e60 100644 --- a/backend/tests/integration/common_utils/managers/persona.py +++ b/backend/tests/integration/common_utils/managers/persona.py @@ -1,9 +1,11 @@ +from uuid import UUID from uuid import uuid4 import requests from onyx.context.search.enums import RecencyBiasSetting from onyx.server.features.persona.models import PersonaSnapshot +from onyx.server.features.persona.models import PersonaUpsertRequest from tests.integration.common_utils.constants import API_SERVER_URL from tests.integration.common_utils.constants import GENERAL_HEADERS from tests.integration.common_utils.test_models import DATestPersona @@ -16,6 +18,9 @@ class PersonaManager: def create( name: str | None = None, description: str | None = None, + system_prompt: str | None = None, + task_prompt: str | None = None, + include_citations: bool = False, num_chunks: float = 5, llm_relevance_filter: bool = True, is_public: bool = True, @@ -28,32 +33,38 @@ def create( llm_model_version_override: str | None = None, users: list[str] | None = None, groups: list[int] | None = None, - category_id: int | None = None, + label_ids: list[int] | None = None, user_performing_action: DATestUser | None = None, ) -> DATestPersona: name = name or f"test-persona-{uuid4()}" description = description or f"Description for {name}" + system_prompt = system_prompt or f"System prompt for {name}" + task_prompt = task_prompt or f"Task prompt for {name}" - persona_creation_request = { - "name": name, - "description": description, - "num_chunks": num_chunks, - "llm_relevance_filter": llm_relevance_filter, - "is_public": is_public, - "llm_filter_extraction": llm_filter_extraction, - "recency_bias": recency_bias, - "prompt_ids": prompt_ids or [0], - "document_set_ids": document_set_ids or [], - "tool_ids": tool_ids or [], - "llm_model_provider_override": llm_model_provider_override, - "llm_model_version_override": llm_model_version_override, - "users": users or [], - "groups": groups or [], - } + persona_creation_request = PersonaUpsertRequest( + name=name, + description=description, + system_prompt=system_prompt, + task_prompt=task_prompt, + include_citations=include_citations, + num_chunks=num_chunks, + llm_relevance_filter=llm_relevance_filter, + is_public=is_public, + llm_filter_extraction=llm_filter_extraction, + recency_bias=recency_bias, + prompt_ids=prompt_ids or [0], + document_set_ids=document_set_ids or [], + tool_ids=tool_ids or [], + llm_model_provider_override=llm_model_provider_override, + llm_model_version_override=llm_model_version_override, + users=[UUID(user) for user in (users or [])], + groups=groups or [], + label_ids=label_ids or [], + ) response = requests.post( f"{API_SERVER_URL}/persona", - json=persona_creation_request, + json=persona_creation_request.model_dump(), headers=user_performing_action.headers if user_performing_action else GENERAL_HEADERS, @@ -77,6 +88,7 @@ def create( llm_model_version_override=llm_model_version_override, users=users or [], groups=groups or [], + label_ids=label_ids or [], ) @staticmethod @@ -84,6 +96,9 @@ def edit( persona: DATestPersona, name: str | None = None, description: str | None = None, + system_prompt: str | None = None, + task_prompt: str | None = None, + include_citations: bool = False, num_chunks: float | None = None, llm_relevance_filter: bool | None = None, is_public: bool | None = None, @@ -96,32 +111,38 @@ def edit( llm_model_version_override: str | None = None, users: list[str] | None = None, groups: list[int] | None = None, + label_ids: list[int] | None = None, user_performing_action: DATestUser | None = None, ) -> DATestPersona: - persona_update_request = { - "name": name or persona.name, - "description": description or persona.description, - "num_chunks": num_chunks or persona.num_chunks, - "llm_relevance_filter": llm_relevance_filter - or persona.llm_relevance_filter, - "is_public": is_public or persona.is_public, - "llm_filter_extraction": llm_filter_extraction + system_prompt = system_prompt or f"System prompt for {persona.name}" + task_prompt = task_prompt or f"Task prompt for {persona.name}" + persona_update_request = PersonaUpsertRequest( + name=name or persona.name, + description=description or persona.description, + system_prompt=system_prompt, + task_prompt=task_prompt, + include_citations=include_citations, + num_chunks=num_chunks or persona.num_chunks, + llm_relevance_filter=llm_relevance_filter or persona.llm_relevance_filter, + is_public=is_public or persona.is_public, + llm_filter_extraction=llm_filter_extraction or persona.llm_filter_extraction, - "recency_bias": recency_bias or persona.recency_bias, - "prompt_ids": prompt_ids or persona.prompt_ids, - "document_set_ids": document_set_ids or persona.document_set_ids, - "tool_ids": tool_ids or persona.tool_ids, - "llm_model_provider_override": llm_model_provider_override + recency_bias=recency_bias or persona.recency_bias, + prompt_ids=prompt_ids or persona.prompt_ids, + document_set_ids=document_set_ids or persona.document_set_ids, + tool_ids=tool_ids or persona.tool_ids, + llm_model_provider_override=llm_model_provider_override or persona.llm_model_provider_override, - "llm_model_version_override": llm_model_version_override + llm_model_version_override=llm_model_version_override or persona.llm_model_version_override, - "users": users or persona.users, - "groups": groups or persona.groups, - } + users=[UUID(user) for user in (users or persona.users)], + groups=groups or persona.groups, + label_ids=label_ids or persona.label_ids, + ) response = requests.patch( f"{API_SERVER_URL}/persona/{persona.id}", - json=persona_update_request, + json=persona_update_request.model_dump(), headers=user_performing_action.headers if user_performing_action else GENERAL_HEADERS, @@ -137,8 +158,8 @@ def edit( llm_relevance_filter=updated_persona_data["llm_relevance_filter"], is_public=updated_persona_data["is_public"], llm_filter_extraction=updated_persona_data["llm_filter_extraction"], - recency_bias=updated_persona_data["recency_bias"], - prompt_ids=updated_persona_data["prompts"], + recency_bias=recency_bias or persona.recency_bias, + prompt_ids=[prompt["id"] for prompt in updated_persona_data["prompts"]], document_set_ids=updated_persona_data["document_sets"], tool_ids=updated_persona_data["tools"], llm_model_provider_override=updated_persona_data[ @@ -149,6 +170,7 @@ def edit( ], users=[user["email"] for user in updated_persona_data["users"]], groups=updated_persona_data["groups"], + label_ids=updated_persona_data["labels"], ) @staticmethod @@ -164,12 +186,29 @@ def get_all( response.raise_for_status() return [PersonaSnapshot(**persona) for persona in response.json()] + @staticmethod + def get_one( + persona_id: int, + user_performing_action: DATestUser | None = None, + ) -> list[PersonaSnapshot]: + response = requests.get( + f"{API_SERVER_URL}/persona/{persona_id}", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return [PersonaSnapshot(**response.json())] + @staticmethod def verify( persona: DATestPersona, user_performing_action: DATestUser | None = None, ) -> bool: - all_personas = PersonaManager.get_all(user_performing_action) + all_personas = PersonaManager.get_one( + persona_id=persona.id, + user_performing_action=user_performing_action, + ) for fetched_persona in all_personas: if fetched_persona.id == persona.id: return ( @@ -199,6 +238,7 @@ def verify( and set(user.email for user in fetched_persona.users) == set(persona.users) and set(fetched_persona.groups) == set(persona.groups) + and set(fetched_persona.labels) == set(persona.label_ids) ) return False diff --git a/backend/tests/integration/common_utils/test_models.py b/backend/tests/integration/common_utils/test_models.py index 659dfffdda1..be8e7ec085d 100644 --- a/backend/tests/integration/common_utils/test_models.py +++ b/backend/tests/integration/common_utils/test_models.py @@ -127,7 +127,7 @@ class DATestPersona(BaseModel): llm_model_version_override: str | None users: list[str] groups: list[int] - category_id: int | None = None + label_ids: list[int] # diff --git a/backend/tests/integration/tests/permissions/test_persona_permissions.py b/backend/tests/integration/tests/permissions/test_persona_permissions.py new file mode 100644 index 00000000000..5f63e5fc21b --- /dev/null +++ b/backend/tests/integration/tests/permissions/test_persona_permissions.py @@ -0,0 +1,175 @@ +""" +This file tests the permissions for creating and editing personas for different user roles: +- Basic users can create personas and edit their own +- Curators can edit personas that belong exclusively to groups they curate +- Admins can edit all personas +""" +import pytest +from requests.exceptions import HTTPError + +from tests.integration.common_utils.managers.persona import PersonaManager +from tests.integration.common_utils.managers.user import DATestUser +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.managers.user_group import UserGroupManager + + +def test_persona_permissions(reset: None) -> None: + # Creating an admin user (first user created is automatically an admin) + admin_user: DATestUser = UserManager.create(name="admin_user") + + # Creating a curator user + curator: DATestUser = UserManager.create(name="curator") + + # Creating a basic user + basic_user: DATestUser = UserManager.create(name="basic_user") + + # Creating user groups + user_group_1 = UserGroupManager.create( + name="curated_user_group", + user_ids=[curator.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + # Setting the user as a curator for the user group + UserGroupManager.set_curator_status( + test_user_group=user_group_1, + user_to_set_as_curator=curator, + user_performing_action=admin_user, + ) + + # Creating another user group that the user is not a curator of + user_group_2 = UserGroupManager.create( + name="uncurated_user_group", + user_ids=[curator.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_2], user_performing_action=admin_user + ) + + """Test that any user can create a persona""" + # Basic user creates a persona + basic_user_persona = PersonaManager.create( + name="basic_user_persona", + description="A persona created by basic user", + is_public=False, + groups=[], + user_performing_action=basic_user, + ) + PersonaManager.verify(basic_user_persona, user_performing_action=basic_user) + + # Curator creates a persona + curator_persona = PersonaManager.create( + name="curator_persona", + description="A persona created by curator", + is_public=False, + groups=[], + user_performing_action=curator, + ) + PersonaManager.verify(curator_persona, user_performing_action=curator) + + # Admin creates personas for different groups + admin_persona_group_1 = PersonaManager.create( + name="admin_persona_group_1", + description="A persona for group 1", + is_public=False, + groups=[user_group_1.id], + user_performing_action=admin_user, + ) + admin_persona_group_2 = PersonaManager.create( + name="admin_persona_group_2", + description="A persona for group 2", + is_public=False, + groups=[user_group_2.id], + user_performing_action=admin_user, + ) + admin_persona_both_groups = PersonaManager.create( + name="admin_persona_both_groups", + description="A persona for both groups", + is_public=False, + groups=[user_group_1.id, user_group_2.id], + user_performing_action=admin_user, + ) + + """Test that users can edit their own personas""" + # Basic user can edit their own persona + PersonaManager.edit( + persona=basic_user_persona, + description="Updated description by basic user", + user_performing_action=basic_user, + ) + PersonaManager.verify(basic_user_persona, user_performing_action=basic_user) + + # Basic user cannot edit other's personas + with pytest.raises(HTTPError): + PersonaManager.edit( + persona=curator_persona, + description="Invalid edit by basic user", + user_performing_action=basic_user, + ) + + """Test curator permissions""" + # Curator can edit personas that belong exclusively to groups they curate + PersonaManager.edit( + persona=admin_persona_group_1, + description="Updated by curator", + user_performing_action=curator, + ) + PersonaManager.verify(admin_persona_group_1, user_performing_action=curator) + + # Curator cannot edit personas in groups they don't curate + with pytest.raises(HTTPError): + PersonaManager.edit( + persona=admin_persona_group_2, + description="Invalid edit by curator", + user_performing_action=curator, + ) + + # Curator cannot edit personas that belong to multiple groups, even if they curate one + with pytest.raises(HTTPError): + PersonaManager.edit( + persona=admin_persona_both_groups, + description="Invalid edit by curator", + user_performing_action=curator, + ) + + """Test admin permissions""" + # Admin can edit any persona + PersonaManager.edit( + persona=basic_user_persona, + description="Updated by admin", + user_performing_action=admin_user, + ) + PersonaManager.verify(basic_user_persona, user_performing_action=admin_user) + + PersonaManager.edit( + persona=curator_persona, + description="Updated by admin", + user_performing_action=admin_user, + ) + PersonaManager.verify(curator_persona, user_performing_action=admin_user) + + PersonaManager.edit( + persona=admin_persona_group_1, + description="Updated by admin", + user_performing_action=admin_user, + ) + PersonaManager.verify(admin_persona_group_1, user_performing_action=admin_user) + + PersonaManager.edit( + persona=admin_persona_group_2, + description="Updated by admin", + user_performing_action=admin_user, + ) + PersonaManager.verify(admin_persona_group_2, user_performing_action=admin_user) + + PersonaManager.edit( + persona=admin_persona_both_groups, + description="Updated by admin", + user_performing_action=admin_user, + ) + PersonaManager.verify(admin_persona_both_groups, user_performing_action=admin_user) From 5ad1cdcee3ccd979371777ced5e9e7f44de559e3 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Fri, 17 Jan 2025 07:55:01 -0800 Subject: [PATCH 4/6] consolidated models and got rid of redundant fields --- backend/onyx/db/persona.py | 6 ++-- backend/onyx/server/features/persona/api.py | 22 ++++++++++--- .../onyx/server/features/persona/models.py | 31 +++++++++++++++++-- .../onyx/server/features/prompt/__init__.py | 0 backend/onyx/server/features/prompt/models.py | 31 ------------------- web/src/app/admin/assistants/lib.ts | 5 ++- 6 files changed, 51 insertions(+), 44 deletions(-) delete mode 100644 backend/onyx/server/features/prompt/__init__.py delete mode 100644 backend/onyx/server/features/prompt/models.py diff --git a/backend/onyx/db/persona.py b/backend/onyx/db/persona.py index 97d2410598d..ffa9b831920 100644 --- a/backend/onyx/db/persona.py +++ b/backend/onyx/db/persona.py @@ -186,11 +186,11 @@ def create_update_persona( # Permission to actually use these is checked later try: - all_prompt_ids = create_persona_request.prompt_ids or [] - if create_persona_request.existing_prompt_id: - all_prompt_ids.append(create_persona_request.existing_prompt_id) + all_prompt_ids = create_persona_request.prompt_ids + if not all_prompt_ids: raise ValueError("No prompt IDs provided") + persona = upsert_persona( persona_id=persona_id, user=user, diff --git a/backend/onyx/server/features/persona/api.py b/backend/onyx/server/features/persona/api.py index a95736cf5b4..feb5a63ad4b 100644 --- a/backend/onyx/server/features/persona/api.py +++ b/backend/onyx/server/features/persona/api.py @@ -49,7 +49,7 @@ from onyx.server.features.persona.models import PersonaSharedNotificationData from onyx.server.features.persona.models import PersonaSnapshot from onyx.server.features.persona.models import PersonaUpsertRequest -from onyx.server.features.prompt.models import PromptSnapshot +from onyx.server.features.persona.models import PromptSnapshot from onyx.server.models import DisplayPriorityRequest from onyx.tools.utils import is_image_generation_available from onyx.utils.logger import setup_logger @@ -179,6 +179,12 @@ def create_persona( db_session: Session = Depends(get_session), tenant_id: str | None = Depends(get_current_tenant_id), ) -> PersonaSnapshot: + prompt_id = ( + persona_upsert_request.prompt_ids[0] + if persona_upsert_request.prompt_ids + and len(persona_upsert_request.prompt_ids) > 0 + else None + ) prompt = upsert_prompt( db_session=db_session, user=user, @@ -186,10 +192,10 @@ def create_persona( system_prompt=persona_upsert_request.system_prompt, task_prompt=persona_upsert_request.task_prompt, include_citations=persona_upsert_request.include_citations, - prompt_id=persona_upsert_request.existing_prompt_id, + prompt_id=prompt_id, ) prompt_snapshot = PromptSnapshot.from_model(prompt) - persona_upsert_request.existing_prompt_id = prompt.id + persona_upsert_request.prompt_ids = [prompt.id] persona_snapshot = create_update_persona( persona_id=None, create_persona_request=persona_upsert_request, @@ -218,6 +224,12 @@ def update_persona( user: User | None = Depends(current_user), db_session: Session = Depends(get_session), ) -> PersonaSnapshot: + prompt_id = ( + persona_upsert_request.prompt_ids[0] + if persona_upsert_request.prompt_ids + and len(persona_upsert_request.prompt_ids) > 0 + else None + ) prompt = upsert_prompt( db_session=db_session, user=user, @@ -225,10 +237,10 @@ def update_persona( system_prompt=persona_upsert_request.system_prompt, task_prompt=persona_upsert_request.task_prompt, include_citations=persona_upsert_request.include_citations, - prompt_id=persona_upsert_request.existing_prompt_id, + prompt_id=prompt_id, ) prompt_snapshot = PromptSnapshot.from_model(prompt) - persona_upsert_request.existing_prompt_id = prompt.id + persona_upsert_request.prompt_ids = [prompt.id] persona_snapshot = create_update_persona( persona_id=persona_id, create_persona_request=persona_upsert_request, diff --git a/backend/onyx/server/features/persona/models.py b/backend/onyx/server/features/persona/models.py index 5431f8688f1..74f7f09a7e1 100644 --- a/backend/onyx/server/features/persona/models.py +++ b/backend/onyx/server/features/persona/models.py @@ -7,9 +7,9 @@ from onyx.context.search.enums import RecencyBiasSetting from onyx.db.models import Persona from onyx.db.models import PersonaLabel +from onyx.db.models import Prompt from onyx.db.models import StarterMessage from onyx.server.features.document_set.models import DocumentSet -from onyx.server.features.prompt.models import PromptSnapshot from onyx.server.features.tool.models import ToolSnapshot from onyx.server.models import MinimalUserSnapshot from onyx.utils.logger import setup_logger @@ -18,6 +18,34 @@ logger = setup_logger() +class PromptSnapshot(BaseModel): + id: int + name: str + description: str + system_prompt: str + task_prompt: str + include_citations: bool + datetime_aware: bool + default_prompt: bool + # Not including persona info, not needed + + @classmethod + def from_model(cls, prompt: Prompt) -> "PromptSnapshot": + if prompt.deleted: + raise ValueError("Prompt has been deleted") + + return PromptSnapshot( + id=prompt.id, + name=prompt.name, + description=prompt.description, + system_prompt=prompt.system_prompt, + task_prompt=prompt.task_prompt, + include_citations=prompt.include_citations, + datetime_aware=prompt.datetime_aware, + default_prompt=prompt.default_prompt, + ) + + # More minimal request for generating a persona prompt class GenerateStarterMessageRequest(BaseModel): name: str @@ -30,7 +58,6 @@ class GenerateStarterMessageRequest(BaseModel): class PersonaUpsertRequest(BaseModel): name: str description: str - existing_prompt_id: int | None = None system_prompt: str task_prompt: str document_set_ids: list[int] diff --git a/backend/onyx/server/features/prompt/__init__.py b/backend/onyx/server/features/prompt/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/backend/onyx/server/features/prompt/models.py b/backend/onyx/server/features/prompt/models.py deleted file mode 100644 index fb7a160166a..00000000000 --- a/backend/onyx/server/features/prompt/models.py +++ /dev/null @@ -1,31 +0,0 @@ -from pydantic import BaseModel - -from onyx.db.models import Prompt - - -class PromptSnapshot(BaseModel): - id: int - name: str - description: str - system_prompt: str - task_prompt: str - include_citations: bool - datetime_aware: bool - default_prompt: bool - # Not including persona info, not needed - - @classmethod - def from_model(cls, prompt: Prompt) -> "PromptSnapshot": - if prompt.deleted: - raise ValueError("Prompt has been deleted") - - return PromptSnapshot( - id=prompt.id, - name=prompt.name, - description=prompt.description, - system_prompt=prompt.system_prompt, - task_prompt=prompt.task_prompt, - include_citations=prompt.include_citations, - datetime_aware=prompt.datetime_aware, - default_prompt=prompt.default_prompt, - ) diff --git a/web/src/app/admin/assistants/lib.ts b/web/src/app/admin/assistants/lib.ts index 377f09657f9..052b33b266f 100644 --- a/web/src/app/admin/assistants/lib.ts +++ b/web/src/app/admin/assistants/lib.ts @@ -4,7 +4,6 @@ import { Persona, StarterMessage } from "./interfaces"; interface PersonaUpsertRequest { name: string; description: string; - existing_prompt_id: number | null; system_prompt: string; task_prompt: string; document_set_ids: number[]; @@ -105,6 +104,7 @@ function buildPersonaUpsertRequest( include_citations, is_public, groups, + existing_prompt_id, users, tool_ids, icon_color, @@ -130,9 +130,8 @@ function buildPersonaUpsertRequest( remove_image, search_start_date, is_default_persona: creationRequest.is_default_persona ?? false, - existing_prompt_id: null, recency_bias: "base_decay", - prompt_ids: [], + prompt_ids: existing_prompt_id ? [existing_prompt_id] : [], llm_filter_extraction: false, llm_relevance_filter: creationRequest.llm_relevance_filter ?? null, llm_model_provider_override: From 6f37e79a9d86c199271607ed8e6371c62661c218 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Fri, 17 Jan 2025 10:21:33 -0800 Subject: [PATCH 5/6] tenant appreciation day --- .../onyx/chat/prompt_builder/citations_prompt.py | 6 ++++-- backend/onyx/db/prompts.py | 13 ------------- backend/onyx/server/query_and_chat/chat_backend.py | 5 ++++- 3 files changed, 8 insertions(+), 16 deletions(-) diff --git a/backend/onyx/chat/prompt_builder/citations_prompt.py b/backend/onyx/chat/prompt_builder/citations_prompt.py index 52043abdf14..51138f0d2a5 100644 --- a/backend/onyx/chat/prompt_builder/citations_prompt.py +++ b/backend/onyx/chat/prompt_builder/citations_prompt.py @@ -1,12 +1,13 @@ from langchain.schema.messages import HumanMessage from langchain.schema.messages import SystemMessage +from sqlalchemy.orm import Session from onyx.chat.models import LlmDoc from onyx.chat.models import PromptConfig from onyx.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS from onyx.context.search.models import InferenceChunk from onyx.db.models import Persona -from onyx.db.prompts import get_default_prompt__read_only +from onyx.db.prompts import get_default_prompt from onyx.db.search_settings import get_multilingual_expansion from onyx.llm.factory import get_llms_for_persona from onyx.llm.factory import get_main_llm_from_tuple @@ -97,11 +98,12 @@ def compute_max_document_tokens( def compute_max_document_tokens_for_persona( + db_session: Session, persona: Persona, actual_user_input: str | None = None, max_llm_token_override: int | None = None, ) -> int: - prompt = persona.prompts[0] if persona.prompts else get_default_prompt__read_only() + prompt = persona.prompts[0] if persona.prompts else get_default_prompt(db_session) return compute_max_document_tokens( prompt_config=PromptConfig.from_model(prompt), llm_config=get_main_llm_from_tuple(get_llms_for_persona(persona)).config, diff --git a/backend/onyx/db/prompts.py b/backend/onyx/db/prompts.py index 0dacad9879c..8f677222baf 100644 --- a/backend/onyx/db/prompts.py +++ b/backend/onyx/db/prompts.py @@ -1,10 +1,7 @@ -from functools import lru_cache - from sqlalchemy import select from sqlalchemy.orm import Session from onyx.auth.schemas import UserRole -from onyx.db.engine import get_sqlalchemy_engine from onyx.db.models import Persona from onyx.db.models import Prompt from onyx.db.models import User @@ -32,16 +29,6 @@ def get_default_prompt(db_session: Session) -> Prompt: return _get_default_prompt(db_session) -@lru_cache() -def get_default_prompt__read_only() -> Prompt: - """Due to the way lru_cache / SQLAlchemy works, this can cause issues - when trying to attach the returned `Prompt` object to a `Persona`. If you are - doing anything other than reading, you should use the `get_default_prompt` - method instead.""" - with Session(get_sqlalchemy_engine()) as db_session: - return _get_default_prompt(db_session) - - def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> list[Prompt]: """Unsafe, can fetch prompts from all users""" if not prompt_ids: diff --git a/backend/onyx/server/query_and_chat/chat_backend.py b/backend/onyx/server/query_and_chat/chat_backend.py index 0613556576a..669a416322f 100644 --- a/backend/onyx/server/query_and_chat/chat_backend.py +++ b/backend/onyx/server/query_and_chat/chat_backend.py @@ -479,7 +479,10 @@ def get_max_document_tokens( raise HTTPException(status_code=404, detail="Persona not found") return MaxSelectedDocumentTokens( - max_tokens=compute_max_document_tokens_for_persona(persona), + max_tokens=compute_max_document_tokens_for_persona( + db_session=db_session, + persona=persona, + ), ) From c012e94837e640996714a3beaf82dbc6cd35af1b Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Fri, 17 Jan 2025 11:46:55 -0800 Subject: [PATCH 6/6] reverted default --- backend/onyx/db/prompts.py | 2 +- backend/onyx/server/features/persona/api.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/backend/onyx/db/prompts.py b/backend/onyx/db/prompts.py index 8f677222baf..df0d9feac5c 100644 --- a/backend/onyx/db/prompts.py +++ b/backend/onyx/db/prompts.py @@ -66,9 +66,9 @@ def upsert_prompt( name: str, system_prompt: str, task_prompt: str, + datetime_aware: bool, prompt_id: int | None = None, personas: list[Persona] | None = None, - datetime_aware: bool = True, include_citations: bool = False, default_prompt: bool = True, # Support backwards compatibility diff --git a/backend/onyx/server/features/persona/api.py b/backend/onyx/server/features/persona/api.py index feb5a63ad4b..684311ea7a8 100644 --- a/backend/onyx/server/features/persona/api.py +++ b/backend/onyx/server/features/persona/api.py @@ -191,6 +191,8 @@ def create_persona( name=build_prompt_name_from_persona_name(persona_upsert_request.name), system_prompt=persona_upsert_request.system_prompt, task_prompt=persona_upsert_request.task_prompt, + # TODO: The PersonaUpsertRequest should provide the value for datetime_aware + datetime_aware=False, include_citations=persona_upsert_request.include_citations, prompt_id=prompt_id, ) @@ -234,6 +236,8 @@ def update_persona( db_session=db_session, user=user, name=build_prompt_name_from_persona_name(persona_upsert_request.name), + # TODO: The PersonaUpsertRequest should provide the value for datetime_aware + datetime_aware=False, system_prompt=persona_upsert_request.system_prompt, task_prompt=persona_upsert_request.task_prompt, include_citations=persona_upsert_request.include_citations,