Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance Action Generalization by Fetching Relevant Tasks for Prompting #6

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 65 additions & 9 deletions agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
from glob import glob
from typing import Callable, Dict, List, Optional, Union

import transformers
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain_community.embeddings import OllamaEmbeddings
from langchain_openai import AzureOpenAIEmbeddings
from transformers.agents import Agent, ReactCodeAgent
from transformers.agents.agents import (
AgentExecutionError,
Expand All @@ -22,17 +25,28 @@
from scripts.llm_engines import AzureOpenAIEngine
from utils import GeneratedTool, add_parent_pointers, parse_generated_tools

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_ORGANIZATION = os.getenv("OPENAI_ORGANIZATION")

EMBED_MODEL_TYPE = os.getenv("EMBED_MODEL_TYPE")
EMBED_MODEL_NAME = os.getenv("EMBED_MODEL_NAME")

AZURE_EMBED_MODEL_NAME = os.getenv("AZURE_EMBED_MODEL_NAME")
AZURE_EMBED_API_KEY = os.getenv("AZURE_EMBED_API_KEY")
AZURE_EMBED_ENDPOINT = os.getenv("AZURE_EMBED_ENDPOINT")
AZURE_EMBED_API_VERSION = os.getenv("AZURE_EMBED_API_VERSION")

class TimeoutException(Exception):
    """Signals that an operation ran past its allotted time limit."""


def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str) -> str:
def format_prompt_with_tools_and_tasks(toolbox: Toolbox, prompt_template: str, tasks: List[str], max_iterations: int) -> str:
    """Fill a prompt template with tool descriptions, example tasks, and the iteration budget.

    Args:
        toolbox: Toolbox whose registered tools are summarized as
            "name: description" lines.
        prompt_template: Template text containing the ``<<tool_descriptions>>``,
            ``<<example_tasks>>`` and ``<<max_iterations>>`` placeholders.
        tasks: Example tasks to list in the prompt; may be empty.
        max_iterations: Value substituted for ``<<max_iterations>>``.

    Returns:
        The template with all three placeholders substituted.
    """
    tool_descriptions = "\n".join(
        f"{tool.name}: {tool.description}" for tool in toolbox._tools.values()
    )
    if not tool_descriptions:
        tool_descriptions = "None"
    example_tasks = "".join(f"- {task}\n" for task in tasks)
    # Mirror the tool-description handling: render an explicit "None" rather
    # than leaving an empty code fence when no example tasks were retrieved.
    if not example_tasks:
        example_tasks = "None"
    prompt = prompt_template.replace("<<tool_descriptions>>", tool_descriptions)
    prompt = prompt.replace("<<example_tasks>>", example_tasks)
    prompt = prompt.replace("<<max_iterations>>", str(max_iterations))
    return prompt


Expand Down Expand Up @@ -186,10 +200,47 @@ def step(self):


class DynamicActionSpaceAgent(UnrestrictedReactCodeAgent):
def __init__(self, generated_tool_dir: str, disable_accum: bool = False, *args, **kwargs):
def __init__(self, dataset, generated_tool_dir: str, num_examples_tasks: int, disable_accum: bool = False, *args, **kwargs):
super().__init__(*args, **kwargs)
self.generated_tool_dir = generated_tool_dir
self.disable_accum = disable_accum
self.num_examples_tasks = num_examples_tasks
if not self.disable_accum:
# Utilized the vectordb for relevant task generation
self.vectordb_path = f"{self.generated_tool_dir}/vectordb"

# Utilize the Chroma database and employ OpenAI Embeddings for vectorization (default: text-embedding-ada-002)
if EMBED_MODEL_TYPE == "OpenAI":
embedding_function = OpenAIEmbeddings(
openai_api_key=OPENAI_API_KEY,
openai_organization=OPENAI_ORGANIZATION,
)
embed_model_name = "openai"
elif EMBED_MODEL_TYPE == "OLLAMA":
embedding_function = OllamaEmbeddings(model=EMBED_MODEL_NAME)
embed_model_name = "ollama"
elif EMBED_MODEL_TYPE == "AzureOpenAI":
embedding_function = AzureOpenAIEmbeddings(
api_key=AZURE_EMBED_API_KEY,
azure_endpoint=AZURE_EMBED_ENDPOINT,
azure_deployment=AZURE_EMBED_MODEL_NAME,
openai_api_version=AZURE_EMBED_API_VERSION,
)
embed_model_name = AZURE_EMBED_MODEL_NAME

self.task_db = Chroma(
collection_name="task_vectordb",
embedding_function=embedding_function,
persist_directory=self.vectordb_path,
)

for task in dataset:
self.task_db.add_texts(
texts=[task["question"]],
)

self.task_db.persist()


# Load generated tools from disk
generated_tools: list[GeneratedTool] = []
Expand Down Expand Up @@ -243,18 +294,23 @@ def initialize_for_run(self, task: str, **kwargs):
if len(kwargs) > 0:
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
self.state = kwargs.copy()
self.system_prompt = transformers.agents.agents.format_prompt_with_tools(
# Sample relevant tasks
if self.disable_accum:
# If we disable accum, we don't need to provide example tasks
tasks = []
else:
tasks = [doc.page_content for doc in self.task_db.similarity_search(task, k=self.num_examples_tasks)]
self.system_prompt = format_prompt_with_tools_and_tasks(
self._toolbox,
self.system_prompt_template,
self.tool_description_template,
tasks,
max_iterations=5
)
generated_tool_descriptions = self.generated_toolbox.show_tool_descriptions(self.tool_description_template)
self.system_prompt = self.system_prompt.replace("<<generated_tool_descriptions>>", generated_tool_descriptions)
self.logs = [{"system_prompt": self.system_prompt, "task": self.task}]
self.logger.warn("\n" * 5)
self.logger.warn("======== New task ========")
# self.logger.log(33, self.task)
# self.logger.debug("System prompt is as follows:")
self.logger.warning("[SYSTEM_PROMPT]")
self.logger.debug(self.system_prompt)
self.logger.warning("[TASK]")
Expand Down
21 changes: 16 additions & 5 deletions dynasaur.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from agents import StructuredOutputDynamicActionSpaceAgent
from env import Env
from prompts import ACTION_DESCRIPTION_TEMPLATE, DYNASAUR_PROMPT
from scripts.llm_engines import AzureOpenAIEngine, StructuredOutputAzureOpenAIEngine
from scripts.llm_engines import AzureOpenAIEngine, StructuredOutputAzureOpenAIEngine, OpenAIEngine, StructuredOutputOpenAIEngine
from scripts.reformulator import prepare_response
from scripts.run_agents import answer_questions

Expand Down Expand Up @@ -46,8 +46,13 @@ def get_env(args):
return env


def get_agent(args, env):
llm_engine = StructuredOutputAzureOpenAIEngine(model_name=args.model_name, response_format="thought_code")
LLM_ENGINE_TYPE = os.getenv("LLM_ENGINE_TYPE", "AzureOpenAI")

def get_agent(dataset,args, env):
if LLM_ENGINE_TYPE == "OpenAI":
llm_engine = StructuredOutputOpenAIEngine(model_name=args.model_name, response_format="thought_code")
else:
llm_engine = StructuredOutputAzureOpenAIEngine(model_name=args.model_name, response_format="thought_code")

# Load initial actions
required_actions = list(actions.get_required_actions(args.generated_action_dir).values())
Expand All @@ -65,13 +70,18 @@ def get_agent(args, env):
generated_tool_dir=args.generated_action_dir,
disable_accum=disable_accum,
env=env,
dataset=dataset,
num_examples_tasks=args.num_examples_tasks,
)

return agent


def get_agent_call_function(args):
llm_engine = AzureOpenAIEngine(args.model_name)
if LLM_ENGINE_TYPE == "OpenAI":
llm_engine = OpenAIEngine(args.model_name)
else:
llm_engine = AzureOpenAIEngine(args.model_name)

def agent_call_function(agent, question: str, **kwargs) -> str:
result = agent.run(question, **kwargs)
Expand Down Expand Up @@ -109,6 +119,7 @@ def agent_call_function(agent, question: str, **kwargs) -> str:
parser.add_argument("--split", type=str, default="2023_level1")
parser.add_argument("--model_name", type=str, default="gpt-4o-2024-08-06")
parser.add_argument("--max_iterations", type=int, default=20)
parser.add_argument("--num_examples_tasks", type=int, default=5)
args = parser.parse_args()

agent_name = f"{args.model_name}-{args.split}"
Expand All @@ -119,7 +130,7 @@ def agent_call_function(agent, question: str, **kwargs) -> str:

dataset = get_dataset(args)
env = get_env(args)
agent = get_agent(args, env)
agent = get_agent(dataset, args, env)

agent_call_function = get_agent_call_function(args)

Expand Down
6 changes: 6 additions & 0 deletions prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@
}
```

# Example tasks
You are provided with example tasks that the new function might later be used to solve. Newly generated functions should be general enough to solve these tasks as well.
```
<<example_tasks>>
```

# Available Functions
You are provided with several available functions. If you need to discover more relevant functions, use the `get_relevant_tools` function.
```
Expand Down
35 changes: 34 additions & 1 deletion scripts/llm_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,16 @@


class OpenAIEngine:
def __init__(self, model_name="gpt-4o"):
    def __init__(self, model_name="gpt-o1"):
        # Engine wrapping the OpenAI chat-completions client.
        # NOTE(review): the default "gpt-o1" does not match any published
        # OpenAI model id ("gpt-4o" and "o1" are) — confirm the intended
        # default; callers passing model_name explicitly are unaffected.
        self.model_name = model_name
        # API key is read from the environment; client creation does not
        # validate it, so a missing key only fails on the first request.
        self.client = OpenAI(
            api_key=os.getenv("OPENAI_API_KEY"),
        )
        # Cumulative usage counters for this engine instance.
        self.metrics = {
            "num_calls": 0,
            "prompt_tokens": 0,
            "completion_tokens": 0,
        }

def __call__(self, messages, stop_sequences=[]):
messages = get_clean_message_list(messages, role_conversions=openai_role_conversions)
Expand Down Expand Up @@ -166,3 +171,31 @@ def __call__(self, messages, temperature=0.5, stop_sequences=None, *args, **kwar
self.metrics["completion_tokens"] += response.usage.completion_tokens

return response.choices[0].message.parsed

class StructuredOutputOpenAIEngine(OpenAIEngine):
    """OpenAI engine whose responses are parsed into a structured-output schema.

    Mirrors ``StructuredOutputAzureOpenAIEngine`` but targets the public
    OpenAI endpoint via the inherited ``OpenAIEngine`` client.
    """

    def __init__(self, response_format: str, *args, **kwargs):
        """Create the engine and select the structured-output schema.

        Args:
            response_format: Either ``"thought_code"`` or ``"thought_action"``.

        Raises:
            ValueError: If ``response_format`` is not a recognized value.
        """
        # Validate before constructing the client: previously an unknown
        # value left ``self.response_format`` unset and the misconfiguration
        # only surfaced as an AttributeError on the first __call__.
        if response_format == "thought_code":
            self.response_format = ThoughtCodeFormat
        elif response_format == "thought_action":
            self.response_format = ThoughtActionFormat
        else:
            raise ValueError(
                f"Unknown response_format {response_format!r}; "
                "expected 'thought_code' or 'thought_action'."
            )
        super().__init__(*args, **kwargs)
        self.response_format_str = response_format

    def __call__(self, messages, temperature=0.5, stop_sequences=None, *args, **kwargs) -> dict:
        """Query the model and return the parsed structured response.

        Args:
            messages: Chat messages; roles are normalized via
                ``get_clean_message_list`` before the request.
            temperature: Sampling temperature forwarded to the API.
            stop_sequences: Unused; kept for signature compatibility with the
                other engine classes — TODO confirm.

        Returns:
            The parsed message object matching ``self.response_format``.
        """
        messages = get_clean_message_list(messages, role_conversions=openai_role_conversions)

        response = self.client.beta.chat.completions.parse(
            model=self.model_name,
            messages=messages,
            response_format=self.response_format,
            temperature=temperature,
            *args,
            **kwargs,
        )

        # Update cumulative usage metrics
        self.metrics["num_calls"] += 1
        self.metrics["prompt_tokens"] += response.usage.prompt_tokens
        self.metrics["completion_tokens"] += response.usage.completion_tokens

        return response.choices[0].message.parsed