Skip to content

Commit d71ba51

Browse files
author
JulianMaurin
committed
feat(api_jws): typing
1 parent c6b6214 commit d71ba51

File tree

1 file changed

+31
-30
lines changed

1 file changed

+31
-30
lines changed

jwt/api_jws.py

+31-30
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
from __future__ import annotations
2+
13
import binascii
24
import json
35
import warnings
4-
from collections.abc import Mapping
5-
from typing import Any, Dict, List, Optional, Type
6+
from typing import Any, Type
67

78
from .algorithms import (
89
Algorithm,
@@ -23,7 +24,7 @@
2324
class PyJWS:
2425
header_typ = "JWT"
2526

26-
def __init__(self, algorithms=None, options=None):
27+
def __init__(self, algorithms=None, options=None) -> None:
2728
self._algorithms = get_default_algorithms()
2829
self._valid_algs = (
2930
set(algorithms) if algorithms is not None else set(self._algorithms)
@@ -39,10 +40,10 @@ def __init__(self, algorithms=None, options=None):
3940
self.options = {**self._get_default_options(), **options}
4041

4142
@staticmethod
42-
def _get_default_options():
43+
def _get_default_options() -> dict[str, bool]:
4344
return {"verify_signature": True}
4445

45-
def register_algorithm(self, alg_id, alg_obj):
46+
def register_algorithm(self, alg_id: str, alg_obj: Algorithm) -> None:
4647
"""
4748
Registers a new Algorithm for use when creating and verifying tokens.
4849
"""
@@ -55,7 +56,7 @@ def register_algorithm(self, alg_id, alg_obj):
5556
self._algorithms[alg_id] = alg_obj
5657
self._valid_algs.add(alg_id)
5758

58-
def unregister_algorithm(self, alg_id):
59+
def unregister_algorithm(self, alg_id: str) -> None:
5960
"""
6061
Unregisters an Algorithm for use when creating and verifying tokens
6162
Throws KeyError if algorithm is not registered.
@@ -69,7 +70,7 @@ def unregister_algorithm(self, alg_id):
6970
del self._algorithms[alg_id]
7071
self._valid_algs.remove(alg_id)
7172

72-
def get_algorithms(self):
73+
def get_algorithms(self) -> list[str]:
7374
"""
7475
Returns a list of supported values for the 'alg' parameter.
7576
"""
@@ -96,9 +97,9 @@ def encode(
9697
self,
9798
payload: bytes,
9899
key: str,
99-
algorithm: Optional[str] = "HS256",
100-
headers: Optional[Dict[str, Any]] = None,
101-
json_encoder: Optional[Type[json.JSONEncoder]] = None,
100+
algorithm: str | None = "HS256",
101+
headers: dict[str, Any] | None = None,
102+
json_encoder: Type[json.JSONEncoder] | None = None,
102103
is_payload_detached: bool = False,
103104
) -> str:
104105
segments = []
@@ -117,7 +118,7 @@ def encode(
117118
is_payload_detached = True
118119

119120
# Header
120-
header = {"typ": self.header_typ, "alg": algorithm_} # type: Dict[str, Any]
121+
header: dict[str, Any] = {"typ": self.header_typ, "alg": algorithm_}
121122

122123
if headers:
123124
self._validate_headers(headers)
@@ -165,11 +166,11 @@ def decode_complete(
165166
self,
166167
jwt: str,
167168
key: str = "",
168-
algorithms: Optional[List[str]] = None,
169-
options: Optional[Dict[str, Any]] = None,
170-
detached_payload: Optional[bytes] = None,
169+
algorithms: list[str] | None = None,
170+
options: dict[str, Any] | None = None,
171+
detached_payload: bytes | None = None,
171172
**kwargs,
172-
) -> Dict[str, Any]:
173+
) -> dict[str, Any]:
173174
if kwargs:
174175
warnings.warn(
175176
"passing additional kwargs to decode_complete() is deprecated "
@@ -210,9 +211,9 @@ def decode(
210211
self,
211212
jwt: str,
212213
key: str = "",
213-
algorithms: Optional[List[str]] = None,
214-
options: Optional[Dict[str, Any]] = None,
215-
detached_payload: Optional[bytes] = None,
214+
algorithms: list[str] | None = None,
215+
options: dict[str, Any] | None = None,
216+
detached_payload: bytes | None = None,
216217
**kwargs,
217218
) -> str:
218219
if kwargs:
@@ -227,7 +228,7 @@ def decode(
227228
)
228229
return decoded["payload"]
229230

230-
def get_unverified_header(self, jwt):
231+
def get_unverified_header(self, jwt: str | bytes) -> dict:
231232
"""Returns back the JWT header parameters as a dict()
232233
233234
Note: The signature is not verified so the header parameters
@@ -238,7 +239,7 @@ def get_unverified_header(self, jwt):
238239

239240
return headers
240241

241-
def _load(self, jwt):
242+
def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict, bytes]:
242243
if isinstance(jwt, str):
243244
jwt = jwt.encode("utf-8")
244245

@@ -261,7 +262,7 @@ def _load(self, jwt):
261262
except ValueError as e:
262263
raise DecodeError(f"Invalid header string: {e}") from e
263264

264-
if not isinstance(header, Mapping):
265+
if not isinstance(header, dict):
265266
raise DecodeError("Invalid header string: must be a json object")
266267

267268
try:
@@ -278,16 +279,16 @@ def _load(self, jwt):
278279

279280
def _verify_signature(
280281
self,
281-
signing_input,
282-
header,
283-
signature,
284-
key="",
285-
algorithms=None,
286-
):
282+
signing_input: bytes,
283+
header: dict,
284+
signature: bytes,
285+
key: str = "",
286+
algorithms: list[str] | None = None,
287+
) -> None:
287288

288289
alg = header.get("alg")
289290

290-
if algorithms is not None and alg not in algorithms:
291+
if not alg or (algorithms is not None and alg not in algorithms):
291292
raise InvalidAlgorithmError("The specified alg value is not allowed")
292293

293294
try:
@@ -299,11 +300,11 @@ def _verify_signature(
299300
if not alg_obj.verify(signing_input, key, signature):
300301
raise InvalidSignatureError("Signature verification failed")
301302

302-
def _validate_headers(self, headers):
303+
def _validate_headers(self, headers: dict[str, Any]) -> None:
303304
if "kid" in headers:
304305
self._validate_kid(headers["kid"])
305306

306-
def _validate_kid(self, kid):
307+
def _validate_kid(self, kid: str) -> None:
307308
if not isinstance(kid, str):
308309
raise InvalidTokenError("Key ID header parameter must be a string")
309310

0 commit comments

Comments
 (0)