dynamic client support (supports OpenAI API) #4

Open · wants to merge 7 commits into main
40 changes: 40 additions & 0 deletions .gitignore
@@ -0,0 +1,40 @@
# Operating System Files
.DS_Store
Thumbs.db

# IDE
.idea/
.vscode/

# Python
__pycache__/
*.py[cod]
*$py.class
.pytest_cache/
.coverage
htmlcov/

# Virtual Environment
venv/
env/
ENV/
.env
.venv

# Distribution / packaging
dist/
build/
*.egg-info/
*.egg

# Jupyter Notebook
.ipynb_checkpoints

# Logs
*.log


# Generated Directories
data/
outputs/
generated_actions/
9 changes: 7 additions & 2 deletions README.md
@@ -9,10 +9,15 @@ Empirically, DynaSaur exhibits remarkable versatility, recovering automatically
### 1. Create a `.env` file and add your keys:
```bash

# Required: Main keys for the agent
# Required: Azure or OpenAI main keys for the agent
# Azure Option
AZURE_API_KEY=""
AZURE_ENDPOINT=""
AZURE_API_VERSION=""
AZURE_MODEL_NAME=""
# OpenAI Option
OPENAI_API_KEY=""
OPENAI_MODEL_NAME=""

# Required: Keys for embeddings used in action retrieval
EMBED_MODEL_TYPE="AzureOpenAI"
@@ -55,7 +60,7 @@ python dynasaur.py
```
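With the OpenAI option filled in, the agent can then be launched by picking the client at the command line, e.g. `python dynasaur.py --client openai` (the `--client` flag is added by this PR and defaults to `azure`).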

# TODOs
- [ ] Add support for the OpenAI API
- [x] Add support for the OpenAI API

# Citation
If you find this work useful, please cite our [paper](https://arxiv.org/pdf/2411.01747):
65 changes: 60 additions & 5 deletions dynasaur.py
@@ -8,11 +8,45 @@
from agents import StructuredOutputDynamicActionSpaceAgent
from env import Env
from prompts import ACTION_DESCRIPTION_TEMPLATE, DYNASAUR_PROMPT
from scripts.llm_engines import AzureOpenAIEngine, StructuredOutputAzureOpenAIEngine
from scripts.llm_engines import AzureOpenAIEngine, OpenAIEngine, StructuredOutputEngine
from scripts.reformulator import prepare_response
from scripts.run_agents import answer_questions


def add_model_args(parser):
model_group = parser.add_argument_group('Model Configuration')
model_group.add_argument("--client", type=str, choices=['azure', 'openai', 'anthropic'], default='azure',
help="Which client to use (azure, openai, or anthropic)")
model_group.add_argument("--model_name", type=str,
help="Model name. If not provided, will use default for chosen client")
return parser
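Note that `--model_name` stays optional at parse time; a concrete default is resolved later by `get_default_model_name`, so an invocation like `python dynasaur.py --client openai` works as long as `OPENAI_MODEL_NAME` is set.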

def get_default_model_name(client):
defaults = {
'azure': os.getenv("AZURE_MODEL_NAME"),
'openai': os.getenv("OPENAI_MODEL_NAME"),
'anthropic': "claude-3-sonnet-20240229"
}
name = defaults.get(client)
if name is None:
raise ValueError(f"No model name configured for client '{client}'; pass --model_name or set AZURE_MODEL_NAME / OPENAI_MODEL_NAME.")
return name

def get_llm_engine(client, model_name=None):
"""Factory function to create appropriate LLM engine based on client choice"""
if model_name is None:
model_name = get_default_model_name(client)

if client == 'azure':
return AzureOpenAIEngine(model_name=model_name)
elif client == 'openai':
return OpenAIEngine(model_name=model_name)
elif client == 'anthropic':
# return AnthropicEngine(model_name=model_name)
raise ValueError("Anthropic client not supported yet")
else:
raise ValueError(f"Unknown client: {client}")

def get_dataset(args):
dataset = datasets.load_dataset("gaia-benchmark/GAIA", args.split)[args.set]
dataset = dataset.rename_columns({"Question": "question", "Final answer": "true_answer", "Level": "level", "Annotator Metadata": "annotations"})
@@ -47,7 +81,18 @@ def get_env(args):


def get_agent(args, env):
llm_engine = StructuredOutputAzureOpenAIEngine(model_name=args.model_name, response_format="thought_code")
base_engine = get_llm_engine(args.client, args.model_name)

if isinstance(base_engine, (OpenAIEngine, AzureOpenAIEngine)):
llm_engine = StructuredOutputEngine(
model_name=base_engine.model_name,
client=base_engine.client,
response_format="thought_code"
)
else:
# For non-OpenAI engines that don't support structured output:
# llm_engine = base_engine
raise ValueError("Non-OpenAI engines not supported yet")

# Load initial actions
required_actions = list(actions.get_required_actions(args.generated_action_dir).values())
@@ -71,8 +116,8 @@


def get_agent_call_function(args):
llm_engine = AzureOpenAIEngine(args.model_name)

llm_engine = get_llm_engine(args.client, args.model_name)
def agent_call_function(agent, question: str, **kwargs) -> str:
result = agent.run(question, **kwargs)

@@ -107,10 +152,20 @@ def agent_call_function(agent, question: str, **kwargs) -> str:
parser.add_argument("--generated_action_dir", type=str, default="generated_actions")
parser.add_argument("--set", type=str, default="validation")
parser.add_argument("--split", type=str, default="2023_level1")
parser.add_argument("--model_name", type=str, default="gpt-4o-2024-08-06")
parser.add_argument("--max_iterations", type=int, default=20)

# Add model configuration arguments
parser = add_model_args(parser)
args = parser.parse_args()

if args.model_name is None:
args.model_name = get_default_model_name(args.client)

print(f"Using {args.client} client with model: {args.model_name}")

agent_name = f"{args.model_name}-{args.split}"
generated_action_dir = os.path.join(args.generated_action_dir, agent_name)
args.agent_name = agent_name
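For example, with `--client openai` and `OPENAI_MODEL_NAME=gpt-4o` (a hypothetical value), `agent_name` becomes `gpt-4o-2023_level1` under the default split, and generated actions are written beneath `generated_actions/gpt-4o-2023_level1/`.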
13 changes: 12 additions & 1 deletion env.py
@@ -812,7 +812,7 @@ def step(self, code, language="Python", stream=False, display=False):
# else:
# state.result += content
state.pwd = self.working_dir
state.ls = subprocess.run(['ls'], cwd=self.working_dir, capture_output=True, text=True).stdout
state.ls = list_directory(self.working_dir)
return state

# if (
@@ -918,3 +918,14 @@ def terminate(self):
): # Not sure why this is None sometimes. We should look into this
language.terminate()
del self._active_languages[language_name]


def list_directory(directory):
"""Return a newline-separated listing of `directory`, similar to `ls`."""
try:
# Get list of files and directories
items = os.listdir(directory)
# Format similar to ls/dir output
return '\n'.join(items)
except Exception as e:
# Return the error text so callers always receive a string
return str(e)
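Unlike the `subprocess.run(['ls'], ...)` call it replaces, `list_directory` also works on Windows, and its error path returns the exception text instead of raising, so `state.ls` always receives a string.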
132 changes: 64 additions & 68 deletions scripts/llm_engines.py
@@ -1,4 +1,5 @@
import os
from abc import ABC
from time import sleep

import openai
@@ -10,24 +11,74 @@
openai_role_conversions = {MessageRole.TOOL_RESPONSE: MessageRole.USER}


class OpenAIEngine:
def __init__(self, model_name="gpt-4o"):
class BaseOpenAIEngine(ABC):
def __init__(self, model_name: str, client):
self.model_name = model_name
self.client = OpenAI(
api_key=os.getenv("OPENAI_API_KEY"),
)
self.client = client
self.metrics = {
"num_calls": 0,
"prompt_tokens": 0,
"completion_tokens": 0,
}

def __call__(self, messages, stop_sequences=[]):
def __call__(self, messages, stop_sequences=[], temperature=0.5, *args, **kwargs):
# NOTE: stop_sequences is accepted for interface compatibility but is no longer forwarded to the API
messages = get_clean_message_list(messages, role_conversions=openai_role_conversions)

response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
stop=stop_sequences,
temperature=0.5,
)
# Retry with linearly growing backoff on transient server errors
success = False
wait_time = 1
while not success:
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
temperature=temperature,
*args,
**kwargs
)
success = True
except openai.InternalServerError:
sleep(wait_time)
wait_time += 1

# Update metrics
self.metrics["num_calls"] += 1
self.metrics["prompt_tokens"] += response.usage.prompt_tokens
self.metrics["completion_tokens"] += response.usage.completion_tokens

return response.choices[0].message.content

def reset(self):
self.metrics = {
"num_calls": 0,
"prompt_tokens": 0,
"completion_tokens": 0,
}

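As a quick sketch of the shared behavior (model name hypothetical; message schema assumed as above):

```python
# Hypothetical: any BaseOpenAIEngine subclass accumulates usage metrics per call.
engine = OpenAIEngine(model_name="gpt-4o-mini")  # example model name
text = engine([{"role": "user", "content": "Say hi"}], temperature=0.0)
print(engine.metrics)  # e.g. {'num_calls': 1, 'prompt_tokens': ..., 'completion_tokens': ...}
engine.reset()  # zero the counters between benchmark runs
```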

class OpenAIEngine(BaseOpenAIEngine):
def __init__(self, model_name="gpt-4", api_key=None):
client = OpenAI(
api_key=api_key or os.getenv("OPENAI_API_KEY"),
)
super().__init__(model_name=model_name, client=client)


class AzureOpenAIEngine(BaseOpenAIEngine):
def __init__(
self,
model_name: str = None,
api_key: str = None,
azure_endpoint: str = None,
api_version: str = None,
):
model_name = model_name or os.getenv("AZURE_MODEL_NAME")
client = AzureOpenAI(
api_key=api_key or os.getenv("AZURE_API_KEY"),
azure_endpoint=azure_endpoint or os.getenv("AZURE_ENDPOINT"),
api_version=api_version or os.getenv("AZURE_API_VERSION"),
)
super().__init__(model_name=model_name, client=client)
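The two subclasses differ only in how their client is constructed; the request loop, retry policy, and metrics bookkeeping all live in `BaseOpenAIEngine`.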


class AnthropicEngine:
def __init__(self, model_name="claude-3-5-sonnet-20240620", use_bedrock=False):
@@ -74,61 +125,6 @@ def __call__(self, messages, stop_sequences=[]):
return full_response_text


class AzureOpenAIEngine:
def __init__(
self,
model_name: str = None,
api_key: str = None,
azure_endpoint: str = None,
api_version: str = None,
):
self.model_name = model_name or os.getenv("AZURE_MODEL_NAME")
self.client = AzureOpenAI(
api_key=api_key or os.getenv("AZURE_API_KEY"),
azure_endpoint=azure_endpoint or os.getenv("AZURE_ENDPOINT"),
api_version=api_version or os.getenv("AZURE_API_VERSION"),
)
self.metrics = {
"num_calls": 0,
"prompt_tokens": 0,
"completion_tokens": 0,
}

def __call__(self, messages, stop_sequences=[], temperature=0.5, *args, **kwargs):
messages = get_clean_message_list(messages, role_conversions=openai_role_conversions)

success = False
wait_time = 1
while not success:
try:
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
# stop=stop_sequences,
temperature=temperature,
*args,
**kwargs
)
success = True
except openai.InternalServerError:
sleep(wait_time)
wait_time += 1

# Update metrics
self.metrics["num_calls"] += 1
self.metrics["prompt_tokens"] += response.usage.prompt_tokens
self.metrics["completion_tokens"] += response.usage.completion_tokens

return response.choices[0].message.content

def reset(self):
self.metrics = {
"num_calls": 0,
"prompt_tokens": 0,
"completion_tokens": 0,
}


class ThoughtCodeFormat(BaseModel):
thought: str
code: str
@@ -139,7 +135,7 @@ class ThoughtActionFormat(BaseModel):
action: str


class StructuredOutputAzureOpenAIEngine(AzureOpenAIEngine):
class StructuredOutputEngine(BaseOpenAIEngine):
def __init__(self, response_format: str, *args, **kwargs):
super().__init__(*args, **kwargs)
self.response_format_str = response_format