Provide groupByKey shortcuts for groupBy.as
EnricoMi committed Dec 8, 2023
1 parent 6290b1a commit 41c5c69
Showing 2 changed files with 68 additions and 4 deletions.
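
In short, the commit adds column-based `groupByKey` overloads to the `ExtendedDatasetV2` implicit class, which simply delegate to `groupBy(...).as[K, V]`. The sketch below shows the intended usage; it is hypothetical, not part of the commit: the `Val` case class and the `id` column mirror the test file in this diff, and the session setup is assumed.

  import org.apache.spark.sql.{Dataset, KeyValueGroupedDataset, SparkSession}
  import uk.co.gresearch.spark._ // brings ExtendedDatasetV2 into scope

  case class Val(id: Int, seq: Int, value: Double)

  val spark = SparkSession.builder().master("local[2]").appName("groupByKey-demo").getOrCreate()
  import spark.implicits._

  val ds: Dataset[Val] = Seq(Val(1, 1, 1.1), Val(1, 2, 1.2), Val(2, 1, 2.1)).toDS()

  // the new shortcuts: group by key columns, with the key type given explicitly
  val byColumn: KeyValueGroupedDataset[Int, Val] = ds.groupByKey[Int]($"id")
  val byName: KeyValueGroupedDataset[Int, Val] = ds.groupByKey[Int]("id")

  // the equivalent long form that the shortcuts wrap
  val longForm: KeyValueGroupedDataset[Int, Val] = ds.groupBy($"id").as[Int, Val]

The `[Int]` type argument selects the key encoder; the value encoder comes from the Dataset itself via the `private implicit val encoder` added in this commit. Passing a `Column` or `String` also steers overload resolution to the extension rather than to Spark's own function-based `groupByKey`.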
38 changes: 38 additions & 0 deletions src/main/scala/uk/co/gresearch/spark/package.scala
@@ -570,6 +570,8 @@ package object spark extends Logging with SparkVersion with BuildVersion {
    * @tparam V inner type of dataset
    */
   implicit class ExtendedDatasetV2[V](ds: Dataset[V]) {
+    private implicit val encoder: Encoder[V] = ds.encoder
+
     /**
      * Compute the histogram of a column when aggregated by aggregate columns.
      * Thresholds are expected to be provided in ascending order.
@@ -675,6 +677,42 @@ package object spark extends Logging with SparkVersion with BuildVersion {
         .partitionBy(partitionColumnsMap.keys.toSeq: _*)
     }
 
+    /**
+     * (Scala-specific)
+     * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key columns.
+     *
+     * @see `org.apache.spark.sql.Dataset.groupByKey(T => K)`
+     *
+     * @note This method should be preferred over `groupByKey(T => K)` because, with that
+     *       function, the Catalyst query planner cannot exploit the existing partitioning
+     *       and ordering of this Dataset.
+     *
+     * {{{
+     *   ds.groupByKey[Int]($"age").flatMapGroups(...)
+     *   ds.groupByKey[(String, String)]($"department", $"gender").flatMapGroups(...)
+     * }}}
+     */
+    def groupByKey[K: Encoder](column: Column, columns: Column*): KeyValueGroupedDataset[K, V] =
+      ds.groupBy(column +: columns: _*).as[K, V]
+
+    /**
+     * (Scala-specific)
+     * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key columns.
+     *
+     * @see `org.apache.spark.sql.Dataset.groupByKey(T => K)`
+     *
+     * @note This method should be preferred over `groupByKey(T => K)` because, with that
+     *       function, the Catalyst query planner cannot exploit the existing partitioning
+     *       and ordering of this Dataset.
+     *
+     * {{{
+     *   ds.groupByKey[Int]("age").flatMapGroups(...)
+     *   ds.groupByKey[(String, String)]("department", "gender").flatMapGroups(...)
+     * }}}
+     */
+    def groupByKey[K: Encoder](column: String, columns: String*): KeyValueGroupedDataset[K, V] =
+      ds.groupBy(column, columns: _*).as[K, V]
+
     /**
      * Groups the Dataset and sorts the groups using the specified columns, so we can
      * further process the sorted groups. See [[uk.co.gresearch.spark.group.SortedGroupByDataset]] for all the available
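The @note in the scaladoc above carries the motivation for these shortcuts: with `groupByKey(T => K)` the key is computed by an opaque Scala closure, so Catalyst cannot see the grouping expression and cannot reuse an existing partitioning or ordering of the Dataset, while a column-based key stays visible to the planner. A minimal sketch of the contrast, reusing the hypothetical `ds` and imports from the sketch near the top:

  // function-based grouping: the key is an opaque closure, invisible to Catalyst
  val opaque: KeyValueGroupedDataset[Int, Val] = ds.groupByKey(v => v.id)

  // column-based grouping (this commit): the key expression is visible to the planner
  val visible: KeyValueGroupedDataset[Int, Val] = ds.groupByKey[Int]($"id")

  // both support the same typed operations, e.g. group sizes as in the test file below
  val sizes: Array[(Int, Int)] = visible
    .mapGroups((id, vals) => (id, vals.length))
    .collect()
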
@@ -16,7 +16,7 @@
 
 package uk.co.gresearch.spark
 
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, KeyValueGroupedDataset, Row}
 import org.scalatest.funspec.AnyFunSpec
 import uk.co.gresearch.spark.GroupBySortedSuite.{valueRowToTuple, valueToTuple}
 import uk.co.gresearch.spark.group.SortedGroupByDataset
@@ -33,7 +33,7 @@ case class State(init: Int) {
   }
 }
 
-class GroupBySortedSuite extends AnyFunSpec with SparkTestSession {
+class GroupBySuite extends AnyFunSpec with SparkTestSession {
 
   import spark.implicits._
 
@@ -50,6 +50,34 @@ class GroupBySortedSuite extends AnyFunSpec with SparkTestSession {
     Val(3, 1, 3.1),
   ).reverse.toDS().repartition(3).cache()
 
+  val df: DataFrame = ds.toDF()
+
+  it("should ds.groupByKey") {
+    testGroupBy(ds.groupByKey($"id"))
+    testGroupBy(ds.groupByKey("id"))
+  }
+
+  it("should df.groupByKey") {
+    testGroupBy(df.groupByKey($"id"))
+    testGroupBy(df.groupByKey("id"))
+  }
+
+  def testGroupBy[T](ds: KeyValueGroupedDataset[Int, T]): Unit = {
+    val actual = ds
+      .mapGroups { (key, it) => (key, it.length) }
+      .collect()
+      .sortBy(v => v._1)
+
+    val expected = Seq(
+      // (key, group length)
+      (1, 4),
+      (2, 3),
+      (3, 1),
+    )
+
+    assert(actual === expected)
+  }
+
   describe("ds.groupBySorted") {
     testGroupByIdSortBySeq(ds.groupBySorted($"id")($"seq", $"value"))
     testGroupByIdSortBySeqDesc(ds.groupBySorted($"id")($"seq".desc, $"value".desc))
@@ -66,8 +94,6 @@ class GroupBySortedSuite extends AnyFunSpec with SparkTestSession {
     testGroupByIdSeqSortByValue(ds.groupByKeySorted(v => (v.id, v.seq))(v => v.value))
   }
 
-  val df: DataFrame = ds.toDF()
-
   describe("df.groupBySorted") {
     testGroupByIdSortBySeq(df.groupBySorted($"id")($"seq", $"value"))
     testGroupByIdSortBySeqDesc(df.groupBySorted($"id")($"seq".desc, $"value".desc))
