From 7e8ebbc5b2ab33b100a17550cc5ee20f634b4831 Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 15 Jul 2022 20:47:38 +0200 Subject: [PATCH 1/4] Use `HTTPStatus` constants in place of literals in tests. --- tests/rest/client/test_login.py | 81 ++++++++++--------- tests/rest/client/test_redactions.py | 31 +++++-- tests/rest/client/test_register.py | 117 ++++++++++++++------------- 3 files changed, 131 insertions(+), 98 deletions(-) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index f6efa5fe37f5..b90430e5d4a8 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -14,6 +14,7 @@ import json import time import urllib.parse +from http import HTTPStatus from typing import Any, Dict, List, Optional from unittest.mock import Mock from urllib.parse import urlencode @@ -134,10 +135,12 @@ def test_POST_ratelimiting_per_address(self) -> None: channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual( + channel.code, HTTPStatus.TOO_MANY_REQUESTS, msg=channel.result + ) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. @@ -152,7 +155,7 @@ def test_POST_ratelimiting_per_address(self) -> None: } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) @override_config( { @@ -179,10 +182,12 @@ def test_POST_ratelimiting_per_account(self) -> None: channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual( + channel.code, HTTPStatus.TOO_MANY_REQUESTS, msg=channel.result + ) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. @@ -197,7 +202,7 @@ def test_POST_ratelimiting_per_account(self) -> None: } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) @override_config( { @@ -224,10 +229,14 @@ def test_POST_ratelimiting_per_account_failed_attempts(self) -> None: channel = self.make_request(b"POST", LOGIN_URL, params) if i == 5: - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual( + channel.code, HTTPStatus.TOO_MANY_REQUESTS, msg=channel.result + ) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual( + channel.code, HTTPStatus.FORBIDDEN, msg=channel.result + ) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. @@ -242,7 +251,7 @@ def test_POST_ratelimiting_per_account_failed_attempts(self) -> None: } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) @override_config({"session_lifetime": "24h"}) def test_soft_logout(self) -> None: @@ -250,7 +259,7 @@ def test_soft_logout(self) -> None: # we shouldn't be able to make requests without an access token channel = self.make_request(b"GET", TEST_URL) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN") # log in as normal @@ -354,7 +363,7 @@ def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None: # Now try to hard logout this session channel = self.make_request(b"POST", "/logout", access_token=access_token) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) @override_config({"session_lifetime": "24h"}) def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out( @@ -380,7 +389,7 @@ def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out( # Now try to hard log out all of the user's sessions channel = self.make_request(b"POST", "/logout/all", access_token=access_token) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) def test_login_with_overly_long_device_id_fails(self) -> None: self.register_user("mickey", "cheese") @@ -878,17 +887,17 @@ def jwt_login(self, *args: Any) -> FakeChannel: def test_login_jwt_valid_registered(self) -> None: self.register_user("kermit", "monkey") channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") def test_login_jwt_valid_unregistered(self) -> None: channel = self.jwt_login({"sub": "frog"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@frog:test") def test_login_jwt_invalid_signature(self) -> None: channel = self.jwt_login({"sub": "frog"}, "notsecret") - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], @@ -897,7 +906,7 @@ def test_login_jwt_invalid_signature(self) -> None: def test_login_jwt_expired(self) -> None: channel = self.jwt_login({"sub": "frog", "exp": 864000}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], @@ -907,7 +916,7 @@ def test_login_jwt_expired(self) -> None: def test_login_jwt_not_before(self) -> None: now = int(time.time()) channel = self.jwt_login({"sub": "frog", "nbf": now + 3600}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], @@ -916,7 +925,7 @@ def test_login_jwt_not_before(self) -> None: def test_login_no_sub(self) -> None: channel = self.jwt_login({"username": "root"}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["error"], "Invalid JWT") @@ -925,12 +934,12 @@ def test_login_iss(self) -> None: """Test validating the issuer claim.""" # A valid issuer. channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") # An invalid issuer. channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], @@ -939,7 +948,7 @@ def test_login_iss(self) -> None: # Not providing an issuer. channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], @@ -949,7 +958,7 @@ def test_login_iss(self) -> None: def test_login_iss_no_config(self) -> None: """Test providing an issuer claim without requiring it in the configuration.""" channel = self.jwt_login({"sub": "kermit", "iss": "invalid"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") @override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}}) @@ -957,12 +966,12 @@ def test_login_aud(self) -> None: """Test validating the audience claim.""" # A valid audience. channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") # An invalid audience. channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], @@ -971,7 +980,7 @@ def test_login_aud(self) -> None: # Not providing an audience. channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], @@ -981,7 +990,7 @@ def test_login_aud(self) -> None: def test_login_aud_no_config(self) -> None: """Test providing an audience without requiring it in the configuration.""" channel = self.jwt_login({"sub": "kermit", "aud": "invalid"}) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], @@ -991,20 +1000,20 @@ def test_login_aud_no_config(self) -> None: def test_login_default_sub(self) -> None: """Test reading user ID from the default subject claim.""" channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") @override_config({"jwt_config": {**base_config, "subject_claim": "username"}}) def test_login_custom_sub(self) -> None: """Test reading user ID from a custom subject claim.""" channel = self.jwt_login({"username": "frog"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@frog:test") def test_login_no_token(self) -> None: params = {"type": "org.matrix.login.jwt"} channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["error"], "Token field for JWT is missing") @@ -1086,12 +1095,12 @@ def jwt_login(self, *args: Any) -> FakeChannel: def test_login_jwt_valid(self) -> None: channel = self.jwt_login({"sub": "kermit"}) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.assertEqual(channel.json_body["user_id"], "@kermit:test") def test_login_jwt_invalid_signature(self) -> None: channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual( channel.json_body["error"], @@ -1152,7 +1161,7 @@ def test_login_appservice_user(self) -> None: b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) def test_login_appservice_user_bot(self) -> None: """Test that the appservice bot can use /login""" @@ -1166,7 +1175,7 @@ def test_login_appservice_user_bot(self) -> None: b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) def test_login_appservice_wrong_user(self) -> None: """Test that non-as users cannot login with the as token""" @@ -1180,7 +1189,7 @@ def test_login_appservice_wrong_user(self) -> None: b"POST", LOGIN_URL, params, access_token=self.service.token ) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) def test_login_appservice_wrong_as(self) -> None: """Test that as users cannot login with wrong as token""" @@ -1194,7 +1203,7 @@ def test_login_appservice_wrong_as(self) -> None: b"POST", LOGIN_URL, params, access_token=self.another_service.token ) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) def test_login_appservice_no_token(self) -> None: """Test that users must provide a token when using the appservice @@ -1208,7 +1217,7 @@ def test_login_appservice_no_token(self) -> None: } channel = self.make_request(b"POST", LOGIN_URL, params) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, msg=channel.result) @skip_unless(HAS_OIDC, "requires OIDC") diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py index 7401b5e0c0fa..c4da6303a52f 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus from typing import List from twisted.test.proto_helpers import MemoryReactor @@ -67,7 +68,11 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: ) def _redact_event( - self, access_token: str, room_id: str, event_id: str, expect_code: int = 200 + self, + access_token: str, + room_id: str, + event_id: str, + expect_code: int = HTTPStatus.OK, ) -> JsonDict: """Helper function to send a redaction event. @@ -76,12 +81,12 @@ def _redact_event( path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id) channel = self.make_request("POST", path, content={}, access_token=access_token) - self.assertEqual(int(channel.result["code"]), expect_code) + self.assertEqual(channel.code, expect_code) return channel.json_body def _sync_room_timeline(self, access_token: str, room_id: str) -> List[JsonDict]: channel = self.make_request("GET", "sync", access_token=self.mod_access_token) - self.assertEqual(channel.result["code"], b"200") + self.assertEqual(channel.code, HTTPStatus.OK) room_sync = channel.json_body["rooms"]["join"][room_id] return room_sync["timeline"]["events"] @@ -117,7 +122,10 @@ def test_redact_event_as_normal(self) -> None: # as a normal, try to redact the admin's event self._redact_event( - self.other_access_token, self.room_id, admin_msg_id, expect_code=403 + self.other_access_token, + self.room_id, + admin_msg_id, + expect_code=HTTPStatus.FORBIDDEN ) # now try to redact our own event @@ -153,7 +161,10 @@ def test_redact_nonexistent_event(self) -> None: # ... but normals cannot self._redact_event( - self.other_access_token, self.room_id, "$zzz", expect_code=404 + self.other_access_token, + self.room_id, + "$zzz", + expect_code=HTTPStatus.NOT_FOUND, ) # when we sync, we should see only the valid redaction @@ -178,12 +189,18 @@ def test_redact_create_event(self) -> None: # room moderators cannot send redactions for create events self._redact_event( - self.mod_access_token, self.room_id, create_event_id, expect_code=403 + self.mod_access_token, + self.room_id, + create_event_id, + expect_code=HTTPStatus.FORBIDDEN, ) # and nor can normals self._redact_event( - self.other_access_token, self.room_id, create_event_id, expect_code=403 + self.other_access_token, + self.room_id, + create_event_id, + expect_code=HTTPStatus.FORBIDDEN, ) def test_redact_event_as_moderator_ratelimit(self) -> None: diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index cb274587465f..d0c511296130 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -16,6 +16,7 @@ import datetime import json import os +from http import HTTPStatus from typing import Any, Dict, List, Tuple import pkg_resources @@ -70,7 +71,7 @@ def test_POST_appservice_registration_valid(self) -> None: b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) det_data = {"user_id": user_id, "home_server": self.hs.hostname} self.assertDictContainsSubset(det_data, channel.json_body) @@ -91,7 +92,7 @@ def test_POST_appservice_registration_no_type(self) -> None: b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - self.assertEqual(channel.result["code"], b"400", channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, msg=channel.result) def test_POST_appservice_registration_invalid(self) -> None: self.appservice = None # no application service exists @@ -102,20 +103,20 @@ def test_POST_appservice_registration_invalid(self) -> None: b"POST", self.url + b"?access_token=i_am_an_app_service", request_data ) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, msg=channel.result) def test_POST_bad_password(self) -> None: request_data = json.dumps({"username": "kermit", "password": 666}) channel = self.make_request(b"POST", self.url, request_data) - self.assertEqual(channel.result["code"], b"400", channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, msg=channel.result) self.assertEqual(channel.json_body["error"], "Invalid password") def test_POST_bad_username(self) -> None: request_data = json.dumps({"username": 777, "password": "monkey"}) channel = self.make_request(b"POST", self.url, request_data) - self.assertEqual(channel.result["code"], b"400", channel.result) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, msg=channel.result) self.assertEqual(channel.json_body["error"], "Invalid username") def test_POST_user_valid(self) -> None: @@ -135,7 +136,7 @@ def test_POST_user_valid(self) -> None: "home_server": self.hs.hostname, "device_id": device_id, } - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.assertDictContainsSubset(det_data, channel.json_body) @override_config({"enable_registration": False}) @@ -145,7 +146,7 @@ def test_POST_disabled_registration(self) -> None: channel = self.make_request(b"POST", self.url, request_data) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) self.assertEqual(channel.json_body["error"], "Registration has been disabled") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -156,7 +157,7 @@ def test_POST_guest_registration(self) -> None: channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"} - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.assertDictContainsSubset(det_data, channel.json_body) def test_POST_disabled_guest_registration(self) -> None: @@ -164,7 +165,7 @@ def test_POST_disabled_guest_registration(self) -> None: channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) self.assertEqual(channel.json_body["error"], "Guest access is disabled") @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) @@ -174,16 +175,18 @@ def test_POST_ratelimiting_guest(self) -> None: channel = self.make_request(b"POST", url, b"{}") if i == 5: - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual( + channel.code, HTTPStatus.TOO_MANY_REQUESTS, msg=channel.result + ) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) def test_POST_ratelimiting(self) -> None: @@ -198,16 +201,18 @@ def test_POST_ratelimiting(self) -> None: channel = self.make_request(b"POST", self.url, request_data) if i == 5: - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual( + channel.code, HTTPStatus.TOO_MANY_REQUESTS, msg=channel.result + ) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) @override_config({"registration_requires_token": True}) def test_POST_registration_requires_token(self) -> None: @@ -235,7 +240,7 @@ def test_POST_registration_requires_token(self) -> None: # Request without auth to get flows and session channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, msg=channel.result) flows = channel.json_body["flows"] # Synapse adds a dummy stage to differentiate flows where otherwise one # flow would be a subset of another flow. @@ -253,7 +258,7 @@ def test_POST_registration_requires_token(self) -> None: } request_data = json.dumps(params) channel = self.make_request(b"POST", self.url, request_data) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, msg=channel.result) completed = channel.json_body["completed"] self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed) @@ -269,7 +274,7 @@ def test_POST_registration_requires_token(self) -> None: "home_server": self.hs.hostname, "device_id": device_id, } - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.assertDictContainsSubset(det_data, channel.json_body) # Check the `completed` counter has been incremented and pending is 0 @@ -299,21 +304,21 @@ def test_POST_registration_token_invalid(self) -> None: "session": session, } channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, msg=channel.result) self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM) self.assertEqual(channel.json_body["completed"], []) # Test with non-string (invalid) params["auth"]["token"] = 1234 channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, msg=channel.result) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) self.assertEqual(channel.json_body["completed"], []) # Test with unknown token (invalid) params["auth"]["token"] = "1234" channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, msg=channel.result) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["completed"], []) @@ -367,7 +372,7 @@ def test_POST_registration_token_limit_uses(self) -> None: "session": session2, } channel = self.make_request(b"POST", self.url, json.dumps(params2)) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, msg=channel.result) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["completed"], []) @@ -387,7 +392,7 @@ def test_POST_registration_token_limit_uses(self) -> None: # Check auth still fails when using token with session2 channel = self.make_request(b"POST", self.url, json.dumps(params2)) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, msg=channel.result) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["completed"], []) @@ -421,7 +426,7 @@ def test_POST_registration_token_expiry(self) -> None: "session": session, } channel = self.make_request(b"POST", self.url, json.dumps(params)) - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, msg=channel.result) self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED) self.assertEqual(channel.json_body["completed"], []) @@ -576,7 +581,7 @@ def test_POST_registration_token_session_expiry_deleted_token(self) -> None: def test_advertised_flows(self) -> None: channel = self.make_request(b"POST", self.url, b"{}") - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, msg=channel.result) flows = channel.json_body["flows"] # with the stock config, we only expect the dummy flow @@ -599,7 +604,7 @@ def test_advertised_flows(self) -> None: ) def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None: channel = self.make_request(b"POST", self.url, b"{}") - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, msg=channel.result) flows = channel.json_body["flows"] self.assertCountEqual( @@ -631,7 +636,7 @@ def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None: ) def test_advertised_flows_no_msisdn_email_required(self) -> None: channel = self.make_request(b"POST", self.url, b"{}") - self.assertEqual(channel.result["code"], b"401", channel.result) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, msg=channel.result) flows = channel.json_body["flows"] # with the stock config, we expect all four combinations of 3pid @@ -675,7 +680,7 @@ def test_request_token_existing_email_inhibit_error(self) -> None: b"register/email/requestToken", {"client_secret": "foobar", "email": email, "send_attempt": 1}, ) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertIsNotNone(channel.json_body.get("sid")) @@ -698,7 +703,7 @@ def test_reject_invalid_email(self) -> None: b"register/email/requestToken", {"client_secret": "foobar", "email": "email@@email", "send_attempt": 1}, ) - self.assertEqual(400, channel.code, channel.result) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) # Check error to ensure that we're not erroring due to a bug in the test. self.assertEqual( channel.json_body, @@ -711,7 +716,7 @@ def test_reject_invalid_email(self) -> None: b"register/email/requestToken", {"client_secret": "foobar", "email": "email", "send_attempt": 1}, ) - self.assertEqual(400, channel.code, channel.result) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) self.assertEqual( channel.json_body, {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, @@ -724,7 +729,7 @@ def test_reject_invalid_email(self) -> None: b"register/email/requestToken", {"client_secret": "foobar", "email": email, "send_attempt": 1}, ) - self.assertEqual(400, channel.code, channel.result) + self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result) self.assertEqual( channel.json_body, {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"}, @@ -749,7 +754,7 @@ def test_inhibit_user_in_use_error(self) -> None: # Check that /available correctly ignores the username provided despite the # username being already registered. channel = self.make_request("GET", "register/available?username=" + username) - self.assertEqual(200, channel.code, channel.result) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) # Test that when starting a UIA registration flow the request doesn't fail because # of a conflicting username @@ -758,7 +763,7 @@ def test_inhibit_user_in_use_error(self) -> None: "register", {"username": username, "type": "m.login.password", "password": "foo"}, ) - self.assertEqual(channel.code, 401) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED) self.assertIn("session", channel.json_body) # Test that finishing the registration fails because of a conflicting username. @@ -768,7 +773,7 @@ def test_inhibit_user_in_use_error(self) -> None: "register", {"auth": {"session": session, "type": LoginType.DUMMY}}, ) - self.assertEqual(channel.code, 400, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.USER_IN_USE) @@ -803,13 +808,13 @@ def test_validity_period(self) -> None: # endpoint. channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.reactor.advance(datetime.timedelta(weeks=1).total_seconds()) channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) self.assertEqual( channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) @@ -830,12 +835,12 @@ def test_manual_renewal(self) -> None: params = {"user_id": user_id} request_data = json.dumps(params) channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) def test_manual_expire(self) -> None: user_id = self.register_user("kermit", "monkey") @@ -852,12 +857,12 @@ def test_manual_expire(self) -> None: } request_data = json.dumps(params) channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) # The specific endpoint doesn't matter, all we need is an authenticated # endpoint. channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEqual(channel.result["code"], b"403", channel.result) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) self.assertEqual( channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result ) @@ -877,18 +882,18 @@ def test_logging_out_expired_user(self) -> None: } request_data = json.dumps(params) channel = self.make_request(b"POST", url, request_data, access_token=admin_tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) # Try to log the user out channel = self.make_request(b"POST", "/logout", access_token=tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) # Log the user in again (allowed for expired accounts) tok = self.login("kermit", "monkey") # Try to log out all of the user's sessions channel = self.make_request(b"POST", "/logout/all", access_token=tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): @@ -963,7 +968,7 @@ def test_renewal_email(self) -> None: renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id)) url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token channel = self.make_request(b"GET", url) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) # Check that we're getting HTML back. content_type = channel.headers.getRawHeaders(b"Content-Type") @@ -981,7 +986,7 @@ def test_renewal_email(self) -> None: # Move 1 day forward. Try to renew with the same token again. url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token channel = self.make_request(b"GET", url) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) # Check that we're getting HTML back. content_type = channel.headers.getRawHeaders(b"Content-Type") @@ -1001,14 +1006,14 @@ def test_renewal_email(self) -> None: # succeed. self.reactor.advance(datetime.timedelta(days=3).total_seconds()) channel = self.make_request(b"GET", "/sync", access_token=tok) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) def test_renewal_invalid_token(self) -> None: # Hit the renewal endpoint with an invalid token and check that it behaves as # expected, i.e. that it responds with 404 Not Found and the correct HTML. url = "/_matrix/client/unstable/account_validity/renew?token=123" channel = self.make_request(b"GET", url) - self.assertEqual(channel.result["code"], b"404", channel.result) + self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, msg=channel.result) # Check that we're getting HTML back. content_type = channel.headers.getRawHeaders(b"Content-Type") @@ -1032,7 +1037,7 @@ def test_manual_email_send(self) -> None: "/_matrix/client/unstable/account_validity/send_mail", access_token=tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.assertEqual(len(self.email_attempts), 1) @@ -1054,7 +1059,7 @@ def test_deactivated_user(self) -> None: channel = self.make_request( "POST", "account/deactivate", request_data, access_token=tok ) - self.assertEqual(channel.code, 200) + self.assertEqual(channel.code, HTTPStatus.OK) self.reactor.advance(datetime.timedelta(days=8).total_seconds()) @@ -1107,7 +1112,7 @@ def test_manual_email_send_expired_account(self) -> None: "/_matrix/client/unstable/account_validity/send_mail", access_token=tok, ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.assertEqual(len(self.email_attempts), 1) @@ -1187,7 +1192,7 @@ def test_GET_token_valid(self) -> None: b"GET", f"{self.url}?token={token}", ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.assertEqual(channel.json_body["valid"], True) def test_GET_token_invalid(self) -> None: @@ -1196,7 +1201,7 @@ def test_GET_token_invalid(self) -> None: b"GET", f"{self.url}?token={token}", ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.assertEqual(channel.json_body["valid"], False) @override_config( @@ -1212,10 +1217,12 @@ def test_GET_ratelimiting(self) -> None: ) if i == 5: - self.assertEqual(channel.result["code"], b"429", channel.result) + self.assertEqual( + channel.code, HTTPStatus.TOO_MANY_REQUESTS, msg=channel.result + ) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) @@ -1223,4 +1230,4 @@ def test_GET_ratelimiting(self) -> None: b"GET", f"{self.url}?token={token}", ) - self.assertEqual(channel.result["code"], b"200", channel.result) + self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result) From 77bacb830c0d36a5303c05b36214e1d29fb4f07d Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 15 Jul 2022 20:49:34 +0200 Subject: [PATCH 2/4] newsfile --- changelog.d/13298.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/13298.misc diff --git a/changelog.d/13298.misc b/changelog.d/13298.misc new file mode 100644 index 000000000000..545a62369f43 --- /dev/null +++ b/changelog.d/13298.misc @@ -0,0 +1 @@ +Use `HTTPStatus` constants in place of literals in tests. \ No newline at end of file From 1a3e478974ef74a1ffa8bd3acf63bce519c8ca5a Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 15 Jul 2022 20:54:38 +0200 Subject: [PATCH 3/4] codestyle --- tests/rest/client/test_login.py | 4 +--- tests/rest/client/test_redactions.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index b90430e5d4a8..d0733616b118 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -234,9 +234,7 @@ def test_POST_ratelimiting_per_account_failed_attempts(self) -> None: ) retry_after_ms = int(channel.json_body["retry_after_ms"]) else: - self.assertEqual( - channel.code, HTTPStatus.FORBIDDEN, msg=channel.result - ) + self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py index c4da6303a52f..909c017e8840 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py @@ -125,7 +125,7 @@ def test_redact_event_as_normal(self) -> None: self.other_access_token, self.room_id, admin_msg_id, - expect_code=HTTPStatus.FORBIDDEN + expect_code=HTTPStatus.FORBIDDEN, ) # now try to redact our own event From 7750a44aa56d4b002bf99578c184f147db855e0c Mon Sep 17 00:00:00 2001 From: dklimpel <5740567+dklimpel@users.noreply.github.com> Date: Fri, 15 Jul 2022 21:14:05 +0200 Subject: [PATCH 4/4] fix test `test_inhibit_user_in_use_error` --- tests/rest/client/test_register.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index d0c511296130..df5de27df7db 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -773,7 +773,7 @@ def test_inhibit_user_in_use_error(self) -> None: "register", {"auth": {"session": session, "type": LoginType.DUMMY}}, ) - self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.json_body) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.USER_IN_USE)