Skip to content

Commit d7f5bf0

Browse files
Exercise Chat: Implement native function calling agent (#154)
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
1 parent 2cffe1c commit d7f5bf0

36 files changed

+1550
-218
lines changed

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -175,3 +175,5 @@ cython_debug/
175175
# and can be added to the global gitignore or merged into this file. For a more nuclear
176176
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
177177
.idea/
178+
179+
.DS_Store

app/common/PipelineEnum.py

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ class PipelineEnum(str, Enum):
55
IRIS_CODE_FEEDBACK = "IRIS_CODE_FEEDBACK"
66
IRIS_CHAT_COURSE_MESSAGE = "IRIS_CHAT_COURSE_MESSAGE"
77
IRIS_CHAT_EXERCISE_MESSAGE = "IRIS_CHAT_EXERCISE_MESSAGE"
8+
IRIS_CHAT_EXERCISE_AGENT_MESSAGE = "IRIS_CHAT_EXERCISE_AGENT_MESSAGE"
89
IRIS_INTERACTION_SUGGESTION = "IRIS_INTERACTION_SUGGESTION"
910
IRIS_CHAT_LECTURE_MESSAGE = "IRIS_CHAT_LECTURE_MESSAGE"
1011
IRIS_COMPETENCY_GENERATION = "IRIS_COMPETENCY_GENERATION"

app/common/__init__.py

-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1 @@
11
from app.common.singleton import Singleton
2-
from app.common.message_converters import (
3-
convert_iris_message_to_langchain_message,
4-
convert_langchain_message_to_iris_message,
5-
)

app/common/message_converters.py

+113-15
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,31 @@
1+
import json
12
from datetime import datetime
2-
from typing import Literal
3+
from typing import Literal, List
34

4-
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
5+
from langchain_core.messages import (
6+
BaseMessage,
7+
HumanMessage,
8+
AIMessage,
9+
SystemMessage,
10+
ToolMessage,
11+
ToolCall,
12+
)
513

14+
from app.common.pyris_message import (
15+
PyrisMessage,
16+
PyrisAIMessage,
17+
IrisMessageRole,
18+
PyrisToolMessage,
19+
)
620
from app.domain.data.text_message_content_dto import TextMessageContentDTO
7-
from app.common.pyris_message import PyrisMessage, IrisMessageRole
21+
from app.domain.data.tool_call_dto import ToolCallDTO, FunctionDTO
22+
from app.domain.data.tool_message_content_dto import ToolMessageContentDTO
823

924

1025
def convert_iris_message_to_langchain_message(
1126
iris_message: PyrisMessage,
1227
) -> BaseMessage:
13-
if len(iris_message.contents) == 0:
28+
if iris_message is None or len(iris_message.contents) == 0:
1429
raise ValueError("IrisMessage contents must not be empty")
1530
message = iris_message.contents[0]
1631
# Check if the message is of type TextMessageContentDTO
@@ -20,41 +35,122 @@ def convert_iris_message_to_langchain_message(
2035
case IrisMessageRole.USER:
2136
return HumanMessage(content=message.text_content)
2237
case IrisMessageRole.ASSISTANT:
38+
if isinstance(iris_message, PyrisAIMessage):
39+
tool_calls = [
40+
ToolCall(
41+
name=tc.function.name,
42+
args=tc.function.arguments,
43+
id=tc.id,
44+
)
45+
for tc in iris_message.tool_calls
46+
]
47+
return AIMessage(content=message.text_content, tool_calls=tool_calls)
2348
return AIMessage(content=message.text_content)
2449
case IrisMessageRole.SYSTEM:
2550
return SystemMessage(content=message.text_content)
2651
case _:
2752
raise ValueError(f"Unknown message role: {iris_message.sender}")
2853

2954

55+
def convert_iris_message_to_langchain_human_message(
56+
iris_message: PyrisMessage,
57+
) -> HumanMessage:
58+
if len(iris_message.contents) == 0:
59+
raise ValueError("IrisMessage contents must not be empty")
60+
message = iris_message.contents[0]
61+
# Check if the message is of type TextMessageContentDTO
62+
if not isinstance(message, TextMessageContentDTO):
63+
raise ValueError("Message must be of type TextMessageContentDTO")
64+
return HumanMessage(content=message.text_content)
65+
66+
67+
def extract_text_from_iris_message(iris_message: PyrisMessage) -> str:
68+
if len(iris_message.contents) == 0:
69+
raise ValueError("IrisMessage contents must not be empty")
70+
message = iris_message.contents[0]
71+
# Check if the message is of type TextMessageContentDTO
72+
if not isinstance(message, TextMessageContentDTO):
73+
raise ValueError("Message must be of type TextMessageContentDTO")
74+
return message.text_content
75+
76+
77+
def convert_langchain_tool_calls_to_iris_tool_calls(
78+
tool_calls: List[ToolCall],
79+
) -> List[ToolCallDTO]:
80+
return [
81+
ToolCallDTO(
82+
function=FunctionDTO(
83+
name=tc["name"],
84+
arguments=json.dumps(tc["args"]),
85+
),
86+
id=tc["id"],
87+
)
88+
for tc in tool_calls
89+
]
90+
91+
3092
def convert_langchain_message_to_iris_message(
3193
base_message: BaseMessage,
3294
) -> PyrisMessage:
33-
match base_message.type:
34-
case "human":
35-
role = IrisMessageRole.USER
36-
case "ai":
37-
role = IrisMessageRole.ASSISTANT
38-
case "system":
39-
role = IrisMessageRole.SYSTEM
40-
case _:
41-
raise ValueError(f"Unknown message type: {base_message.type}")
42-
contents = [TextMessageContentDTO(textContent=base_message.content)]
95+
type_to_role = {
96+
"human": IrisMessageRole.USER,
97+
"ai": IrisMessageRole.ASSISTANT,
98+
"system": IrisMessageRole.SYSTEM,
99+
"tool": IrisMessageRole.TOOL,
100+
}
101+
102+
role = type_to_role.get(base_message.type)
103+
if role is None:
104+
raise ValueError(f"Unknown message type: {base_message.type}")
105+
106+
if isinstance(base_message, (HumanMessage, SystemMessage)):
107+
contents = [TextMessageContentDTO(textContent=base_message.content)]
108+
elif isinstance(base_message, AIMessage):
109+
if base_message.tool_calls:
110+
contents = [TextMessageContentDTO(textContent=base_message.content)]
111+
tool_calls = convert_langchain_tool_calls_to_iris_tool_calls(
112+
base_message.tool_calls
113+
)
114+
return PyrisAIMessage(
115+
contents=contents,
116+
tool_calls=tool_calls,
117+
send_at=datetime.now(),
118+
)
119+
else:
120+
contents = [TextMessageContentDTO(textContent=base_message.content)]
121+
elif isinstance(base_message, ToolMessage):
122+
contents = [
123+
ToolMessageContentDTO(
124+
toolContent=base_message.content,
125+
toolName=base_message.additional_kwargs["name"],
126+
toolCallId=base_message.tool_call_id,
127+
)
128+
]
129+
return PyrisToolMessage(
130+
contents=contents,
131+
send_at=datetime.now(),
132+
)
133+
else:
134+
raise ValueError(f"Unknown message type: {type(base_message)}")
43135
return PyrisMessage(
44136
contents=contents,
45137
sender=role,
46138
send_at=datetime.now(),
47139
)
48140

49141

50-
def map_role_to_str(role: IrisMessageRole) -> Literal["user", "assistant", "system"]:
142+
def map_role_to_str(
143+
role: IrisMessageRole,
144+
) -> Literal["user", "assistant", "system", "tool"]:
51145
match role:
52146
case IrisMessageRole.USER:
53147
return "user"
54148
case IrisMessageRole.ASSISTANT:
55149
return "assistant"
56150
case IrisMessageRole.SYSTEM:
57151
return "system"
152+
case IrisMessageRole.TOOL:
153+
return "tool"
58154
case _:
59155
raise ValueError(f"Unknown message role: {role}")
60156

@@ -67,5 +163,7 @@ def map_str_to_role(role: str) -> IrisMessageRole:
67163
return IrisMessageRole.ASSISTANT
68164
case "system":
69165
return IrisMessageRole.SYSTEM
166+
case "tool":
167+
return IrisMessageRole.TOOL
70168
case _:
71169
raise ValueError(f"Unknown message role: {role}")

app/common/pyris_message.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
11
from datetime import datetime
22
from enum import Enum
3-
from typing import List
3+
from typing import List, Optional
44

55
from pydantic import BaseModel, ConfigDict, Field
66

77
from app.domain.data.message_content_dto import MessageContentDTO
88
from app.common.token_usage_dto import TokenUsageDTO
9+
from app.domain.data.tool_call_dto import ToolCallDTO
10+
from app.domain.data.tool_message_content_dto import ToolMessageContentDTO
911

1012

1113
class IrisMessageRole(str, Enum):
1214
USER = "USER"
1315
ASSISTANT = "LLM"
1416
SYSTEM = "SYSTEM"
17+
TOOL = "TOOL"
1518

1619

1720
class PyrisMessage(BaseModel):
@@ -21,7 +24,20 @@ class PyrisMessage(BaseModel):
2124

2225
sent_at: datetime | None = Field(alias="sentAt", default=None)
2326
sender: IrisMessageRole
24-
contents: List[MessageContentDTO] = []
27+
28+
contents: List[MessageContentDTO] = Field(default=[])
2529

2630
def __str__(self):
2731
return f"{self.sender.lower()}: {self.contents}"
32+
33+
34+
class PyrisAIMessage(PyrisMessage):
35+
model_config = ConfigDict(populate_by_name=True)
36+
sender: IrisMessageRole = IrisMessageRole.ASSISTANT
37+
tool_calls: Optional[List[ToolCallDTO]] = Field(alias="toolCalls")
38+
39+
40+
class PyrisToolMessage(PyrisMessage):
41+
model_config = ConfigDict(populate_by_name=True)
42+
sender: IrisMessageRole = IrisMessageRole.TOOL
43+
contents: List[ToolMessageContentDTO] = Field(default=[])

app/config.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def get_settings(cls):
3232
try:
3333
with open(file_path, "r") as file:
3434
settings_file = yaml.safe_load(file)
35-
return cls.parse_obj(settings_file)
35+
return cls.model_validate(settings_file)
3636
except FileNotFoundError as e:
3737
raise FileNotFoundError(
3838
f"Configuration file not found at {file_path}."
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1-
from typing import Optional
1+
from typing import Optional, Any
22

33
from pydantic import Field
44

55
from ..chat_pipeline_execution_dto import ChatPipelineExecutionDTO
66
from ...data.extended_course_dto import ExtendedCourseDTO
7-
from ...data.metrics.competency_jol_dto import CompetencyJolDTO
87
from ...data.metrics.student_metrics_dto import StudentMetricsDTO
8+
from ...event.pyris_event_dto import PyrisEventDTO
99

1010

1111
class CourseChatPipelineExecutionDTO(ChatPipelineExecutionDTO):
1212
course: ExtendedCourseDTO
1313
metrics: Optional[StudentMetricsDTO]
14-
competency_jol: Optional[CompetencyJolDTO] = Field(None, alias="competencyJol")
14+
event_payload: Optional[PyrisEventDTO[Any]] = Field(None, alias="eventPayload")
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
1-
from typing import Optional
1+
from typing import Optional, Any
2+
3+
from pydantic import Field
24

35
from app.domain.chat.chat_pipeline_execution_dto import ChatPipelineExecutionDTO
46
from app.domain.data.course_dto import CourseDTO
57
from app.domain.data.programming_exercise_dto import ProgrammingExerciseDTO
68
from app.domain.data.programming_submission_dto import ProgrammingSubmissionDTO
9+
from app.domain.event.pyris_event_dto import PyrisEventDTO
710

811

912
class ExerciseChatPipelineExecutionDTO(ChatPipelineExecutionDTO):
1013
submission: Optional[ProgrammingSubmissionDTO] = None
1114
exercise: ProgrammingExerciseDTO
1215
course: CourseDTO
16+
event_payload: Optional[PyrisEventDTO[Any]] = Field(None, alias="eventPayload")

app/domain/data/competency_dto.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from datetime import datetime
22
from enum import Enum
3-
from typing import Optional
3+
from typing import Optional, List
44

55
from pydantic import BaseModel, Field
66
from pydantic.v1 import validator
@@ -22,6 +22,7 @@ class CompetencyDTO(BaseModel):
2222
taxonomy: Optional[CompetencyTaxonomy] = None
2323
soft_due_date: Optional[datetime] = Field(default=None, alias="softDueDate")
2424
optional: Optional[bool] = None
25+
exercise_list: Optional[List[int]] = Field(default=[], alias="exerciseList")
2526

2627

2728
class Competency(BaseModel):

app/domain/data/exercise_with_submissions_dto.py

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class IncludedInOverallScore(str, Enum):
3434

3535
class ExerciseWithSubmissionsDTO(BaseModel):
3636
id: int = Field(alias="id")
37+
url: Optional[str] = Field(alias="url", default=None)
3738
title: str = Field(alias="title")
3839
type: ExerciseType = Field(alias="type")
3940
mode: ExerciseMode = Field(alias="mode")
+5-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
from typing import Union
22

3+
from .tool_message_content_dto import ToolMessageContentDTO
34
from ...domain.data.image_message_content_dto import ImageMessageContentDTO
45
from ...domain.data.json_message_content_dto import JsonMessageContentDTO
56
from ...domain.data.text_message_content_dto import TextMessageContentDTO
67

78
MessageContentDTO = Union[
8-
TextMessageContentDTO, ImageMessageContentDTO, JsonMessageContentDTO
9+
TextMessageContentDTO,
10+
ImageMessageContentDTO,
11+
JsonMessageContentDTO,
12+
ToolMessageContentDTO,
913
]

app/domain/data/programming_exercise_dto.py

+6
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,9 @@ class ProgrammingExerciseDTO(BaseModel):
3030
problem_statement: str = Field(alias="problemStatement", default=None)
3131
start_date: Optional[datetime] = Field(alias="startDate", default=None)
3232
end_date: Optional[datetime] = Field(alias="endDate", default=None)
33+
max_points: Optional[float] = Field(alias="maxPoints", default=None)
34+
recent_changes: Optional[str] = Field(
35+
alias="recentChanges",
36+
default=None,
37+
description="Git diff of the recent changes",
38+
)

app/domain/data/tool_call_dto.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from typing import Literal, Any
2+
3+
from pydantic import BaseModel, ConfigDict, Field, Json
4+
5+
6+
class FunctionDTO(BaseModel):
7+
name: str = Field(..., alias="name")
8+
arguments: Json[Any] = Field(..., alias="arguments")
9+
10+
11+
class ToolCallDTO(BaseModel):
12+
13+
model_config = ConfigDict(populate_by_name=True)
14+
id: str = Field(alias="id")
15+
type: Literal["function"] = "function"
16+
function: FunctionDTO = Field(alias="function")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from typing import Optional
2+
3+
from pydantic import BaseModel, ConfigDict, Field
4+
5+
6+
class ToolMessageContentDTO(BaseModel):
7+
8+
model_config = ConfigDict(populate_by_name=True)
9+
name: Optional[str] = Field(alias="toolName", default="")
10+
tool_content: str = Field(alias="toolContent")
11+
tool_call_id: str = Field(alias="toolCallId")

app/domain/event/pyris_event_dto.py

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from typing import TypeVar, Generic, Optional
2+
3+
from pydantic import Field, BaseModel
4+
5+
T = TypeVar("T")
6+
7+
8+
class PyrisEventDTO(BaseModel, Generic[T]):
9+
event_type: Optional[str] = Field(default=None, alias="eventType")
10+
event: Optional[T] = Field(default=None, alias="event")

0 commit comments

Comments
 (0)