From 1bef49484977b2bab1487574fa6eab6780f8c8a1 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Mon, 25 Sep 2023 19:40:57 -0400 Subject: [PATCH 1/2] Start adding a simple LLM inference example Signed-off-by: Fabrice Normandin --- .../distributed/LLM_inference/README.md | 96 +++++++ .../distributed/LLM_inference/client.py | 36 +++ .../distributed/LLM_inference/server.py | 242 ++++++++++++++++++ .../distributed/LLM_inference/setup.py | 0 4 files changed, 374 insertions(+) create mode 100644 docs/examples/distributed/LLM_inference/README.md create mode 100644 docs/examples/distributed/LLM_inference/client.py create mode 100644 docs/examples/distributed/LLM_inference/server.py create mode 100644 docs/examples/distributed/LLM_inference/setup.py diff --git a/docs/examples/distributed/LLM_inference/README.md b/docs/examples/distributed/LLM_inference/README.md new file mode 100644 index 00000000..69d24a76 --- /dev/null +++ b/docs/examples/distributed/LLM_inference/README.md @@ -0,0 +1,96 @@ +# FastAPI + HuggingFace + SLURM + +Proof-of-concept for an API that performs inference with a Large Language Model (LLM) on the Mila cluster. + +![LLM_api](https://user-images.githubusercontent.com/13387299/184188304-3ce82a7f-29a6-49ed-86ba-4842db4e207e.png) + +## The goal: + +- One ML researcher/student can submit this as a job on a SLURM cluster, and other users can use a single shared model instance via HTTP or a simple python client. + +## Installation: + +To run the server locally: + +```console +> conda env create -n llm python=3.10 +> conda activate llm +> pip install git+https://www.github.com/lebrice/LLM_api.git +``` + +(WIP) To connect to a running LLM server: + +(Requires python >= 3.7) +```console +> pip install git+https://www.github.com/lebrice/LLM_api.git +``` + + +## Usage: + +Available options: +```console +$ python app/server.py --help +usage: server.py [-h] [--model str] [--hf_cache_dir Path] [--port int] + [--reload bool] [--offload_folder Path] [--use_public_ip bool] + + API for querying a large language model. + +options: + -h, --help show this help message and exit + +Settings ['settings']: + Configuration settings for the API. + + --model str HuggingFace model to use. Examples: facebook/opt-13b, + facebook/opt-30b, facebook/opt-66b, bigscience/bloom, + etc. (default: facebook/opt-13b) + --hf_cache_dir Path (default: $SCRATCH/cache/huggingface) + --port int The port to run the server on. (default: 12345) + --reload bool Whether to restart the server (and reload the model) when + the source code changes. (default: False) + --offload_folder Path + Folder where the model weights will be offloaded if the + entire model doesn't fit in memory. (default: + $SLURM_TMPDIR) + --use_public_ip bool Set to True to make the server available on the node's + public IP, rather than localhost. Setting this to False + is useful when using VSCode to debug the server, since + the port forwarding is done automatically for you. + Setting this to True makes it so many users on the + cluster can share the same server. However, at the + moment, you would still need to do the port forwarding + setup yourself, if you want to access the server from + outside the cluster. 
(default: False) +``` + +Spinning up the server: +```console +> python app/server.py +HF_HOME='/home/mila/n/normandf/scratch/cache/huggingface' +TRANSFORMERS_CACHE='/home/mila/n/normandf/scratch/cache/huggingface/transformers' +Running the server with the following settings: {"model_capacity": "13b", "hf_cache_dir": "~/scratch/cache/huggingface", "port": 12345, "reload": false, "offload_folder": "/Tmp/slurm.1968686.0"} +INFO: Started server process [25042] +INFO: Waiting for application startup. +Writing address_string='cn-b003:8000' to server.txt +INFO: Application startup complete. +INFO: Uvicorn running on http://127.0.0.1:12345 (Press CTRL+C to quit) +``` + +(WIP) Run as a slurm job: + +```console +> sbatch run_server.sh +``` + +(WIP) Using the python client to Connect to a running server: + +```python +import time +from app.client import server_is_up, get_completion_text +while not server_is_up(): + print("Waiting for the server to be online...") + time.sleep(10) +print("server is up!") +rest_of_story = get_completion_text("Once upon a time, there lived a great wizard.") +``` diff --git a/docs/examples/distributed/LLM_inference/client.py b/docs/examples/distributed/LLM_inference/client.py new file mode 100644 index 00000000..1c6ee6c9 --- /dev/null +++ b/docs/examples/distributed/LLM_inference/client.py @@ -0,0 +1,36 @@ +""" TODO: Client-side code to communicate with the server that is running somewhere on the cluster. + +IDEAS: +- Could look for slurm jobs that have a given name, like `deploy.sh` and extract the port from the + job's command-line ags! +""" +from pathlib import Path +import requests +import time + + +def get_server_url_and_port() -> tuple[str, int]: + with open("server.txt") as f: + server_url_and_port = f.read().strip() + server_url, _, port = server_url_and_port.partition(":") + return server_url, int(port) + + +def debug(): + # WIP: Not working yet. + while not Path("server.txt").exists(): + time.sleep(1) + print(f"Waiting for server to start...") + server_url, port = get_server_url_and_port() + print(f"Found server at {server_url}:{port}") + response = requests.get( + f"http://{server_url}:{port}/complete/", + params={ + "prompt": "Hello, my name is Bob. I love fishing, hunting, and my favorite food is", + }, + ) + print(response) + + +if __name__ == "__main__": + debug() diff --git a/docs/examples/distributed/LLM_inference/server.py b/docs/examples/distributed/LLM_inference/server.py new file mode 100644 index 00000000..766c0064 --- /dev/null +++ b/docs/examples/distributed/LLM_inference/server.py @@ -0,0 +1,242 @@ +""" API for querying a large language model. """ +from __future__ import annotations + +import functools +import logging +import os +import socket +from dataclasses import asdict, dataclass +from logging import getLogger as get_logger +from pathlib import Path + +import torch +import uvicorn +from fastapi import Depends, FastAPI, Request +from fastapi.responses import RedirectResponse +from pydantic import BaseSettings +from simple_parsing import ArgumentParser +from transformers import AutoModelForCausalLM, AutoTokenizer, BloomForCausalLM +from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer +from transformers.models.opt.modeling_opt import OPTForCausalLM + +# TODO: Setup logging correctly with FastAPI. 
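+# One possible approach (an untested sketch): attach a handler to the module logger defined
+# just below, so that its messages show up alongside uvicorn's output, e.g.:
+#   _handler = logging.StreamHandler()
+#   _handler.setFormatter(logging.Formatter("%(levelname)s: %(name)s - %(message)s"))
+#   logger.addHandler(_handler)
+# uvicorn's own loggers are configured separately, e.g. via the `log_level` argument passed
+# to `uvicorn.run()` in `main()` below.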
+logger = get_logger(__name__) +logger.setLevel(logging.DEBUG) + +# SCRATCH = Path(os.environ["SCRATCH"]) +SCRATCH = Path("/data/fake_scratch") +# SLURM_TMPDIR = Path(os.environ.get("SLURM_TMPDIR", f"/Tmp/slurm.{os.environ['SLURM_JOB_ID']}.0")) +SLURM_TMPDIR = None + + +@dataclass(init=False) +class Settings(BaseSettings): + """Configuration settings for the API.""" + + model: str = "meta-llama/Llama-2-7b-chat-hf" + """ HuggingFace model to use. + Examples: facebook/opt-13b, facebook/opt-30b, facebook/opt-66b, bigscience/bloom, etc. + """ + + hf_cache_dir: Path = SCRATCH / "cache" / "huggingface" + + port: int = 12345 + """ The port to run the server on.""" + + reload: bool = False + """ Whether to restart the server (and reload the model) when the source code changes. """ + + offload_folder: Path = Path(SLURM_TMPDIR or "model_offload") + """ + Folder where the model weights will be offloaded if the entire model doesn't fit in memory. + """ + + use_public_ip: bool = False + """ Set to True to make the server available on the node's public IP, rather than localhost. + + Setting this to False is useful when using VSCode to debug the server, since the port + forwarding is done automatically for you. + + Setting this to True makes it so many users on the cluster can share the same server. However, + at the moment, you would still need to do the port forwarding setup yourself, if you want to + access the server from outside the cluster. + """ + + +def write_server_address_to_file(port: int = 12345): + node_hostname = socket.gethostname() + with open("server.txt", "w") as f: + address_string = f"{node_hostname}:{port}" + print(f"Writing {address_string=} to server.txt") + f.write(address_string) + + +app = FastAPI( + on_startup=[ + write_server_address_to_file, + ], + title="SLURM + FastAPI + HuggingFace", + dependencies=[], +) + + +@functools.cache +def get_settings() -> Settings: + # Creates a Settings object from the environment variables. 
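+    # NOTE: with pydantic's BaseSettings, each field can also be set through a matching
+    # environment variable (e.g. MODEL, PORT, HF_CACHE_DIR), falling back to the defaults
+    # declared on the class; `main()` relies on this behaviour when `reload=True`.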
+ return Settings() + + +@app.get("/") +def root(request: Request): + return RedirectResponse(url=f"{request.base_url}docs") + + +@dataclass +class CompletionResponse: + prompt: str + response: str + model: str + + +def preload_components(settings: Settings = Depends(get_settings)): + print(f"Preloading components: {settings=}") + load_completion_model(capacity=settings.model, offload_folder=settings.offload_folder) + load_tokenizer(capacity=settings.model) + + +@app.get("/complete/") +async def get_completion( + prompt: str, + max_response_length: int = 30, + settings: Settings = Depends(get_settings), +) -> CompletionResponse: + """Returns the completion of the given prompt by a language model with the given capacity.""" + model_name = settings.model + offload_folder = settings.offload_folder + print(f"Completion request: {prompt=}, model: {model_name}") + + model = load_completion_model(model=model_name, offload_folder=offload_folder) + tokenizer = load_tokenizer(model=model_name) + + response_text = get_response_text( + model=model, + tokenizer=tokenizer, + prompt=prompt, + max_response_length=max_response_length, + ) + + print(f"Completion response: {response_text}") + return CompletionResponse( + prompt=prompt, + response=response_text, + model=model_name, + ) + + +@functools.cache +def load_completion_model(model: str, offload_folder: Path) -> OPTForCausalLM | BloomForCausalLM: + print(f"Loading model: {model}...") + extra_kwargs = {} + if model.startswith("bigscience/bloom"): + extra_kwargs.update(load_in_8bit=True) + pretrained_causal_lm_model = AutoModelForCausalLM.from_pretrained( + model, + device_map="auto", + torch_dtype=torch.float16, + offload_folder=offload_folder, + use_auth_token=True, + **extra_kwargs, + ) + print("Done.") + return pretrained_causal_lm_model + + +@functools.cache +def load_tokenizer(model: str) -> GPT2Tokenizer: + print(f"Loading Tokenizer for model {model}...") + # NOTE: See https://github.com/huggingface/tokenizers/pull/1005 + kwargs = {} + if model.startswith("facebook/opt"): + kwargs.update(use_fast=False) + pretrained_tokenizer = AutoTokenizer.from_pretrained( + model, + device_map="auto", + torch_dtype=torch.float16, + **kwargs, + ) + return pretrained_tokenizer + + +@torch.no_grad() +def get_response_text( + model: OPTForCausalLM | BloomForCausalLM, + tokenizer: GPT2Tokenizer, + prompt: str, + max_response_length: int = 30, +) -> str: + inputs = tokenizer(prompt, return_tensors="pt") + print(f"Generating based on {prompt=}...") + generate_ids = model.generate( + inputs.input_ids.to(model.device), max_length=max_response_length + ) + prompt_and_response = tokenizer.batch_decode( + generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False + )[0] + assert isinstance(prompt_and_response, str) + model_response = prompt_and_response.replace(prompt, "").lstrip() + return model_response + + +# TODOs: +# - Check with students what kind of functionality they want, e.g. extracting representations: +# @torch.no_grad() +# def get_hidden_state(prompt: str, capacity: Capacity = DEFAULT_CAPACITY) -> Tensor: +# inputs = tokenize(prompt) +# model = load_embedding_model() +# outputs = model(**inputs.to(model.device)) + +# last_hidden_states = outputs.last_hidden_state +# return last_hidden_states +# - Add a training example! +# - Create a slurm sbatch script to run this. 
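+#   A rough sketch of what that sbatch script could look like (resource values and module
+#   names below are placeholders to adapt to your cluster):
+#       #!/bin/bash
+#       #SBATCH --gres=gpu:1
+#       #SBATCH --mem=48G
+#       #SBATCH --time=12:00:00
+#       module load anaconda/3
+#       conda activate llm
+#       python app/server.py --use_public_ip True --port 12345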
+ + +def main(): + parser = ArgumentParser(description=__doc__) + parser.add_arguments(Settings, "settings", default=Settings()) + args = parser.parse_args() + settings: Settings = args.settings + + HF_HOME = os.environ.setdefault("HF_HOME", str(settings.hf_cache_dir)) + TRANSFORMERS_CACHE = os.environ.setdefault( + "TRANSFORMERS_CACHE", str(settings.hf_cache_dir / "transformers") + ) + print(f"{HF_HOME=}") + print(f"{TRANSFORMERS_CACHE=}") + + print(f"Running the server with the following settings: {settings.json()}") + + # NOTE: Can't use `reload` or `workers` when passing the app by value. + if not settings.reload: + app.dependency_overrides[get_settings] = lambda: settings + else: + # NOTE: If we we want to use `reload=True`, we set the environment variables, so they are + # used when that module gets imported. + for k, v in asdict(settings).items(): + os.environ[k.upper()] = str(v) + + write_server_address_to_file(port=settings.port) + + uvicorn.run( + (app if not settings.reload else "app.server:app"), # type: ignore + port=settings.port, + # Using the public IP makes the server publicly available, but a bit harder to debug (no + # automatic port forwarding in VSCode for example). + host=socket.gethostname() if settings.use_public_ip else "127.0.0.1", + log_level="debug", + reload=settings.reload, + ) + + +if __name__ == "__main__": + main() diff --git a/docs/examples/distributed/LLM_inference/setup.py b/docs/examples/distributed/LLM_inference/setup.py new file mode 100644 index 00000000..e69de29b From c0478b42a85f5c631c6e3d16c8a18e97eec44a78 Mon Sep 17 00:00:00 2001 From: Setepenre Date: Mon, 2 Oct 2023 19:41:38 +0000 Subject: [PATCH 2/2] Add Server lookup --- .../distributed/LLM_inference/client.py | 97 ++++++++++++++++++- 1 file changed, 93 insertions(+), 4 deletions(-) diff --git a/docs/examples/distributed/LLM_inference/client.py b/docs/examples/distributed/LLM_inference/client.py index 1c6ee6c9..fc07b8d2 100644 --- a/docs/examples/distributed/LLM_inference/client.py +++ b/docs/examples/distributed/LLM_inference/client.py @@ -9,11 +9,99 @@ import time +def _fetch_job_info(name): + # Mock this for testing + command = ["squeue", "-h", f"--name={name}", "--format=\"%A %j %T %P %U %k %N\""] + return subprocess.check_output(command, text=True) + + +def get_slurm_job_by_name(name): + """Retrieve a list of jobs that match a given job name""" + + output =_fetch_job_info(name) + jobs = [] + + def parse_meta(comment): + data = dict() + if comment != "(null)": + items = comment.split('|') + for kv in items: + try: + k, v = kv.split('=', maxsplit=1) + data[k] = v + except: + pass + + return data + + for line in output.splitlines(): + job_id, job_name, status, partition, user, comment, nodes = line.split(' ') + + jobs.append({ + "job_id":job_id, + "job_name":job_name, + "status":status, + "partition":partition, + "user":user, + "comment": parse_meta(comment), + "nodes": nodes + }) + + return jobs + + +def find_suitable_inference_server(jobs, model): + """Select suitable jobs from a list, looking for a specific model""" + selected = [] + + def is_shared(job): + return job["comment"].get("shared", 'y') == 'y' + + def is_running(job): + return job['status'] == "RUNNING" + + def has_model(job, model): + if model is None: + return True + + # FIXME: + # /network/weights/llama.var/llama2/Llama-2-7b-hf != meta-llama/Llama-2-7b-hf + # + return job['comment']['model'] == model + + def select(job): + selected.append({ + "model": job['comment']["model"], + "host": job["comment"]["host"], + "port": 
job["comment"]["port"],
+        })
+
+    for job in jobs:
+        if is_shared(job) and is_running(job) and has_model(job, model):
+            select(job)
+
+    return selected
+
+
+def get_inference_server(model=None):
+    """Retrieve an inference server from slurm jobs"""
+    jobs = get_slurm_job_by_name('inference_server_SHARED.sh')
+
+    servers = find_suitable_inference_server(jobs, model)
+
+    try:
+        return random.choice(servers)
+    except IndexError:
+        return None
+
+
 def get_server_url_and_port() -> tuple[str, int]:
-    with open("server.txt") as f:
-        server_url_and_port = f.read().strip()
-        server_url, _, port = server_url_and_port.partition(":")
-    return server_url, int(port)
+    server = get_inference_server()
+
+    if server is None:
+        return None
+
+    return server['host'], int(server['port'])
 
 
 def debug():
@@ -21,6 +109,7 @@ def debug():
     while not Path("server.txt").exists():
         time.sleep(1)
         print(f"Waiting for server to start...")
+
     server_url, port = get_server_url_and_port()
     print(f"Found server at {server_url}:{port}")
     response = requests.get(