Skip to content

Commit

Permalink
Benchmark CI (actual) (huggingface#754)
Browse files Browse the repository at this point in the history
* refactor and benchmark

* update code

* Add accelerate logging

* logs

* quick fix

* update config

* precommit

* modify training example

* fix multi-gpu all_reduce error `Tensors must be CUDA and dense`

* support more models and benchmark

* update

* add changes

* upload benchmark

* precommit

* add tyro as a dependency

* add tyro

* pre-commit

* precommit

* weird...

* lol typo

* precommit

* sigh

* push changes

* Update benchmark/README.md

Co-authored-by: Leandro von Werra <[email protected]>

* Add experiments

* upload image to tag specific folder

* add openrlbenchmark documentation

* rename

* remove unused field

* precommit

* update slurm template

* add dependency

* update dependency

* ..

* .

* quick change

* push changes

* update

* update

* remove wandb tag code

* quick change

* precommit

* update test

* update dependency

* update test

* update benchmark dependency

---------

Co-authored-by: Leandro von Werra <[email protected]>
  • Loading branch information
2 people authored and kushal-tri committed Sep 19, 2023
1 parent c55f2a0 commit 912a183
Show file tree
Hide file tree
Showing 10 changed files with 165 additions and 65 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install .[test]
pip install .[test,benchmark]
- name: Login
run: wandb login ${{ secrets.WANDB_API_KEY }} && huggingface-cli login --token ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
Expand Down
64 changes: 59 additions & 5 deletions benchmark/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import uuid
from distutils.util import strtobool

import requests


def parse_args():
# fmt: off
Expand Down Expand Up @@ -38,14 +40,65 @@ def parse_args():
def run_experiment(command: str):
command_list = shlex.split(command)
print(f"running {command}")
fd = subprocess.Popen(command_list)
return_code = fd.wait()
assert return_code == 0

# Use subprocess.PIPE to capture the output
fd = subprocess.Popen(command_list, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, errors = fd.communicate()

return_code = fd.returncode
assert return_code == 0, f"Command failed with error: {errors.decode('utf-8')}"

# Convert bytes to string and strip leading/trailing whitespaces
return output.decode("utf-8").strip()


def autotag() -> str:
    """Build a comma-separated wandb tag string from the current git state.

    The first tag is the output of ``git describe --tags``. When that yields
    nothing, a synthetic ``no-tag-<commit count>-g<short hash>`` tag is used
    instead. If the GitHub search API returns a pull request for the current
    commit, ``pr-<number>`` is appended.

    Every lookup is best-effort: failures are printed and the tags collected
    so far are still returned.

    Returns:
        The tags joined by ``,``; an empty string if nothing was identified.
    """
    print("autotag feature is enabled")
    git_tag = ""
    try:
        git_tag = subprocess.check_output(["git", "describe", "--tags"]).decode("ascii").strip()
        print(f"identified git tag: {git_tag}")
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers a missing `git` executable.
        print(e)
    if not git_tag:
        try:
            count = int(subprocess.check_output(["git", "rev-list", "--count", "HEAD"]).decode("ascii").strip())
            short_hash = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode("ascii").strip()
            git_tag = f"no-tag-{count}-g{short_hash}"
            print(f"identified git tag: {git_tag}")
        except (subprocess.CalledProcessError, OSError) as e:
            print(e)

    # Collect tags in a list so an empty git tag never produces a leading comma.
    tags = [git_tag] if git_tag else []

    try:
        git_commit = subprocess.check_output(["git", "rev-parse", "--verify", "HEAD"]).decode("ascii").strip()
    except (subprocess.CalledProcessError, OSError) as e:
        # Without a commit hash there is nothing to search for on GitHub.
        print(e)
        return ",".join(tags)

    try:
        # try finding the pull request number on github
        response = requests.get(
            f"https://api.github.com/search/issues?q=repo:huggingface/trl+is:pr+{git_commit}",
            timeout=10,  # never let a stalled connection block the benchmark launcher
        )
        if response.status_code == 200:
            items = response.json()["items"]
            if len(items) > 0:
                pr_number = items[0]["number"]
                tags.append(f"pr-{pr_number}")
                print(f"identified github pull request: {pr_number}")
    except Exception as e:
        # Tagging must not abort the run: report and fall through.
        print(e)

    return ",".join(tags)


if __name__ == "__main__":
args = parse_args()

if args.auto_tag:
existing_wandb_tag = os.environ.get("WANDB_TAGS", "")
wandb_tag = autotag()
if len(wandb_tag) > 0:
if len(existing_wandb_tag) > 0:
os.environ["WANDB_TAGS"] = ",".join([existing_wandb_tag, wandb_tag])
else:
os.environ["WANDB_TAGS"] = wandb_tag
print("WANDB_TAGS: ", os.environ.get("WANDB_TAGS", ""))
commands = []
for seed in range(0, args.num_seeds):
commands += [" ".join([args.command, "--seed", str(args.start_seed + seed)])]
Expand Down Expand Up @@ -93,4 +146,5 @@ def run_experiment(command: str):
slurm_path = os.path.join("slurm", f"{filename}.slurm")
print(f"saving command in {slurm_path}")
if args.workers > 0:
run_experiment(f"sbatch {slurm_path}")
job_id = run_experiment(f"sbatch --parsable {slurm_path}")
print(f"Job ID: {job_id}")
30 changes: 30 additions & 0 deletions benchmark/benchmark_and_report.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Launch the core benchmarks, then queue a follow-up Slurm job that runs once
# they have finished. TAGS_STRING and FOLDER_STRING are passed to that job
# through its environment.
echo PATH is $PATH
echo PYTHONPATH is $PYTHONPATH
echo which python is $(which python)

export WANDB_ENTITY=huggingface

# benchmark.py prints "Job ID: <id>" and "WANDB_TAGS: <tags>" lines; capture
# them so they can be parsed below.
bash benchmark/benchmark_core.sh > output.txt

# Extract Job IDs into an array
job_ids=($(grep "Job ID:" output.txt | awk '{print $3}'))

# Extract WANDB_TAGS into an array
WANDB_TAGS=($(grep "WANDB_TAGS:" output.txt | awk '{print $2}'))
# Only the first WANDB_TAGS line is used; split it on commas.
WANDB_TAGS=($(echo $WANDB_TAGS | tr "," "\n"))

# Print to verify
echo "Job IDs: ${job_ids[@]}"
echo "WANDB_TAGS: ${WANDB_TAGS[@]}"

# Build "?tag=a&tag=b" for the plot filters and "a_b" for the output folder.
TAGS_STRING="?tag=${WANDB_TAGS[0]}"
FOLDER_STRING="${WANDB_TAGS[0]}"
for tag in "${WANDB_TAGS[@]:1}"; do
TAGS_STRING+="&tag=$tag"
FOLDER_STRING+="_$tag"
done

echo "TAGS_STRING: $TAGS_STRING"
echo "FOLDER_STRING: $FOLDER_STRING"

# A bare $job_ids expands to the first array element only, so join every job
# ID with ":" (sbatch's separator) to make the follow-up job wait for all of them.
job_dependency=$(IFS=:; echo "${job_ids[*]}")

TAGS_STRING=$TAGS_STRING FOLDER_STRING=$FOLDER_STRING sbatch --dependency=afterany:$job_dependency benchmark/post_github_comment.sbatch
24 changes: 12 additions & 12 deletions benchmark/benchmark_core.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# hello world experiment
python benchmark/benchmark.py \
--command "python examples/scripts/sentiment_tuning.py --ppo_config.log_with wandb" \
--num-seeds 5 \
--num-seeds 3 \
--start-seed 1 \
--workers 10 \
--slurm-nodes 1 \
Expand All @@ -11,14 +11,14 @@ python benchmark/benchmark.py \
--slurm-total-cpus 12 \
--slurm-template-path benchmark/trl.slurm_template

# compound
python benchmark/benchmark.py \
--command "python examples/scripts/sentiment_tuning.py --ppo_config.exp_name sentiment_tuning_gpt2xl_grad_accu --ppo_config.model_name gpt2-xl --ppo_config.mini_batch_size 16 --ppo_config.gradient_accumulation_steps 8 --ppo_config.log_with wandb" \
--num-seeds 5 \
--start-seed 1 \
--workers 10 \
--slurm-nodes 1 \
--slurm-gpus-per-task 1 \
--slurm-ntasks 1 \
--slurm-total-cpus 12 \
--slurm-template-path benchmark/trl.slurm_template
# # compound
# python benchmark/benchmark.py \
# --command "python examples/scripts/sentiment_tuning.py --ppo_config.exp_name sentiment_tuning_gpt2xl_grad_accu --ppo_config.model_name gpt2-xl --ppo_config.mini_batch_size 16 --ppo_config.gradient_accumulation_steps 8 --ppo_config.log_with wandb" \
# --num-seeds 3 \
# --start-seed 1 \
# --workers 10 \
# --slurm-nodes 1 \
# --slurm-gpus-per-task 1 \
# --slurm-ntasks 1 \
# --slurm-total-cpus 12 \
# --slurm-template-path benchmark/trl.slurm_template
23 changes: 23 additions & 0 deletions benchmark/plot_core.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# pip install openrlbenchmark==0.2.1a5
# see https://github.com/openrlbenchmark/openrlbenchmark#get-started for documentation
#
# Reads TAGS_STRING and FOLDER_STRING from the environment; neither is set in
# this script, so the caller must export them.
echo "we deal with $TAGS_STRING"

# "sentiment_tuning_gpt2xl_grad_accu$TAGS_STRING" \

# Run openrlbenchmark's multi-metric tool on the `sentiment_tuning` experiment,
# narrowed by $TAGS_STRING. The output filename points under
# benchmark/trl/$FOLDER_STRING, the folder that is uploaded below.
python -m openrlbenchmark.rlops_multi_metrics \
--filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \
"sentiment_tuning$TAGS_STRING" \
--env-ids sentiment-analysis:lvwerra/distilbert-imdb \
--no-check-empty-runs \
--pc.ncols 2 \
--pc.ncols-legend 1 \
--output-filename benchmark/trl/$FOLDER_STRING/different_models \
--scan-history


# Pass the output folder to upload_benchmark.py, targeting
# images/benchmark/$FOLDER_STRING in the trl-internal-testing/example-images
# dataset repo.
python benchmark/upload_benchmark.py \
--folder_path="benchmark/trl/$FOLDER_STRING" \
--path_in_repo="images/benchmark/$FOLDER_STRING" \
--repo_id="trl-internal-testing/example-images" \
--repo_type="dataset"

26 changes: 26 additions & 0 deletions benchmark/post_github_comment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import json
import os

from ghapi.all import GhApi


# Post the benchmark images as a comment on the GitHub issue described by
# GITHUB_CONTEXT.
#
# Environment:
#   FOLDER_STRING                 sub-folder name under benchmark/trl/ (optional, defaults to "")
#   GITHUB_CONTEXT                JSON with "repository" ("owner/name") and event.issue.number
#   PERSONAL_ACCESS_TOKEN_GITHUB  token handed to GhApi
FOLDER_STRING = os.environ.get("FOLDER_STRING", "")
folder = f"benchmark/trl/{FOLDER_STRING}"
host_url = f"https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/benchmark/{FOLDER_STRING}"

# Create a GitHub API instance
github_context = json.loads(os.environ["GITHUB_CONTEXT"])
token = os.environ["PERSONAL_ACCESS_TOKEN_GITHUB"]  # this needs to be refreshed every 12 months
status_message = "**[COSTA BENCHMARK BOT]**: Here are the results"
repo = github_context["repository"]
owner, repo = repo.split("/")
api = GhApi(owner=owner, repo=repo, token=token)

# Add every `.png` file in the folder to the comment. os.listdir returns
# entries in arbitrary order, so sort them to keep the image order stable
# across runs.
image_names = sorted(name for name in os.listdir(folder) if name.endswith(".png"))
body = status_message + "".join(f"\n![{name}]({host_url}/{name})" for name in image_names)

# Create a comment on the issue
api.issues.create_comment(issue_number=github_context["event"]["issue"]["number"], body=body)
9 changes: 9 additions & 0 deletions benchmark/post_github_comment.sbatch
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#!/bin/bash
#SBATCH --job-name=trl
#SBATCH --partition=production-cluster
#SBATCH --ntasks=1
#SBATCH --output=slurm/logs/%x_%j.out

# Follow-up job: build and upload the plots, then post them as a GitHub comment.
# NOTE(review): the 2-minute pause presumably lets the finished runs finish
# syncing before they are queried — confirm.
sleep 2m
bash benchmark/plot_core.sh
srun python benchmark/post_github_comment.py
8 changes: 4 additions & 4 deletions benchmark/trl.slurm_template
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
#!/bin/bash
#SBATCH --partition=dev-cluster
#SBATCH --job-name=trl
#SBATCH --partition=production-cluster
#SBATCH --gpus-per-task={{gpus_per_task}}
#SBATCH --cpus-per-gpu={{cpus_per_gpu}}
#SBATCH --ntasks={{ntasks}}
#SBATCH --mem-per-cpu=11G
#SBATCH --output=slurm/logs/%x_%j.out
#SBATCH --array={{array}}

#SBATCH --exclude=ip-26-0-156-239,ip-26-0-148-151,ip-26-0-146-212,ip-26-0-145-137,ip-26-0-146-249,ip-26-0-146-149,ip-26-0-147-233,ip-26-0-145-154,ip-26-0-144-35,ip-26-0-144-189,ip-26-0-146-183,ip-26-0-147-120,ip-26-0-144-95,ip-26-0-145-193
{{nodes}}

seeds={{seeds}}
seed=${seeds[$SLURM_ARRAY_TASK_ID % {{len_seeds}}]}

echo "Running task $SLURM_ARRAY_TASK_ID with seed: $seed"
srun {{command}} --seed $seed
srun {{command}} --ppo_config.seed $seed
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
"diffusers": ["diffusers>=0.18.0"],
"deepspeed": ["deepspeed>=0.9.5"],
"dev": ["parameterized", "pytest", "pytest-xdist", "pre-commit", "peft>=0.4.0", "diffusers>=0.18.0"],
"benchmark": ["wandb", "ghapi", "openrlbenchmark==0.2.1a5"],
}

setup(
Expand Down
43 changes: 0 additions & 43 deletions trl/trainer/ppo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,55 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import subprocess
import sys
import warnings
from dataclasses import dataclass, field
from typing import Literal, Optional

import numpy as np
import requests
import tyro

from trl.trainer.utils import exact_div

from ..core import flatten_dict


def autotag() -> str:
    """Derive a wandb tag string from the current git checkout.

    Returns the ``git describe --tags`` output, or an empty string when that
    command fails. When the checked-out branch is not ``main`` and the GitHub
    search API returns a pull request for the current commit, ``,pr-<number>``
    is appended. Errors during the branch/PR lookup are logged as warnings and
    do not affect the tag already found.
    """
    logging.info("autotag feature is enabled")
    try:
        described = subprocess.check_output(["git", "describe", "--tags"])
    except subprocess.CalledProcessError:
        # No usable tag: give up on auto-tagging entirely.
        return ""
    git_tag = described.decode("ascii").strip()
    wandb_tag = f"{git_tag}"
    logging.info(f"identified git tag: {git_tag}")

    head_output = subprocess.check_output(["git", "rev-parse", "--verify", "HEAD"])
    git_commit = head_output.decode("ascii").strip()
    try:
        branch_output = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"])
        git_branch = branch_output.decode("ascii").strip()
        if git_branch == "main":
            logging.info("current branch is main, not searching for pull request")
        else:
            # Off main: ask GitHub for a pull request that matches this commit.
            response = requests.get(
                f"https://api.github.com/search/issues?q=repo:huggingface/trl+is:pr+{git_commit}"
            )
            if response.status_code == 200:
                items = response.json()["items"]
                if len(items) > 0:
                    pr_number = items[0]["number"]
                    wandb_tag += f",pr-{pr_number}"
                    logging.info(f"identified github pull request: {pr_number}")
    except Exception as e:
        logging.warning(f"Automatic autotag failed with the following error: {e}")

    return wandb_tag


@dataclass
class PPOConfig:
"""
Expand Down Expand Up @@ -181,14 +146,6 @@ def __post_init__(self):
try:
import wandb # noqa: F401

existing_wandb_tag = os.environ.get("WANDB_TAGS", "")
wandb_tag = autotag()
if len(wandb_tag) > 0:
if len(existing_wandb_tag) > 0:
os.environ["WANDB_TAGS"] = ",".join([existing_wandb_tag, wandb_tag])
else:
os.environ["WANDB_TAGS"] = wandb_tag
logging.info(f"the following tags will be used for wandb logging: {os.environ['WANDB_TAGS']}")
except ImportError:
raise ImportError(
"Please install wandb to use wandb logging. You can do this by running `pip install wandb`."
Expand Down

0 comments on commit 912a183

Please sign in to comment.