[Frontend] Rerank API (Jina- and Cohere-compatible API) #12376
Merged
26 commits
b6610fb feat: serving_rerank implementation (K-Mistele)
a82b4bb fix: imports (K-Mistele)
99acff6 doc: add example requests and scripts (K-Mistele)
31b5137 test: rerank (K-Mistele)
485e328 feat: serving_rerank implementation (K-Mistele)
8922f81 fix: imports (K-Mistele)
dc0d158 doc: add example requests and scripts (K-Mistele)
4ed459b test: rerank (K-Mistele)
676eea0 added /v2/rerank route (K-Mistele)
b66bcc2 fix(docs): extra spaces (K-Mistele)
c44dee4 fix(docs): cross-reference target for rerank API (K-Mistele)
cce2873 fix(tests): needed to break up model quotes (K-Mistele)
a38060f doc(example): update jina example to reflect lack of SDK, add cohere … (K-Mistele)
901021f fix: remove logger warnings and make the linter happy (K-Mistele)
4849575 fix: file name (K-Mistele)
36e85a5 fix(nit): ordering on assertions (K-Mistele)
4adb94b fix(tests): was using score instead of rerank (K-Mistele)
dc92240 fix(api): use rereank as the default API for scoring (K-Mistele)
330aa22 fix(merge) (K-Mistele)
ce85821 Merge branch 'vllm-project:main' into main (K-Mistele)
29a0366 doc: v2 rerank endpoint (K-Mistele)
844d39a fix: remove duplicate file and fix vllm start command in examples (K-Mistele)
af83c25 fix: only load serving rerank if model supports score (K-Mistele)
a53b59c Merge branch 'vllm-project:main' into main (K-Mistele)
17441f5 merge (K-Mistele)
974c0be fix(tests): use correct API for rerank tests (K-Mistele)
@@ -0,0 +1,32 @@
"""
Example of using the OpenAI entrypoint's rerank API which is compatible with
the Cohere SDK: https://github.com/cohere-ai/cohere-python

run: vllm serve BAAI/bge-reranker-base
"""
import cohere

# cohere v1 client
co = cohere.Client(base_url="http://localhost:8000", api_key="sk-fake-key")
rerank_v1_result = co.rerank(
    model="BAAI/bge-reranker-base",
    query="What is the capital of France?",
    documents=[
        "The capital of France is Paris", "Reranking is fun!",
        "vLLM is an open-source framework for fast AI serving"
    ])

print(rerank_v1_result)

# or the v2
co2 = cohere.ClientV2("sk-fake-key", base_url="http://localhost:8000")

v2_rerank_result = co2.rerank(
    model="BAAI/bge-reranker-base",
    query="What is the capital of France?",
    documents=[
        "The capital of France is Paris", "Reranking is fun!",
        "vLLM is an open-source framework for fast AI serving"
    ])

print(v2_rerank_result)
@@ -0,0 +1,33 @@
"""
Example of using the OpenAI entrypoint's rerank API which is compatible with
Jina and Cohere https://jina.ai/reranker

run: vllm serve BAAI/bge-reranker-base
"""
import json

import requests

url = "http://127.0.0.1:8000/rerank"

headers = {"accept": "application/json", "Content-Type": "application/json"}

data = {
    "model":
    "BAAI/bge-reranker-base",
    "query":
    "What is the capital of France?",
    "documents": [
        "The capital of Brazil is Brasilia.",
        "The capital of France is Paris.", "Horses and cows are both animals"
    ]
}
response = requests.post(url, headers=headers, json=data)

# Check the response
if response.status_code == 200:
    print("Request successful!")
    print(json.dumps(response.json(), indent=2))
else:
    print(f"Request failed with status code: {response.status_code}")
    print(response.text)
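
The request body in the example above can also be assembled by a small helper, which makes the optional `top_n` cutoff explicit. This is a minimal sketch and not code from the PR: `build_rerank_payload` is a hypothetical name, its validation rules are assumptions, and it only uses the fields that appear in this PR's examples and tests (`model`, `query`, `documents`, `top_n`).

```python
from typing import Any, Optional


def build_rerank_payload(model: str,
                         query: str,
                         documents: list[str],
                         top_n: Optional[int] = None) -> dict[str, Any]:
    """Build a JSON-serializable body for a POST to the rerank endpoint.

    Hypothetical helper for illustration; not part of vLLM.
    """
    if not documents:
        raise ValueError("documents must contain at least one entry")
    payload: dict[str, Any] = {
        "model": model,
        "query": query,
        "documents": list(documents),
    }
    # top_n is optional; leave it out unless the caller asked for a cutoff.
    if top_n is not None:
        if top_n < 1:
            raise ValueError("top_n must be a positive integer")
        payload["top_n"] = top_n
    return payload
```

The returned dict can be passed as `json=` to `requests.post`, in place of the hand-written `data` dict.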
@@ -0,0 +1,87 @@
import pytest
import requests

from vllm.entrypoints.openai.protocol import RerankResponse

from ...utils import RemoteOpenAIServer

MODEL_NAME = "BAAI/bge-reranker-base"


@pytest.fixture(scope="module")
def server():
    args = ["--enforce-eager", "--max-model-len", "100"]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_rerank_texts(server: RemoteOpenAIServer, model_name: str):
    query = "What is the capital of France?"
    documents = [
        "The capital of Brazil is Brasilia.", "The capital of France is Paris."
    ]

    rerank_response = requests.post(server.url_for("rerank"),
                                    json={
                                        "model": model_name,
                                        "query": query,
                                        "documents": documents,
                                    })
    rerank_response.raise_for_status()
    rerank = RerankResponse.model_validate(rerank_response.json())

    assert rerank.id is not None
    assert rerank.results is not None
    assert len(rerank.results) == 2
    assert rerank.results[0].relevance_score >= 0.9
    assert rerank.results[1].relevance_score <= 0.01


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_top_n(server: RemoteOpenAIServer, model_name: str):
    query = "What is the capital of France?"
    documents = [
        "The capital of Brazil is Brasilia.",
        "The capital of France is Paris.", "Cross-encoder models are neat"
    ]

    rerank_response = requests.post(server.url_for("rerank"),
                                    json={
                                        "model": model_name,
                                        "query": query,
                                        "documents": documents,
                                        "top_n": 2
                                    })
    rerank_response.raise_for_status()
    rerank = RerankResponse.model_validate(rerank_response.json())

    assert rerank.id is not None
    assert rerank.results is not None
    assert len(rerank.results) == 2
    assert rerank.results[0].relevance_score >= 0.9
    assert rerank.results[1].relevance_score <= 0.01


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_rerank_max_model_len(server: RemoteOpenAIServer, model_name: str):

    query = "What is the capital of France?" * 100
    documents = [
        "The capital of Brazil is Brasilia.", "The capital of France is Paris."
    ]

    rerank_response = requests.post(server.url_for("rerank"),
                                    json={
                                        "model": model_name,
                                        "query": query,
                                        "documents": documents
                                    })
    assert rerank_response.status_code == 400
    # Assert just a small fragment of the response
    assert "Please reduce the length of the input." in \
        rerank_response.text
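
The tests above read `id`, `results`, and `relevance_score` from the response. A client that wants the matching document text back could post-process the JSON along these lines. This is a sketch under assumptions: `top_documents` is a hypothetical helper, the `index` field (the position of the document in the request) is assumed from the Cohere/Jina response shape rather than shown in this diff, and the sample payload is hand-written, not real server output.

```python
from typing import Any


def top_documents(response: dict[str, Any],
                  documents: list[str],
                  min_score: float = 0.0) -> list[tuple[str, float]]:
    """Pair each result with its source text, best score first.

    Hypothetical helper for illustration; not part of vLLM.
    """
    ranked = sorted(response["results"],
                    key=lambda r: r["relevance_score"],
                    reverse=True)
    # "index" is assumed to point back into the request's documents list.
    return [(documents[r["index"]], r["relevance_score"])
            for r in ranked if r["relevance_score"] >= min_score]


docs = [
    "The capital of Brazil is Brasilia.",
    "The capital of France is Paris.",
]
# Hand-written sample payload, not real server output.
sample = {
    "id": "rerank-example",
    "results": [
        {"index": 0, "relevance_score": 0.001},
        {"index": 1, "relevance_score": 0.998},
    ],
}
best = top_documents(sample, docs, min_score=0.5)
```

Sorting inside the helper means the caller does not have to rely on the order in which the server returns results.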
I think we should remove these warnings as the Cohere Python client will access this URL by default. Unless there's a way to change the URL in the client?
There's a way to change the base URL, but that's just the server or hostname. Unlike OpenAI, which expects you to include the /v1 in the base_url if you change it, Cohere doesn't want you to set it; it just wants the host and automatically sets /v1 or /v2 depending on whether you use the v1 client or the v2 client.
I will remove the logger warnings.
Remember to remove or switch to warning_once
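
The base-URL behaviour discussed in this thread (the Cohere client takes only a host and appends /v1 or /v2 itself, while the Jina-style example posts to /rerank directly) can be summarised in a short sketch. `rerank_endpoint` is a hypothetical helper written for illustration, not part of vLLM or the Cohere SDK, and the /v1/rerank path is inferred from the thread rather than shown in the diff.

```python
from typing import Optional


def rerank_endpoint(host: str, api_version: Optional[int] = None) -> str:
    """Return the rerank URL for a bare host.

    api_version=None gives the unversioned /rerank route; 1 or 2 gives the
    versioned routes that a Cohere v1 or v2 client would call.
    """
    base = host.rstrip("/")
    if api_version is None:
        return f"{base}/rerank"
    if api_version not in (1, 2):
        raise ValueError("api_version must be 1, 2, or None")
    return f"{base}/v{api_version}/rerank"
```

For example, `rerank_endpoint("http://localhost:8000", 2)` yields the /v2/rerank route added in commit 676eea0.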