From 912a1831a7965b4e8218fef4af2a320558e8a5d5 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 13 Sep 2023 13:34:00 -0400 Subject: [PATCH] Benchmark CI (actual) (#754) * refactor and benchmark * update code * Add accelerate logging * logs * quick fix * update config * precommit * modify training example * fix multi-gpu all_reduce error `Tensors must be CUDA and dense` * support more models and benchmark * update * add changes * upload benchmark * precommit * add tyro as a dependency * add tyro * pre-commit * precommit * weird... * lol typo * precommit * sigh * push changes * Update benchmark/README.md Co-authored-by: Leandro von Werra * Add experiments * upload image to tag specific folder * add openrlbenchmark documentation * rename * remove unused field * precommit * update slurm template * add dependency * update dependency * .. * . * quick change * push changes * update * update * remove wandb tag code * quick change * precommit * update test * update dependency * update test * update benchmark dependency --------- Co-authored-by: Leandro von Werra --- .github/workflows/benchmark.yml | 2 +- benchmark/benchmark.py | 64 +++++++++++++++++++++++++--- benchmark/benchmark_and_report.sh | 30 +++++++++++++ benchmark/benchmark_core.sh | 24 +++++------ benchmark/plot_core.sh | 23 ++++++++++ benchmark/post_github_comment.py | 26 +++++++++++ benchmark/post_github_comment.sbatch | 9 ++++ benchmark/trl.slurm_template | 8 ++-- setup.py | 1 + trl/trainer/ppo_config.py | 43 ------------------- 10 files changed, 165 insertions(+), 65 deletions(-) create mode 100644 benchmark/benchmark_and_report.sh create mode 100644 benchmark/plot_core.sh create mode 100644 benchmark/post_github_comment.py create mode 100644 benchmark/post_github_comment.sbatch diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 1360b3babcc..2ed588eedc7 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -38,7 +38,7 @@ jobs: python-version: 
${{ matrix.python-version }} - name: Install dependencies run: | - pip install .[test] + pip install .[test,benchmark] - name: Login run: wandb login ${{ secrets.WANDB_API_KEY }} && huggingface-cli login --token ${{ secrets.HUGGING_FACE_HUB_TOKEN }} diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py index ffc4bbfafb3..895000f24e3 100644 --- a/benchmark/benchmark.py +++ b/benchmark/benchmark.py @@ -6,6 +6,8 @@ import uuid from distutils.util import strtobool +import requests + def parse_args(): # fmt: off @@ -38,14 +40,65 @@ def parse_args(): def run_experiment(command: str): command_list = shlex.split(command) print(f"running {command}") - fd = subprocess.Popen(command_list) - return_code = fd.wait() - assert return_code == 0 + + # Use subprocess.PIPE to capture the output + fd = subprocess.Popen(command_list, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + output, errors = fd.communicate() + + return_code = fd.returncode + assert return_code == 0, f"Command failed with error: {errors.decode('utf-8')}" + + # Convert bytes to string and strip leading/trailing whitespaces + return output.decode("utf-8").strip() + + +def autotag() -> str: + wandb_tag = "" + print("autotag feature is enabled") + git_tag = "" + try: + git_tag = subprocess.check_output(["git", "describe", "--tags"]).decode("ascii").strip() + print(f"identified git tag: {git_tag}") + except subprocess.CalledProcessError as e: + print(e) + if len(git_tag) == 0: + try: + count = int(subprocess.check_output(["git", "rev-list", "--count", "HEAD"]).decode("ascii").strip()) + hash = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode("ascii").strip() + git_tag = f"no-tag-{count}-g{hash}" + print(f"identified git tag: {git_tag}") + except subprocess.CalledProcessError as e: + print(e) + wandb_tag = f"{git_tag}" + + git_commit = subprocess.check_output(["git", "rev-parse", "--verify", "HEAD"]).decode("ascii").strip() + try: + # try finding the pull request number on github + prs = 
requests.get(f"https://api.github.com/search/issues?q=repo:huggingface/trl+is:pr+{git_commit}") + if prs.status_code == 200: + prs = prs.json() + if len(prs["items"]) > 0: + pr = prs["items"][0] + pr_number = pr["number"] + wandb_tag += f",pr-{pr_number}" + print(f"identified github pull request: {pr_number}") + except Exception as e: + print(e) + + return wandb_tag if __name__ == "__main__": args = parse_args() - + if args.auto_tag: + existing_wandb_tag = os.environ.get("WANDB_TAGS", "") + wandb_tag = autotag() + if len(wandb_tag) > 0: + if len(existing_wandb_tag) > 0: + os.environ["WANDB_TAGS"] = ",".join([existing_wandb_tag, wandb_tag]) + else: + os.environ["WANDB_TAGS"] = wandb_tag + print("WANDB_TAGS: ", os.environ.get("WANDB_TAGS", "")) commands = [] for seed in range(0, args.num_seeds): commands += [" ".join([args.command, "--seed", str(args.start_seed + seed)])] @@ -93,4 +146,5 @@ def run_experiment(command: str): slurm_path = os.path.join("slurm", f"{filename}.slurm") print(f"saving command in {slurm_path}") if args.workers > 0: - run_experiment(f"sbatch {slurm_path}") + job_id = run_experiment(f"sbatch --parsable {slurm_path}") + print(f"Job ID: {job_id}") diff --git a/benchmark/benchmark_and_report.sh b/benchmark/benchmark_and_report.sh new file mode 100644 index 00000000000..26738de90f1 --- /dev/null +++ b/benchmark/benchmark_and_report.sh @@ -0,0 +1,30 @@ +echo PATH is $PATH +echo PYTHONPATH is $PYTHONPATH +echo whcih python is $(which python) + +export WANDB_ENTITY=huggingface + +bash benchmark/benchmark_core.sh > output.txt + +# Extract Job IDs into an array +job_ids=($(grep "Job ID:" output.txt | awk '{print $3}')) + +# Extract WANDB_TAGS into an array +WANDB_TAGS=($(grep "WANDB_TAGS:" output.txt | awk '{print $2}')) +WANDB_TAGS=($(echo $WANDB_TAGS | tr "," "\n")) + +# Print to verify +echo "Job IDs: ${job_ids[@]}" +echo "WANDB_TAGS: ${WANDB_TAGS[@]}" + +TAGS_STRING="?tag=${WANDB_TAGS[0]}" +FOLDER_STRING="${WANDB_TAGS[0]}" +for tag in 
"${WANDB_TAGS[@]:1}"; do + TAGS_STRING+="&tag=$tag" + FOLDER_STRING+="_$tag" +done + +echo "TAGS_STRING: $TAGS_STRING" +echo "FOLDER_STRING: $FOLDER_STRING" + +TAGS_STRING=$TAGS_STRING FOLDER_STRING=$FOLDER_STRING sbatch --dependency=afterany:$job_ids benchmark/post_github_comment.sbatch diff --git a/benchmark/benchmark_core.sh b/benchmark/benchmark_core.sh index fed0511f011..81c3676ac5a 100644 --- a/benchmark/benchmark_core.sh +++ b/benchmark/benchmark_core.sh @@ -2,7 +2,7 @@ # hello world experiment python benchmark/benchmark.py \ --command "python examples/scripts/sentiment_tuning.py --ppo_config.log_with wandb" \ - --num-seeds 5 \ + --num-seeds 3 \ --start-seed 1 \ --workers 10 \ --slurm-nodes 1 \ @@ -11,14 +11,14 @@ python benchmark/benchmark.py \ --slurm-total-cpus 12 \ --slurm-template-path benchmark/trl.slurm_template -# compound -python benchmark/benchmark.py \ - --command "python examples/scripts/sentiment_tuning.py --ppo_config.exp_name sentiment_tuning_gpt2xl_grad_accu --ppo_config.model_name gpt2-xl --ppo_config.mini_batch_size 16 --ppo_config.gradient_accumulation_steps 8 --ppo_config.log_with wandb" \ - --num-seeds 5 \ - --start-seed 1 \ - --workers 10 \ - --slurm-nodes 1 \ - --slurm-gpus-per-task 1 \ - --slurm-ntasks 1 \ - --slurm-total-cpus 12 \ - --slurm-template-path benchmark/trl.slurm_template \ No newline at end of file +# # compound +# python benchmark/benchmark.py \ +# --command "python examples/scripts/sentiment_tuning.py --ppo_config.exp_name sentiment_tuning_gpt2xl_grad_accu --ppo_config.model_name gpt2-xl --ppo_config.mini_batch_size 16 --ppo_config.gradient_accumulation_steps 8 --ppo_config.log_with wandb" \ +# --num-seeds 3 \ +# --start-seed 1 \ +# --workers 10 \ +# --slurm-nodes 1 \ +# --slurm-gpus-per-task 1 \ +# --slurm-ntasks 1 \ +# --slurm-total-cpus 12 \ +# --slurm-template-path benchmark/trl.slurm_template diff --git a/benchmark/plot_core.sh b/benchmark/plot_core.sh new file mode 100644 index 00000000000..7cf77cfe2af --- 
/dev/null +++ b/benchmark/plot_core.sh @@ -0,0 +1,23 @@ +# pip install openrlbenchmark==0.2.1a5 +# see https://github.com/openrlbenchmark/openrlbenchmark#get-started for documentation +echo "we deal with $TAGS_STRING" + + # "sentiment_tuning_gpt2xl_grad_accu$TAGS_STRING" \ + +python -m openrlbenchmark.rlops_multi_metrics \ + --filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \ + "sentiment_tuning$TAGS_STRING" \ + --env-ids sentiment-analysis:lvwerra/distilbert-imdb \ + --no-check-empty-runs \ + --pc.ncols 2 \ + --pc.ncols-legend 1 \ + --output-filename benchmark/trl/$FOLDER_STRING/different_models \ + --scan-history + + +python benchmark/upload_benchmark.py \ + --folder_path="benchmark/trl/$FOLDER_STRING" \ + --path_in_repo="images/benchmark/$FOLDER_STRING" \ + --repo_id="trl-internal-testing/example-images" \ + --repo_type="dataset" + diff --git a/benchmark/post_github_comment.py b/benchmark/post_github_comment.py new file mode 100644 index 00000000000..70241ef1319 --- /dev/null +++ b/benchmark/post_github_comment.py @@ -0,0 +1,26 @@ +import json +import os + +from ghapi.all import GhApi + + +FOLDER_STRING = os.environ.get("FOLDER_STRING", "") +folder = f"benchmark/trl/{FOLDER_STRING}" +host_url = f"https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/benchmark/{FOLDER_STRING}" + +# Create a GitHub API instance +github_context = json.loads(os.environ["GITHUB_CONTEXT"]) +token = os.environ["PERSONAL_ACCESS_TOKEN_GITHUB"] # this needs to be refreshed every 12 months +status_message = "**[COSTA BENCHMARK BOT]**: Here are the results" +body = status_message +repo = github_context["repository"] +owner, repo = repo.split("/") +api = GhApi(owner=owner, repo=repo, token=token) + +# for each `.png` file in the folder, add it to the comment +for file in os.listdir(folder): + if file.endswith(".png"): + body += 
f"\n![{file}]({host_url}/{file})" + +# Create a comment on the issue +api.issues.create_comment(issue_number=github_context["event"]["issue"]["number"], body=body) diff --git a/benchmark/post_github_comment.sbatch b/benchmark/post_github_comment.sbatch new file mode 100644 index 00000000000..3bc602252cc --- /dev/null +++ b/benchmark/post_github_comment.sbatch @@ -0,0 +1,9 @@ +#!/bin/bash +#SBATCH --job-name=trl +#SBATCH --partition=production-cluster +#SBATCH --ntasks=1 +#SBATCH --output=slurm/logs/%x_%j.out + +sleep 2m +bash benchmark/plot_core.sh +srun python benchmark/post_github_comment.py diff --git a/benchmark/trl.slurm_template b/benchmark/trl.slurm_template index 7bfc7bc3302..3de9eb0babe 100644 --- a/benchmark/trl.slurm_template +++ b/benchmark/trl.slurm_template @@ -1,16 +1,16 @@ #!/bin/bash -#SBATCH --partition=dev-cluster +#SBATCH --job-name=trl +#SBATCH --partition=production-cluster #SBATCH --gpus-per-task={{gpus_per_task}} #SBATCH --cpus-per-gpu={{cpus_per_gpu}} #SBATCH --ntasks={{ntasks}} -#SBATCH --mem-per-cpu=11G #SBATCH --output=slurm/logs/%x_%j.out #SBATCH --array={{array}} - +#SBATCH --exclude=ip-26-0-156-239,ip-26-0-148-151,ip-26-0-146-212,ip-26-0-145-137,ip-26-0-146-249,ip-26-0-146-149,ip-26-0-147-233,ip-26-0-145-154,ip-26-0-144-35,ip-26-0-144-189,ip-26-0-146-183,ip-26-0-147-120,ip-26-0-144-95,ip-26-0-145-193 {{nodes}} seeds={{seeds}} seed=${seeds[$SLURM_ARRAY_TASK_ID % {{len_seeds}}]} echo "Running task $SLURM_ARRAY_TASK_ID with seed: $seed" -srun {{command}} --seed $seed +srun {{command}} --ppo_config.seed $seed diff --git a/setup.py b/setup.py index 518e0603d11..2aaad9c72bd 100644 --- a/setup.py +++ b/setup.py @@ -73,6 +73,7 @@ "diffusers": ["diffusers>=0.18.0"], "deepspeed": ["deepspeed>=0.9.5"], "dev": ["parameterized", "pytest", "pytest-xdist", "pre-commit", "peft>=0.4.0", "diffusers>=0.18.0"], + "benchmark": ["wandb", "ghapi", "openrlbenchmark==0.2.1a5"], } setup( diff --git a/trl/trainer/ppo_config.py b/trl/trainer/ppo_config.py index 
2acd8fffb81..cdcaf1b9946 100644 --- a/trl/trainer/ppo_config.py +++ b/trl/trainer/ppo_config.py @@ -11,16 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging import os -import subprocess import sys import warnings from dataclasses import dataclass, field from typing import Literal, Optional import numpy as np -import requests import tyro from trl.trainer.utils import exact_div @@ -28,38 +25,6 @@ from ..core import flatten_dict -def autotag() -> str: - wandb_tag = "" - logging.info("autotag feature is enabled") - try: - git_tag = subprocess.check_output(["git", "describe", "--tags"]).decode("ascii").strip() - wandb_tag = f"{git_tag}" - logging.info(f"identified git tag: {git_tag}") - except subprocess.CalledProcessError: - return wandb_tag - - git_commit = subprocess.check_output(["git", "rev-parse", "--verify", "HEAD"]).decode("ascii").strip() - try: - # if the current branch is not main, try find the PR number - git_branch = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"]).decode("ascii").strip() - if git_branch != "main": - # try finding the pull request number on github - prs = requests.get(f"https://api.github.com/search/issues?q=repo:huggingface/trl+is:pr+{git_commit}") - if prs.status_code == 200: - prs = prs.json() - if len(prs["items"]) > 0: - pr = prs["items"][0] - pr_number = pr["number"] - wandb_tag += f",pr-{pr_number}" - logging.info(f"identified github pull request: {pr_number}") - else: - logging.info("current branch is main, not searching for pull request") - except Exception as e: - logging.warning(f"Automatic autotag failed with the following error: {e}") - - return wandb_tag - - @dataclass class PPOConfig: """ @@ -181,14 +146,6 @@ def __post_init__(self): try: import wandb # noqa: F401 - existing_wandb_tag = os.environ.get("WANDB_TAGS", "") - wandb_tag = autotag() - if 
len(wandb_tag) > 0: - if len(existing_wandb_tag) > 0: - os.environ["WANDB_TAGS"] = ",".join([existing_wandb_tag, wandb_tag]) - else: - os.environ["WANDB_TAGS"] = wandb_tag - logging.info(f"the following tags will be used for wandb logging: {os.environ['WANDB_TAGS']}") except ImportError: raise ImportError( "Please install wandb to use wandb logging. You can do this by running `pip install wandb`."