Skip to content

Commit

Permalink
DH-4725 add crud endpoints for instruction (#184)
Browse files Browse the repository at this point in the history
  • Loading branch information
DishenWang2023 committed May 7, 2024
1 parent 9f28194 commit f31a489
Show file tree
Hide file tree
Showing 13 changed files with 264 additions and 0 deletions.
2 changes: 2 additions & 0 deletions apps/ai/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from modules.auth import controller as auth_controller
from modules.db_connection import controller as db_connection_controller
from modules.golden_sql import controller as golden_sql_controller
from modules.instruction import controller as instruction_controller
from modules.organization import controller as organization_controller
from modules.query import controller as query_controller
from modules.table_description import controller as table_description_controller
Expand Down Expand Up @@ -43,6 +44,7 @@
app.include_router(auth_controller.router, tags=["Authentication"])
app.include_router(db_connection_controller.router, tags=["Database Connection"])
app.include_router(golden_sql_controller.router, tags=["Golden SQL"])
app.include_router(instruction_controller.router, tags=["Instruction"])
app.include_router(organization_controller.router, tags=["Organization"])
app.include_router(query_controller.router, tags=["Query"])
app.include_router(table_description_controller.router, tags=["Table Description"])
Expand Down
1 change: 1 addition & 0 deletions apps/ai/server/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
QUERY_RESPONSE_COL = "nl_query_responses"
GOLDEN_SQL_COL = "golden_records"
TABLE_DESCRIPTION_COL = "table_descriptions"
INSTRUCTION_COL = "instructions"

USER_COL = "users"
ORGANIZATION_COL = "organizations"
Expand Down
Empty file.
64 changes: 64 additions & 0 deletions apps/ai/server/modules/instruction/controller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from fastapi import APIRouter, Depends, status
from fastapi.security import HTTPBearer

from modules.instruction.models.requests import InstructionRequest
from modules.instruction.models.responses import InstructionResponse
from modules.instruction.service import InstructionService
from utils.auth import Authorize, VerifyToken

router = APIRouter(
prefix="/instruction",
responses={404: {"description": "Not found"}},
)

token_auth_scheme = HTTPBearer()
authorize = Authorize()
table_description_service = InstructionService()


@router.get("/list", status_code=status.HTTP_200_OK)
async def get_instructions(
token: str = Depends(token_auth_scheme),
) -> list[InstructionResponse]:
user = authorize.user(VerifyToken(token.credentials).verify())
organization = authorize.get_organization_by_user(user)
return await table_description_service.get_instructions(
organization.db_connection_id
)


@router.post("", status_code=status.HTTP_201_CREATED)
async def add_instructions(
instruction_request: InstructionRequest,
token: str = Depends(token_auth_scheme),
) -> InstructionResponse:
user = authorize.user(VerifyToken(token.credentials).verify())
organization = authorize.get_organization_by_user(user)
return await table_description_service.add_instruction(
instruction_request, organization.db_connection_id
)


@router.put("/{id}")
async def update_instruction(
id: str,
instruction_request: InstructionRequest,
token: str = Depends(token_auth_scheme),
) -> InstructionResponse:
user = authorize.user(VerifyToken(token.credentials).verify())
organization = authorize.get_organization_by_user(user)
authorize.instruction_in_organization(id, organization)
return await table_description_service.update_instruction(
id, instruction_request, organization.db_connection_id
)


@router.delete("/{id}")
async def delete_instruction(
id: str,
token: str = Depends(token_auth_scheme),
):
user = authorize.user(VerifyToken(token.credentials).verify())
organization = authorize.get_organization_by_user(user)
authorize.instruction_in_organization(id, organization)
return await table_description_service.delete_instruction(id)
Empty file.
12 changes: 12 additions & 0 deletions apps/ai/server/modules/instruction/models/entities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Any

from pydantic import BaseModel, Field


class BaseInstruction(BaseModel):
instruction: str


class Instruction(BaseInstruction):
id: Any = Field(alias="_id")
db_connection_id: str
5 changes: 5 additions & 0 deletions apps/ai/server/modules/instruction/models/requests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from modules.instruction.models.entities import BaseInstruction


class InstructionRequest(BaseInstruction):
pass
6 changes: 6 additions & 0 deletions apps/ai/server/modules/instruction/models/responses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from modules.instruction.models.entities import BaseInstruction


class InstructionResponse(BaseInstruction):
id: str
db_connection_id: str
Empty file.
58 changes: 58 additions & 0 deletions apps/ai/server/modules/instruction/service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import httpx

from config import settings
from modules.instruction.models.requests import InstructionRequest
from modules.instruction.models.responses import InstructionResponse
from utils.exception import raise_for_status


class InstructionService:
async def get_instructions(
self, db_connection_id: str
) -> list[InstructionResponse]:
async with httpx.AsyncClient() as client:
response = await client.get(
settings.k2_core_url + "/instructions",
params={"db_connection_id": db_connection_id},
)
raise_for_status(response.status_code, response.text)
return [InstructionResponse(**td) for td in response.json()]

async def add_instruction(
self, instruction_request: InstructionRequest, db_connection_id: str
) -> InstructionResponse:
async with httpx.AsyncClient() as client:
response = await client.post(
settings.k2_core_url + "/instructions",
json={
"db_connection_id": db_connection_id,
**instruction_request.dict(),
},
)
raise_for_status(response.status_code, response.text)
return InstructionResponse(**response.json())

async def update_instruction(
self,
instruction_id,
instruction_request: InstructionRequest,
db_connection_id: str,
) -> InstructionResponse:
async with httpx.AsyncClient() as client:
response = await client.put(
settings.k2_core_url + f"/instructions/{instruction_id}",
json={
"db_connection_id": db_connection_id,
**instruction_request.dict(),
},
)
raise_for_status(response.status_code, response.text)
return InstructionResponse(**response.json())

async def delete_instruction(self, instruction_id):
async with httpx.AsyncClient() as client:
response = await client.delete(
settings.k2_core_url + f"/instructions/{instruction_id}",
)
raise_for_status(response.status_code, response.text)
return {"id": instruction_id}
Empty file.
99 changes: 99 additions & 0 deletions apps/ai/server/tests/instruction/test_instruction_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from unittest import TestCase
from unittest.mock import AsyncMock, Mock, patch

from bson import ObjectId
from fastapi import status
from fastapi.testclient import TestClient
from httpx import Response

from app import app
from modules.organization.models.entities import Organization
from modules.user.models.entities import User

client = TestClient(app)


@patch("utils.auth.VerifyToken.verify", Mock(return_value={"email": ""}))
@patch.multiple(
"utils.auth.Authorize",
user=Mock(
return_value=User(
id="123",
email="[email protected]",
username="test_user",
organization_id="123",
)
),
get_organization_by_user=Mock(
return_value=Organization(
id="123", name="test_org", db_connection_id="0123456789ab0123456789ab"
)
),
instruction_in_organization=Mock(return_value=None),
)
class TestInstructionAPI(TestCase):
url = "/instruction"
test_header = {"Authorization": "Bearer some-token"}
test_1 = {
"id": ObjectId(b"foo-bar-quux"),
"instruction": "test_instruction",
"db_connection_id": "0123456789ab0123456789ab",
}

test_response_0 = {
"id": str(test_1["id"]),
"instruction": "test_instruction",
"db_connection_id": "0123456789ab0123456789ab",
}

test_response_1 = test_response_0.copy()

@patch(
"httpx.AsyncClient.get",
AsyncMock(return_value=Response(status_code=200, json=[test_response_0])),
)
def test_get_instructions(self):
response = client.get(self.url + "/list", headers=self.test_header)
assert response.status_code == status.HTTP_200_OK
assert response.json() == [self.test_response_1]

@patch(
"httpx.AsyncClient.post",
AsyncMock(return_value=Response(status_code=201, json=test_response_0)),
)
def test_add_instruction(self):
response = client.post(
self.url,
headers=self.test_header,
json={
"instruction": "test_description",
},
)
assert response.status_code == status.HTTP_201_CREATED
assert response.json() == self.test_response_1

@patch(
"httpx.AsyncClient.put",
AsyncMock(return_value=Response(status_code=200, json=test_response_0)),
)
def test_update_instruction(self):
response = client.put(
self.url + "/666f6f2d6261722d71757578",
headers=self.test_header,
json={
"instruction": "test_description",
},
)
assert response.status_code == status.HTTP_200_OK
assert response.json() == self.test_response_1

@patch(
"modules.instruction.service.InstructionService.delete_instruction",
AsyncMock(return_value={"status": True}),
)
def test_delete_instruction(self):
response = client.delete(
self.url + "/666f6f2d6261722d71757578",
headers=self.test_header,
)
assert response.status_code == status.HTTP_200_OK
17 changes: 17 additions & 0 deletions apps/ai/server/utils/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from config import (
DATABASE_CONNECTION_REF_COL,
GOLDEN_SQL_REF_COL,
INSTRUCTION_COL,
QUERY_RESPONSE_REF_COL,
TABLE_DESCRIPTION_COL,
USER_COL,
Expand Down Expand Up @@ -131,6 +132,22 @@ def user(self, payload: dict) -> User:
)
return user

def instruction_in_organization(
self, instruction_id: str, organization: OrganizationResponse
):
instruction = MongoDB.find_one(
INSTRUCTION_COL,
{
"_id": ObjectId(instruction_id),
"db_connection_id": organization.db_connection_id,
},
)

if not instruction:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Instruction not found"
)

def db_connection_in_organization(self, db_connection_id: str, org_id: str):
self._item_in_organization(
DATABASE_CONNECTION_REF_COL,
Expand Down

0 comments on commit f31a489

Please sign in to comment.