Fix Python API calling into Scala code #132

Merged · 11 commits · Dec 8, 2022
1 change: 1 addition & 0 deletions .github/actions/build/action.yml
@@ -39,6 +39,7 @@ runs:
run: |
mvn --batch-mode --update-snapshots clean compile test-compile
mvn --batch-mode package -DskipTests -Dmaven.test.skip=true
mvn --batch-mode install -DskipTests -Dmaven.test.skip=true -Dgpg.skip
shell: bash

- name: Upload Binaries
2 changes: 1 addition & 1 deletion .github/actions/test-jvm/action.yml
@@ -57,7 +57,7 @@ runs:
if: always()
uses: actions/upload-artifact@v3
with:
name: Unit Test Results (Spark ${{ inputs.spark-version }} Scala ${{ inputs.scala-version }})
name: JVM Test Results (Spark ${{ inputs.spark-version }} Scala ${{ inputs.scala-version }})
path: |
target/surefire-reports/*.xml
!target/surefire-reports/TEST-org.scalatest*.xml
46 changes: 39 additions & 7 deletions .github/actions/test-python/action.yml
@@ -29,6 +29,9 @@ runs:
run: |
./set-version.sh ${{ inputs.spark-version }} ${{ inputs.scala-version }}
git diff

SPARK_EXTENSION_VERSION=$(grep --max-count=1 "<version>.*</version>" pom.xml | sed -E -e "s/\s*<[^>]+>//g")
echo "SPARK_EXTENSION_VERSION=$SPARK_EXTENSION_VERSION" | tee -a "$GITHUB_ENV"
shell: bash

- name: Fetch Binaries Artifact
@@ -54,7 +57,7 @@
uses: actions/cache@v3
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-test-${{ inputs.python-version }}-${{ hashFiles('requirements.txt') }}
key: ${{ runner.os }}-pip-test-${{ inputs.python-version }}-${{ hashFiles(format('python/requirements-{0}_{1}.txt', inputs.spark-compat-version, inputs.scala-compat-version)) }}
restore-keys: ${{ runner.os }}-pip-test-${{ inputs.python-version }}-

- name: Setup Python
@@ -67,22 +70,51 @@
python -m pip install --upgrade pip
pip install pypandoc
pip install -r python/requirements-${{ inputs.spark-compat-version }}_${{ inputs.scala-compat-version }}.txt
pip install pytest
pip install pytest unittest-xml-reporting

SPARK_HOME=$(python -c "import pyspark; import os; print(os.path.dirname(pyspark.__file__))")
echo "SPARK_HOME=$SPARK_HOME" | tee -a "$GITHUB_ENV"
shell: bash
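
Side note on the SPARK_HOME line above: a pip-installed pyspark ships the Spark distribution inside the package directory, so the package path doubles as SPARK_HOME. A standalone restatement of that lookup (a sketch, not part of the diff):

    import os
    import pyspark

    # pip-installed pyspark bundles bin/spark-submit and bin/spark-shell
    # under the package directory, e.g. .../site-packages/pyspark
    spark_home = os.path.dirname(pyspark.__file__)

The integration and release test steps below rely on exactly these bundled launchers.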

- name: Python Tests
- name: Python Unit Tests
env:
PYTHONPATH: python:python/test
run: |
python -m pytest python/test --junit-xml pytest.xml
python -m pytest python/test --junit-xml test-results/pytest.xml
shell: bash

- name: Install Spark Extension
run: mvn --batch-mode install -DskipTests -Dmaven.test.skip=true -Dgpg.skip
shell: bash

- name: Python Integration Tests
run: |
# feed the loop from process substitution rather than a pipe,
# so `state` is set in the current shell and survives the loop
state=pass
while read -r test
do
if ! "$SPARK_HOME"/bin/spark-submit --packages uk.co.gresearch.spark:spark-extension_${{ inputs.scala-compat-version }}:$SPARK_EXTENSION_VERSION "$test" test-results
then
state=fail
fi
done < <(find python/test -name 'test*.py')
if [[ "$state" == "fail" ]]; then exit 1; fi
shell: bash
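
For this loop to work, every file matched by find python/test -name 'test*.py' must be runnable as a plain script under spark-submit. A minimal sketch of such a module (the class and test are hypothetical; SparkTest.main() is added in python/test/spark_common.py further down):

    from spark_common import SparkTest


    class MyTest(SparkTest):
        def test_count(self):
            # cls.spark is created in SparkTest.setUpClass
            self.assertEqual(3, self.spark.range(3).count())


    if __name__ == '__main__':
        # writes XML results into the directory passed as sys.argv[1]
        # by the loop above ("test-results")
        SparkTest.main()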

- name: Python Release Test
run: |
$SPARK_HOME/bin/spark-submit --packages uk.co.gresearch.spark:spark-extension_${{ inputs.scala-compat-version }}:$SPARK_EXTENSION_VERSION test-release.py
shell: bash

- name: Scala Release Test
run: |
$SPARK_HOME/bin/spark-shell --packages uk.co.gresearch.spark:spark-extension_${{ inputs.scala-compat-version }}:$SPARK_EXTENSION_VERSION < test-release.scala
shell: bash

- name: Upload Unit Test Results
- name: Upload Test Results
if: always()
uses: actions/upload-artifact@v3
with:
name: Unit Test Results (Spark ${{ inputs.spark-version }} Scala ${{ inputs.scala-version }} Python ${{ inputs.python-version }})
path: pytest.xml
name: Python Test Results (Spark ${{ inputs.spark-version }} Scala ${{ inputs.scala-version }} Python ${{ inputs.python-version }})
path: test-results/*.xml

branding:
icon: 'check-circle'
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
@@ -128,7 +128,7 @@ jobs:
with:
spark-version: ${{ matrix.spark-version }}
scala-version: ${{ matrix.scala-version }}
spark-compat-version: ${{ matrix.spark-compat-version }}-snapshot
spark-compat-version: ${{ matrix.spark-compat-version }}-SNAPSHOT

test-jvm:
name: Test (Spark ${{ matrix.spark-compat-version }}.${{ matrix.spark-patch-version }} Scala ${{ matrix.scala-version }})
@@ -303,7 +303,7 @@ jobs:
with:
spark-version: ${{ matrix.spark-version }}
scala-version: ${{ matrix.scala-version }}
spark-compat-version: ${{ matrix.spark-compat-version }}-snapshot
spark-compat-version: ${{ matrix.spark-compat-version }}-SNAPSHOT
scala-compat-version: ${{ matrix.scala-compat-version }}

event_file:
16 changes: 5 additions & 11 deletions python/gresearch/spark/__init__.py
@@ -28,13 +28,7 @@ def _to_seq(jvm: JVMView, list: List[Any]) -> JavaObject:


def _to_map(jvm: JVMView, map: Mapping[Any, Any]) -> JavaObject:
return _get_scala_object(jvm, "scala.collection.JavaConverters").mapAsScalaMap(map)


def _get_scala_object(jvm: JVMView, name: str) -> JavaObject:
clazz = jvm.java.lang.Class.forName('{}$'.format(name))
ff = clazz.getDeclaredField("MODULE$")
return ff.get(None)
return jvm.scala.collection.JavaConverters.mapAsScalaMap(map)

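This removal is the heart of the PR: instead of reflectively unwrapping a Scala companion object through its MODULE$ field, the Python side now lets py4j resolve the object directly, which works because the Scala compiler emits static forwarder methods for companion-object members. A minimal sketch of both styles, assuming an active SparkSession named spark:

    jvm = spark.sparkContext._jvm

    # old style: reflective lookup of the companion object instance
    clazz = jvm.java.lang.Class.forName('scala.collection.JavaConverters$')
    converters = clazz.getDeclaredField('MODULE$').get(None)
    smap = converters.mapAsScalaMap({'a': 1})

    # new style: py4j invokes the static forwarder on the companion class
    smap = jvm.scala.collection.JavaConverters.mapAsScalaMap({'a': 1})
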

def histogram(self: DataFrame,
@@ -58,7 +52,7 @@ def histogram(self: DataFrame,
value_column = col(value_column)
aggregate_columns = [col(column) for column in aggregate_columns]

hist = _get_scala_object(jvm, 'uk.co.gresearch.spark.Histogram')
hist = jvm.uk.co.gresearch.spark.Histogram
jdf = hist.of(self._jdf, _to_seq(jvm, thresholds), value_column, _to_seq(jvm, aggregate_columns))
return DataFrame(jdf, self.session_or_ctx())

@@ -95,12 +89,12 @@ def with_row_numbers(self: DataFrame,
ascending: Union[bool, List[bool]] = True) -> DataFrame:
jvm = self._sc._jvm
jsl = self._sc._getJavaStorageLevel(storage_level)
juho = _get_scala_object(jvm, 'uk.co.gresearch.spark.UnpersistHandle')
juho = jvm.uk.co.gresearch.spark.UnpersistHandle
juh = unpersist_handle._handle if unpersist_handle else juho.Noop()
jcols = self._sort_cols([order], {'ascending': ascending}) if not isinstance(order, list) or order else jvm.PythonUtils.toSeq([])

row_numners = _get_scala_object(jvm, 'uk.co.gresearch.spark.RowNumbers')
jdf = row_numners \
row_numbers = jvm.uk.co.gresearch.spark.RowNumbers
jdf = row_numbers \
.withRowNumberColumnName(row_number_column_name) \
.withStorageLevel(jsl) \
.withUnpersistHandle(juh) \
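
A hedged usage sketch of the resulting API; it assumes (as elsewhere in this package) that importing gresearch.spark attaches with_row_numbers to DataFrame, and parameter names follow the signature above:

    from pyspark.sql import SparkSession
    import gresearch.spark  # assumed to patch DataFrame with with_row_numbers

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(30,), (10,), (20,)], ['value'])
    # appends a row-number column (name set via row_number_column_name)
    df.with_row_numbers(order='value', ascending=True).show()
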
7 changes: 3 additions & 4 deletions python/gresearch/spark/diff/__init__.py
@@ -20,8 +20,8 @@
from pyspark.sql import DataFrame
from pyspark.sql.types import DataType

from gresearch.spark import _to_seq, _to_map, _get_scala_object
from gresearch.spark.diff.comparator import DiffComparator, DefaultDiffComparator
from gresearch.spark import _to_seq, _to_map
from gresearch.spark.diff.comparator import DiffComparator, DiffComparators, DefaultDiffComparator


class DiffMode(Enum):
@@ -253,8 +253,7 @@ def _to_java_map(self, jvm: JVMView, map: Mapping[Any, DiffComparator], key_to_j
return _to_map(jvm, {key_to_java(jvm, key): cmp._to_java(jvm) for key, cmp in map.items()})

def _to_java_data_type(self, jvm: JVMView, dt: DataType) -> JavaObject:
jdt = _get_scala_object(jvm, "org.apache.spark.sql.types.DataType").fromJson(dt.json())
return jdt
return jvm.org.apache.spark.sql.types.DataType.fromJson(dt.json())

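The same direct-access pattern covers Spark's own Scala objects: a Python DataType crosses into the JVM by round-tripping through its JSON form. A minimal sketch, assuming an active SparkSession spark:

    from pyspark.sql.types import IntegerType

    jvm = spark.sparkContext._jvm
    # DataType.fromJson on the JVM rebuilds the type from DataType.json() in Python
    jdt = jvm.org.apache.spark.sql.types.DataType.fromJson(IntegerType().json())
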

class Differ:
8 changes: 4 additions & 4 deletions python/gresearch/spark/diff/comparator/__init__.py
@@ -4,14 +4,14 @@

from py4j.java_gateway import JVMView, JavaObject

from gresearch.spark import _get_scala_object


class DiffComparator(abc.ABC):
@abc.abstractmethod
def _to_java(self, jvm: JVMView) -> JavaObject:
pass


class DiffComparators:
@staticmethod
def default() -> 'DefaultDiffComparator':
return DefaultDiffComparator()
@@ -31,12 +31,12 @@ def duration(duration: str) -> 'DurationDiffComparator':

class DefaultDiffComparator(DiffComparator):
def _to_java(self, jvm: JVMView) -> JavaObject:
return _get_scala_object(jvm, "uk.co.gresearch.spark.diff.comparator.DefaultDiffComparator")
return jvm.uk.co.gresearch.spark.diff.DiffComparators.default()


class NullSafeEqualDiffComparator(DiffComparator):
def _to_java(self, jvm: JVMView) -> JavaObject:
return _get_scala_object(jvm, "uk.co.gresearch.spark.diff.comparator.NullSafeEqualDiffComparator")
return jvm.uk.co.gresearch.spark.diff.DiffComparators.nullSafeEqual()


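Note what changed in _to_java: comparators are no longer fetched as companion objects under uk.co.gresearch.spark.diff.comparator but constructed through the Scala DiffComparators factory, mirroring the new Python DiffComparators class above. A short sketch, assuming an active SparkSession spark:

    from gresearch.spark.diff.comparator import DiffComparators

    jvm = spark.sparkContext._jvm
    # resolves to uk.co.gresearch.spark.diff.DiffComparators.default() on the JVM
    jcmp = DiffComparators.default()._to_java(jvm)
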
@dataclass(frozen=True)
58 changes: 40 additions & 18 deletions python/test/spark_common.py
@@ -16,21 +16,19 @@
import logging
import os
import subprocess
import sys
import unittest

from pyspark import SparkConf
from pyspark.sql import SparkSession

logger = logging.getLogger()
logger.level = logging.INFO


@contextlib.contextmanager
def spark_session():
from pyspark.sql import SparkSession

session = SparkSession \
.builder \
.config(conf=SparkTest.conf) \
.getOrCreate()

session = SparkTest.get_spark_session()
try:
yield session
finally:
@@ -39,6 +37,22 @@ def spark_session():

class SparkTest(unittest.TestCase):

@staticmethod
def main():
if len(sys.argv) == 2:
# a location to store test results was provided; this requires the unittest-xml-reporting package
import xmlrunner

unittest.main(
testRunner=xmlrunner.XMLTestRunner(output=sys.argv[1]),
argv=sys.argv[:1],
# these make sure that some options that are not applicable
# remain hidden from the help menu.
failfast=False, buffer=False, catchbreak=False
)
else:
unittest.main()

@staticmethod
def get_pom_path() -> str:
paths = ['.', '..', os.path.join('..', '..')]
@@ -73,24 +87,32 @@ def get_spark_config(path, dependencies) -> SparkConf:
]))),
])

path = get_pom_path.__func__()
dependencies = get_dependencies_from_mvn.__func__(path)
logging.info('found {} JVM dependencies'.format(len(dependencies.split(':'))))
conf = get_spark_config.__func__(path, dependencies)
@classmethod
def get_spark_session(cls) -> SparkSession:
builder = SparkSession.builder

if 'PYSPARK_GATEWAY_PORT' in os.environ:
logging.info('Running inside existing Spark environment')
else:
logging.info('Setting up Spark environment')
path = cls.get_pom_path()
dependencies = cls.get_dependencies_from_mvn(path)
logging.info('found {} JVM dependencies'.format(len(dependencies.split(':'))))
conf = cls.get_spark_config(path, dependencies)
builder.config(conf=conf)

return builder.getOrCreate()

spark: SparkSession = None

@classmethod
def setUpClass(cls):
super(SparkTest, cls).setUpClass()
logging.info('launching Spark')

cls.spark = SparkSession \
.builder \
.config(conf=cls.conf) \
.getOrCreate()
logging.info('launching Spark session')
cls.spark = cls.get_spark_session()

@classmethod
def tearDownClass(cls):
logging.info('stopping Spark')
logging.info('stopping Spark session')
cls.spark.stop()
super(SparkTest, cls).tearDownClass()
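
The PYSPARK_GATEWAY_PORT check is what lets the same test modules run both ways: spark-submit exports that variable for the Python process it launches, so get_spark_session() merely attaches to the already-running JVM, while a plain pytest run assembles a SparkConf that pulls the freshly installed package from the local Maven repository. A compact restatement of the switch:

    import os

    def running_under_spark_submit() -> bool:
        # spark-submit's launcher sets PYSPARK_GATEWAY_PORT for its Python child;
        # SparkTest.get_spark_session() keys on the same variable
        return 'PYSPARK_GATEWAY_PORT' in os.environ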
46 changes: 25 additions & 21 deletions python/test/test_diff.py
@@ -12,25 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

logger = logging.getLogger()
logger.level = logging.INFO

import unittest
import re

from py4j.java_gateway import JavaObject
from pyspark.sql import Row
from pyspark.sql.functions import col, when
from pyspark.sql.types import IntegerType, DateType
from py4j.java_gateway import JavaObject

from gresearch.spark.diff import Differ, DiffOptions, DiffMode, DiffComparators
from spark_common import SparkTest
from gresearch.spark.diff import Differ, DiffOptions, DiffMode, DiffComparator


class DiffTest(SparkTest):

expected_diff = None

@classmethod
def setUpClass(cls):
super(DiffTest, cls).setUpClass()
@@ -278,9 +274,9 @@ def test_diff_mode_consts(self):
self.assertIsNotNone(DiffMode.Default.name, jmodes.Default().toString())

def test_diff_fluent_setters(self):
cmp1 = DiffComparator.default()
cmp2 = DiffComparator.epsilon(0.01)
cmp3 = DiffComparator.duration('PT24H')
cmp1 = DiffComparators.default()
cmp2 = DiffComparators.epsilon(0.01)
cmp3 = DiffComparators.duration('PT24H')

default = DiffOptions()
options = default \
@@ -337,7 +333,7 @@ def test_diff_fluent_setters(self):

def test_diff_with_comparators(self):
options = DiffOptions() \
.with_column_name_comparator(DiffComparator.epsilon(0.1).as_relative(), 'val')
.with_column_name_comparator(DiffComparators.epsilon(0.1).as_relative(), 'val')

diff = self.left_df.diff_with_options(self.right_df, options, 'id').orderBy('id').collect()
expected = self.spark.createDataFrame(self.expected_diff) \
@@ -348,27 +344,35 @@ def test_diff_fluent_setters(self):

def test_diff_options_with_duplicate_comparators(self):
options = DiffOptions() \
.with_data_type_comparator(DiffComparator.default(), DateType(), IntegerType()) \
.with_column_name_comparator(DiffComparator.default(), 'col1', 'col2')
.with_data_type_comparator(DiffComparators.default(), DateType(), IntegerType()) \
.with_column_name_comparator(DiffComparators.default(), 'col1', 'col2')

with self.assertRaisesRegex(ValueError, "A comparator for data type date exists already."):
options.with_data_type_comparator(DiffComparator.default(), DateType())
options.with_data_type_comparator(DiffComparators.default(), DateType())

with self.assertRaisesRegex(ValueError, "A comparator for data type int exists already."):
options.with_data_type_comparator(DiffComparator.default(), IntegerType())
options.with_data_type_comparator(DiffComparators.default(), IntegerType())

with self.assertRaisesRegex(ValueError, "A comparator for data types date, int exists already."):
options.with_data_type_comparator(DiffComparator.default(), DateType(), IntegerType())
options.with_data_type_comparator(DiffComparators.default(), DateType(), IntegerType())

with self.assertRaisesRegex(ValueError, "A comparator for column name col1 exists already."):
options.with_column_name_comparator(DiffComparator.default(), 'col1')
options.with_column_name_comparator(DiffComparators.default(), 'col1')

with self.assertRaisesRegex(ValueError, "A comparator for column name col2 exists already."):
options.with_column_name_comparator(DiffComparator.default(), 'col2')
options.with_column_name_comparator(DiffComparators.default(), 'col2')

with self.assertRaisesRegex(ValueError, "A comparator for column names col1, col2 exists already."):
options.with_column_name_comparator(DiffComparator.default(), 'col1', 'col2')
options.with_column_name_comparator(DiffComparators.default(), 'col1', 'col2')

def test_diff_comparators(self):
jvm = self.spark.sparkContext._jvm
self.assertIsNotNone(DiffComparators.default()._to_java(jvm))
self.assertIsNotNone(DiffComparators.nullSafeEqual()._to_java(jvm))
self.assertIsNotNone(DiffComparators.epsilon(0.01)._to_java(jvm))
if jvm.uk.co.gresearch.spark.diff.comparator.DurationDiffComparator.isSupportedBySpark():
self.assertIsNotNone(DiffComparators.duration('PT24H')._to_java(jvm))


if __name__ == '__main__':
unittest.main()
SparkTest.main()