Skip to content

Commit 1a8065e

Browse files
committed
Fixes aggregations of attribute references to values of union types
1 parent 5121093 commit 1a8065e

File tree

9 files changed

+242
-37
lines changed

9 files changed

+242
-37
lines changed

CHANGELOG.md

+3-1
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,20 @@ Thank you to all who have contributed!
3232

3333
### Changed
3434
- Function resolution logic: Now the function resolver would match all possible candidate(based on if the argument can be coerced to the Signature parameter type). If there are multiple match it will first attempt to pick the one requires the least cast, then pick the function with the highest precedence.
35+
- **Behavioral change**: The COUNT aggregate function now returns INT64.
3536

3637
### Deprecated
3738

3839
### Fixed
40+
- Fixes aggregations of attribute references to values of union types. This fix also allows for proper error handling by passing the UnknownAggregateFunction problem to the ProblemCallback. Please note that, with this change, the planner will no longer immediately throw an IllegalStateException for this exact scenario.
3941

4042
### Removed
4143

4244
### Security
4345

4446
### Contributors
4547
Thank you to all who have contributed!
46-
- @<your-username>
48+
- @johnedquinn
4749

4850
## [0.14.3] - 2024-02-14
4951

partiql-planner/src/main/kotlin/org/partiql/planner/Errors.kt

+8
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,14 @@ public sealed class PlanningProblemDetails(
9494
"Unknown function `$identifier($types)"
9595
})
9696

97+
public data class UnknownAggregateFunction(
98+
val identifier: String,
99+
val args: List<StaticType>,
100+
) : PlanningProblemDetails(ProblemSeverity.ERROR, {
101+
val types = args.joinToString { "<${it.toString().lowercase()}>" }
102+
"Unknown aggregate function `$identifier($types)"
103+
})
104+
97105
public object ExpressionAlwaysReturnsNullOrMissing : PlanningProblemDetails(
98106
severity = ProblemSeverity.ERROR,
99107
messageFormatter = { "Expression always returns null or missing." }

partiql-planner/src/main/kotlin/org/partiql/planner/internal/PartiQLHeader.kt

+11-2
Original file line numberDiff line numberDiff line change
@@ -702,13 +702,13 @@ internal object PartiQLHeader : Header() {
702702
private fun count() = listOf(
703703
FunctionSignature.Aggregation(
704704
name = "count",
705-
returns = INT32,
705+
returns = INT64,
706706
parameters = listOf(FunctionParameter("value", ANY)),
707707
isNullable = false,
708708
),
709709
FunctionSignature.Aggregation(
710710
name = "count_star",
711-
returns = INT32,
711+
returns = INT64,
712712
parameters = listOf(),
713713
isNullable = false,
714714
),
@@ -741,6 +741,15 @@ internal object PartiQLHeader : Header() {
741741
)
742742
}
743743

744+
/**
745+
* According to SQL:1999 Section 6.16 Syntax Rule 14.c and Rule 14.d:
746+
* > If AVG is specified and DT is exact numeric, then the declared type of the result is exact
747+
* numeric with implementation-defined precision not less than the precision of DT and
748+
* implementation-defined scale not less than the scale of DT.
749+
*
750+
* > If DT is approximate numeric, then the declared type of the result is approximate numeric
751+
* with implementation-defined precision not less than the precision of DT.
752+
*/
744753
private fun avg() = types.numeric.map {
745754
FunctionSignature.Aggregation(
746755
name = "avg",

partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt

+54-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
package org.partiql.planner.internal.transforms
22

3+
import org.partiql.errors.Problem
34
import org.partiql.errors.ProblemCallback
5+
import org.partiql.errors.UNKNOWN_PROBLEM_LOCATION
46
import org.partiql.plan.PlanNode
57
import org.partiql.plan.partiQLPlan
8+
import org.partiql.planner.PlanningProblemDetails
69
import org.partiql.planner.internal.ir.Agg
710
import org.partiql.planner.internal.ir.Catalog
811
import org.partiql.planner.internal.ir.Fn
@@ -12,7 +15,9 @@ import org.partiql.planner.internal.ir.Rel
1215
import org.partiql.planner.internal.ir.Rex
1316
import org.partiql.planner.internal.ir.Statement
1417
import org.partiql.planner.internal.ir.visitor.PlanBaseVisitor
18+
import org.partiql.types.function.FunctionSignature
1519
import org.partiql.value.PartiQLValueExperimental
20+
import org.partiql.value.PartiQLValueType
1621

1722
/**
1823
* This is an internal utility to translate from the internal unresolved plan used for typing to the public plan IR.
@@ -58,7 +63,7 @@ internal object PlanTransform : PlanBaseVisitor<PlanNode, ProblemCallback>() {
5863
override fun visitAggResolved(node: Agg.Resolved, ctx: ProblemCallback) = org.partiql.plan.Agg(node.signature)
5964

6065
override fun visitAggUnresolved(node: Agg.Unresolved, ctx: ProblemCallback): org.partiql.plan.Rex.Op {
61-
error("Unresolved aggregation ${node.identifier}")
66+
error("Internal error: This should have been handled somewhere else. Cause: Unresolved aggregation ${node.identifier}.")
6267
}
6368

6469
override fun visitStatement(node: Statement, ctx: ProblemCallback) =
@@ -331,11 +336,56 @@ internal object PlanTransform : PlanBaseVisitor<PlanNode, ProblemCallback>() {
331336
groups = node.groups.map { visitRex(it, ctx) },
332337
)
333338

334-
override fun visitRelOpAggregateCall(node: Rel.Op.Aggregate.Call, ctx: ProblemCallback) =
335-
org.partiql.plan.Rel.Op.Aggregate.Call(
336-
agg = visitAgg(node.agg, ctx),
339+
@OptIn(PartiQLValueExperimental::class)
340+
override fun visitRelOpAggregateCall(node: Rel.Op.Aggregate.Call, ctx: ProblemCallback): org.partiql.plan.Rel.Op.Aggregate.Call {
341+
val agg = when (val agg = node.agg) {
342+
is Agg.Unresolved -> {
343+
val name = agg.identifier.toNormalizedString()
344+
ctx.invoke(
345+
Problem(
346+
UNKNOWN_PROBLEM_LOCATION,
347+
PlanningProblemDetails.UnknownAggregateFunction(
348+
agg.identifier.toString(),
349+
node.args.map { it.type }
350+
)
351+
)
352+
)
353+
org.partiql.plan.Agg(
354+
FunctionSignature.Aggregation(
355+
"UNKNOWN_AGG::$name",
356+
returns = PartiQLValueType.MISSING,
357+
parameters = emptyList()
358+
)
359+
)
360+
}
361+
is Agg.Resolved -> {
362+
visitAggResolved(agg, ctx)
363+
}
364+
}
365+
return org.partiql.plan.Rel.Op.Aggregate.Call(
366+
agg = agg,
337367
args = node.args.map { visitRex(it, ctx) },
338368
)
369+
}
370+
371+
private fun Identifier.toNormalizedString(): String {
372+
return when (this) {
373+
is Identifier.Symbol -> this.toNormalizedString()
374+
is Identifier.Qualified -> {
375+
val toJoin = listOf(this.root) + this.steps
376+
toJoin.joinToString(separator = ".") { ident ->
377+
ident.toNormalizedString()
378+
}
379+
}
380+
}
381+
}
382+
383+
private fun Identifier.Symbol.toNormalizedString(): String {
384+
return when (this.caseSensitivity) {
385+
Identifier.CaseSensitivity.SENSITIVE -> "\"${this.symbol}\""
386+
Identifier.CaseSensitivity.INSENSITIVE -> this.symbol
387+
}
388+
}
339389

340390
override fun visitRelOpExclude(node: Rel.Op.Exclude, ctx: ProblemCallback) = org.partiql.plan.Rel.Op.Exclude(
341391
input = visitRel(node.input, ctx),

partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1168,7 +1168,7 @@ internal class PlanTyper(
11681168
fun resolveAgg(agg: Agg.Unresolved, arguments: List<Rex>): Pair<Rel.Op.Aggregate.Call, StaticType> {
11691169
var missingArg = false
11701170
val args = arguments.map {
1171-
val arg = visitRex(it, null)
1171+
val arg = visitRex(it, it.type)
11721172
if (arg.type.isMissable()) missingArg = true
11731173
arg
11741174
}

partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeEnv.kt

+49-20
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import org.partiql.planner.internal.ir.rexOpVarResolved
1010
import org.partiql.spi.BindingCase
1111
import org.partiql.spi.BindingName
1212
import org.partiql.spi.BindingPath
13+
import org.partiql.types.AnyOfType
14+
import org.partiql.types.AnyType
1315
import org.partiql.types.StaticType
1416
import org.partiql.types.StructType
1517
import org.partiql.types.TupleConstraint
@@ -85,30 +87,28 @@ internal class TypeEnv(public val schema: List<Rel.Binding>) {
8587
for (i in schema.indices) {
8688
val local = schema[i]
8789
val type = local.type
88-
if (type is StructType) {
89-
when (type.containsKey(name)) {
90-
true -> {
91-
if (c != null && known) {
92-
// TODO root was already definitively matched, emit ambiguous error.
93-
return null
94-
}
95-
c = rex(type, rexOpVarResolved(i))
96-
known = true
90+
when (type.containsKey(name)) {
91+
true -> {
92+
if (c != null && known) {
93+
// TODO root was already definitively matched, emit ambiguous error.
94+
return null
9795
}
98-
null -> {
99-
if (c != null) {
100-
if (known) {
101-
continue
102-
} else {
103-
// TODO we have more than one possible match, emit ambiguous error.
104-
return null
105-
}
96+
c = rex(type, rexOpVarResolved(i))
97+
known = true
98+
}
99+
null -> {
100+
if (c != null) {
101+
if (known) {
102+
continue
103+
} else {
104+
// TODO we have more than one possible match, emit ambiguous error.
105+
return null
106106
}
107-
c = rex(type, rexOpVarResolved(i))
108-
known = false
109107
}
110-
false -> continue
108+
c = rex(type, rexOpVarResolved(i))
109+
known = false
111110
}
111+
false -> continue
112112
}
113113
}
114114
return c
@@ -152,4 +152,33 @@ internal class TypeEnv(public val schema: List<Rel.Binding>) {
152152
val closed = constraints.contains(TupleConstraint.Open(false))
153153
return if (closed) false else null
154154
}
155+
156+
/**
157+
* Searches for the [BindingName] within the given [StaticType].
158+
*
159+
* Returns
160+
* - true iff known to contain key
161+
* - false iff known to NOT contain key
162+
* - null iff NOT known to contain key
163+
*
164+
* @param name
165+
* @return
166+
*/
167+
private fun StaticType.containsKey(name: BindingName): Boolean? {
168+
return when (val type = this.flatten()) {
169+
is StructType -> type.containsKey(name)
170+
is AnyOfType -> {
171+
val anyKnownToContainKey = type.allTypes.any { it.containsKey(name) == true }
172+
val anyKnownToNotContainKey = type.allTypes.any { it.containsKey(name) == false }
173+
val anyNotKnownToContainKey = type.allTypes.any { it.containsKey(name) == null }
174+
when {
175+
anyKnownToNotContainKey.not() && anyNotKnownToContainKey.not() -> true
176+
anyKnownToContainKey.not() && anyNotKnownToContainKey -> false
177+
else -> null
178+
}
179+
}
180+
is AnyType -> null
181+
else -> false
182+
}
183+
}
155184
}

partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeUtils.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ internal fun StaticType.toRuntimeType(): PartiQLValueType {
9393
// handle anyOf(null, T) cases
9494
val t = types.filter { it !is NullType && it !is MissingType }
9595
return if (t.size != 1) {
96-
error("Cannot have a UNION runtime type: $this")
96+
PartiQLValueType.ANY
9797
} else {
9898
t.first().asRuntimeType()
9999
}

0 commit comments

Comments
 (0)