Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trusted publishing: prevent OIDC credential re-use #16254

Merged
merged 24 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
747cbb2
Update documentation
DarkaMaul Jun 18, 2024
bccf3b8
Remove outdated sentence
DarkaMaul Jun 18, 2024
905b10b
Create OIDCJtiTokens table
DarkaMaul Jul 9, 2024
9841636
Implement JWT token reuse detection
DarkaMaul Jul 10, 2024
00837af
Revert "Create OIDCJtiTokens table"
DarkaMaul Jul 10, 2024
9f0aa32
Merge remote-tracking branch 'trailofforks/main' into dm/jti
DarkaMaul Jul 10, 2024
d446fd0
Linting
DarkaMaul Jul 10, 2024
5a65892
Move jti to preverified claims
DarkaMaul Jul 11, 2024
1d81567
Update lock period to 5s.
DarkaMaul Jul 11, 2024
d20fc13
Merge branch 'main' into dm/jti
woodruffw Jul 11, 2024
776238c
Add a prefix for the Redis key
DarkaMaul Jul 11, 2024
3f56126
Merge branch 'main' into dm/jti
DarkaMaul Jul 12, 2024
d326cf4
Remove duplication in Github/Gitlab
DarkaMaul Jul 15, 2024
799b322
Rephrase token error message
DarkaMaul Jul 15, 2024
62f3f19
Merge branch 'main' into dm/jti
DarkaMaul Jul 15, 2024
a088b79
Fix test
DarkaMaul Jul 16, 2024
af58daf
Merge branch 'main' into dm/jti
woodruffw Jul 16, 2024
3845bfe
Merge branch 'main' into dm/jti
DarkaMaul Jul 19, 2024
28cfcdb
Merge branch 'main' into dm/jti
woodruffw Jul 22, 2024
31f4417
Check JTI in Publishers with check_existing_jti
DarkaMaul Aug 9, 2024
35e50e2
Renmame token_identifier_exists and move jti token check
DarkaMaul Aug 12, 2024
99b90e0
Renmame jwt_token_identifier
DarkaMaul Aug 12, 2024
90638ae
Merge branch 'main' into dm/jti
DarkaMaul Aug 12, 2024
ba9bbfd
Merge branch 'main' into dm/jti
di Aug 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/unit/oidc/models/test_github.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ def test_github_publisher_all_known_claims(self):
"nbf",
"exp",
"aud",
"jti",
# unchecked claims
"actor",
"actor_id",
"jti",
"run_id",
"run_number",
"run_attempt",
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/oidc/models/test_gitlab.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def test_gitlab_publisher_all_known_claims(self):
"nbf",
"exp",
"aud",
"jti",
# unchecked claims
"project_id",
"namespace_id",
Expand All @@ -78,7 +79,6 @@ def test_gitlab_publisher_all_known_claims(self):
"runner_environment",
"ci_config_sha",
"project_visibility",
"jti",
"user_access_level",
"groups_direct",
}
Expand Down
126 changes: 124 additions & 2 deletions tests/unit/oidc/test_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime

import jwt
import pretend
Expand All @@ -22,6 +23,7 @@

from tests.common.db.oidc import GitHubPublisherFactory, PendingGitHubPublisherFactory
from warehouse.oidc import errors, interfaces, services
from warehouse.oidc.interfaces import SignedClaims


def test_oidc_publisher_service_factory(metrics):
Expand Down Expand Up @@ -184,7 +186,7 @@ def test_find_publisher(self, metrics, monkeypatch):
metrics=metrics,
)

token = pretend.stub()
token = SignedClaims({})

publisher = pretend.stub(verify_claims=pretend.call_recorder(lambda c: True))
find_publisher_by_issuer = pretend.call_recorder(lambda *a, **kw: publisher)
Expand Down Expand Up @@ -253,7 +255,7 @@ def test_find_publisher_verify_claims_fails(self, metrics, monkeypatch):
services, "find_publisher_by_issuer", find_publisher_by_issuer
)

claims = pretend.stub()
claims = SignedClaims({})
with pytest.raises(errors.InvalidPublisherError):
service.find_publisher(claims)
assert service.metrics.increment.calls == [
Expand All @@ -268,6 +270,125 @@ def test_find_publisher_verify_claims_fails(self, metrics, monkeypatch):
]
assert publisher.verify_claims.calls == [pretend.call(claims)]

def test_find_publisher_prevent_reuse_token(self, monkeypatch, mockredis, metrics):
service = services.OIDCPublisherService(
session=pretend.stub(),
publisher="fakepublisher",
issuer_url=pretend.stub(),
audience="fakeaudience",
cache_url="redis://fake.example.com",
metrics=metrics,
)

monkeypatch.setattr(services.redis, "StrictRedis", mockredis)

publisher = pretend.stub(verify_claims=pretend.call_recorder(lambda c: True))
find_publisher_by_issuer = pretend.call_recorder(lambda *a, **kw: publisher)
monkeypatch.setattr(
services, "find_publisher_by_issuer", find_publisher_by_issuer
)

expiration = int(
(
datetime.datetime.now(tz=datetime.UTC) + datetime.timedelta(minutes=15)
).timestamp()
)
jwt_token_identifier = "6e67b1cb-2b8d-4be5-91cb-757edb2ec970"
service.store_jwt_identifier(jwt_token_identifier, expiration=expiration)

claims = SignedClaims(
{
"iss": "foo",
"iat": 1516239022,
"nbf": 1516239022,
"exp": expiration,
"aud": "pypi",
"jti": jwt_token_identifier,
}
)

with pytest.raises(errors.ReusedTokenError):
service.find_publisher(claims, pending=False)

assert (
pretend.call(
"warehouse.oidc.reused_token", tags=["publisher:fakepublisher"]
)
in metrics.increment.calls
)

def test_find_publisher_store_jti(self, monkeypatch, mockredis, metrics):
service = services.OIDCPublisherService(
session=pretend.stub(),
publisher="fakepublisher",
issuer_url=pretend.stub(),
audience="fakeaudience",
cache_url="redis://fake.example.com",
metrics=metrics,
)

monkeypatch.setattr(services.redis, "StrictRedis", mockredis)

publisher = pretend.stub(verify_claims=pretend.call_recorder(lambda c: True))
find_publisher_by_issuer = pretend.call_recorder(lambda *a, **kw: publisher)
monkeypatch.setattr(
services, "find_publisher_by_issuer", find_publisher_by_issuer
)

expiration = int(
(
datetime.datetime.now(tz=datetime.UTC) + datetime.timedelta(minutes=15)
).timestamp()
)
jwt_token_identifier = "6e67b1cb-2b8d-4be5-91cb-757edb2ec970"
claims = SignedClaims(
{
"iss": "foo",
"iat": 1516239022,
"nbf": 1516239022,
"exp": expiration,
"aud": "pypi",
"jti": jwt_token_identifier,
}
)

service.find_publisher(claims, pending=False)
assert service.token_identifier_exists(jwt_token_identifier) is True

def test_find_publisher_jti_not_stored_if_pending(
self, monkeypatch, mockredis, metrics
):
service = services.OIDCPublisherService(
session=pretend.stub(),
publisher="fakepublisher",
issuer_url=pretend.stub(),
audience="fakeaudience",
cache_url="redis://fake.example.com",
metrics=metrics,
)

monkeypatch.setattr(services.redis, "StrictRedis", mockredis)

publisher = pretend.stub(verify_claims=pretend.call_recorder(lambda c: True))
find_publisher_by_issuer = pretend.call_recorder(lambda *a, **kw: publisher)
monkeypatch.setattr(
services, "find_publisher_by_issuer", find_publisher_by_issuer
)
jwt_token_identifier = "6e67b1cb-2b8d-4be5-91cb-757edb2ec970"
claims = SignedClaims(
{
"iss": "foo",
"iat": 1516239022,
"nbf": 1516239022,
"exp": int(datetime.datetime.now(tz=datetime.UTC).timestamp()),
"aud": "pypi",
"jti": jwt_token_identifier,
}
)

service.find_publisher(claims, pending=True)
assert service.token_identifier_exists(jwt_token_identifier) is False

def test_get_keyset_not_cached(self, monkeypatch, mockredis):
service = services.OIDCPublisherService(
session=pretend.stub(),
Expand Down Expand Up @@ -820,6 +941,7 @@ def test_find_publisher(self, monkeypatch):
"nbf": 1516239022,
"exp": 9999999999,
"aud": "pypi",
"jti": "6e67b1cb-2b8d-4be5-91cb-757edb2ec970",
}

service = services.NullOIDCPublisherService(
Expand Down
31 changes: 31 additions & 0 deletions tests/unit/oidc/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,37 @@ def test_mint_token_trusted_publisher_lookup_fails(dummy_github_oidc_jwt):
]


def test_mint_token_duplicate_token(dummy_github_oidc_jwt):
def find_publishers_mockup(_, pending: bool = False):
if pending is False:
raise errors.ReusedTokenError("some message")
else:
raise errors.InvalidPublisherError("some message")

claims = pretend.stub()
oidc_service = pretend.stub(
verify_jwt_signature=pretend.call_recorder(lambda token: claims),
find_publisher=find_publishers_mockup,
)
request = pretend.stub(
response=pretend.stub(status=None),
find_service=pretend.call_recorder(lambda cls, **kw: oidc_service),
flags=pretend.stub(disallow_oidc=lambda *a: False),
)

response = views.mint_token(oidc_service, dummy_github_oidc_jwt, request)
assert request.response.status == 422
assert response == {
"message": "Token request failed",
"errors": [
{
"code": "invalid-reuse-token",
"description": "valid token, but already used",
}
],
}


def test_mint_token_pending_publisher_project_already_exists(
db_request, dummy_github_oidc_jwt
):
Expand Down
4 changes: 4 additions & 0 deletions warehouse/oidc/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,7 @@

class InvalidPublisherError(Exception):
pass


class ReusedTokenError(Exception):
pass
10 changes: 9 additions & 1 deletion warehouse/oidc/models/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,18 @@ class GitHubPublisherMixin:
"environment": _check_environment,
}

__preverified_claims__ = {
"iss",
"iat",
"nbf",
"exp",
"aud",
"jti",
}

__unchecked_claims__ = {
"actor",
"actor_id",
"jti",
"run_id",
"run_number",
"run_attempt",
Expand Down
10 changes: 9 additions & 1 deletion warehouse/oidc/models/gitlab.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,15 @@ class GitLabPublisherMixin:

__required_unverifiable_claims__: set[str] = {"ref_path", "sha"}

__preverified_claims__ = {
"iss",
"iat",
"nbf",
"exp",
"aud",
"jti",
}

__optional_verifiable_claims__: dict[str, CheckClaimCallable[Any]] = {
"environment": _check_environment,
}
Expand Down Expand Up @@ -149,7 +158,6 @@ class GitLabPublisherMixin:
"runner_environment",
"ci_config_sha",
"project_visibility",
"jti",
"user_access_level",
"groups_direct",
}
Expand Down
39 changes: 38 additions & 1 deletion warehouse/oidc/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# limitations under the License.

import json
import typing
import warnings

import jwt
Expand All @@ -21,7 +22,7 @@
from zope.interface import implementer

from warehouse.metrics.interfaces import IMetricsService
from warehouse.oidc.errors import InvalidPublisherError
from warehouse.oidc.errors import InvalidPublisherError, ReusedTokenError
from warehouse.oidc.interfaces import IOIDCPublisherService, SignedClaims
from warehouse.oidc.models import OIDCPublisher, PendingOIDCPublisher
from warehouse.oidc.utils import find_publisher_by_issuer
Expand Down Expand Up @@ -219,6 +220,23 @@ def _get_key_for_token(self, token):
unverified_header = jwt.get_unverified_header(token)
return self._get_key(unverified_header["kid"])

def token_identifier_exists(self, jti: str) -> bool:
"""
Check if a JWT Token Identifier has already been used.
"""
with redis.StrictRedis.from_url(self.cache_url) as r:
return bool(r.exists(jti))

def store_jwt_identifier(self, jti: str, expiration: int) -> None:
"""
Store the JTI with its expiration date if the key does not exist.
"""
with redis.StrictRedis.from_url(self.cache_url) as r:
# Defensive: to prevent races, we expire the JTI slightly after
# the token expiration date. Thus, the lock will not be
# released before the token invalidation.
r.set(jti, exat=expiration + 1, value="placeholder", nx=True)

def verify_jwt_signature(self, unverified_token: str) -> SignedClaims | None:
try:
key = self._get_key_for_token(unverified_token)
Expand Down Expand Up @@ -287,11 +305,30 @@ def find_publisher(
publisher = find_publisher_by_issuer(
self.db, self.issuer_url, signed_claims, pending=pending
)

jwt_token_identifier: str | None = signed_claims.get("jti", None)
# jti is in the __preverified_claims__ set, so if it was present,
# it was already checked
if pending is False and jwt_token_identifier:
if self.token_identifier_exists(jwt_token_identifier):
self.metrics.increment(
"warehouse.oidc.reused_token",
tags=metrics_tags,
)
raise ReusedTokenError("JWT Token already used to mint a token.")

publisher.verify_claims(signed_claims)
self.metrics.increment(
"warehouse.oidc.find_publisher.ok",
tags=metrics_tags,
)

if pending is False and jwt_token_identifier:
# Of note, exp is coming from a trusted source here,
# so we don't validate it
expiration = typing.cast(int, signed_claims.get("exp"))
self.store_jwt_identifier(jwt_token_identifier, expiration)

return publisher
except InvalidPublisherError as e:
self.metrics.increment(
Expand Down
12 changes: 11 additions & 1 deletion warehouse/oidc/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from warehouse.macaroons.interfaces import IMacaroonService
from warehouse.macaroons.services import DatabaseMacaroonService
from warehouse.metrics.interfaces import IMetricsService
from warehouse.oidc.errors import InvalidPublisherError
from warehouse.oidc.errors import InvalidPublisherError, ReusedTokenError
from warehouse.oidc.interfaces import IOIDCPublisherService
from warehouse.oidc.models import OIDCPublisher, PendingOIDCPublisher
from warehouse.oidc.services import OIDCPublisherService
Expand Down Expand Up @@ -238,6 +238,16 @@ def mint_token(
publisher = oidc_service.find_publisher(claims, pending=False)
# NOTE: assert to persuade mypy of the correct type here.
assert isinstance(publisher, OIDCPublisher)
except ReusedTokenError:
return _invalid(
errors=[
{
"code": "invalid-reuse-token",
"description": "valid token, but already used",
}
],
request=request,
)
except InvalidPublisherError as e:
return _invalid(
errors=[
Expand Down