Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Make search statement in List Room and User Admin API case-insensitive #8931

Merged
merged 3 commits into from
Dec 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8931.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make search statement in List Room and List User Admin API case-insensitive.
9 changes: 6 additions & 3 deletions docs/admin_api/user_admin_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@ It returns a JSON body like the following:
],
"avatar_url": "<avatar_url>",
"admin": false,
"deactivated": false
"deactivated": false,
"password_hash": "$2b$12$p9B4GkqYdRTPGD",
"creation_ts": 1560432506,
"appservice_id": null,
"consent_server_notice_sent": null,
"consent_version": null
}

URL parameters:
Expand Down Expand Up @@ -139,7 +144,6 @@ A JSON body is returned with the following shape:
"users": [
{
"name": "<user_id1>",
"password_hash": "<password_hash1>",
"is_guest": 0,
"admin": 0,
"user_type": null,
Expand All @@ -148,7 +152,6 @@ A JSON body is returned with the following shape:
"avatar_url": null
}, {
"name": "<user_id2>",
"password_hash": "<password_hash2>",
"is_guest": 0,
"admin": 1,
"user_type": null,
Expand Down
7 changes: 4 additions & 3 deletions synapse/storage/databases/main/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,12 +342,13 @@ def get_users_paginate_txn(txn):
filters = []
args = [self.hs.config.server_name]

# `name` is in database already in lower case
if name:
filters.append("(name LIKE ? OR displayname LIKE ?)")
args.extend(["@%" + name + "%:%", "%" + name + "%"])
filters.append("(name LIKE ? OR LOWER(displayname) LIKE ?)")
args.extend(["@%" + name.lower() + "%:%", "%" + name.lower() + "%"])
elif user_id:
filters.append("name LIKE ?")
args.extend(["%" + user_id + "%"])
args.extend(["%" + user_id.lower() + "%"])

if not guests:
filters.append("is_guest = 0")
Expand Down
4 changes: 2 additions & 2 deletions synapse/storage/databases/main/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,14 +379,14 @@ async def get_rooms_paginate(
# Filter room names by a string
where_statement = ""
if search_term:
where_statement = "WHERE state.name LIKE ?"
where_statement = "WHERE LOWER(state.name) LIKE ?"

# Our postgres db driver converts ? -> %s in SQL strings as that's the
# placeholder for postgres.
# HOWEVER, if you put a % into your SQL then everything goes wibbly.
# To get around this, we're going to surround search_term with %'s
# before giving it to the database in python instead
search_term = "%" + search_term + "%"
search_term = "%" + search_term.lower() + "%"

# Set ordering
if RoomSortOrder(order_by) == RoomSortOrder.SIZE:
Expand Down
7 changes: 7 additions & 0 deletions tests/rest/admin/test_room.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,6 +1050,13 @@ def _search_test(
_search_test(room_id_2, "else")
_search_test(room_id_2, "se")

# Test case insensitive
_search_test(room_id_1, "SOMETHING")
_search_test(room_id_1, "THING")

_search_test(room_id_2, "ELSE")
_search_test(room_id_2, "SE")

_search_test(None, "foo")
_search_test(None, "bar")
_search_test(None, "", expected_http_code=400)
Expand Down
101 changes: 98 additions & 3 deletions tests/rest/admin/test_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import json
import urllib.parse
from binascii import unhexlify
from typing import Optional

from mock import Mock

Expand Down Expand Up @@ -466,8 +467,12 @@ def prepare(self, reactor, clock, hs):
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")

self.register_user("user1", "pass1", admin=False)
self.register_user("user2", "pass2", admin=False)
self.user1 = self.register_user(
"user1", "pass1", admin=False, displayname="Name 1"
)
self.user2 = self.register_user(
"user2", "pass2", admin=False, displayname="Name 2"
)

def test_no_auth(self):
"""
Expand All @@ -476,7 +481,20 @@ def test_no_auth(self):
request, channel = self.make_request("GET", self.url, b"{}")

self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("M_MISSING_TOKEN", channel.json_body["errcode"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])

def test_requester_is_no_admin(self):
"""
If the user is not a server admin, an error is returned.
"""
other_user_token = self.login("user1", "pass1")

request, channel = self.make_request(
"GET", self.url, access_token=other_user_token,
)

self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])

def test_all_users(self):
"""
Expand All @@ -493,6 +511,83 @@ def test_all_users(self):
self.assertEqual(3, len(channel.json_body["users"]))
self.assertEqual(3, channel.json_body["total"])

# Check that all fields are available
for u in channel.json_body["users"]:
self.assertIn("name", u)
self.assertIn("is_guest", u)
self.assertIn("admin", u)
self.assertIn("user_type", u)
self.assertIn("deactivated", u)
self.assertIn("displayname", u)
self.assertIn("avatar_url", u)

def test_search_term(self):
"""Test that searching for a users works correctly"""

def _search_test(
expected_user_id: Optional[str],
search_term: str,
search_field: Optional[str] = "name",
expected_http_code: Optional[int] = 200,
):
"""Search for a user and check that the returned user's id is a match

Args:
expected_user_id: The user_id expected to be returned by the API. Set
to None to expect zero results for the search
search_term: The term to search for user names with
search_field: Field which is to request: `name` or `user_id`
expected_http_code: The expected http code for the request
"""
url = self.url + "?%s=%s" % (search_field, search_term,)
request, channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
self.assertEqual(expected_http_code, channel.code, msg=channel.json_body)

if expected_http_code != 200:
return

# Check that users were returned
self.assertTrue("users" in channel.json_body)
users = channel.json_body["users"]

# Check that the expected number of users were returned
expected_user_count = 1 if expected_user_id else 0
self.assertEqual(len(users), expected_user_count)
self.assertEqual(channel.json_body["total"], expected_user_count)

if expected_user_id:
# Check that the first returned user id is correct
u = users[0]
self.assertEqual(expected_user_id, u["name"])

# Perform search tests
_search_test(self.user1, "er1")
_search_test(self.user1, "me 1")

_search_test(self.user2, "er2")
_search_test(self.user2, "me 2")

_search_test(self.user1, "er1", "user_id")
_search_test(self.user2, "er2", "user_id")

# Test case insensitive
_search_test(self.user1, "ER1")
_search_test(self.user1, "NAME 1")

_search_test(self.user2, "ER2")
_search_test(self.user2, "NAME 2")

_search_test(self.user1, "ER1", "user_id")
_search_test(self.user2, "ER2", "user_id")

_search_test(None, "foo")
_search_test(None, "bar")

_search_test(None, "foo", "user_id")
_search_test(None, "bar", "user_id")


class UserRestTestCase(unittest.HomeserverTestCase):

Expand Down
7 changes: 7 additions & 0 deletions tests/storage/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,10 @@ def test_get_users_paginate(self):

self.assertEquals(1, total)
self.assertEquals(self.displayname, users.pop()["displayname"])

users, total = yield defer.ensureDeferred(
self.store.get_users_paginate(0, 10, name="BC", guests=False)
)

self.assertEquals(1, total)
self.assertEquals(self.displayname, users.pop()["displayname"])