Skip to content

Commit

Permalink
mcp adapter
Browse files Browse the repository at this point in the history
  • Loading branch information
phact committed Feb 13, 2025
1 parent 2fda450 commit 8addfa6
Show file tree
Hide file tree
Showing 5 changed files with 515 additions and 103 deletions.
82 changes: 60 additions & 22 deletions client/astra_assistants/astra_assistants_manager.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,41 @@
import logging
import os
from typing import List
import uuid
from typing import List, Any

from litellm import get_llm_provider

from astra_assistants import patch, OpenAIWithDefaultKey
from astra_assistants.astra_assistants_event_handler import AstraEventHandler
from astra_assistants.tools.tool_interface import ToolInterface
from astra_assistants.utils import env_var_is_missing, get_env_vars_for_provider
from astra_assistants.mcp_openai_adapter import MCPOpenAIAAdapter

logger = logging.getLogger(__name__)

class AssistantManager:
def __init__(self, instructions: str = None, model: str = "gpt-4o", name: str = "managed_assistant", tools: List[ToolInterface] = None, thread_id: str = None, thread: str = None, assistant_id: str = None, client = None, tool_resources = None):
def __init__(self,
instructions: str = None,
model: str = "gpt-4o",
name: str = "managed_assistant",
tools: List[ToolInterface] = None,
thread_id: str = None,
thread: str = None,
assistant_id: str = None,
client = None,
tool_resources = None,
mcp_represenations = None
):

if instructions is None and assistant_id is None:
raise Exception("Instructions must be provided if assistant_id is not provided")
if tools is None:
tools = []
# Only patch if astra token is provided


self.tools = tools

# Initialize client using the provided client or the default based on environment tokens.
if client is not None:
self.client = client
else:
Expand All @@ -31,7 +49,6 @@ def __init__(self, instructions: str = None, model: str = "gpt-4o", name: str =
self.client = OpenAIWithDefaultKey()
self.model = model
self.instructions = instructions
self.tools = tools
self.tool_resources = tool_resources
self.name = name
self.tool_call_arguments = None
Expand All @@ -48,9 +65,25 @@ def __init__(self, instructions: str = None, model: str = "gpt-4o", name: str =
elif thread_id is not None:
self.thread = self.client.beta.threads.retrieve(thread_id)


self.mcp_adapter = None
self.register_mcp(mcp_represenations)

logger.info(f'assistant {self.assistant}')
logger.info(f'thread {self.thread}')

def register_mcp(self, mcp_representations):
# If MCP representations are provided, convert them to tools using the adapter.
if mcp_representations is not None:
self.mcp_adapter = MCPOpenAIAAdapter(mcp_representations)

mcp_tools = self.mcp_adapter.get_tools()
self.tools.extend(mcp_tools)

schemas = self.mcp_adapter.get_json_schema_for_tools()
assistant = self.client.beta.assistants.update(assistant_id=self.assistant.id, tools=schemas)
self.assistant = assistant

def get_client(self):
return self.client

Expand All @@ -65,25 +98,24 @@ def create_assistant(self):
for tool in self.tools:
if hasattr(tool, 'to_function'):
tool_holder.append(tool.to_function())

if len(tool_holder) == 0:
tool_holder = self.tools

# Create and return the assistant
# Create and return the assistant with the combined tool definitions.
self.assistant = self.client.beta.assistants.create(
name=self.name,
instructions=self.instructions,
model=self.model,
tools=tool_holder,
tool_resources=self.tool_resources
)
logger.debug("Assistant created:", self.assistant)
logger.debug("Assistant created: %s", self.assistant)
return self.assistant

def create_thread(self):
# Create and return a new thread
# Create and return a new thread.
thread = self.client.beta.threads.create()
logger.debug("Thread generated:", thread)
logger.debug("Thread generated: %s", thread)
return thread

def stream_thread(self, content, tool_choice = None, thread_id: str = None, thread = None, additional_instructions = None):
Expand Down Expand Up @@ -112,7 +144,6 @@ def stream_thread(self, content, tool_choice = None, thread_id: str = None, thre
"event_handler": event_handler,
"additional_instructions": additional_instructions
}
# Conditionally add 'tool_choice' if it's not None
if tool_choice is not None:
args["tool_choice"] = tool_choice

Expand All @@ -121,8 +152,6 @@ def stream_thread(self, content, tool_choice = None, thread_id: str = None, thre
for text in stream.text_deltas:
yield text

tool_call_results = None
tool_call_arguments = None
self.tool_call_arguments = event_handler.arguments
if event_handler.stream is not None:
if event_handler.tool_call_results is not None:
Expand All @@ -133,7 +162,7 @@ def stream_thread(self, content, tool_choice = None, thread_id: str = None, thre
except Exception as e:
logger.error(e)
raise e

async def run_thread(self, content, tool = None, thread_id: str = None, thread = None, additional_instructions = None):
if thread_id is not None:
thread = self.client.beta.threads.retrieve(thread_id)
Expand All @@ -142,10 +171,15 @@ async def run_thread(self, content, tool = None, thread_id: str = None, thread =

assistant = self.assistant
event_handler = AstraEventHandler(self.client)

tool_choice = None
if tool is not None:
event_handler.register_tool(tool)
tool_choice = tool.tool_choice_object()

for tool in self.tools:
event_handler.register_tool(tool)

try:
self.client.beta.threads.messages.create(
thread_id=thread.id, role="user", content=content
Expand All @@ -156,33 +190,37 @@ async def run_thread(self, content, tool = None, thread_id: str = None, thread =
"event_handler": event_handler,
"additional_instructions": additional_instructions
}
# Conditionally add 'tool_choice' if it's not None
if tool_choice is not None:
args["tool_choice"] = tool_choice

text = ""
with self.client.beta.threads.runs.create_and_stream(**args) as stream:
with self.client.beta.threads.runs.stream(**args) as stream:
for part in stream.text_deltas:
text += part

tool_call_results = None
if event_handler.stream is not None:
with event_handler.stream as stream:
for part in stream.text_deltas:
text += part

tool_call_results = event_handler.tool_call_results
file_search = event_handler.file_search
if tool_call_results is not None:
file_search = event_handler.file_search

tool_call_results['file_search'] = file_search
tool_call_results['text'] = text
tool_call_results['arguments'] = event_handler.arguments
tool_call_results['file_search'] = file_search
tool_call_results['text'] = text
tool_call_results['arguments'] = event_handler.arguments
else:
print("event_handler.stream is not None but tool_call_results is None, bug?")

logger.info(tool_call_results)
tool_call_results
if tool_call_results is not None:
return tool_call_results
return {"text": text, "file_search": event_handler.file_search}
except Exception as e:
logger.error(e)
raise e
raise e

def shutdown(self):
self.mcp_adapter.shutdown()
191 changes: 191 additions & 0 deletions client/astra_assistants/mcp_openai_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import asyncio
import os
import threading
import re
from abc import ABC
from contextlib import AsyncExitStack
from typing import List, Union, Optional, Literal, Dict, Any, Type

from mcp.types import CallToolResult
from pydantic import BaseModel, Field, create_model

# Import the high‐level MCP client interfaces from the official SDK.
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

from astra_assistants.tools.tool_interface import ToolInterface, ToolResult

# --- MCP Representation Models ---

class MCPRepresentationBase(BaseModel):
type: str

class MCPRepresentationStdio(MCPRepresentationBase):
type: str = Literal["stdio"]
command: str
arguments: Optional[List[str]] = None
env_vars: Optional[List[str]] = None

class MCPRepresentationSSE(MCPRepresentationBase):
type: str = Literal["sse"]
sse_url: str

MCPRepresentation = Union[MCPRepresentationStdio, MCPRepresentationSSE]

# --- Helper functions ---

def generate_pydantic_model_from_schema(schema: Dict[str, Any], model_name: str = "DynamicModel") -> Type[BaseModel]:
fields = {}
properties = schema.get("properties", {})
required_fields = set(schema.get("required", []))
type_mapping = {
"string": str,
"integer": int,
"number": float,
"boolean": bool,
"array": list,
"object": dict
}
for field_name, field_schema in properties.items():
field_type = type_mapping.get(field_schema.get("type"), Any)
if field_name in required_fields:
fields[field_name] = (field_type, ...)
else:
fields[field_name] = (field_type, None)
return create_model(model_name, **fields)

def to_camel_case(s: str) -> str:
return ''.join(word.capitalize() for word in re.split(r'[_-]', s))

# --- MCP Tool Adapter (implements ToolInterface) ---

class MCPToolAdapter(ToolInterface, ABC):
def __init__(self, representation: MCPRepresentation, mcp_session: ClientSession, mcp_tool):
self.representation = representation
self.mcp_session = mcp_session
self.mcp_tool = mcp_tool

def get_model(self):
return generate_pydantic_model_from_schema(
self.mcp_tool.inputSchema,
to_camel_case(self.mcp_tool.name)
)

def to_function(self) -> dict:
return {
"type": "function",
"function": {
"name": self.mcp_tool.name,
"description": self.mcp_tool.description,
"parameters": self.mcp_tool.inputSchema
}
}

def call(self, arguments: BaseModel) -> CallToolResult:
# Use the background loop to run the async call synchronously.
future = asyncio.run_coroutine_threadsafe(
self.mcp_session.call_tool(
self.mcp_tool.name,
arguments=arguments.model_dump()
),
self.mcp_session_loop # set below when session is created
)
return {"output": future.result().content[0].text}

# --- MCP OpenAI Adapter ---

class MCPOpenAIAAdapter:
"""
This adapter connects to an MCP server using the official Python SDK (via stdio transport)
on a dedicated background thread. This allows synchronous methods (like call) to schedule
async work via asyncio.run_coroutine_threadsafe.
"""
def __init__(
self,
mcp_representations: List[MCPRepresentation] = None,
):
self.exit_stack = AsyncExitStack()
self.mcp_representations = mcp_representations or []
self.server_params = []
for rep in self.mcp_representations:
if rep.type == 'stdio':
env_vars = {"PATH": os.environ["PATH"]}
if rep.env_vars is not None:
# Assume env_vars are provided as a dict-like mapping or as "KEY=VALUE" strings.
for var in rep.env_vars:
if "=" in var:
key, value = var.split("=", 1)
env_vars[key] = value
# Split command into executable and arguments.
parts = rep.command.split()
executable = parts[0]
initial_args = parts[1:]
combined_args = initial_args + (rep.arguments or [])
server_param = StdioServerParameters(
command=executable,
args=combined_args,
env=env_vars,
)
self.server_params.append(server_param)
elif rep.type == 'sse':
self.server_params.append(rep.sse_url)
self.session: Optional[ClientSession] = None
self.tools: List[MCPToolAdapter] = []
self._bg_loop = asyncio.new_event_loop()
self._bg_thread = threading.Thread(target=self._run_bg_loop, daemon=True)
self._bg_thread.start()

def _run_bg_loop(self):
asyncio.set_event_loop(self._bg_loop)
self._bg_loop.run_forever()

def sync_connect(self):
"""
Synchronously connect to the MCP server using the background loop.
This schedules the async connect() coroutine on the background loop.
"""
for server_param in self.server_params:
asyncio.run_coroutine_threadsafe(self._connect(server_param), self._bg_loop).result()

async def _connect(self, server_param):
transport = await self.exit_stack.enter_async_context(stdio_client(server_param))
self.stdio, self.write = transport
# Create the session on the background loop.
self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
await self.session.initialize()
# Attach the background loop reference to each tool adapter.
result = await self.session.list_tools()
for rep in self.mcp_representations:
for tool in result.tools:
adapter = MCPToolAdapter(representation=rep, mcp_session=self.session, mcp_tool=tool)
# Set the event loop used by this session (i.e. the background loop)
adapter.mcp_session_loop = self._bg_loop
self.tools.append(adapter)

def get_tools(self) -> List[MCPToolAdapter]:
if self.session is None:
self.sync_connect()
return self.tools

def get_json_schema_for_tools(self) -> List[dict]:
# Since to_function() is synchronous, simply return the schemas.
return [tool_adapter.to_function() for tool_adapter in self.tools]

def shutdown(self):
"""
Cleanly shuts down the background loop and thread.
"""
# First, if session exists, schedule exit of the exit stack.
if self.session is not None:
future = asyncio.run_coroutine_threadsafe(self.exit_stack.aclose(), self._bg_loop)
try:
future.result(timeout=5)
except Exception as e:
print("Error during exit_stack.aclose():", e)
self.session = None
# Signal the background loop to stop.
self._bg_loop.call_soon_threadsafe(self._bg_loop.stop)
# Wait for the background thread to finish.
self._bg_thread.join(timeout=5)
# Close the loop.
self._bg_loop.close()
Loading

0 comments on commit 8addfa6

Please sign in to comment.