Skip to content

Commit 8b0ac12

Browse files
authored
langchain[patch]: Adds MMR to memory vector store (#6481)
* Adds MMR to memory vector store * Adds memory MMR docs
1 parent 8640a66 commit 8b0ac12

File tree

3 files changed

+143
-20
lines changed

3 files changed

+143
-20
lines changed

docs/core_docs/docs/integrations/vectorstores/memory.ipynb

+52-3
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@
164164
},
165165
{
166166
"cell_type": "code",
167-
"execution_count": 4,
167+
"execution_count": 3,
168168
"id": "aa0a16fa",
169169
"metadata": {},
170170
"outputs": [
@@ -199,7 +199,7 @@
199199
},
200200
{
201201
"cell_type": "code",
202-
"execution_count": 5,
202+
"execution_count": 4,
203203
"id": "5efd2eaa",
204204
"metadata": {},
205205
"outputs": [
@@ -232,7 +232,7 @@
232232
},
233233
{
234234
"cell_type": "code",
235-
"execution_count": 6,
235+
"execution_count": 5,
236236
"id": "f3460093",
237237
"metadata": {},
238238
"outputs": [
@@ -265,6 +265,55 @@
265265
"await retriever.invoke(\"biology\");"
266266
]
267267
},
268+
{
269+
"cell_type": "markdown",
270+
"id": "423d779a",
271+
"metadata": {},
272+
"source": [
273+
"### Maximal marginal relevance\n",
274+
"\n",
275+
"This vector store also supports maximal marginal relevance (MMR), a technique that first fetches a larger number of results (given by `searchKwargs.fetchK`) with classic similarity search, then reranks them for diversity and returns the top `k` results. This helps guard against redundant information:"
276+
]
277+
},
278+
{
279+
"cell_type": "code",
280+
"execution_count": 6,
281+
"id": "56817a1c",
282+
"metadata": {},
283+
"outputs": [
284+
{
285+
"name": "stdout",
286+
"output_type": "stream",
287+
"text": [
288+
"[\n",
289+
" Document {\n",
290+
" pageContent: 'The powerhouse of the cell is the mitochondria',\n",
291+
" metadata: { source: 'https://example.com' },\n",
292+
" id: undefined\n",
293+
" },\n",
294+
" Document {\n",
295+
" pageContent: 'Buildings are made out of brick',\n",
296+
" metadata: { source: 'https://example.com' },\n",
297+
" id: undefined\n",
298+
" }\n",
299+
"]\n"
300+
]
301+
}
302+
],
303+
"source": [
304+
"const mmrRetriever = vectorStore.asRetriever({\n",
305+
" searchType: \"mmr\",\n",
306+
" searchKwargs: {\n",
307+
" fetchK: 10,\n",
308+
" },\n",
309+
" // Optional filter\n",
310+
" filter: filter,\n",
311+
" k: 2,\n",
312+
"});\n",
313+
"\n",
314+
"await mmrRetriever.invoke(\"biology\");"
315+
]
316+
},
268317
{
269318
"cell_type": "markdown",
270319
"id": "e2e0a211",

langchain/src/vectorstores/memory.ts

+61-17
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1-
import { VectorStore } from "@langchain/core/vectorstores";
1+
import {
2+
MaxMarginalRelevanceSearchOptions,
3+
VectorStore,
4+
} from "@langchain/core/vectorstores";
25
import type { EmbeddingsInterface } from "@langchain/core/embeddings";
3-
import { Document } from "@langchain/core/documents";
6+
import { Document, DocumentInterface } from "@langchain/core/documents";
47
import { cosine } from "../util/ml-distance/similarities.js";
8+
import { maximalMarginalRelevance } from "../util/math.js";
59

610
/**
711
* Interface representing a vector in memory. It includes the content
@@ -82,21 +86,11 @@ export class MemoryVectorStore extends VectorStore {
8286
this.memoryVectors = this.memoryVectors.concat(memoryVectors);
8387
}
8488

85-
/**
86-
* Method to perform a similarity search in the memory vector store. It
87-
* calculates the similarity between the query vector and each vector in
88-
* the store, sorts the results by similarity, and returns the top `k`
89-
* results along with their scores.
90-
* @param query Query vector to compare against the vectors in the store.
91-
* @param k Number of top results to return.
92-
* @param filter Optional filter function to apply to the vectors before performing the search.
93-
* @returns Promise that resolves with an array of tuples, each containing a `Document` and its similarity score.
94-
*/
95-
async similaritySearchVectorWithScore(
89+
protected async _queryVectors(
9690
query: number[],
9791
k: number,
9892
filter?: this["FilterType"]
99-
): Promise<[Document, number][]> {
93+
) {
10094
const filterFunction = (memoryVector: MemoryVector) => {
10195
if (!filter) {
10296
return true;
@@ -109,25 +103,75 @@ export class MemoryVectorStore extends VectorStore {
109103
return filter(doc);
110104
};
111105
const filteredMemoryVectors = this.memoryVectors.filter(filterFunction);
112-
const searches = filteredMemoryVectors
106+
return filteredMemoryVectors
113107
.map((vector, index) => ({
114108
similarity: this.similarity(query, vector.embedding),
115109
index,
110+
metadata: vector.metadata,
111+
content: vector.content,
112+
embedding: vector.embedding,
116113
}))
117114
.sort((a, b) => (a.similarity > b.similarity ? -1 : 0))
118115
.slice(0, k);
116+
}
119117

118+
/**
119+
* Method to perform a similarity search in the memory vector store. It
120+
* calculates the similarity between the query vector and each vector in
121+
* the store, sorts the results by similarity, and returns the top `k`
122+
* results along with their scores.
123+
* @param query Query vector to compare against the vectors in the store.
124+
* @param k Number of top results to return.
125+
* @param filter Optional filter function to apply to the vectors before performing the search.
126+
* @returns Promise that resolves with an array of tuples, each containing a `Document` and its similarity score.
127+
*/
128+
async similaritySearchVectorWithScore(
129+
query: number[],
130+
k: number,
131+
filter?: this["FilterType"]
132+
): Promise<[Document, number][]> {
133+
const searches = await this._queryVectors(query, k, filter);
120134
const result: [Document, number][] = searches.map((search) => [
121135
new Document({
122-
metadata: filteredMemoryVectors[search.index].metadata,
123-
pageContent: filteredMemoryVectors[search.index].content,
136+
metadata: search.metadata,
137+
pageContent: search.content,
124138
}),
125139
search.similarity,
126140
]);
127141

128142
return result;
129143
}
130144

145+
async maxMarginalRelevanceSearch(
146+
query: string,
147+
options: MaxMarginalRelevanceSearchOptions<this["FilterType"]>
148+
): Promise<DocumentInterface[]> {
149+
const queryEmbedding = await this.embeddings.embedQuery(query);
150+
151+
const searches = await this._queryVectors(
152+
queryEmbedding,
153+
options.fetchK ?? 20,
154+
options.filter
155+
);
156+
157+
const embeddingList = searches.map((searchResp) => searchResp.embedding);
158+
159+
const mmrIndexes = maximalMarginalRelevance(
160+
queryEmbedding,
161+
embeddingList,
162+
options.lambda,
163+
options.k
164+
);
165+
166+
return mmrIndexes.map(
167+
(idx) =>
168+
new Document({
169+
metadata: searches[idx].metadata,
170+
pageContent: searches[idx].content,
171+
})
172+
);
173+
}
174+
131175
/**
132176
* Static method to create a `MemoryVectorStore` instance from an array of
133177
* texts. It creates a `Document` for each text and metadata pair, and

langchain/src/vectorstores/tests/memory.test.ts

+30
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,33 @@ test("MemoryVectorStore with custom similarity", async () => {
9797
expect(similarityCalledCount).toBe(4);
9898
expect(results).toHaveLength(3);
9999
});
100+
101+
test("MemoryVectorStore with max marginal relevance", async () => {
102+
const embeddings = new SyntheticEmbeddings({
103+
vectorSize: 1536,
104+
});
105+
let similarityCalled = false;
106+
let similarityCalledCount = 0;
107+
const store = new MemoryVectorStore(embeddings, {
108+
similarity: (a: number[], b: number[]) => {
109+
similarityCalledCount += 1;
110+
similarityCalled = true;
111+
return cosine(a, b);
112+
},
113+
});
114+
115+
expect(store).toBeDefined();
116+
117+
await store.addDocuments([
118+
{ pageContent: "hello", metadata: { a: 1 } },
119+
{ pageContent: "hi", metadata: { a: 1 } },
120+
{ pageContent: "bye", metadata: { a: 1 } },
121+
{ pageContent: "what's this", metadata: { a: 1 } },
122+
]);
123+
124+
const results = await store.maxMarginalRelevanceSearch("hello", { k: 3 });
125+
126+
expect(similarityCalled).toBe(true);
127+
expect(similarityCalledCount).toBe(4);
128+
expect(results).toHaveLength(3);
129+
});

0 commit comments

Comments
 (0)