From 17385e6dec34da357f25748c6a2b3469338d50eb Mon Sep 17 00:00:00 2001
From: Vsevolod Stepanov
Date: Thu, 30 May 2024 17:02:50 +0200
Subject: [PATCH 1/3] add checks for spark logging
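
Example of code that the new LoggingMatcher flags (illustrative
snippet mirroring the unit tests below; `sc` is assumed to be the
notebook's SparkContext):

    # flagged: set the cluster conf 'spark.log.level' instead
    sc.setLogLevel("INFO")
    # flagged: use logging.getLogger() instead
    log4jLogger = sc._jvm.org.apache.log4j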
---
 .../labs/ucx/source_code/spark_connect.py    | 44 ++++++++++-
 tests/unit/source_code/test_spark_connect.py | 75 ++++++++++++++++++-
 2 files changed, 116 insertions(+), 3 deletions(-)

diff --git a/src/databricks/labs/ucx/source_code/spark_connect.py b/src/databricks/labs/ucx/source_code/spark_connect.py
index 4416c8e604..a468547bf5 100644
--- a/src/databricks/labs/ucx/source_code/spark_connect.py
+++ b/src/databricks/labs/ucx/source_code/spark_connect.py
@@ -68,9 +68,8 @@ class RDDApiMatcher(SharedClusterMatcher):
     ]
 
     def lint(self, node: ast.AST) -> Iterator[Advice]:
-        if not isinstance(node, ast.Call):
+        if not isinstance(node, ast.Call) or not isinstance(node.func, ast.Attribute):
             return
-        assert isinstance(node.func, ast.Attribute)  # Avoid linter warning
         if node.func.attr not in self._SC_METHODS:
             return
         function_name = AstHelper.get_full_function_name(node)
@@ -104,12 +103,53 @@ def lint(self, node: ast.AST) -> Iterator[Advice]:
         )
 
 
+class LoggingMatcher(SharedClusterMatcher):
+    def lint(self, node: ast.AST) -> Iterator[Advice]:
+        yield from self._match_sc_set_log_level(node)
+        yield from self._match_jvm_log(node)
+
+    def _match_sc_set_log_level(self, node: ast.AST) -> Iterator[Advice]:
+        if not isinstance(node, ast.Call) or not isinstance(node.func, ast.Attribute):
+            return
+        if node.func.attr != 'setLogLevel':
+            return
+        function_name = AstHelper.get_full_function_name(node)
+        if not function_name or not function_name.endswith('sc.setLogLevel'):
+            return
+
+        yield Failure(
+            code='spark-logging-in-shared-clusters',
+            message=f'Cannot set Spark log level directly from code on {self._cluster_type_str()}. '
+            f'Remove the call and set the cluster spark conf \'spark.log.level\' instead',
+            start_line=node.lineno,
+            start_col=node.col_offset,
+            end_line=node.end_lineno or 0,
+            end_col=node.end_col_offset or 0,
+        )
+
+    def _match_jvm_log(self, node: ast.AST) -> Iterator[Advice]:
+        if not isinstance(node, ast.Attribute):
+            return
+        attribute_name = AstHelper.get_full_attribute_name(node)
+        if attribute_name and attribute_name.endswith('org.apache.log4j'):
+            yield Failure(
+                code='spark-logging-in-shared-clusters',
+                message=f'Cannot access Spark Driver JVM logger on {self._cluster_type_str()}. '
+                f'Use logging.getLogger() instead',
+                start_line=node.lineno,
+                start_col=node.col_offset,
+                end_line=node.end_lineno or 0,
+                end_col=node.end_col_offset or 0,
+            )
+
+
 class SparkConnectLinter(Linter):
     def __init__(self, is_serverless: bool = False):
         self._matchers = [
             JvmAccessMatcher(is_serverless=is_serverless),
             RDDApiMatcher(is_serverless=is_serverless),
             SparkSqlContextMatcher(is_serverless=is_serverless),
+            LoggingMatcher(is_serverless=is_serverless),
         ]
 
     def lint(self, code: str) -> Iterator[Advice]:
diff --git a/tests/unit/source_code/test_spark_connect.py b/tests/unit/source_code/test_spark_connect.py
index 40ee85b8a3..fc7383be9b 100644
--- a/tests/unit/source_code/test_spark_connect.py
+++ b/tests/unit/source_code/test_spark_connect.py
@@ -1,5 +1,8 @@
+import ast
+from itertools import chain
+
 from databricks.labs.ucx.source_code.base import Failure
-from databricks.labs.ucx.source_code.spark_connect import SparkConnectLinter
+from databricks.labs.ucx.source_code.spark_connect import LoggingMatcher, SparkConnectLinter
 
 
 def test_jvm_access_match_shared():
@@ -140,6 +143,76 @@ def test_rdd_context_match_serverless():
     ] == list(linter.lint(code))
 
 
+def test_logging_shared():
+    logging_matcher = LoggingMatcher(is_serverless=False)
+    code = """
+sc.setLogLevel("INFO")
+setLogLevel("WARN")
+
+log4jLogger = sc._jvm.org.apache.log4j
+LOGGER = log4jLogger.LogManager.getLogger(__name__)
+sc._jvm.org.apache.log4j.LogManager.getLogger(__name__).info("test")
+
+    """
+
+    assert [
+        Failure(
+            code='spark-logging-in-shared-clusters',
+            message='Cannot set Spark log level directly from code on UC Shared Clusters. '
+            'Remove the call and set the cluster spark conf \'spark.log.level\' instead',
+            start_line=2,
+            start_col=0,
+            end_line=2,
+            end_col=22,
+        ),
+        Failure(
+            code='spark-logging-in-shared-clusters',
+            message='Cannot access Spark Driver JVM logger on UC Shared Clusters. ' 'Use logging.getLogger() instead',
+            start_line=5,
+            start_col=14,
+            end_line=5,
+            end_col=38,
+        ),
+        Failure(
+            code='spark-logging-in-shared-clusters',
+            message='Cannot access Spark Driver JVM logger on UC Shared Clusters. ' 'Use logging.getLogger() instead',
+            start_line=7,
+            start_col=0,
+            end_line=7,
+            end_col=24,
+        ),
+    ] == list(chain.from_iterable([logging_matcher.lint(node) for node in ast.walk(ast.parse(code))]))
+
+
+def test_logging_serverless():
+    logging_matcher = LoggingMatcher(is_serverless=True)
+    code = """
+sc.setLogLevel("INFO")
+log4jLogger = sc._jvm.org.apache.log4j
+
+    """
+
+    assert [
+        Failure(
+            code='spark-logging-in-shared-clusters',
+            message='Cannot set Spark log level directly from code on Serverless Compute. '
+            'Remove the call and set the cluster spark conf \'spark.log.level\' instead',
+            start_line=2,
+            start_col=0,
+            end_line=2,
+            end_col=22,
+        ),
+        Failure(
+            code='spark-logging-in-shared-clusters',
+            message='Cannot access Spark Driver JVM logger on Serverless Compute. ' 'Use logging.getLogger() instead',
+            start_line=3,
+            start_col=14,
+            end_line=3,
+            end_col=38,
+        ),
+    ] == list(chain.from_iterable([logging_matcher.lint(node) for node in ast.walk(ast.parse(code))]))
+
+
 def test_valid_code():
     linter = SparkConnectLinter()
     code = """

From bb3027bcfbf32bb677f6b8c892ed7316cb690639 Mon Sep 17 00:00:00 2001
From: Vsevolod Stepanov
Date: Thu, 30 May 2024 18:04:25 +0200
Subject: [PATCH 2/3] detect sc.conf uses
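
Example of code that SparkSqlContextMatcher now flags (illustrative
snippet taken from the new unit tests; the suggested rewrite assumes
the active session object is named `spark`):

    # flagged: rewrite as spark.conf.get('spark.my.conf')
    df.sparkContext.getConf().get('spark.my.conf')
    sc._conf().get('spark.my.conf')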
---
 .../labs/ucx/source_code/spark_connect.py    | 26 +++++++++++++---
 tests/unit/source_code/test_spark_connect.py | 30 +++++++++++++++++++
 2 files changed, 52 insertions(+), 4 deletions(-)

diff --git a/src/databricks/labs/ucx/source_code/spark_connect.py b/src/databricks/labs/ucx/source_code/spark_connect.py
index a468547bf5..93ba40b627 100644
--- a/src/databricks/labs/ucx/source_code/spark_connect.py
+++ b/src/databricks/labs/ucx/source_code/spark_connect.py
@@ -87,15 +87,33 @@ def lint(self, node: ast.AST) -> Iterator[Advice]:
 
 class SparkSqlContextMatcher(SharedClusterMatcher):
     _ATTRIBUTES = ["sc", "sqlContext", "sparkContext"]
+    _KNOWN_REPLACEMENTS = {"getConf": "conf", "_conf": "conf"}
 
     def lint(self, node: ast.AST) -> Iterator[Advice]:
         if not isinstance(node, ast.Attribute):
             return
-        if not isinstance(node.value, ast.Name) or node.value.id not in SparkSqlContextMatcher._ATTRIBUTES:
-            return
-        yield Failure(
+
+        if isinstance(node.value, ast.Name) and node.value.id in SparkSqlContextMatcher._ATTRIBUTES:
+            yield self._get_advice(node, node.value.id)
+        # sparkContext can be an attribute as in df.sparkContext.getConf()
+        if isinstance(node.value, ast.Attribute) and node.value.attr == 'sparkContext':
+            yield self._get_advice(node, node.value.attr)
+
+    def _get_advice(self, node: ast.Attribute, name: str) -> Advice:
+        if node.attr in SparkSqlContextMatcher._KNOWN_REPLACEMENTS:
+            replacement = SparkSqlContextMatcher._KNOWN_REPLACEMENTS[node.attr]
+            return Failure(
+                code='legacy-context-in-shared-clusters',
+                message=f'{name} and {node.attr} are not supported on {self._cluster_type_str()}. '
+                f'Rewrite it using spark.{replacement}',
+                start_line=node.lineno,
+                start_col=node.col_offset,
+                end_line=node.end_lineno or 0,
+                end_col=node.end_col_offset or 0,
+            )
+        return Failure(
             code='legacy-context-in-shared-clusters',
-            message=f'{node.value.id} is not supported on {self._cluster_type_str()}. Rewrite it using spark',
+            message=f'{name} is not supported on {self._cluster_type_str()}. Rewrite it using spark',
             start_line=node.lineno,
             start_col=node.col_offset,
             end_line=node.end_lineno or 0,
diff --git a/tests/unit/source_code/test_spark_connect.py b/tests/unit/source_code/test_spark_connect.py
index fc7383be9b..9bece047bb 100644
--- a/tests/unit/source_code/test_spark_connect.py
+++ b/tests/unit/source_code/test_spark_connect.py
@@ -143,6 +143,36 @@ def test_rdd_context_match_serverless():
     ] == list(linter.lint(code))
 
 
+def test_conf_shared():
+    linter = SparkConnectLinter(is_serverless=False)
+    code = """df.sparkContext.getConf().get('spark.my.conf')"""
+    assert [
+        Failure(
+            code='legacy-context-in-shared-clusters',
+            message='sparkContext and getConf are not supported on UC Shared Clusters. Rewrite it using spark.conf',
+            start_line=1,
+            start_col=0,
+            end_line=1,
+            end_col=23,
+        ),
+    ] == list(linter.lint(code))
+
+
+def test_conf_serverless():
+    linter = SparkConnectLinter(is_serverless=True)
+    code = """sc._conf().get('spark.my.conf')"""
+    assert [
+        Failure(
+            code='legacy-context-in-shared-clusters',
+            message='sc and _conf are not supported on Serverless Compute. Rewrite it using spark.conf',
+            start_line=1,
+            start_col=0,
+            end_line=1,
+            end_col=8,
+        ),
+    ] == list(linter.lint(code))
+
+
 def test_logging_shared():
     logging_matcher = LoggingMatcher(is_serverless=False)
     code = """

From 28baf8b1ec91c8b59fd48c90bdad2a68cba0b637 Mon Sep 17 00:00:00 2001
From: Vsevolod Stepanov
Date: Thu, 30 May 2024 18:20:49 +0200
Subject: [PATCH 3/3] detect mapPartitions uses
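
Example of an RDD API use that RDDApiMatcher now flags (illustrative
snippet mirroring the new unit test; `myUdf` is a placeholder from
that test):

    # flagged twice: once for .rdd, once for mapPartitions();
    # use df.mapInArrow() or a Pandas UDF instead
    df.rdd.mapPartitions(myUdf)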
---
 .../labs/ucx/source_code/spark_connect.py    | 25 +++++++++++++++++-
 tests/unit/source_code/test_spark_connect.py | 26 +++++++++++++++++++
 2 files changed, 50 insertions(+), 1 deletion(-)

diff --git a/src/databricks/labs/ucx/source_code/spark_connect.py b/src/databricks/labs/ucx/source_code/spark_connect.py
index 93ba40b627..f3f386f0d3 100644
--- a/src/databricks/labs/ucx/source_code/spark_connect.py
+++ b/src/databricks/labs/ucx/source_code/spark_connect.py
@@ -68,6 +68,26 @@ class RDDApiMatcher(SharedClusterMatcher):
     ]
 
     def lint(self, node: ast.AST) -> Iterator[Advice]:
+        yield from self._lint_sc(node)
+        yield from self._lint_rdd_use(node)
+
+    def _lint_rdd_use(self, node: ast.AST) -> Iterator[Advice]:
+        if isinstance(node, ast.Attribute):
+            if node.attr == 'rdd':
+                yield self._rdd_failure(node)
+            return
+        if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute) and node.func.attr == 'mapPartitions':
+            yield Failure(
+                code='rdd-in-shared-clusters',
+                message=f'RDD APIs are not supported on {self._cluster_type_str()}. '
+                f'Use mapInArrow() or Pandas UDFs instead',
+                start_line=node.lineno,
+                start_col=node.col_offset,
+                end_line=node.end_lineno or 0,
+                end_col=node.end_col_offset or 0,
+            )
+
+    def _lint_sc(self, node: ast.AST) -> Iterator[Advice]:
         if not isinstance(node, ast.Call) or not isinstance(node.func, ast.Attribute):
             return
         if node.func.attr not in self._SC_METHODS:
@@ -75,7 +95,10 @@ def lint(self, node: ast.AST) -> Iterator[Advice]:
         function_name = AstHelper.get_full_function_name(node)
         if not function_name or not function_name.endswith(f"sc.{node.func.attr}"):
             return
-        yield Failure(
+        yield self._rdd_failure(node)
+
+    def _rdd_failure(self, node: ast.AST) -> Advice:
+        return Failure(
             code='rdd-in-shared-clusters',
             message=f'RDD APIs are not supported on {self._cluster_type_str()}. Rewrite it using DataFrame API',
             start_line=node.lineno,
diff --git a/tests/unit/source_code/test_spark_connect.py b/tests/unit/source_code/test_spark_connect.py
index 9bece047bb..bd9fee3bb4 100644
--- a/tests/unit/source_code/test_spark_connect.py
+++ b/tests/unit/source_code/test_spark_connect.py
@@ -143,6 +143,32 @@ def test_rdd_context_match_serverless():
     ] == list(linter.lint(code))
 
 
+def test_rdd_map_partitions():
+    linter = SparkConnectLinter(is_serverless=False)
+    code = """
+df = spark.createDataFrame([])
+df.rdd.mapPartitions(myUdf)
+    """
+    assert [
+        Failure(
+            code="rdd-in-shared-clusters",
+            message='RDD APIs are not supported on UC Shared Clusters. Use mapInArrow() or Pandas UDFs instead',
+            start_line=3,
+            start_col=0,
+            end_line=3,
+            end_col=27,
+        ),
+        Failure(
+            code="rdd-in-shared-clusters",
+            message='RDD APIs are not supported on UC Shared Clusters. Rewrite it using DataFrame API',
+            start_line=3,
+            start_col=0,
+            end_line=3,
+            end_col=6,
+        ),
+    ] == list(linter.lint(code))
+
+
 def test_conf_shared():
     linter = SparkConnectLinter(is_serverless=False)
     code = """df.sparkContext.getConf().get('spark.my.conf')"""