diff --git a/src/databricks/labs/ucx/source_code/dependencies.py b/src/databricks/labs/ucx/source_code/dependencies.py
new file mode 100644
index 0000000000..0a615ca7f0
--- /dev/null
+++ b/src/databricks/labs/ucx/source_code/dependencies.py
@@ -0,0 +1,205 @@
+from __future__ import annotations
+
+import abc
+from collections.abc import Callable
+
+from databricks.sdk.service.workspace import ObjectType, ObjectInfo, ExportFormat
+from databricks.sdk import WorkspaceClient
+
+
+class Dependency:
+    """A workspace object, identified by its path, that source code depends on.
+
+    Equality and hashing only consider the path, not the object type.
+    """
+
+    @staticmethod
+    def from_object_info(object_info: ObjectInfo):
+        """Build a Dependency from the object type and path of an ObjectInfo."""
+        assert object_info.path is not None
+        return Dependency(object_info.object_type, object_info.path)
+
+    def __init__(self, object_type: ObjectType | None, path: str):
+        self._type = object_type
+        self._path = path
+
+    @property
+    def type(self) -> ObjectType | None:
+        return self._type
+
+    @property
+    def path(self) -> str:
+        return self._path
+
+    def __hash__(self):
+        return hash(self.path)
+
+    def __eq__(self, other):
+        return isinstance(other, Dependency) and self.path == other.path
+
+
+class SourceContainer(abc.ABC):
+    """Abstract piece of source code able to register its dependencies on a graph."""
+
+    @property
+    @abc.abstractmethod
+    def object_type(self) -> ObjectType:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def build_dependency_graph(self, graph) -> None:
+        raise NotImplementedError()
+
+
+class DependencyLoader:
+    """Loads the source container behind a Dependency using the workspace client."""
+
+    def __init__(self, ws: WorkspaceClient):
+        self._ws = ws
+
+    def load_dependency(self, dependency: Dependency) -> SourceContainer | None:
+        """Load the container of the given dependency.
+
+        Raises ValueError when the object cannot be located or has an unexpected
+        type, and NotImplementedError when it is not a notebook.
+        """
+        object_info = self._load_object(dependency)
+        if object_info.object_type is ObjectType.NOTEBOOK:
+            return self._load_notebook(object_info)
+        raise NotImplementedError(str(object_info.object_type))
+
+    def _load_object(self, dependency: Dependency) -> ObjectInfo:
+        result = self._ws.workspace.list(dependency.path)
+        # only the first listed object is considered
+        object_info = next(iter(result), None)
+        if object_info is None:
+            raise ValueError(f"Could not locate object at '{dependency.path}'")
+        if dependency.type is not None and object_info.object_type is not dependency.type:
+            raise ValueError(
+                f"Invalid object at '{dependency.path}', expected a {str(dependency.type)}, got a {str(object_info.object_type)}"
+            )
+        return object_info
+
+    def _load_notebook(self, object_info: ObjectInfo) -> SourceContainer:
+        # local import to avoid circular dependency
+        # pylint: disable=import-outside-toplevel
+        from databricks.labs.ucx.source_code.notebook import Notebook
+
+        assert object_info.path is not None
+        assert object_info.language is not None
+        source = self._load_source(object_info)
+        return Notebook.parse(object_info.path, source, object_info.language)
+
+    def _load_source(self, object_info: ObjectInfo) -> str:
+        if not object_info.language or not object_info.path:
+            raise ValueError(f"Invalid ObjectInfo: {object_info}")
+        with self._ws.workspace.download(object_info.path, format=ExportFormat.SOURCE) as f:
+            return f.read().decode("utf-8")
+
+
+class DependencyGraph:
+    """Node of a dependency graph, knowing its parent and its child graphs.
+
+    An already registered dependency is re-used instead of being duplicated,
+    hence a node may be shared by several parents and the graph may hold cycles.
+    """
+
+    def __init__(self, dependency: Dependency, parent: DependencyGraph | None, loader: DependencyLoader):
+        self._dependency = dependency
+        self._parent = parent
+        self._loader = loader
+        self._dependencies: dict[Dependency, DependencyGraph] = {}
+
+    @property
+    def dependency(self):
+        return self._dependency
+
+    @property
+    def path(self):
+        return self._dependency.path
+
+    def register_dependency(self, dependency: Dependency) -> DependencyGraph | None:
+        """Register a dependency of this node, loading and walking it when new.
+
+        Returns the existing or newly created child graph, or None when the
+        loader returns no container.
+        """
+        # already registered ?
+        child_graph = self.locate_dependency(dependency)
+        if child_graph is not None:
+            self._dependencies[dependency] = child_graph
+            return child_graph
+        # nay, create the child graph and populate it
+        child_graph = DependencyGraph(dependency, self, self._loader)
+        # registered before loading, so nested registrations of the same path
+        # locate this node rather than load it again
+        self._dependencies[dependency] = child_graph
+        container = self._loader.load_dependency(dependency)
+        if not container:
+            return None
+        container.build_dependency_graph(child_graph)
+        return child_graph
+
+    def locate_dependency(self, dependency: Dependency) -> DependencyGraph | None:
+        """Return the graph registered anywhere under the root for this path, if any."""
+        # the nested function records its match in a list it can mutate
+        found: list[DependencyGraph] = []
+        path = dependency.path
+        # TODO https://github.com/databrickslabs/ucx/issues/1287
+        path = path[2:] if path.startswith('./') else path
+
+        def check_registered_dependency(graph):
+            # TODO https://github.com/databrickslabs/ucx/issues/1287
+            graph_path = graph.path[2:] if graph.path.startswith('./') else graph.path
+            if graph_path == path:
+                found.append(graph)
+                return True
+            return False
+
+        self.root.visit(check_registered_dependency)
+        return found[0] if found else None
+
+    @property
+    def root(self):
+        return self if self._parent is None else self._parent.root
+
+    @property
+    def dependencies(self) -> set[Dependency]:
+        """All dependencies reachable from this node, this node's own included."""
+        dependencies: set[Dependency] = set()
+
+        def add_to_dependencies(graph: DependencyGraph) -> bool:
+            # never interrupt: visit() already skips nodes it has seen, and
+            # stopping at the first duplicate would drop the remaining nodes
+            dependencies.add(graph.dependency)
+            return False
+
+        self.visit(add_to_dependencies)
+        return dependencies
+
+    @property
+    def paths(self) -> set[str]:
+        return {d.path for d in self.dependencies}
+
+    # when visit_node returns True it interrupts the visit
+    def visit(
+        self,
+        visit_node: Callable[[DependencyGraph], bool | None],
+        visited: set[DependencyGraph] | None = None,
+    ) -> bool | None:
+        """Walk this node then its children depth-first, each node at most once.
+
+        `visited` holds the nodes already walked; it is created on the first
+        call and shared with recursive calls so that cycles terminate.
+        Returns True when visit_node interrupted the visit, False otherwise.
+        """
+        if visited is None:
+            visited = set()
+        if self in visited:
+            return False
+        visited.add(self)
+        if visit_node(self):
+            return True
+        for child in self._dependencies.values():
+            if child.visit(visit_node, visited):
+                return True
+        return False
diff --git
a/src/databricks/labs/ucx/source_code/notebook.py b/src/databricks/labs/ucx/source_code/notebook.py index 7173d6cce6..d93606bb93 100644 --- a/src/databricks/labs/ucx/source_code/notebook.py +++ b/src/databricks/labs/ucx/source_code/notebook.py @@ -4,13 +4,13 @@ import logging from abc import ABC, abstractmethod from ast import parse as parse_python -from collections.abc import Callable from enum import Enum from sqlglot import ParseError as SQLParseError from sqlglot import parse as parse_sql -from databricks.sdk.service.workspace import Language +from databricks.sdk.service.workspace import Language, ObjectType +from databricks.labs.ucx.source_code.dependencies import DependencyGraph, Dependency, SourceContainer from databricks.labs.ucx.source_code.python_linter import ASTLinter, PythonLinter @@ -79,7 +79,8 @@ def build_dependency_graph(self, parent: DependencyGraph): assert isinstance(node, ast.Call) path = PythonLinter.get_dbutils_notebook_run_path_arg(node) if isinstance(path, ast.Constant): - parent.register_dependency(path.value.strip("'").strip('"')) + dependency = Dependency(ObjectType.NOTEBOOK, path.value.strip("'").strip('"')) + parent.register_dependency(dependency) class RCell(Cell): @@ -155,7 +156,7 @@ def build_dependency_graph(self, parent: DependencyGraph): start = line.index(command) if start >= 0: path = line[start + len(command) :].strip() - parent.register_dependency(path.strip('"')) + parent.register_dependency(Dependency(ObjectType.NOTEBOOK, path.strip('"'))) return raise ValueError("Missing notebook path in %run command") @@ -297,76 +298,7 @@ def wrap_with_magic(self, code: str, cell_language: CellLanguage) -> str: return "\n".join(lines) -class DependencyGraph: - - def __init__(self, path: str, parent: DependencyGraph | None, locator: Callable[[str], Notebook]): - self._path = path - self._parent = parent - self._locator = locator - self._dependencies: dict[str, DependencyGraph] = {} - - @property - def path(self): - return self._path - - 
def register_dependency(self, path: str) -> DependencyGraph | None: - # already registered ? - child_graph = self.locate_dependency(path) - if child_graph is not None: - self._dependencies[path] = child_graph - return child_graph - # nay, create the child graph and populate it - child_graph = DependencyGraph(path, self, self._locator) - self._dependencies[path] = child_graph - notebook = self._locator(path) - if not notebook: - return None - notebook.build_dependency_graph(child_graph) - return child_graph - - def locate_dependency(self, path: str) -> DependencyGraph | None: - # need a list since unlike JS, Python won't let you assign closure variables - found: list[DependencyGraph] = [] - path = path[2:] if path.startswith('./') else path - - def check_registered_dependency(graph): - graph_path = graph.path[2:] if graph.path.startswith('./') else graph.path - if graph_path == path: - found.append(graph) - return True - return False - - self.root.visit(check_registered_dependency) - return found[0] if len(found) > 0 else None - - @property - def root(self): - return self if self._parent is None else self._parent.root - - @property - def paths(self) -> set[str]: - paths: set[str] = set() - - def add_to_paths(graph: DependencyGraph) -> bool: - if graph.path in paths: - return True - paths.add(graph.path) - return False - - self.visit(add_to_paths) - return paths - - # when visit_node returns True it interrupts the visit - def visit(self, visit_node: Callable[[DependencyGraph], bool | None]) -> bool | None: - if visit_node(self): - return True - for dependency in self._dependencies.values(): - if dependency.visit(visit_node): - return True - return False - - -class Notebook: +class Notebook(SourceContainer): @staticmethod def parse(path: str, source: str, default_language: Language) -> Notebook: @@ -383,6 +315,10 @@ def __init__(self, path: str, source: str, language: Language, cells: list[Cell] self._cells = cells self._ends_with_lf = ends_with_lf + @property + def 
object_type(self) -> ObjectType: + return ObjectType.NOTEBOOK + @property def path(self) -> str: return self._path diff --git a/src/databricks/labs/ucx/source_code/notebook_migrator.py b/src/databricks/labs/ucx/source_code/notebook_migrator.py index 4f765c0bb5..70454cb2a6 100644 --- a/src/databricks/labs/ucx/source_code/notebook_migrator.py +++ b/src/databricks/labs/ucx/source_code/notebook_migrator.py @@ -2,13 +2,25 @@ from databricks.sdk.service.workspace import ExportFormat, ObjectInfo, ObjectType from databricks.labs.ucx.source_code.languages import Languages -from databricks.labs.ucx.source_code.notebook import DependencyGraph, Notebook, RunCell +from databricks.labs.ucx.source_code.notebook import Notebook, RunCell +from databricks.labs.ucx.source_code.dependencies import DependencyGraph, Dependency, DependencyLoader class NotebookMigrator: - def __init__(self, ws: WorkspaceClient, languages: Languages): + def __init__(self, ws: WorkspaceClient, languages: Languages, loader: DependencyLoader): self._ws = ws self._languages = languages + self._loader = loader + + def build_dependency_graph(self, object_info: ObjectInfo) -> DependencyGraph: + if not object_info.path or not object_info.language or object_info.object_type is not ObjectType.NOTEBOOK: + raise ValueError("Not a valid Notebook") + dependency = Dependency.from_object_info(object_info) + graph = DependencyGraph(dependency, None, self._loader) + container = self._loader.load_dependency(dependency) + if container is not None: + container.build_dependency_graph(graph) + return graph def revert(self, object_info: ObjectInfo): if not object_info.path: @@ -21,23 +33,16 @@ def revert(self, object_info: ObjectInfo): def apply(self, object_info: ObjectInfo) -> bool: if not object_info.path or not object_info.language or object_info.object_type is not ObjectType.NOTEBOOK: return False - notebook = self._load_notebook(object_info) + notebook = self._loader.load_dependency(Dependency.from_object_info(object_info)) 
+ assert isinstance(notebook, Notebook) return self._apply(notebook) - def build_dependency_graph(self, object_info: ObjectInfo) -> DependencyGraph: - if not object_info.path or not object_info.language or object_info.object_type is not ObjectType.NOTEBOOK: - raise ValueError("Not a valid Notebook") - notebook = self._load_notebook(object_info) - dependencies = DependencyGraph(object_info.path, None, self._load_notebook_from_path) - notebook.build_dependency_graph(dependencies) - return dependencies - def _apply(self, notebook: Notebook) -> bool: changed = False for cell in notebook.cells: # %run is not a supported language, so this needs to come first if isinstance(cell, RunCell): - # TODO data on what to change to ? + # TODO migration data ? if cell.migrate_notebook_path(): changed = True continue @@ -52,27 +57,3 @@ def _apply(self, notebook: Notebook) -> bool: self._ws.workspace.upload(notebook.path, notebook.to_migrated_code().encode("utf-8")) # TODO https://github.com/databrickslabs/ucx/issues/1327 store 'migrated' status return changed - - def _load_notebook_from_path(self, path: str) -> Notebook: - object_info = self._load_object(path) - if object_info.object_type is not ObjectType.NOTEBOOK: - raise ValueError(f"Not a Notebook: {path}") - return self._load_notebook(object_info) - - def _load_object(self, path: str) -> ObjectInfo: - result = self._ws.workspace.list(path) - object_info = next((oi for oi in result), None) - if object_info is None: - raise ValueError(f"Could not locate object at '{path}'") - return object_info - - def _load_notebook(self, object_info: ObjectInfo) -> Notebook: - assert object_info is not None and object_info.path is not None and object_info.language is not None - source = self._load_source(object_info) - return Notebook.parse(object_info.path, source, object_info.language) - - def _load_source(self, object_info: ObjectInfo) -> str: - if not object_info.language or not object_info.path: - raise ValueError(f"Invalid ObjectInfo: 
{object_info}") - with self._ws.workspace.download(object_info.path, format=ExportFormat.SOURCE) as f: - return f.read().decode("utf-8") diff --git a/tests/unit/source_code/test_notebook.py b/tests/unit/source_code/test_notebook.py index f3ef36df0e..c116542e8c 100644 --- a/tests/unit/source_code/test_notebook.py +++ b/tests/unit/source_code/test_notebook.py @@ -1,10 +1,11 @@ -from collections.abc import Callable +from unittest.mock import create_autospec import pytest -from databricks.sdk.service.workspace import Language +from databricks.sdk.service.workspace import Language, ObjectType from databricks.labs.ucx.source_code.base import Advisory -from databricks.labs.ucx.source_code.notebook import Notebook, DependencyGraph +from databricks.labs.ucx.source_code.notebook import Notebook +from databricks.labs.ucx.source_code.dependencies import DependencyGraph, DependencyLoader, Dependency from databricks.labs.ucx.source_code.python_linter import PythonLinter from tests.unit import _load_sources @@ -106,27 +107,29 @@ def test_notebook_generates_runnable_cells(source: tuple[str, Language, list[str assert cell.is_runnable() -def notebook_locator( - paths: list[str], sources: list[str], languages: list[Language] -) -> Callable[[str], Notebook | None]: - def locator(path: str) -> Notebook | None: - local_path = path[2:] if path.startswith('./') else path +def mock_dependency_loader(paths: list[str], sources: list[str], languages: list[Language]) -> DependencyLoader: + def load_dependency_side_effect(*args): + dependency = args[0] + local_path = dependency.path[2:] if dependency.path.startswith('./') else dependency.path index = paths.index(local_path) if index < 0: - raise ValueError(f"Can't locate notebook {path}") + raise ValueError(f"Can't locate: {dependency.path}") return Notebook.parse(paths[index], sources[index], languages[index]) - return locator + loader = create_autospec(DependencyLoader) + loader.load_dependency.side_effect = load_dependency_side_effect + 
return loader def test_notebook_builds_leaf_dependency_graph(): paths = ["leaf1.py.txt"] sources: list[str] = _load_sources(Notebook, *paths) languages = [Language.PYTHON] * len(paths) - locator = notebook_locator(paths, sources, languages) - notebook = locator(paths[0]) - graph = DependencyGraph(paths[0], None, locator) - notebook.build_dependency_graph(graph) + loader = mock_dependency_loader(paths, sources, languages) + dependency = Dependency(ObjectType.NOTEBOOK, paths[0]) + graph = DependencyGraph(dependency, None, loader) + container = loader.load_dependency(dependency) + container.build_dependency_graph(graph) assert graph.paths == {"leaf1.py.txt"} @@ -134,10 +137,11 @@ def test_notebook_builds_depth1_dependency_graph(): paths = ["root1.run.py.txt", "leaf1.py.txt", "leaf2.py.txt"] sources: list[str] = _load_sources(Notebook, *paths) languages = [Language.PYTHON] * len(paths) - locator = notebook_locator(paths, sources, languages) - notebook = locator(paths[0]) - graph = DependencyGraph(paths[0], None, locator) - notebook.build_dependency_graph(graph) + loader = mock_dependency_loader(paths, sources, languages) + dependency = Dependency(ObjectType.NOTEBOOK, paths[0]) + graph = DependencyGraph(dependency, None, loader) + container = loader.load_dependency(dependency) + container.build_dependency_graph(graph) actual = {path[2:] if path.startswith('./') else path for path in graph.paths} assert actual == set(paths) @@ -146,10 +150,11 @@ def test_notebook_builds_depth2_dependency_graph(): paths = ["root2.run.py.txt", "root1.run.py.txt", "leaf1.py.txt", "leaf2.py.txt"] sources: list[str] = _load_sources(Notebook, *paths) languages = [Language.PYTHON] * len(paths) - locator = notebook_locator(paths, sources, languages) - notebook = locator(paths[0]) - graph = DependencyGraph(paths[0], None, locator) - notebook.build_dependency_graph(graph) + loader = mock_dependency_loader(paths, sources, languages) + dependency = Dependency(ObjectType.NOTEBOOK, paths[0]) + graph = 
DependencyGraph(dependency, None, loader) + container = loader.load_dependency(dependency) + container.build_dependency_graph(graph) actual = {path[2:] if path.startswith('./') else path for path in graph.paths} assert actual == set(paths) @@ -158,16 +163,20 @@ def test_notebook_builds_dependency_graph_avoiding_duplicates(): paths = ["root3.run.py.txt", "root1.run.py.txt", "leaf1.py.txt", "leaf2.py.txt"] sources: list[str] = _load_sources(Notebook, *paths) languages = [Language.PYTHON] * len(paths) - locator = notebook_locator(paths, sources, languages) - notebook = locator(paths[0]) + loader = mock_dependency_loader(paths, sources, languages) + old_load_dependency_side_effect = loader.load_dependency.side_effect + dependency = Dependency(ObjectType.NOTEBOOK, paths[0]) + graph = DependencyGraph(dependency, None, loader) visited: list[str] = [] - def registering_locator(path: str): - visited.append(path) - return locator(path) + def load_dependency_side_effect(*args): + dep = args[0] + visited.append(dep.path) + return old_load_dependency_side_effect(*args) - graph = DependencyGraph(paths[0], None, registering_locator) - notebook.build_dependency_graph(graph) + loader.load_dependency.side_effect = load_dependency_side_effect + container = loader.load_dependency(dependency) + container.build_dependency_graph(graph) # if visited once only, set and list will have same len assert len(set(visited)) == len(visited) @@ -176,10 +185,11 @@ def test_notebook_builds_cyclical_dependency_graph(): paths = ["cyclical1.run.py.txt", "cyclical2.run.py.txt"] sources: list[str] = _load_sources(Notebook, *paths) languages = [Language.PYTHON] * len(paths) - locator = notebook_locator(paths, sources, languages) - notebook = locator(paths[0]) - graph = DependencyGraph(paths[0], None, locator) - notebook.build_dependency_graph(graph) + loader = mock_dependency_loader(paths, sources, languages) + dependency = Dependency(ObjectType.NOTEBOOK, paths[0]) + graph = DependencyGraph(dependency, 
None, loader) + container = loader.load_dependency(dependency) + container.build_dependency_graph(graph) actual = {path[2:] if path.startswith('./') else path for path in graph.paths} assert actual == set(paths) @@ -188,10 +198,11 @@ def test_notebook_builds_python_dependency_graph(): paths = ["root4.py.txt", "leaf3.py.txt"] sources: list[str] = _load_sources(Notebook, *paths) languages = [Language.PYTHON] * len(paths) - locator = notebook_locator(paths, sources, languages) - notebook = locator(paths[0]) - graph = DependencyGraph(paths[0], None, locator) - notebook.build_dependency_graph(graph) + loader = mock_dependency_loader(paths, sources, languages) + dependency = Dependency(ObjectType.NOTEBOOK, paths[0]) + graph = DependencyGraph(dependency, None, loader) + container = loader.load_dependency(dependency) + container.build_dependency_graph(graph) actual = {path[2:] if path.startswith('./') else path for path in graph.paths} assert actual == set(paths) diff --git a/tests/unit/source_code/test_notebook_migrator.py b/tests/unit/source_code/test_notebook_migrator.py index 32112fc5d1..865e65144a 100644 --- a/tests/unit/source_code/test_notebook_migrator.py +++ b/tests/unit/source_code/test_notebook_migrator.py @@ -6,6 +6,7 @@ from databricks.sdk.service.workspace import ExportFormat, Language, ObjectInfo, ObjectType from databricks.labs.ucx.hive_metastore.table_migrate import MigrationIndex +from databricks.labs.ucx.source_code.dependencies import DependencyLoader from databricks.labs.ucx.source_code.languages import Languages from databricks.labs.ucx.source_code.notebook import Notebook from databricks.labs.ucx.source_code.notebook_migrator import NotebookMigrator @@ -15,7 +16,7 @@ def test_apply_invalid_object_fails(): ws = create_autospec(WorkspaceClient) languages = create_autospec(Languages) - migrator = NotebookMigrator(ws, languages) + migrator = NotebookMigrator(ws, languages, DependencyLoader(ws)) object_info = ObjectInfo(language=Language.PYTHON) assert 
not migrator.apply(object_info) @@ -23,7 +24,7 @@ def test_apply_invalid_object_fails(): def test_revert_invalid_object_fails(): ws = create_autospec(WorkspaceClient) languages = create_autospec(Languages) - migrator = NotebookMigrator(ws, languages) + migrator = NotebookMigrator(ws, languages, DependencyLoader(ws)) object_info = ObjectInfo(language=Language.PYTHON) assert not migrator.revert(object_info) @@ -32,7 +33,7 @@ def test_revert_restores_original_code(): ws = create_autospec(WorkspaceClient) ws.workspace.download.return_value.__enter__.return_value.read.return_value = b'original_code' languages = create_autospec(Languages) - migrator = NotebookMigrator(ws, languages) + migrator = NotebookMigrator(ws, languages, DependencyLoader(ws)) object_info = ObjectInfo(path='path', language=Language.PYTHON) migrator.revert(object_info) ws.workspace.download.assert_called_with('path.bak', format=ExportFormat.SOURCE) @@ -48,7 +49,9 @@ def test_apply_returns_false_when_language_not_supported(): ws.workspace.download.return_value.__enter__.return_value.read.return_value = notebook_code.encode("utf-8") languages = create_autospec(Languages) languages.is_supported.return_value = False - migrator = NotebookMigrator(ws, languages) + loader = create_autospec(DependencyLoader) + loader.load_dependency.return_value = Notebook.parse('path', notebook_code, Language.R) + migrator = NotebookMigrator(ws, languages, loader) object_info = ObjectInfo(path='path', language=Language.R, object_type=ObjectType.NOTEBOOK) result = migrator.apply(object_info) assert not result @@ -63,7 +66,9 @@ def test_apply_returns_false_when_no_fixes_applied(): languages = create_autospec(Languages) languages.is_supported.return_value = True languages.apply_fixes.return_value = "# original code" # cell code - migrator = NotebookMigrator(ws, languages) + loader = create_autospec(DependencyLoader) + loader.load_dependency.return_value = Notebook.parse('path', notebook_code, Language.R) + migrator = 
NotebookMigrator(ws, languages, loader) object_info = ObjectInfo(path='path', language=Language.PYTHON, object_type=ObjectType.NOTEBOOK) assert not migrator.apply(object_info) @@ -81,7 +86,9 @@ def test_apply_returns_true_and_changes_code_when_fixes_applied(): languages = create_autospec(Languages) languages.is_supported.return_value = True languages.apply_fixes.return_value = migrated_cell_code - migrator = NotebookMigrator(ws, languages) + loader = create_autospec(DependencyLoader) + loader.load_dependency.return_value = Notebook.parse('path', original_code, Language.R) + migrator = NotebookMigrator(ws, languages, loader) object_info = ObjectInfo(path='path', language=Language.PYTHON, object_type=ObjectType.NOTEBOOK) assert migrator.apply(object_info) ws.workspace.upload.assert_any_call('path.bak', original_code.encode("utf-8")) @@ -111,7 +118,7 @@ def list_side_effect(*args): ws = create_autospec(WorkspaceClient) ws.workspace.download.side_effect = download_side_effect ws.workspace.list.side_effect = list_side_effect - migrator = NotebookMigrator(ws, Languages(create_autospec(MigrationIndex))) + migrator = NotebookMigrator(ws, Languages(create_autospec(MigrationIndex)), DependencyLoader(ws)) object_info = ObjectInfo(path="root3.run.py.txt", language=Language.PYTHON, object_type=ObjectType.NOTEBOOK) migrator.build_dependency_graph(object_info) assert len(visited) == len(paths) @@ -134,7 +141,7 @@ def download_side_effect(*args, **kwargs): ws = create_autospec(WorkspaceClient) ws.workspace.download.side_effect = download_side_effect ws.workspace.list.return_value = [] - migrator = NotebookMigrator(ws, Languages(create_autospec(MigrationIndex))) + migrator = NotebookMigrator(ws, Languages(create_autospec(MigrationIndex)), DependencyLoader(ws)) object_info = ObjectInfo(path="root1.run.py.txt", language=Language.PYTHON, object_type=ObjectType.NOTEBOOK) with pytest.raises(ValueError): migrator.build_dependency_graph(object_info)