Skip to content

Commit

Permalink
DH-4757 separarted user upload and verified query logic (#187)
Browse files Browse the repository at this point in the history
  • Loading branch information
DishenWang2023 committed May 7, 2024
1 parent f31a489 commit 3755652
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 35 deletions.
11 changes: 5 additions & 6 deletions apps/ai/server/modules/golden_sql/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from fastapi import APIRouter, Depends, status
from fastapi.security import HTTPBearer

from modules.golden_sql.models.entities import GoldenSQLSource
from modules.golden_sql.models.requests import GoldenSQLRequest
from modules.golden_sql.models.responses import GoldenSQLResponse
from modules.golden_sql.service import GoldenSQLService
Expand Down Expand Up @@ -43,13 +42,13 @@ async def get_golden_sql(
return golden_sql_service.get_golden_sql(id)


@router.post("", status_code=status.HTTP_201_CREATED)
async def add_golden_sql(
@router.post("/user-upload", status_code=status.HTTP_201_CREATED)
async def add_user_upload_golden_sql(
golden_sql_requests: List[GoldenSQLRequest], token: str = Depends(token_auth_scheme)
) -> GoldenSQLResponse:
) -> List[GoldenSQLResponse]:
org_id = authorize.user_and_get_org_id(VerifyToken(token.credentials).verify())
return await golden_sql_service.add_golden_sql(
golden_sql_requests, org_id, GoldenSQLSource.user_upload
return await golden_sql_service.add_user_upload_golden_sql(
golden_sql_requests, org_id
)


Expand Down
4 changes: 2 additions & 2 deletions apps/ai/server/modules/golden_sql/models/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ class BaseGoldenSQL(BaseModel):


class GoldenSQLSource(Enum):
user_upload = "USER_UPLOAD"
verified_query = "VERIFIED_QUERY"
USER_UPLOAD = "USER_UPLOAD"
VERIFIED_QUERY = "VERIFIED_QUERY"


class GoldenSQLRef(BaseModel):
Expand Down
75 changes: 54 additions & 21 deletions apps/ai/server/modules/golden_sql/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,43 +53,32 @@ def get_golden_sqls(
def get_verified_golden_sql_ref(self, query_response_id: str) -> GoldenSQLRef:
return self.repo.get_verified_golden_sql_ref(query_response_id)

async def add_golden_sql(
async def add_verified_query_golden_sql(
self,
golden_sql_requests: List[GoldenSQLRequest],
golden_sql_request: GoldenSQLRequest,
org_id: str,
source: GoldenSQLSource,
query_response_id: str = None,
query_response_id: str,
) -> GoldenSQLResponse:
golden_sql_ref = self.repo.get_verified_golden_sql_ref(query_response_id)
# if already exist, delete golden_sql_ref and call delete /golden-records
if golden_sql_ref:
await self.delete_golden_sql("", query_response_id)
async with httpx.AsyncClient() as client:
if query_response_id:
golden_sql_ref = self.repo.get_verified_golden_sql_ref(
query_response_id
)
# if already exist, delete golden_sql_ref and call delete /golden-records
if golden_sql_ref:
await self.delete_golden_sql("", query_response_id)

# add golden_sql using core
response = await client.post(
settings.k2_core_url + "/golden-records",
# core should have consistent request body
json=[
golden_sql_request.dict()
for golden_sql_request in golden_sql_requests
],
json=[golden_sql_request.dict()],
timeout=settings.default_k2_core_timeout,
)
raise_for_status(response.status_code, response.text)
response_json = response.json()[0]
golden_sql = GoldenSQL(**response_json)
golden_sql.id = ObjectId(response_json["id"])
golden_sql = GoldenSQL(_id=ObjectId(response_json["id"]), **response_json)

display_id = self.repo.get_next_display_id(org_id)

golden_sql_ref_data = GoldenSQLRef(
golden_sql_id=golden_sql.id,
organization_id=ObjectId(org_id),
source=source.value,
source=GoldenSQLSource.VERIFIED_QUERY.value,
query_response_id=ObjectId(query_response_id),
created_time=datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S"),
display_id=display_id,
Expand All @@ -100,6 +89,50 @@ async def add_golden_sql(
golden_sql_ref = self.repo.get_golden_sql_ref(str(golden_sql.id))
return self._get_mapped_golden_sql_response(golden_sql, golden_sql_ref)

async def add_user_upload_golden_sql(
self, golden_sql_requests: List[GoldenSQLRequest], org_id: str
) -> List[GoldenSQLResponse]:
async with httpx.AsyncClient() as client:
response = await client.post(
settings.k2_core_url + "/golden-records",
json=[
golden_sql_request.dict()
for golden_sql_request in golden_sql_requests
],
timeout=settings.default_k2_core_timeout,
)
raise_for_status(response.status_code, response.text)

response_jsons = response.json()
golden_sqls = [
GoldenSQL(_id=ObjectId(response_json["id"]), **response_json)
for response_json in response_jsons
]

golden_sql_responses = []

for golden_sql in golden_sqls:
display_id = self.repo.get_next_display_id(org_id)

golden_sql_ref_data = GoldenSQLRef(
golden_sql_id=golden_sql.id,
organization_id=ObjectId(org_id),
source=GoldenSQLSource.USER_UPLOAD.value,
created_time=datetime.now(timezone.utc).strftime(
"%Y-%m-%d %H:%M:%S"
),
display_id=display_id,
)

# add golden_sql_ref
self.repo.add_golden_sql_ref(golden_sql_ref_data.dict(exclude={"id"}))
golden_sql_ref = self.repo.get_golden_sql_ref(str(golden_sql.id))
golden_sql_responses.append(
self._get_mapped_golden_sql_response(golden_sql, golden_sql_ref)
)

return golden_sql_responses

async def delete_golden_sql(
self, golden_id: str, query_response_id: str = None
) -> dict:
Expand Down
6 changes: 2 additions & 4 deletions apps/ai/server/modules/query/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from bson import ObjectId

from config import settings
from modules.golden_sql.models.entities import GoldenSQLSource
from modules.golden_sql.models.requests import GoldenSQLRequest
from modules.golden_sql.service import GoldenSQLService
from modules.organization.models.responses import OrganizationResponse
Expand Down Expand Up @@ -198,11 +197,10 @@ async def patch_query(
sql_query=query_request.sql_query,
db_connection_id=organization.db_connection_id,
)
await self.golden_sql_service.add_golden_sql(
await self.golden_sql_service.add_verified_query_golden_sql(
golden_sql,
organization.id,
source=GoldenSQLSource.verified_query,
query_response_id=query_id,
query_id,
)
else:
golden_sql_ref = self.golden_sql_service.get_verified_golden_sql_ref(
Expand Down
4 changes: 2 additions & 2 deletions apps/ai/server/tests/golden_sql/test_golden_sql_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def test_get_golden_sql(self):
)
def test_add_golden_sql(self):
response = client.post(
self.url,
self.url + "/user-upload",
headers=self.test_header,
json=[
{
Expand All @@ -159,7 +159,7 @@ def test_add_golden_sql(self):
],
)
assert response.status_code == status.HTTP_201_CREATED
assert response.json() == self.test_response_1
assert response.json() == [self.test_response_1]

@patch(
"httpx.AsyncClient.delete",
Expand Down

0 comments on commit 3755652

Please sign in to comment.