Add linters to check for spark logging and configuration access #1808

Merged: 3 commits, May 30, 2024
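
For orientation, a minimal sketch (mine, not part of this PR) of how the new checks surface: per the diff below, SparkConnectLinter.lint() takes source code as a string and yields Advice objects whose codes now include rdd-in-shared-clusters, legacy-context-in-shared-clusters and spark-logging-in-shared-clusters. The snippet assumes Advice exposes code and message attributes, matching the Failure constructor arguments in the diff.

from databricks.labs.ucx.source_code.spark_connect import SparkConnectLinter

# Made-up notebook code containing the patterns the new matchers look for.
code = """
sc.setLogLevel("INFO")
df.sparkContext.getConf().get('spark.my.conf')
log4j = sc._jvm.org.apache.log4j
"""

linter = SparkConnectLinter(is_serverless=False)
for advice in linter.lint(code):
    print(advice.code, advice.message)
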
93 changes: 87 additions & 6 deletions src/databricks/labs/ucx/source_code/spark_connect.py
@@ -68,15 +68,37 @@ class RDDApiMatcher(SharedClusterMatcher):
     ]
 
     def lint(self, node: ast.AST) -> Iterator[Advice]:
-        if not isinstance(node, ast.Call):
+        yield from self._lint_sc(node)
+        yield from self._lint_rdd_use(node)
+
+    def _lint_rdd_use(self, node: ast.AST) -> Iterator[Advice]:
+        if isinstance(node, ast.Attribute):
+            if node.attr == 'rdd':
+                yield self._rdd_failure(node)
+            return
+        if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute) and node.func.attr == 'mapPartitions':
+            yield Failure(
+                code='rdd-in-shared-clusters',
+                message=f'RDD APIs are not supported on {self._cluster_type_str()}. '
+                f'Use mapInArrow() or Pandas UDFs instead',
+                start_line=node.lineno,
+                start_col=node.col_offset,
+                end_line=node.end_lineno or 0,
+                end_col=node.end_col_offset or 0,
+            )
+
+    def _lint_sc(self, node: ast.AST) -> Iterator[Advice]:
+        if not isinstance(node, ast.Call) or not isinstance(node.func, ast.Attribute):
             return
-        assert isinstance(node.func, ast.Attribute)  # Avoid linter warning
         if node.func.attr not in self._SC_METHODS:
             return
         function_name = AstHelper.get_full_function_name(node)
         if not function_name or not function_name.endswith(f"sc.{node.func.attr}"):
             return
-        yield Failure(
+        yield self._rdd_failure(node)
+
+    def _rdd_failure(self, node: ast.AST) -> Advice:
+        return Failure(
             code='rdd-in-shared-clusters',
             message=f'RDD APIs are not supported on {self._cluster_type_str()}. Rewrite it using DataFrame API',
             start_line=node.lineno,
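
With the _lint_rdd_use() path added above, plain .rdd attribute access and rdd.mapPartitions(...) calls are now reported alongside the existing sc.* RDD factory methods. A rough illustration (my own snippet, mirroring the new unit tests further down) of input that now yields rdd-in-shared-clusters advice:

df = spark.createDataFrame([])
df.rdd.mapPartitions(my_udf)  # flagged twice: once for the '.rdd' access, once for mapPartitions(),
                              # the latter suggesting mapInArrow() or Pandas UDFs instead
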
@@ -88,28 +110,87 @@ def lint(self, node: ast.AST) -> Iterator[Advice]:
 
 class SparkSqlContextMatcher(SharedClusterMatcher):
     _ATTRIBUTES = ["sc", "sqlContext", "sparkContext"]
+    _KNOWN_REPLACEMENTS = {"getConf": "conf", "_conf": "conf"}
 
     def lint(self, node: ast.AST) -> Iterator[Advice]:
         if not isinstance(node, ast.Attribute):
             return
-        if not isinstance(node.value, ast.Name) or node.value.id not in SparkSqlContextMatcher._ATTRIBUTES:
+
+        if isinstance(node.value, ast.Name) and node.value.id in SparkSqlContextMatcher._ATTRIBUTES:
+            yield self._get_advice(node, node.value.id)
+        # sparkContext can be an attribute as in df.sparkContext.getConf()
+        if isinstance(node.value, ast.Attribute) and node.value.attr == 'sparkContext':
+            yield self._get_advice(node, node.value.attr)
+
+    def _get_advice(self, node: ast.Attribute, name: str) -> Advice:
+        if node.attr in SparkSqlContextMatcher._KNOWN_REPLACEMENTS:
+            replacement = SparkSqlContextMatcher._KNOWN_REPLACEMENTS[node.attr]
+            return Failure(
+                code='legacy-context-in-shared-clusters',
+                message=f'{name} and {node.attr} are not supported on {self._cluster_type_str()}. '
+                f'Rewrite it using spark.{replacement}',
+                start_line=node.lineno,
+                start_col=node.col_offset,
+                end_line=node.end_lineno or 0,
+                end_col=node.end_col_offset or 0,
+            )
+        return Failure(
+            code='legacy-context-in-shared-clusters',
+            message=f'{name} is not supported on {self._cluster_type_str()}. Rewrite it using spark',
+            start_line=node.lineno,
+            start_col=node.col_offset,
+            end_line=node.end_lineno or 0,
+            end_col=node.end_col_offset or 0,
+        )
+
+
+class LoggingMatcher(SharedClusterMatcher):
+    def lint(self, node: ast.AST) -> Iterator[Advice]:
+        yield from self._match_sc_set_log_level(node)
+        yield from self._match_jvm_log(node)
+
+    def _match_sc_set_log_level(self, node: ast.AST) -> Iterator[Advice]:
+        if not isinstance(node, ast.Call) or not isinstance(node.func, ast.Attribute):
             return
+        if node.func.attr != 'setLogLevel':
+            return
+        function_name = AstHelper.get_full_function_name(node)
+        if not function_name or not function_name.endswith('sc.setLogLevel'):
+            return
+
         yield Failure(
-            code='legacy-context-in-shared-clusters',
-            message=f'{node.value.id} is not supported on {self._cluster_type_str()}. Rewrite it using spark',
+            code='spark-logging-in-shared-clusters',
+            message=f'Cannot set Spark log level directly from code on {self._cluster_type_str()}. '
+            f'Remove the call and set the cluster spark conf \'spark.log.level\' instead',
             start_line=node.lineno,
             start_col=node.col_offset,
             end_line=node.end_lineno or 0,
             end_col=node.end_col_offset or 0,
         )
+
+    def _match_jvm_log(self, node: ast.AST) -> Iterator[Advice]:
+        if not isinstance(node, ast.Attribute):
+            return
+        attribute_name = AstHelper.get_full_attribute_name(node)
+        if attribute_name and attribute_name.endswith('org.apache.log4j'):
+            yield Failure(
+                code='spark-logging-in-shared-clusters',
+                message=f'Cannot access Spark Driver JVM logger on {self._cluster_type_str()}. '
+                f'Use logging.getLogger() instead',
+                start_line=node.lineno,
+                start_col=node.col_offset,
+                end_line=node.end_lineno or 0,
+                end_col=node.end_col_offset or 0,
+            )
 
 
 class SparkConnectLinter(Linter):
     def __init__(self, is_serverless: bool = False):
         self._matchers = [
             JvmAccessMatcher(is_serverless=is_serverless),
             RDDApiMatcher(is_serverless=is_serverless),
             SparkSqlContextMatcher(is_serverless=is_serverless),
+            LoggingMatcher(is_serverless=is_serverless),
         ]
 
     def lint(self, code: str) -> Iterator[Advice]:
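
To make the new and extended matchers concrete, a hedged sketch (mine, not from the PR) of the code each one flags, with the rewrites the advice messages point to:

# legacy-context-in-shared-clusters: legacy contexts and their conf accessors
sc._conf().get('spark.my.conf')                 # advice suggests spark.conf
df.sparkContext.getConf().get('spark.my.conf')  # 'sparkContext' is matched as an attribute too

# spark-logging-in-shared-clusters: driver-side logging
sc.setLogLevel("INFO")            # set the cluster conf 'spark.log.level' instead
log4j = sc._jvm.org.apache.log4j  # use logging.getLogger() instead

# Rewrites along the lines the advice messages suggest:
spark.conf.get('spark.my.conf')

import logging
logger = logging.getLogger(__name__)
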
131 changes: 130 additions & 1 deletion tests/unit/source_code/test_spark_connect.py
@@ -1,5 +1,8 @@
+import ast
+from itertools import chain
+
 from databricks.labs.ucx.source_code.base import Failure
-from databricks.labs.ucx.source_code.spark_connect import SparkConnectLinter
+from databricks.labs.ucx.source_code.spark_connect import LoggingMatcher, SparkConnectLinter
 
 
 def test_jvm_access_match_shared():
@@ -140,6 +143,132 @@ def test_rdd_context_match_serverless():
     ] == list(linter.lint(code))
 
 
+def test_rdd_map_partitions():
+    linter = SparkConnectLinter(is_serverless=False)
+    code = """
+df = spark.createDataFrame([])
+df.rdd.mapPartitions(myUdf)
+"""
+    assert [
+        Failure(
+            code="rdd-in-shared-clusters",
+            message='RDD APIs are not supported on UC Shared Clusters. Use mapInArrow() or Pandas UDFs instead',
+            start_line=3,
+            start_col=0,
+            end_line=3,
+            end_col=27,
+        ),
+        Failure(
+            code="rdd-in-shared-clusters",
+            message='RDD APIs are not supported on UC Shared Clusters. Rewrite it using DataFrame API',
+            start_line=3,
+            start_col=0,
+            end_line=3,
+            end_col=6,
+        ),
+    ] == list(linter.lint(code))
+
+
+def test_conf_shared():
+    linter = SparkConnectLinter(is_serverless=False)
+    code = """df.sparkContext.getConf().get('spark.my.conf')"""
+    assert [
+        Failure(
+            code='legacy-context-in-shared-clusters',
+            message='sparkContext and getConf are not supported on UC Shared Clusters. Rewrite it using spark.conf',
+            start_line=1,
+            start_col=0,
+            end_line=1,
+            end_col=23,
+        ),
+    ] == list(linter.lint(code))
+
+
+def test_conf_serverless():
+    linter = SparkConnectLinter(is_serverless=True)
+    code = """sc._conf().get('spark.my.conf')"""
+    assert [
+        Failure(
+            code='legacy-context-in-shared-clusters',
+            message='sc and _conf are not supported on Serverless Compute. Rewrite it using spark.conf',
+            start_line=1,
+            start_col=0,
+            end_line=1,
+            end_col=8,
+        ),
+    ] == list(linter.lint(code))
+
+
+def test_logging_shared():
+    logging_matcher = LoggingMatcher(is_serverless=False)
+    code = """
+sc.setLogLevel("INFO")
+setLogLevel("WARN")
+
+log4jLogger = sc._jvm.org.apache.log4j
+LOGGER = log4jLogger.LogManager.getLogger(__name__)
+sc._jvm.org.apache.log4j.LogManager.getLogger(__name__).info("test")
+
+"""
+
+    assert [
+        Failure(
+            code='spark-logging-in-shared-clusters',
+            message='Cannot set Spark log level directly from code on UC Shared Clusters. '
+            'Remove the call and set the cluster spark conf \'spark.log.level\' instead',
+            start_line=2,
+            start_col=0,
+            end_line=2,
+            end_col=22,
+        ),
+        Failure(
+            code='spark-logging-in-shared-clusters',
+            message='Cannot access Spark Driver JVM logger on UC Shared Clusters. ' 'Use logging.getLogger() instead',
+            start_line=5,
+            start_col=14,
+            end_line=5,
+            end_col=38,
+        ),
+        Failure(
+            code='spark-logging-in-shared-clusters',
+            message='Cannot access Spark Driver JVM logger on UC Shared Clusters. ' 'Use logging.getLogger() instead',
+            start_line=7,
+            start_col=0,
+            end_line=7,
+            end_col=24,
+        ),
+    ] == list(chain.from_iterable([logging_matcher.lint(node) for node in ast.walk(ast.parse(code))]))
+
+
+def test_logging_serverless():
+    logging_matcher = LoggingMatcher(is_serverless=True)
+    code = """
+sc.setLogLevel("INFO")
+log4jLogger = sc._jvm.org.apache.log4j
+
+"""
+
+    assert [
+        Failure(
+            code='spark-logging-in-shared-clusters',
+            message='Cannot set Spark log level directly from code on Serverless Compute. '
+            'Remove the call and set the cluster spark conf \'spark.log.level\' instead',
+            start_line=2,
+            start_col=0,
+            end_line=2,
+            end_col=22,
+        ),
+        Failure(
+            code='spark-logging-in-shared-clusters',
+            message='Cannot access Spark Driver JVM logger on Serverless Compute. ' 'Use logging.getLogger() instead',
+            start_line=3,
+            start_col=14,
+            end_line=3,
+            end_col=38,
+        ),
+    ] == list(chain.from_iterable([logging_matcher.lint(node) for node in ast.walk(ast.parse(code))]))
+
+
 def test_valid_code():
     linter = SparkConnectLinter()
     code = """