1
+ from __future__ import annotations
2
+
1
3
import binascii
2
4
import json
3
5
import warnings
4
- from collections .abc import Mapping
5
- from typing import Any , Dict , List , Optional , Type
6
+ from typing import Any , Type
6
7
7
8
from .algorithms import (
8
9
Algorithm ,
23
24
class PyJWS :
24
25
header_typ = "JWT"
25
26
26
- def __init__ (self , algorithms = None , options = None ):
27
+ def __init__ (self , algorithms = None , options = None ) -> None :
27
28
self ._algorithms = get_default_algorithms ()
28
29
self ._valid_algs = (
29
30
set (algorithms ) if algorithms is not None else set (self ._algorithms )
@@ -39,10 +40,10 @@ def __init__(self, algorithms=None, options=None):
39
40
self .options = {** self ._get_default_options (), ** options }
40
41
41
42
@staticmethod
42
- def _get_default_options ():
43
+ def _get_default_options () -> dict [ str , bool ] :
43
44
return {"verify_signature" : True }
44
45
45
- def register_algorithm (self , alg_id , alg_obj ) :
46
+ def register_algorithm (self , alg_id : str , alg_obj : Algorithm ) -> None :
46
47
"""
47
48
Registers a new Algorithm for use when creating and verifying tokens.
48
49
"""
@@ -55,7 +56,7 @@ def register_algorithm(self, alg_id, alg_obj):
55
56
self ._algorithms [alg_id ] = alg_obj
56
57
self ._valid_algs .add (alg_id )
57
58
58
- def unregister_algorithm (self , alg_id ) :
59
+ def unregister_algorithm (self , alg_id : str ) -> None :
59
60
"""
60
61
Unregisters an Algorithm for use when creating and verifying tokens
61
62
Throws KeyError if algorithm is not registered.
@@ -69,7 +70,7 @@ def unregister_algorithm(self, alg_id):
69
70
del self ._algorithms [alg_id ]
70
71
self ._valid_algs .remove (alg_id )
71
72
72
- def get_algorithms (self ):
73
+ def get_algorithms (self ) -> list [ str ] :
73
74
"""
74
75
Returns a list of supported values for the 'alg' parameter.
75
76
"""
@@ -96,9 +97,9 @@ def encode(
96
97
self ,
97
98
payload : bytes ,
98
99
key : str ,
99
- algorithm : Optional [ str ] = "HS256" ,
100
- headers : Optional [ Dict [ str , Any ]] = None ,
101
- json_encoder : Optional [ Type [json .JSONEncoder ]] = None ,
100
+ algorithm : str | None = "HS256" ,
101
+ headers : dict [ str , Any ] | None = None ,
102
+ json_encoder : Type [json .JSONEncoder ] | None = None ,
102
103
is_payload_detached : bool = False ,
103
104
) -> str :
104
105
segments = []
@@ -117,7 +118,7 @@ def encode(
117
118
is_payload_detached = True
118
119
119
120
# Header
120
- header = {"typ" : self .header_typ , "alg" : algorithm_ } # type: Dict[str, Any]
121
+ header : dict [ str , Any ] = {"typ" : self .header_typ , "alg" : algorithm_ }
121
122
122
123
if headers :
123
124
self ._validate_headers (headers )
@@ -165,11 +166,11 @@ def decode_complete(
165
166
self ,
166
167
jwt : str ,
167
168
key : str = "" ,
168
- algorithms : Optional [ List [ str ]] = None ,
169
- options : Optional [ Dict [ str , Any ]] = None ,
170
- detached_payload : Optional [ bytes ] = None ,
169
+ algorithms : list [ str ] | None = None ,
170
+ options : dict [ str , Any ] | None = None ,
171
+ detached_payload : bytes | None = None ,
171
172
** kwargs ,
172
- ) -> Dict [str , Any ]:
173
+ ) -> dict [str , Any ]:
173
174
if kwargs :
174
175
warnings .warn (
175
176
"passing additional kwargs to decode_complete() is deprecated "
@@ -210,9 +211,9 @@ def decode(
210
211
self ,
211
212
jwt : str ,
212
213
key : str = "" ,
213
- algorithms : Optional [ List [ str ]] = None ,
214
- options : Optional [ Dict [ str , Any ]] = None ,
215
- detached_payload : Optional [ bytes ] = None ,
214
+ algorithms : list [ str ] | None = None ,
215
+ options : dict [ str , Any ] | None = None ,
216
+ detached_payload : bytes | None = None ,
216
217
** kwargs ,
217
218
) -> str :
218
219
if kwargs :
@@ -227,7 +228,7 @@ def decode(
227
228
)
228
229
return decoded ["payload" ]
229
230
230
- def get_unverified_header (self , jwt ) :
231
+ def get_unverified_header (self , jwt : str | bytes ) -> dict :
231
232
"""Returns back the JWT header parameters as a dict()
232
233
233
234
Note: The signature is not verified so the header parameters
@@ -238,7 +239,7 @@ def get_unverified_header(self, jwt):
238
239
239
240
return headers
240
241
241
- def _load (self , jwt ) :
242
+ def _load (self , jwt : str | bytes ) -> tuple [ bytes , bytes , dict , bytes ] :
242
243
if isinstance (jwt , str ):
243
244
jwt = jwt .encode ("utf-8" )
244
245
@@ -261,7 +262,7 @@ def _load(self, jwt):
261
262
except ValueError as e :
262
263
raise DecodeError (f"Invalid header string: { e } " ) from e
263
264
264
- if not isinstance (header , Mapping ):
265
+ if not isinstance (header , dict ):
265
266
raise DecodeError ("Invalid header string: must be a json object" )
266
267
267
268
try :
@@ -278,16 +279,16 @@ def _load(self, jwt):
278
279
279
280
def _verify_signature (
280
281
self ,
281
- signing_input ,
282
- header ,
283
- signature ,
284
- key = "" ,
285
- algorithms = None ,
286
- ):
282
+ signing_input : bytes ,
283
+ header : dict ,
284
+ signature : bytes ,
285
+ key : str = "" ,
286
+ algorithms : list [ str ] | None = None ,
287
+ ) -> None :
287
288
288
289
alg = header .get ("alg" )
289
290
290
- if algorithms is not None and alg not in algorithms :
291
+ if not alg or ( algorithms is not None and alg not in algorithms ) :
291
292
raise InvalidAlgorithmError ("The specified alg value is not allowed" )
292
293
293
294
try :
@@ -299,11 +300,11 @@ def _verify_signature(
299
300
if not alg_obj .verify (signing_input , key , signature ):
300
301
raise InvalidSignatureError ("Signature verification failed" )
301
302
302
- def _validate_headers (self , headers ) :
303
+ def _validate_headers (self , headers : dict [ str , Any ]) -> None :
303
304
if "kid" in headers :
304
305
self ._validate_kid (headers ["kid" ])
305
306
306
- def _validate_kid (self , kid ) :
307
+ def _validate_kid (self , kid : str ) -> None :
307
308
if not isinstance (kid , str ):
308
309
raise InvalidTokenError ("Key ID header parameter must be a string" )
309
310
0 commit comments