@@ -76,9 +76,9 @@ def _build_service_param(self, args: Dict[str, str]) -> str:
76
76
77
77
async def _validate_ticket (
78
78
self , ticket : str , service_args : Dict [str , str ]
79
- ) -> Tuple [str , Optional [str ]]:
79
+ ) -> Tuple [str , Dict [ str , Optional [str ] ]]:
80
80
"""
81
- Validate a CAS ticket with the server, parse the response, and return the user and display name .
81
+ Validate a CAS ticket with the server, parse the response, and return the user and other attributes .
82
82
83
83
Args:
84
84
ticket: The CAS ticket from the client.
@@ -97,22 +97,7 @@ async def _validate_ticket(
97
97
# even if that's being used old-http style to signal end-of-data
98
98
body = pde .response
99
99
100
- user , attributes = self ._parse_cas_response (body )
101
- displayname = attributes .pop (self ._cas_displayname_attribute , None )
102
-
103
- for required_attribute , required_value in self ._cas_required_attributes .items ():
104
- # If required attribute was not in CAS Response - Forbidden
105
- if required_attribute not in attributes :
106
- raise LoginError (401 , "Unauthorized" , errcode = Codes .UNAUTHORIZED )
107
-
108
- # Also need to check value
109
- if required_value is not None :
110
- actual_value = attributes [required_attribute ]
111
- # If required attribute value does not match expected - Forbidden
112
- if required_value != actual_value :
113
- raise LoginError (401 , "Unauthorized" , errcode = Codes .UNAUTHORIZED )
114
-
115
- return user , displayname
100
+ return self ._parse_cas_response (body )
116
101
117
102
def _parse_cas_response (
118
103
self , cas_response_body : bytes
@@ -208,7 +193,7 @@ async def handle_ticket(
208
193
args ["redirectUrl" ] = client_redirect_url
209
194
if session :
210
195
args ["session" ] = session
211
- username , user_display_name = await self ._validate_ticket (ticket , args )
196
+ username , attributes = await self ._validate_ticket (ticket , args )
212
197
213
198
# first check if we're doing a UIA
214
199
if session :
@@ -218,14 +203,36 @@ async def handle_ticket(
218
203
219
204
# otherwise, we're handling a login request.
220
205
206
+ # Ensure that the attributes of the logged in user meet the required
207
+ # attributes.
208
+ for required_attribute , required_value in self ._cas_required_attributes .items ():
209
+ # If required attribute was not in CAS Response - Forbidden
210
+ if required_attribute not in attributes :
211
+ self ._sso_handler .render_error (
212
+ request , "unauthorised" , "You are not authorised to log in here."
213
+ )
214
+ return
215
+
216
+ # Also need to check value
217
+ if required_value is not None :
218
+ actual_value = attributes [required_attribute ]
219
+ # If required attribute value does not match expected - Forbidden
220
+ if required_value != actual_value :
221
+ self ._sso_handler .render_error (
222
+ request ,
223
+ "unauthorised" ,
224
+ "You are not authorised to log in here." ,
225
+ )
226
+ return
227
+
221
228
# Pull out the user-agent and IP from the request.
222
229
user_agent = request .get_user_agent ("" )
223
230
ip_address = self .hs .get_ip_from_request (request )
224
231
225
232
# Get the matrix ID from the CAS username.
226
233
try :
227
234
user_id = await self ._map_cas_user_to_matrix_user (
228
- username , user_display_name , user_agent , ip_address
235
+ username , attributes , user_agent , ip_address
229
236
)
230
237
except MappingException as e :
231
238
logger .exception ("Could not map user" )
@@ -242,7 +249,7 @@ async def handle_ticket(
242
249
async def _map_cas_user_to_matrix_user (
243
250
self ,
244
251
remote_user_id : str ,
245
- display_name : Optional [str ],
252
+ attributes : Dict [ str , Optional [str ] ],
246
253
user_agent : str ,
247
254
ip_address : str ,
248
255
) -> str :
@@ -251,7 +258,7 @@ async def _map_cas_user_to_matrix_user(
251
258
252
259
Args:
253
260
remote_user_id: The username from the CAS response.
254
- display_name: The display name from the CAS response.
261
+ attributes: Additional attributes from the CAS response.
255
262
user_agent: The user agent of the client making the request.
256
263
ip_address: The IP address of the client making the request.
257
264
@@ -262,12 +269,14 @@ async def _map_cas_user_to_matrix_user(
262
269
The user ID associated with this response.
263
270
"""
264
271
272
+ # Note that CAS does not support a mapping provider, so the logic is hard-coded.
273
+ localpart = map_username_to_mxid_localpart (remote_user_id )
274
+ display_name = attributes .pop (self ._cas_displayname_attribute , None )
275
+
265
276
async def cas_response_to_user_attributes (failures : int ) -> UserAttributes :
266
277
"""
267
278
Map from CAS attributes to user attributes.
268
279
"""
269
- localpart = map_username_to_mxid_localpart (remote_user_id )
270
-
271
280
# Due to the grandfathering logic matching any previously registered
272
281
# mxids it isn't expected for there to be any failures.
273
282
if failures :
@@ -278,9 +287,7 @@ async def cas_response_to_user_attributes(failures: int) -> UserAttributes:
278
287
async def grandfather_existing_users () -> Optional [str ]:
279
288
# Since CAS did not used to support storing data into the user_external_ids
280
289
# tables, we need to attempt to map to existing users.
281
- user_id = UserID (
282
- map_username_to_mxid_localpart (remote_user_id ), self ._hostname
283
- ).to_string ()
290
+ user_id = UserID (localpart , self ._hostname ).to_string ()
284
291
285
292
logger .debug (
286
293
"Looking for existing account based on mapped %s" , user_id ,
0 commit comments