Skip to content

Commit 4b515ed

Browse files
committed
added cross encoder serving
1 parent f1a3f37 commit 4b515ed

File tree

5 files changed

+157
-7
lines changed

5 files changed

+157
-7
lines changed

lib/demo/application.ex

+12
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ defmodule Demo.Application do
99
def start(_type, _args) do
1010
children = [
1111
DemoWeb.Telemetry,
12+
{Nx.Serving, serving: cross(), name: CrossEncoder},
1213
{Nx.Serving, serving: serving(), name: SentenceTransformer},
1314
Demo.Repo,
1415
{DNSCluster, query: Application.get_env(:demo, :dns_cluster_query) || :ignore},
@@ -48,4 +49,15 @@ defmodule Demo.Application do
4849
defn_options: [compiler: EXLA]
4950
)
5051
end
52+
53+
# Builds the serving that the supervision tree registers as `CrossEncoder`.
#
# Loads the cross-encoder model and a BERT tokenizer from Hugging Face and
# passes both to `Demo.Encoder.cross_encoder/3`. Either load raises a
# `MatchError` when it does not return `{:ok, _}`.
def cross() do
  model_source = {:hf, "cross-encoder/ms-marco-MiniLM-L-6-v2"}
  tokenizer_source = {:hf, "bert-base-uncased"}

  {:ok, model_info} = Bumblebee.load_model(model_source)
  {:ok, tokenizer} = Bumblebee.load_tokenizer(tokenizer_source)

  # Compiled with EXLA for batches of 32 at a single sequence length of 512.
  serving_opts = [
    compile: [batch_size: 32, sequence_length: [512]],
    defn_options: [compiler: EXLA]
  ]

  Demo.Encoder.cross_encoder(model_info, tokenizer, serving_opts)
end
5163
end

lib/demo/encoder.ex

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
defmodule Demo.Encoder do
  @moduledoc false

  alias Bumblebee.Shared

  @doc """
  Builds an `Nx.Serving` that runs `model_info`'s model over tokenized input
  and returns the raw logits as scores.

  ## Arguments

    * `model_info` - a map with `:model`, `:params` and `:spec` keys
    * `tokenizer` - the tokenizer passed to `Bumblebee.apply_tokenizer/3`
    * `opts` - keyword options, validated with `Keyword.validate!/2`

  ## Options

    * `:compile` - keyword list with `:batch_size` and `:sequence_length`;
      both are required when `:compile` is given
    * `:defn_options` - options forwarded to the serving and the compiled
      function (default `[]`)
    * `:preallocate_params` - forwarded to `Shared.maybe_preallocate/3`
      (default `false`)

  ## Result

  Running the serving returns `%{results: scores}`, where `scores` is the
  flattened list of logits. The input is handed to the tokenizer unchanged;
  it is presumably a list of `{query, passage}` pairs (confirm against the
  caller).
  """
  def cross_encoder(model_info, tokenizer, opts \\ []) do
    %{model: model, params: params, spec: _spec} = model_info

    opts =
      Keyword.validate!(opts, [
        :compile,
        defn_options: [],
        preallocate_params: false
      ])

    preallocate_params = opts[:preallocate_params]
    defn_options = opts[:defn_options]

    # `compile` stays nil when the option is absent, in which case
    # `batch_size` and `sequence_length` below are nil as well.
    compile =
      if compile = opts[:compile] do
        compile
        |> Keyword.validate!([:batch_size, :sequence_length])
        |> Shared.require_options!([:batch_size, :sequence_length])
      end

    batch_size = compile[:batch_size]
    sequence_length = compile[:sequence_length]

    scores_fun = fn params, inputs ->
      Axon.predict(model, params, inputs)
    end

    batch_keys = Shared.sequence_batch_keys(sequence_length)

    Nx.Serving.new(
      fn batch_key, defn_options ->
        params = Shared.maybe_preallocate(params, preallocate_params, defn_options)

        scores_fun =
          Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn ->
            # The batch key carries the concrete sequence length to compile for.
            {:sequence_length, sequence_length} = batch_key

            inputs = %{
              "input_ids" => Nx.template({batch_size, sequence_length}, :u32),
              "attention_mask" => Nx.template({batch_size, sequence_length}, :u32)
            }

            [params, inputs]
          end)

        fn inputs ->
          inputs = Shared.maybe_pad(inputs, batch_size)
          scores_fun.(params, inputs)
        end
      end,
      defn_options
    )
    |> Nx.Serving.batch_size(batch_size)
    |> Nx.Serving.process_options(batch_keys: batch_keys)
    |> Nx.Serving.client_preprocessing(fn raw_input ->
      # Tokenize on the binary backend, as the original implementation did.
      inputs =
        Nx.with_default_backend(Nx.BinaryBackend, fn ->
          Bumblebee.apply_tokenizer(tokenizer, raw_input,
            length: sequence_length,
            return_token_type_ids: false
          )
        end)

      batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length)
      batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key)

      # No per-request client info is needed: the result shape no longer
      # depends on how many items were submitted.
      {batch, nil}
    end)
    |> Nx.Serving.client_postprocessing(fn {scores, _metadata}, _client_info ->
      # Always return the map. The previous version piped this map through
      # `Shared.normalize_output/2`, which appears to be written for a list of
      # results (internal Bumblebee helper; confirm against the installed
      # version) and would not match a map when exactly one item was submitted.
      %{results: Nx.to_flat_list(scores.logits)}
    end)
  end
end

lib/demo/section.ex

+22-3
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,32 @@ defmodule Demo.Section do
2727
|> validate_required(@required_attrs)
2828
end
2929

30-
def search_document(document_id, embedding) do
30+
# Returns at most four sections of the document `document_id`, ordered by
# `max_inner_product/2` between each stored embedding and `embedding`.
#
# Each row is a `{id, page, text, document_id}` tuple.
def search_document_embedding(document_id, embedding) do
  query =
    from(s in Section,
      where: s.document_id == ^document_id,
      order_by: max_inner_product(s.embedding, ^embedding),
      limit: 4,
      select: {s.id, s.page, s.text, s.document_id}
    )

  Demo.Repo.all(query)
end
39+
40+
# Full-text search over the sections of the document `document_id`.
#
# Keeps only sections whose English text vector matches `search` and returns
# at most four of them, highest `ts_rank_cd` first, as
# `{id, page, text, document_id}` tuples.
def search_document_text(document_id, search) do
  query =
    from(s in Section,
      where: s.document_id == ^document_id,
      where:
        fragment("to_tsvector('english', ?) @@ plainto_tsquery('english', ?)", s.text, ^search),
      order_by: [
        desc:
          fragment(
            "ts_rank_cd(to_tsvector('english', ?), plainto_tsquery('english', ?))",
            s.text,
            ^search
          )
      ],
      limit: 4,
      select: {s.id, s.page, s.text, s.document_id}
    )

  Demo.Repo.all(query)
end
3958
end

lib/demo_web/live/page_live.ex

+33-4
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ defmodule DemoWeb.PageLive do
1212

1313
socket =
1414
socket
15-
|> assign(task: nil, lookup: nil, filename: nil, messages: messages, version: version, documents: documents, result: nil, text: nil, loading: false, selected: nil, query: nil, transformer: nil, llama: nil, path: nil, focused: false, loadingpdf: false)
15+
|> assign(encoder: nil, task: nil, lookup: nil, filename: nil, messages: messages, version: version, documents: documents, result: nil, text: nil, loading: false, selected: nil, query: nil, transformer: nil, llama: nil, path: nil, focused: false, loadingpdf: false)
1616
|> allow_upload(:document, accept: ~w(.pdf), progress: &handle_progress/3, auto_upload: true, max_entries: 1)
1717

1818
{:ok, socket}
@@ -71,9 +71,36 @@ defmodule DemoWeb.PageLive do
7171

7272
@impl true
7373
# Handles the reply of the `lookup` task: gathers candidate sections from both
# the embedding search and the text search, then starts a task that scores
# every candidate against the question with the `CrossEncoder` serving and
# keeps the best-scoring one.
def handle_info({ref, {selected, question, %{embedding: embedding}}}, socket)
    when socket.assigns.lookup.ref == ref do
  by_embedding = Demo.Section.search_document_embedding(selected.id, embedding)
  by_text = Demo.Section.search_document_text(selected.id, question)

  # A section can be returned by both searches; keep its first occurrence.
  candidates =
    Enum.uniq_by(by_embedding ++ by_text, fn {section_id, _, _, _} -> section_id end)

  pairs = Enum.map(candidates, fn {_id, _page, text, _document_id} -> {question, text} end)

  encoder =
    Task.async(fn ->
      scores = CrossEncoder |> Nx.Serving.batched_run(pairs) |> results()

      ranked =
        scores
        |> Enum.zip(candidates)
        |> Enum.map(fn {score, {id, page, text, document_id}} ->
          %{id: id, page: page, text: text, document_id: document_id, score: score}
        end)
        |> Enum.sort(&(&1.score > &2.score))

      # `List.first/1` yields nil when there were no candidates.
      {question, List.first(ranked)}
    end)

  {:noreply, assign(socket, lookup: nil, encoder: encoder)}
end
100+
101+
@impl true
102+
def handle_info({ref, {question, section}}, socket) when socket.assigns.encoder.ref == ref do
103+
version = socket.assigns.version
77104
document = socket.assigns.documents |> Enum.find(&(&1.id == section.document_id))
78105

79106
prompt = """
@@ -91,7 +118,7 @@ defmodule DemoWeb.PageLive do
91118
{section, Replicate.Predictions.wait(prediction)}
92119
end)
93120

94-
{:noreply, assign(socket, lookup: nil, llama: llama, selected: document)}
121+
{:noreply, assign(socket, encoder: nil, llama: llama, selected: document)}
95122
end
96123

97124
@impl true
@@ -215,6 +242,8 @@ defmodule DemoWeb.PageLive do
215242
end)
216243
end
217244

245+
# Extracts the value stored under `:results` from a map that has that key.
def results(%{results: scores}) do
  scores
end
246+
218247
@impl true
219248
def render(assigns) do
220249
~H"""
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
defmodule Demo.Repo.Migrations.AddGinIndex do
  use Ecto.Migration

  # Adds a GIN index over the English text-search vector of `sections.text`,
  # the same expression the full-text query in `Demo.Section` filters on.
  #
  # Raw SQL passed to `execute/1` cannot be reversed automatically, so the
  # two-argument form is used: the first statement runs on migrate, the
  # second on rollback.
  def change do
    execute(
      """
      CREATE INDEX sections_text_search_idx ON sections USING GIN (to_tsvector('english', text));
      """,
      """
      DROP INDEX sections_text_search_idx;
      """
    )
  end
end

0 commit comments

Comments
 (0)