feat(ingest/mlflow): add dataset lineage #12837

Status: Open. Wants to merge 11 commits into base branch master.
metadata-ingestion/src/datahub/ingestion/source/mlflow.py (102 changes: 96 additions & 6 deletions)
@@ -1,6 +1,8 @@
import json
+import logging
import time
from dataclasses import dataclass
-from typing import Any, Callable, Iterable, List, Optional, TypeVar, Union
+from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union

from mlflow import MlflowClient
from mlflow.entities import Experiment, Run
@@ -42,6 +43,7 @@
    AuditStampClass,
    ContainerClass,
    DataPlatformInstanceClass,
+    DataProcessInstanceInputClass,
    DataProcessInstanceOutputClass,
    DataProcessInstancePropertiesClass,
    DataProcessInstanceRunEventClass,
@@ -60,16 +62,15 @@
    TagAssociationClass,
    TagPropertiesClass,
    TimeStampClass,
+    UpstreamClass,
+    UpstreamLineageClass,
    VersionPropertiesClass,
    VersionTagClass,
    _Aspect,
)
-from datahub.metadata.urns import (
-    DataPlatformUrn,
-    MlModelUrn,
-    VersionSetUrn,
-)
+from datahub.metadata.urns import DataPlatformUrn, MlModelUrn, VersionSetUrn
from datahub.sdk.container import Container
+from datahub.sdk.dataset import Dataset

T = TypeVar("T")

@@ -213,6 +214,7 @@ def _get_experiment_workunits(self) -> Iterable[MetadataWorkUnit]:
        if runs:
            for run in runs:
                yield from self._get_run_workunits(experiment, run)
+                yield from self._get_dataset_input_workunits(run)

    def _get_experiment_custom_properties(self, experiment):
        experiment_custom_props = getattr(experiment, "tags", {}) or {}
@@ -262,6 +264,94 @@ def _convert_run_result_type(
type="SKIPPED", nativeResultType=self.platform
)

    def _get_dataset_schema(self, schema: str) -> Optional[List[Tuple[str, str]]]:
        try:
            schema_dict = json.loads(schema)
        except json.JSONDecodeError:
            logging.getLogger(__name__).warning("Failed to parse dataset schema JSON")
            return None

        if "mlflow_colspec" in schema_dict:
            try:
                return [
                    (field["name"], field["type"])
                    for field in schema_dict["mlflow_colspec"]
                ]
            except (KeyError, TypeError):
                return None
        # The schema is not in a recognized format; return None and let the
        # caller fall back to a custom property.
        return None
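
For reference, a minimal sketch (not part of the diff) of the colspec payload this helper expects; the column names here are made up, but the "mlflow_colspec" key matches what the parser above looks for:

    import json

    schema_json = json.dumps(
        {
            "mlflow_colspec": [
                {"name": "sepal_length", "type": "double"},
                {"name": "species", "type": "string"},
            ]
        }
    )
    # _get_dataset_schema(schema_json) would return:
    #     [("sepal_length", "double"), ("species", "string")]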

    def _get_dataset_platform_from_source_type(self, source_type: str) -> str:
        # Manually map an MLflow source type to the corresponding DataHub platform.
        if source_type == "gs":
            return "gcs"
        return source_type

    def _get_dataset_input_workunits(self, run: Run) -> Iterable[MetadataWorkUnit]:
        run_urn = DataProcessInstance(
            id=run.info.run_id,
            orchestrator=self.platform,
        ).urn
        dataset_inputs = run.inputs.dataset_inputs
        dataset_reference_urns = []
        for dataset_input in dataset_inputs:
            dataset = dataset_input.dataset
            source_type = dataset.source_type
            # Flatten the input's InputTag entities into a plain {key: value} dict.
            dataset_tags = {k[1]: v[1] for k, v in dataset_input.tags}
            formatted_platform = self._get_dataset_platform_from_source_type(
                source_type
            )
            custom_properties = dataset_tags
            formatted_schema = self._get_dataset_schema(dataset.schema)
            # If the schema could not be parsed, pass the raw schema string
            # through as a custom property instead.
            if formatted_schema is None:
                custom_properties["schema"] = dataset.schema
            # If the dataset is local or code, create a local dataset reference.
            if source_type in ("local", "code"):
                local_dataset_reference = Dataset(
                    platform=formatted_platform,
                    name=dataset.name,
                    schema=formatted_schema,
                    custom_properties=custom_properties,
                )
                yield from local_dataset_reference.as_workunits()
                dataset_reference_urns.append(str(local_dataset_reference.urn))
            # Otherwise, create a hosted dataset plus a reference to it on the
            # mlflow platform.
            else:
                hosted_dataset = Dataset(
                    platform=formatted_platform,
                    name=dataset.name,
                    schema=formatted_schema,
                    custom_properties=custom_properties,
                )

Collaborator commented on the hosted_dataset lines above:

    imo we should not be generating dataset entities for non mlflow platforms

    here we should just do hosted_dataset_urn = DatasetUrn.... if that urn exists, lineage will show up by default. if it doesn't exist, they'll need to go into the UI and click "show hidden edges" to make them show up

(A sketch of this suggestion follows the method below.)
                hosted_dataset_reference = Dataset(
                    platform=self.platform,
                    name=dataset.name,
                    schema=formatted_schema,
                    custom_properties=custom_properties,
                    upstreams=UpstreamLineageClass(
                        upstreams=[
                            UpstreamClass(dataset=str(hosted_dataset.urn), type="COPY")
                        ]
                    ),
                )
                dataset_reference_urns.append(str(hosted_dataset_reference.urn))

                yield from hosted_dataset.as_workunits()
                yield from hosted_dataset_reference.as_workunits()

        # Add the dataset references as input edges (upstreams) of the run.
        if dataset_reference_urns:
            input_edges = [
                EdgeClass(destinationUrn=dataset_reference_urn)
                for dataset_reference_urn in dataset_reference_urns
            ]
            yield MetadataChangeProposalWrapper(
                entityUrn=str(run_urn),
                aspect=DataProcessInstanceInputClass(inputs=[], inputEdges=input_edges),
            ).as_workunit()
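
Taken together, one hosted input (say, an s3 source named "training_data"; both names assumed for illustration) yields the chain: hosted dataset <- COPY <- mlflow reference dataset <- input edge <- run. A minimal sketch of the aspects involved, not part of the diff:

    from datahub.metadata.schema_classes import (
        DataProcessInstanceInputClass,
        EdgeClass,
        UpstreamClass,
        UpstreamLineageClass,
    )

    # The mlflow reference dataset points upstream at the hosted (s3) dataset:
    reference_lineage = UpstreamLineageClass(
        upstreams=[
            UpstreamClass(
                dataset="urn:li:dataset:(urn:li:dataPlatform:s3,training_data,PROD)",
                type="COPY",
            )
        ]
    )

    # The run then lists the mlflow reference dataset as an input edge:
    run_inputs = DataProcessInstanceInputClass(
        inputs=[],
        inputEdges=[
            EdgeClass(
                destinationUrn="urn:li:dataset:(urn:li:dataPlatform:mlflow,training_data,PROD)"
            )
        ],
    )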

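The collaborator's alternative, sketched under the assumption that only a URN is built and no Dataset entity is emitted for the external platform (variable names hypothetical):

    from datahub.metadata.urns import DatasetUrn

    # Reference the hosted data by URN only, without emitting a Dataset entity.
    hosted_dataset_urn = DatasetUrn(platform="s3", name="training_data")
    # If this urn already exists in DataHub, lineage to it shows up by default;
    # otherwise it only appears after clicking "show hidden edges" in the UI.
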
    def _get_run_workunits(
        self, experiment: Experiment, run: Run
    ) -> Iterable[MetadataWorkUnit]: