Skip to content
This repository was archived by the owner on Nov 30, 2022. It is now read-only.

Reduce # of clients connected to the application db [#810] #944

Merged
merged 10 commits into from
Jul 27, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ The types of changes are:
### Fixed

* Correct build arg variable name [925](https://github.com/ethyca/fidesops/pull/925)
* Reduce number of clients connected to the application db [#944](https://github.com/ethyca/fidesops/pull/944)

## [1.6.3](https://github.com/ethyca/fidesops/compare/1.6.2...1.6.3)

Expand Down
2 changes: 1 addition & 1 deletion src/fidesops/api/v1/endpoints/privacy_request_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ def get_request_preview_queries(
traversal: Traversal = Traversal(dataset_graph, identity_seed)
queries: Dict[CollectionAddress, str] = collect_queries(
traversal,
TaskResources(EMPTY_REQUEST, Policy(), connection_configs),
TaskResources(EMPTY_REQUEST, Policy(), connection_configs, db),
)
return [
DryRunDatasetResponse(
Expand Down
29 changes: 20 additions & 9 deletions src/fidesops/service/privacy_request/request_runner_service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Set
from typing import ContextManager, Dict, List, Optional, Set

from celery import Task
from celery.utils.log import get_task_logger
from fideslib.db.session import get_db_session
from pydantic import ValidationError
Expand Down Expand Up @@ -150,8 +151,22 @@ def queue_privacy_request(
return task.task_id


@celery_app.task()
class DatabaseTask(Task):  # pylint: disable=W0223
    """Celery task base class that lazily creates and caches a database Session."""

    # Cached Session; stays None until the `session` property is first read.
    _session = None

    @property
    def session(self) -> ContextManager[Session]:
        """Return the cached Session, building it on first access.

        The Session is stored on the task instance so later reads reuse it
        instead of calling ``get_db_session`` again.
        """
        if self._session is not None:
            return self._session

        session_factory = get_db_session(config)
        self._session = session_factory()
        return self._session
Comment on lines +154 to +164
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This allows us to share the Session across a Celery process.

This also limits how often get_db_session is called per process. Each call was creating a new Engine, and each Engine opened its own Connection Pool whose connections weren't being reused.



@celery_app.task(base=DatabaseTask, bind=True)
def run_privacy_request(
self: DatabaseTask,
privacy_request_id: str,
from_webhook_id: Optional[str] = None,
from_step: Optional[str] = None,
Expand All @@ -169,8 +184,7 @@ def run_privacy_request(
# can't be passed into and between tasks
from_step = PausedStep(from_step) # type: ignore

SessionLocal = get_db_session(config)
with SessionLocal() as session:
with self.session as session:

privacy_request = PrivacyRequest.get(db=session, object_id=privacy_request_id)
if privacy_request.status == PrivacyRequestStatus.canceled:
Expand All @@ -190,7 +204,6 @@ def run_privacy_request(
after_webhook_id=from_webhook_id,
)
if not proceed:
session.close()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These session.close() calls were removed because they were redundant: the Session is used as a context manager above (see `with self.session as session`), so it is closed automatically.

return

policy = privacy_request.policy
Expand All @@ -217,6 +230,7 @@ def run_privacy_request(
graph=dataset_graph,
connection_configs=connection_configs,
identity=identity_data,
session=session,
)

upload_access_results(
Expand All @@ -238,19 +252,18 @@ def run_privacy_request(
access_request_data=get_cached_data_for_erasures(
privacy_request.id
),
session=session,
)

except PrivacyRequestPaused as exc:
privacy_request.pause_processing(session)
_log_warning(exc, config.dev_mode)
session.close()
return

except BaseException as exc: # pylint: disable=broad-except
privacy_request.error_processing(db=session)
# If dev mode, log traceback
_log_exception(exc, config.dev_mode)
session.close()
return

# Run post-execution webhooks
Expand All @@ -260,14 +273,12 @@ def run_privacy_request(
webhook_cls=PolicyPostWebhook, # type: ignore
)
if not proceed:
session.close()
return

privacy_request.finished_processing_at = datetime.utcnow()
privacy_request.status = PrivacyRequestStatus.complete
privacy_request.save(db=session)
logging.info(f"Privacy request {privacy_request.id} run completed.")
session.close()


def initiate_paused_privacy_request_followup(privacy_request: PrivacyRequest) -> None:
Expand Down
13 changes: 10 additions & 3 deletions src/fidesops/task/graph_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import dask
from dask.threaded import get
from sqlalchemy.orm import Session

from fidesops.common_exceptions import CollectionDisabled, PrivacyRequestPaused
from fidesops.core.config import config
Expand Down Expand Up @@ -568,10 +569,13 @@ def run_access_request(
graph: DatasetGraph,
connection_configs: List[ConnectionConfig],
identity: Dict[str, Any],
session: Session,
) -> Dict[str, List[Row]]:
"""Run the access request"""
traversal: Traversal = Traversal(graph, identity)
with TaskResources(privacy_request, policy, connection_configs) as resources:
with TaskResources(
privacy_request, policy, connection_configs, session
) as resources:
Comment on lines +576 to +578
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're passing the Session in to TaskResources here so we can use it to create the ExecutionLogs.


def collect_tasks_fn(
tn: TraversalNode, data: Dict[CollectionAddress, GraphTask]
Expand Down Expand Up @@ -636,17 +640,20 @@ def update_erasure_mapping_from_cache(
)


def run_erasure( # pylint: disable = too-many-arguments
def run_erasure( # pylint: disable = too-many-arguments, too-many-locals
privacy_request: PrivacyRequest,
policy: Policy,
graph: DatasetGraph,
connection_configs: List[ConnectionConfig],
identity: Dict[str, Any],
access_request_data: Dict[str, List[Row]],
session: Session,
) -> Dict[str, int]:
"""Run an erasure request"""
traversal: Traversal = Traversal(graph, identity)
with TaskResources(privacy_request, policy, connection_configs) as resources:
with TaskResources(
privacy_request, policy, connection_configs, session
) as resources:

def collect_tasks_fn(
tn: TraversalNode, data: Dict[CollectionAddress, GraphTask]
Expand Down
9 changes: 4 additions & 5 deletions src/fidesops/task/task_resources.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import logging
from typing import Any, Dict, List, Optional

from fideslib.db.session import get_db_session
from sqlalchemy.orm import Session

from fidesops.common_exceptions import ConnectorNotFoundException
from fidesops.core.config import config
from fidesops.graph.config import CollectionAddress
from fidesops.models.connectionconfig import ConnectionConfig, ConnectionType
from fidesops.models.policy import ActionType, Policy
Expand Down Expand Up @@ -97,6 +96,7 @@ def __init__(
request: PrivacyRequest,
policy: Policy,
connection_configs: List[ConnectionConfig],
session: Session,
):
self.request = request
self.policy = policy
Expand All @@ -106,6 +106,7 @@ def __init__(
c.key: c for c in connection_configs
}
self.connections = Connections()
self.session = session

def __enter__(self) -> "TaskResources":
"""Support 'with' usage for closing resources"""
Expand Down Expand Up @@ -157,8 +158,7 @@ def write_execution_log( # pylint: disable=too-many-arguments
message: str = None,
) -> Any:
"""Store in application db. Return the created or written-to id field value."""
SessionLocal = get_db_session(config)
db = SessionLocal()
db = self.session
Comment on lines -160 to +161
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're using the same Session that is being used for the PrivacyRequest, not creating a new one, which reduces the number of connections we're opening. It also has a side effect of causing other resources bound to the Session for the PrivacyRequest to get the most recent state when the ExecutionLog is saved, because we commit the Session.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the main implication here is with request cancellation? I.e., if .cancelled_at / .paused_at / .status were to change, we may be able to abort the traversal entirely?

Copy link
Contributor Author

@pattisdr pattisdr Jul 27, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think this is a good example of where, in the future, we can make better decisions mid-traversal based on the resources' current state.


ExecutionLog.create(
db=db,
Expand All @@ -172,7 +172,6 @@ def write_execution_log( # pylint: disable=too-many-arguments
"message": message,
},
)
db.close()

def get_connector(self, key: FidesOpsKey) -> Any:
"""Create or return the client corresponding to the given ConnectionConfig key"""
Expand Down
2 changes: 1 addition & 1 deletion src/fidesops/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def _create_celery(config_path: str = config.execution.celery_config_path) -> Ce

def start_worker() -> None:
logger.info("Running Celery worker...")
celery_app.worker_main(argv=["worker", "--loglevel=info"])
celery_app.worker_main(argv=["worker", "--loglevel=info", "--concurrency=2"])
Copy link
Contributor Author

@pattisdr pattisdr Jul 27, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be configurable in the future, but this matches our state before the move to Celery. The default is the number of cores on your machine, but limiting the number of PrivacyRequests that can be run simultaneously per worker has a lot of positive effects. These include reducing the number of simultaneous connections to the application database and to the customers' own databases, and reducing simultaneous requests against their connected APIs.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I've added a ticket to be addressed as a follow-up.



if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions tests/ops/graph/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_node_contains_field(self) -> None:
assert node.contains_field(lambda f: f.identity == "ssn")


def test_retry_decorator(privacy_request, policy):
def test_retry_decorator(privacy_request, policy, db):
input_data = {"test": "data"}
graph: DatasetGraph = integration_db_graph("postgres_example")
traversal = Traversal(graph, {"email": "X"})
Expand All @@ -81,7 +81,7 @@ def __init__(self):
self.start_logged = 0
self.retry_logged = 0
self.end_called_with = ()
self.resources = TaskResources(privacy_request, policy, [])
self.resources = TaskResources(privacy_request, policy, [], db)

def log_end(self, action_type: ActionType, exc: Optional[str] = None):
self.end_called_with = (action_type, exc)
Expand Down
4 changes: 4 additions & 0 deletions tests/ops/integration_tests/saas/test_adobe_campaign_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def test_adobe_campaign_access_request_task(
adobe_campaign_identity_email,
adobe_campaign_connection_config,
adobe_campaign_dataset_config,
db,
) -> None:
"""Full access request based on the Adobe Campaign SaaS config"""

Expand All @@ -46,6 +47,7 @@ def test_adobe_campaign_access_request_task(
graph,
[adobe_campaign_connection_config],
{"email": adobe_campaign_identity_email},
db,
)

assert_rows_match(
Expand Down Expand Up @@ -188,6 +190,7 @@ def test_adobe_campaign_saas_erasure_request_task(
graph,
[adobe_campaign_connection_config],
{"email": erasure_email},
db,
)

assert_rows_match(
Expand Down Expand Up @@ -297,6 +300,7 @@ def test_adobe_campaign_saas_erasure_request_task(
[adobe_campaign_connection_config],
{"email": erasure_email},
get_cached_data_for_erasures(privacy_request.id),
db,
)

# Assert erasure request made to adobe_campaign_user
Expand Down
8 changes: 3 additions & 5 deletions tests/ops/integration_tests/saas/test_hubspot_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def test_saas_access_request_task(
graph,
[connection_config_hubspot],
{"email": hubspot_identity_email},
db,
)

assert_rows_match(
Expand Down Expand Up @@ -148,11 +149,7 @@ def test_saas_erasure_request_task(
graph = DatasetGraph(merged_graph)

v = graph_task.run_access_request(
privacy_request,
policy,
graph,
[connection_config_hubspot],
identity_kwargs,
privacy_request, policy, graph, [connection_config_hubspot], identity_kwargs, db
)

assert_rows_match(
Expand All @@ -173,6 +170,7 @@ def test_saas_erasure_request_task(
[connection_config_hubspot],
identity_kwargs,
get_cached_data_for_erasures(privacy_request.id),
db,
)

# Masking request only issued to "contacts" and "subscription_preferences" endpoints
Expand Down
3 changes: 3 additions & 0 deletions tests/ops/integration_tests/saas/test_mailchimp_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def test_mailchimp_access_request_task(
graph,
[mailchimp_connection_config],
{"email": mailchimp_identity_email},
db,
)

assert_rows_match(
Expand Down Expand Up @@ -149,6 +150,7 @@ def test_mailchimp_erasure_request_task(
graph,
[mailchimp_connection_config],
{"email": mailchimp_identity_email},
db,
)

v = graph_task.run_erasure(
Expand All @@ -158,6 +160,7 @@ def test_mailchimp_erasure_request_task(
[mailchimp_connection_config],
{"email": mailchimp_identity_email},
get_cached_data_for_erasures(privacy_request.id),
db,
)

logs = (
Expand Down
3 changes: 3 additions & 0 deletions tests/ops/integration_tests/saas/test_outreach_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def test_outreach_access_request_task(
graph,
[outreach_connection_config],
{"email": outreach_identity_email},
db,
)

assert_rows_match(
Expand Down Expand Up @@ -115,6 +116,7 @@ def test_outreach_erasure_request_task(
graph,
[outreach_connection_config],
{"email": outreach_erasure_identity_email},
db,
)

# verify staged data is available for erasure
Expand All @@ -136,6 +138,7 @@ def test_outreach_erasure_request_task(
[outreach_connection_config],
{"email": outreach_erasure_identity_email},
get_cached_data_for_erasures(privacy_request.id),
db,
)

# Assert erasure request made to prospects and recipients
Expand Down
4 changes: 4 additions & 0 deletions tests/ops/integration_tests/saas/test_salesforce_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def test_salesforce_access_request_task(
salesforce_identity_email,
salesforce_connection_config,
salesforce_dataset_config,
db,
) -> None:
"""Full access request based on the Salesforce SaaS config"""

Expand All @@ -45,6 +46,7 @@ def test_salesforce_access_request_task(
graph,
[salesforce_connection_config],
{"email": salesforce_identity_email},
db,
)

assert_rows_match(
Expand Down Expand Up @@ -401,6 +403,7 @@ def test_salesforce_erasure_request_task(
graph,
[salesforce_connection_config],
{"email": salesforce_erasure_identity_email},
db,
)

# verify staged data is available for erasure
Expand Down Expand Up @@ -714,6 +717,7 @@ def test_salesforce_erasure_request_task(
[salesforce_connection_config],
{"email": salesforce_erasure_identity_email},
get_cached_data_for_erasures(privacy_request.id),
db,
)

# verify masking request was issued for endpoints with update actions
Expand Down
3 changes: 3 additions & 0 deletions tests/ops/integration_tests/saas/test_segment_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def test_segment_saas_access_request_task(
graph,
[segment_connection_config],
{"email": segment_identity_email},
db,
)

assert_rows_match(
Expand Down Expand Up @@ -168,6 +169,7 @@ def test_segment_saas_erasure_request_task(
graph,
[segment_connection_config],
{"email": erasure_email},
db,
)

assert_rows_match(
Expand Down Expand Up @@ -214,6 +216,7 @@ def test_segment_saas_erasure_request_task(
[segment_connection_config],
{"email": erasure_email},
get_cached_data_for_erasures(privacy_request.id),
db,
)

# Assert erasure request made to segment_user - cannot verify success immediately as this can take
Expand Down
Loading