diff --git a/apps/ai/server/modules/golden_sql/controller.py b/apps/ai/server/modules/golden_sql/controller.py index a990367c..223664a2 100644 --- a/apps/ai/server/modules/golden_sql/controller.py +++ b/apps/ai/server/modules/golden_sql/controller.py @@ -3,7 +3,6 @@ from fastapi import APIRouter, Depends, status from fastapi.security import HTTPBearer -from modules.golden_sql.models.entities import GoldenSQLSource from modules.golden_sql.models.requests import GoldenSQLRequest from modules.golden_sql.models.responses import GoldenSQLResponse from modules.golden_sql.service import GoldenSQLService @@ -43,13 +42,13 @@ async def get_golden_sql( return golden_sql_service.get_golden_sql(id) -@router.post("", status_code=status.HTTP_201_CREATED) -async def add_golden_sql( +@router.post("/user-upload", status_code=status.HTTP_201_CREATED) +async def add_user_upload_golden_sql( golden_sql_requests: List[GoldenSQLRequest], token: str = Depends(token_auth_scheme) -) -> GoldenSQLResponse: +) -> List[GoldenSQLResponse]: org_id = authorize.user_and_get_org_id(VerifyToken(token.credentials).verify()) - return await golden_sql_service.add_golden_sql( - golden_sql_requests, org_id, GoldenSQLSource.user_upload + return await golden_sql_service.add_user_upload_golden_sql( + golden_sql_requests, org_id ) diff --git a/apps/ai/server/modules/golden_sql/models/entities.py b/apps/ai/server/modules/golden_sql/models/entities.py index 47816936..f2ab2969 100644 --- a/apps/ai/server/modules/golden_sql/models/entities.py +++ b/apps/ai/server/modules/golden_sql/models/entities.py @@ -10,8 +10,8 @@ class BaseGoldenSQL(BaseModel): class GoldenSQLSource(Enum): - user_upload = "USER_UPLOAD" - verified_query = "VERIFIED_QUERY" + USER_UPLOAD = "USER_UPLOAD" + VERIFIED_QUERY = "VERIFIED_QUERY" class GoldenSQLRef(BaseModel): diff --git a/apps/ai/server/modules/golden_sql/service.py b/apps/ai/server/modules/golden_sql/service.py index be43797e..d9120bff 100644 --- a/apps/ai/server/modules/golden_sql/service.py +++ b/apps/ai/server/modules/golden_sql/service.py @@ -53,43 +53,32 @@ def get_golden_sqls( def get_verified_golden_sql_ref(self, query_response_id: str) -> GoldenSQLRef: return self.repo.get_verified_golden_sql_ref(query_response_id) - async def add_golden_sql( + async def add_verified_query_golden_sql( self, - golden_sql_requests: List[GoldenSQLRequest], + golden_sql_request: GoldenSQLRequest, org_id: str, - source: GoldenSQLSource, - query_response_id: str = None, + query_response_id: str, ) -> GoldenSQLResponse: + golden_sql_ref = self.repo.get_verified_golden_sql_ref(query_response_id) + # if already exist, delete golden_sql_ref and call delete /golden-records + if golden_sql_ref: + await self.delete_golden_sql("", query_response_id) async with httpx.AsyncClient() as client: - if query_response_id: - golden_sql_ref = self.repo.get_verified_golden_sql_ref( - query_response_id - ) - # if already exist, delete golden_sql_ref and call delete /golden-records - if golden_sql_ref: - await self.delete_golden_sql("", query_response_id) - - # add golden_sql using core response = await client.post( settings.k2_core_url + "/golden-records", - # core should have consistent request body - json=[ - golden_sql_request.dict() - for golden_sql_request in golden_sql_requests - ], + json=[golden_sql_request.dict()], timeout=settings.default_k2_core_timeout, ) raise_for_status(response.status_code, response.text) response_json = response.json()[0] - golden_sql = GoldenSQL(**response_json) - golden_sql.id = ObjectId(response_json["id"]) + golden_sql = GoldenSQL(_id=ObjectId(response_json["id"]), **response_json) display_id = self.repo.get_next_display_id(org_id) golden_sql_ref_data = GoldenSQLRef( golden_sql_id=golden_sql.id, organization_id=ObjectId(org_id), - source=source.value, + source=GoldenSQLSource.VERIFIED_QUERY.value, query_response_id=ObjectId(query_response_id), created_time=datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"), display_id=display_id, @@ -100,6 +89,50 @@ async def add_golden_sql( golden_sql_ref = self.repo.get_golden_sql_ref(str(golden_sql.id)) return self._get_mapped_golden_sql_response(golden_sql, golden_sql_ref) + async def add_user_upload_golden_sql( + self, golden_sql_requests: List[GoldenSQLRequest], org_id: str + ) -> List[GoldenSQLResponse]: + async with httpx.AsyncClient() as client: + response = await client.post( + settings.k2_core_url + "/golden-records", + json=[ + golden_sql_request.dict() + for golden_sql_request in golden_sql_requests + ], + timeout=settings.default_k2_core_timeout, + ) + raise_for_status(response.status_code, response.text) + + response_jsons = response.json() + golden_sqls = [ + GoldenSQL(_id=ObjectId(response_json["id"]), **response_json) + for response_json in response_jsons + ] + + golden_sql_responses = [] + + for golden_sql in golden_sqls: + display_id = self.repo.get_next_display_id(org_id) + + golden_sql_ref_data = GoldenSQLRef( + golden_sql_id=golden_sql.id, + organization_id=ObjectId(org_id), + source=GoldenSQLSource.USER_UPLOAD.value, + created_time=datetime.now(timezone.utc).strftime( + "%Y-%m-%d %H:%M:%S" + ), + display_id=display_id, + ) + + # add golden_sql_ref + self.repo.add_golden_sql_ref(golden_sql_ref_data.dict(exclude={"id"})) + golden_sql_ref = self.repo.get_golden_sql_ref(str(golden_sql.id)) + golden_sql_responses.append( + self._get_mapped_golden_sql_response(golden_sql, golden_sql_ref) + ) + + return golden_sql_responses + async def delete_golden_sql( self, golden_id: str, query_response_id: str = None ) -> dict: diff --git a/apps/ai/server/modules/query/service.py b/apps/ai/server/modules/query/service.py index c54787cc..670df9a7 100644 --- a/apps/ai/server/modules/query/service.py +++ b/apps/ai/server/modules/query/service.py @@ -5,7 +5,6 @@ from bson import ObjectId from config import settings -from modules.golden_sql.models.entities import GoldenSQLSource from modules.golden_sql.models.requests import GoldenSQLRequest from modules.golden_sql.service import GoldenSQLService from modules.organization.models.responses import OrganizationResponse @@ -198,11 +197,10 @@ async def patch_query( sql_query=query_request.sql_query, db_connection_id=organization.db_connection_id, ) - await self.golden_sql_service.add_golden_sql( + await self.golden_sql_service.add_verified_query_golden_sql( golden_sql, organization.id, - source=GoldenSQLSource.verified_query, - query_response_id=query_id, + query_id, ) else: golden_sql_ref = self.golden_sql_service.get_verified_golden_sql_ref( diff --git a/apps/ai/server/tests/golden_sql/test_golden_sql_api.py b/apps/ai/server/tests/golden_sql/test_golden_sql_api.py index dcd23d8b..e38ea5b7 100644 --- a/apps/ai/server/tests/golden_sql/test_golden_sql_api.py +++ b/apps/ai/server/tests/golden_sql/test_golden_sql_api.py @@ -148,7 +148,7 @@ def test_get_golden_sql(self): ) def test_add_golden_sql(self): response = client.post( - self.url, + self.url + "/user-upload", headers=self.test_header, json=[ { @@ -159,7 +159,7 @@ def test_add_golden_sql(self): ], ) assert response.status_code == status.HTTP_201_CREATED - assert response.json() == self.test_response_1 + assert response.json() == [self.test_response_1] @patch( "httpx.AsyncClient.delete",