Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit cde552e

Browse files
committed
Restructure the CAS code to be more like SAML/OIDC.
1 parent a9e5a2a commit cde552e

File tree

2 files changed

+41
-34
lines changed

2 files changed

+41
-34
lines changed

synapse/handlers/cas_handler.py

+34-27
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,9 @@ def _build_service_param(self, args: Dict[str, str]) -> str:
7676

7777
async def _validate_ticket(
7878
self, ticket: str, service_args: Dict[str, str]
79-
) -> Tuple[str, Optional[str]]:
79+
) -> Tuple[str, Dict[str, Optional[str]]]:
8080
"""
81-
Validate a CAS ticket with the server, parse the response, and return the user and display name.
81+
Validate a CAS ticket with the server, parse the response, and return the user and other attributes.
8282
8383
Args:
8484
ticket: The CAS ticket from the client.
@@ -97,22 +97,7 @@ async def _validate_ticket(
9797
# even if that's being used old-http style to signal end-of-data
9898
body = pde.response
9999

100-
user, attributes = self._parse_cas_response(body)
101-
displayname = attributes.pop(self._cas_displayname_attribute, None)
102-
103-
for required_attribute, required_value in self._cas_required_attributes.items():
104-
# If required attribute was not in CAS Response - Forbidden
105-
if required_attribute not in attributes:
106-
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
107-
108-
# Also need to check value
109-
if required_value is not None:
110-
actual_value = attributes[required_attribute]
111-
# If required attribute value does not match expected - Forbidden
112-
if required_value != actual_value:
113-
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
114-
115-
return user, displayname
100+
return self._parse_cas_response(body)
116101

117102
def _parse_cas_response(
118103
self, cas_response_body: bytes
@@ -208,7 +193,7 @@ async def handle_ticket(
208193
args["redirectUrl"] = client_redirect_url
209194
if session:
210195
args["session"] = session
211-
username, user_display_name = await self._validate_ticket(ticket, args)
196+
username, attributes = await self._validate_ticket(ticket, args)
212197

213198
# first check if we're doing a UIA
214199
if session:
@@ -218,14 +203,36 @@ async def handle_ticket(
218203

219204
# otherwise, we're handling a login request.
220205

206+
# Ensure that the attributes of the logged in user meet the required
207+
# attributes.
208+
for required_attribute, required_value in self._cas_required_attributes.items():
209+
# If required attribute was not in CAS Response - Forbidden
210+
if required_attribute not in attributes:
211+
self._sso_handler.render_error(
212+
request, "unauthorised", "You are not authorised to log in here."
213+
)
214+
return
215+
216+
# Also need to check value
217+
if required_value is not None:
218+
actual_value = attributes[required_attribute]
219+
# If required attribute value does not match expected - Forbidden
220+
if required_value != actual_value:
221+
self._sso_handler.render_error(
222+
request,
223+
"unauthorised",
224+
"You are not authorised to log in here.",
225+
)
226+
return
227+
221228
# Pull out the user-agent and IP from the request.
222229
user_agent = request.get_user_agent("")
223230
ip_address = self.hs.get_ip_from_request(request)
224231

225232
# Get the matrix ID from the CAS username.
226233
try:
227234
user_id = await self._map_cas_user_to_matrix_user(
228-
username, user_display_name, user_agent, ip_address
235+
username, attributes, user_agent, ip_address
229236
)
230237
except MappingException as e:
231238
logger.exception("Could not map user")
@@ -242,7 +249,7 @@ async def handle_ticket(
242249
async def _map_cas_user_to_matrix_user(
243250
self,
244251
remote_user_id: str,
245-
display_name: Optional[str],
252+
attributes: Dict[str, Optional[str]],
246253
user_agent: str,
247254
ip_address: str,
248255
) -> str:
@@ -251,7 +258,7 @@ async def _map_cas_user_to_matrix_user(
251258
252259
Args:
253260
remote_user_id: The username from the CAS response.
254-
display_name: The display name from the CAS response.
261+
attributes: Additional attributes from the CAS response.
255262
user_agent: The user agent of the client making the request.
256263
ip_address: The IP address of the client making the request.
257264
@@ -262,12 +269,14 @@ async def _map_cas_user_to_matrix_user(
262269
The user ID associated with this response.
263270
"""
264271

272+
# Note that CAS does not support a mapping provider, so the logic is hard-coded.
273+
localpart = map_username_to_mxid_localpart(remote_user_id)
274+
display_name = attributes.pop(self._cas_displayname_attribute, None)
275+
265276
async def cas_response_to_user_attributes(failures: int) -> UserAttributes:
266277
"""
267278
Map from CAS attributes to user attributes.
268279
"""
269-
localpart = map_username_to_mxid_localpart(remote_user_id)
270-
271280
# Due to the grandfathering logic matching any previously registered
272281
# mxids it isn't expected for there to be any failures.
273282
if failures:
@@ -278,9 +287,7 @@ async def cas_response_to_user_attributes(failures: int) -> UserAttributes:
278287
async def grandfather_existing_users() -> Optional[str]:
279288
# Since CAS did not used to support storing data into the user_external_ids
280289
# tables, we need to attempt to map to existing users.
281-
user_id = UserID(
282-
map_username_to_mxid_localpart(remote_user_id), self._hostname
283-
).to_string()
290+
user_id = UserID(localpart, self._hostname).to_string()
284291

285292
logger.debug(
286293
"Looking for existing account based on mapped %s", user_id,

tests/handlers/test_cas.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ def make_homeserver(self, reactor, clock):
4646
def test_map_cas_user_to_user(self):
4747
"""Ensure that mapping the CAS user returned from a provider to an MXID works properly."""
4848
cas_user_id = "test_user"
49-
display_name = ""
49+
attributes = {}
5050
mxid = self.get_success(
5151
self.handler._map_cas_user_to_matrix_user(
52-
cas_user_id, display_name, "user-agent", "10.10.10.10"
52+
cas_user_id, attributes, "user-agent", "10.10.10.10"
5353
)
5454
)
5555
self.assertEqual(mxid, "@test_user:test")
@@ -63,29 +63,29 @@ def test_map_cas_user_to_existing_user(self):
6363

6464
# Map a user via SSO.
6565
cas_user_id = "test_user"
66-
display_name = ""
66+
attributes = {}
6767
mxid = self.get_success(
6868
self.handler._map_cas_user_to_matrix_user(
69-
cas_user_id, display_name, "user-agent", "10.10.10.10"
69+
cas_user_id, attributes, "user-agent", "10.10.10.10"
7070
)
7171
)
7272
self.assertEqual(mxid, "@test_user:test")
7373

7474
# Subsequent calls should map to the same mxid.
7575
mxid = self.get_success(
7676
self.handler._map_cas_user_to_matrix_user(
77-
cas_user_id, display_name, "user-agent", "10.10.10.10"
77+
cas_user_id, attributes, "user-agent", "10.10.10.10"
7878
)
7979
)
8080
self.assertEqual(mxid, "@test_user:test")
8181

8282
def test_map_cas_user_to_invalid_localpart(self):
8383
"""CAS automaps invalid characters to base-64 encoding."""
8484
cas_user_id = "föö"
85-
display_name = ""
85+
attributes = {}
8686
mxid = self.get_success(
8787
self.handler._map_cas_user_to_matrix_user(
88-
cas_user_id, display_name, "user-agent", "10.10.10.10"
88+
cas_user_id, attributes, "user-agent", "10.10.10.10"
8989
)
9090
)
9191
self.assertEqual(mxid, "@f=c3=b6=c3=b6:test")

0 commit comments

Comments
 (0)