Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add 1 memory endpoint for Retrieving Points #892

Merged
merged 4 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions core/cat/memory/vector_memory_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,20 @@ def get_all_points(self):
)

return all_points

# Retrieve one page of points, with an optional offset and page-size limit.
def get_all_points_with_offset(self, limit: int = 10000, offset: "str | None" = None):
    """Scroll through the collection and return a single page of points.

    Parameters
    ----------
    limit : int
        Maximum number of points to retrieve in this page (default 10000).
    offset : str | None
        Point id to start scrolling from. Pass None (the default) to
        retrieve the first page.

    Returns
    -------
    tuple
        ``(points, next_page_offset)``; ``next_page_offset`` is None when
        there are no further points to fetch (see the route examples that
        loop until it is None).
    """
    # with_vectors=True so each returned point carries its embedding,
    # matching the sibling get_all_points method.
    all_points, next_page_offset = self.client.scroll(
        collection_name=self.collection_name,
        with_vectors=True,
        offset=offset,  # start from the given offset, or the beginning if None
        limit=limit,  # cap how many points this page may contain
    )

    return all_points, next_page_offset

# True when the underlying Qdrant client talks to a remote server
# (its inner `_client` is a QdrantRemote), as opposed to a local/embedded DB.
def db_is_remote(self):
    return isinstance(self.client._client, QdrantRemote)
Expand Down
82 changes: 82 additions & 0 deletions core/cat/routes/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,3 +277,85 @@ async def get_conversation_history(
"""Get the specified user's conversation history from working memory"""

return {"history": stray.working_memory.history}

# GET all the points from a single collection
@router.get("/collections/{collection_id}/points")
async def get_collections_points(
    request: Request,
    collection_id: str,
    limit: int = Query(
        default=100,
        description="How many points to return"
    ),
    offset: str = Query(
        default=None,
        description="If provided (or not empty string) - skip points with ids less than given `offset`"
    ),
    stray=Depends(HTTPAuth(AuthResource.MEMORY, AuthPermission.READ)),
) -> Dict:
    """Retrieve all the points from a single collection.

    Returns a dict with the page of ``points`` and a ``next_offset`` to pass
    back in the next request; ``next_offset`` is None on the last page.

    Example
    ----------
    ```
    collection = "declarative"
    res = requests.get(
        f"http://localhost:1865/memory/collections/{collection}/points",
    )
    json = res.json()
    points = json["points"]

    for point in points:
        payload = point["payload"]
        vector = point["vector"]
        print(payload)
        print(vector)
    ```

    Example using offset
    ----------
    ```
    # get all the points with limit 10
    limit = 10
    next_offset = ""
    collection = "declarative"

    while True:
        res = requests.get(
            f"http://localhost:1865/memory/collections/{collection}/points?limit={limit}&offset={next_offset}",
        )
        json = res.json()
        points = json["points"]
        next_offset = json["next_offset"]

        for point in points:
            payload = point["payload"]
            vector = point["vector"]
            print(payload)
            print(vector)

        if next_offset is None:
            break
    ```
    """

    # Reject requests for collections that do not exist, listing the valid ones.
    collections = list(stray.memory.vectors.collections.keys())
    if collection_id not in collections:
        raise HTTPException(
            status_code=400,
            detail={"error": f"Collection does not exist. Available collections: {collections}"},
        )

    # Clients may send an empty string to mean "start from the beginning";
    # normalize it to None, which is what the vector memory layer expects.
    if offset == "":
        offset = None

    memory_collection = stray.memory.vectors.collections[collection_id]
    points, next_offset = memory_collection.get_all_points_with_offset(limit=limit, offset=offset)

    return {
        "points": points,
        "next_offset": next_offset,
    }

102 changes: 102 additions & 0 deletions core/tests/routes/memory/test_memory_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,105 @@ def test_create_memory_point(client, patch_time_now, collection):
assert memory["metadata"] == expected_metadata



@pytest.mark.parametrize("collection", ["episodic", "declarative"])
def test_get_collection_points(client, patch_time_now, collection):
    # Insert 100 points into the collection under test.
    n_points = 100
    created = [
        {"content": f"MIAO {i}!", "metadata": {"custom_key": f"custom_key_{i}"}}
        for i in range(n_points)
    ]
    for body in created:
        res = client.post(f"/memory/collections/{collection}/points", json=body)
        assert res.status_code == 200

    # Fetch without an explicit limit: the default (100) covers everything.
    res = client.get(f"/memory/collections/{collection}/points")
    assert res.status_code == 200
    json = res.json()

    points = json["points"]
    offset = json["next_offset"]

    # Everything fit in one page, so there is no next offset.
    assert offset is None

    expected_payloads = [
        {
            "page_content": p["content"],
            "metadata": {"when": FAKE_TIMESTAMP, "source": "user", **p["metadata"]},
        }
        for p in created
    ]

    assert len(points) == len(created)
    # Every returned point must carry an id and its vector.
    for point in points:
        assert "id" in point
        assert "vector" in point

    # Payload order is not guaranteed: sort both sides before comparing.
    actual_payloads = sorted(
        (p["payload"] for p in points), key=lambda p: p["page_content"]
    )
    expected_payloads = sorted(expected_payloads, key=lambda p: p["page_content"])
    assert actual_payloads == expected_payloads



@pytest.mark.parametrize("collection", ["episodic", "declarative"])
def test_get_collection_points_offset(client, patch_time_now, collection):
    # Insert 200 points into the collection under test.
    n_points = 200
    created = [
        {"content": f"MIAO {i}!", "metadata": {"custom_key": f"custom_key_{i}"}}
        for i in range(n_points)
    ]
    for body in created:
        res = client.post(f"/memory/collections/{collection}/points", json=body)
        assert res.status_code == 200

    # Page through the collection 10 points at a time, starting from the
    # beginning (empty-string offset means "first page").
    limit = 10
    next_offset = ""
    collected = []

    while True:
        res = client.get(
            f"/memory/collections/{collection}/points?limit={limit}&offset={next_offset}",
        )
        assert res.status_code == 200
        json = res.json()
        page = json["points"]
        next_offset = json["next_offset"]
        assert len(page) == limit

        collected.extend(page)

        # A None offset marks the final page.
        if next_offset is None:
            break

    # Build the expected payloads for every inserted point.
    expected_payloads = [
        {
            "page_content": p["content"],
            "metadata": {"when": FAKE_TIMESTAMP, "source": "user", **p["metadata"]},
        }
        for p in created
    ]

    assert len(collected) == len(created)
    # Every returned point must carry an id and its vector.
    for point in collected:
        assert "id" in point
        assert "vector" in point

    # Payload order is not guaranteed: sort both sides before comparing.
    actual_payloads = sorted(
        (p["payload"] for p in collected), key=lambda p: p["page_content"]
    )
    expected_payloads = sorted(expected_payloads, key=lambda p: p["page_content"])
    assert actual_payloads == expected_payloads


Loading