Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes aggregations of attribute references to values of union types #1383

Merged
merged 3 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,18 @@ Thank you to all who have contributed!
### Changed
- Change `StaticType.AnyOfType`'s `.toString` to not perform `.flatten()`
- Change modeling of `COALESCE` and `NULLIF` to dedicated nodes in logical plan
- Function resolution logic: The function resolver now matches all possible candidates (based on whether each argument can be coerced to the signature's parameter type). If there are multiple matches, it first attempts to pick the one that requires the fewest casts, then picks the function with the highest precedence.
- **Behavioral change**: The COUNT aggregate function now returns INT64.

### Deprecated
- The current SqlBlock, SqlDialect, and SqlLayout are marked as deprecated and will be slightly changed in the next release.
- Deprecates constructor and properties `variableName` and `caseSensitive` of `org.partiql.planner.PlanningProblemDetails.UndefinedVariable`
in favor of newly added constructor and properties `name` and `inScopeVariables`.

### Fixed
- `StaticType.flatten()` on an `AnyOfType` with `AnyType` will return `AnyType`
- Updates the default `.sql()` method to use a more efficient (internal) printer implementation.

- Fixes aggregations of attribute references to values of union types. This fix also allows for proper error handling by passing the UnknownAggregateFunction problem to the ProblemCallback. Please note that, with this change, the planner will no longer immediately throw an IllegalStateException for this exact scenario.

### Removed

Expand All @@ -51,6 +55,7 @@ Thank you to all who have contributed!
- @<your-username>
- @rchowell
- @alancai98
- @johnedquinn

## [0.14.4]

Expand Down
67 changes: 61 additions & 6 deletions partiql-planner/src/main/kotlin/org/partiql/planner/Errors.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package org.partiql.planner

import org.partiql.errors.ProblemDetails
import org.partiql.errors.ProblemSeverity
import org.partiql.plan.Identifier
import org.partiql.planner.internal.utils.PlanUtils
import org.partiql.types.StaticType

/**
Expand All @@ -24,15 +26,60 @@ public sealed class PlanningProblemDetails(
public data class CompileError(val errorMessage: String) :
PlanningProblemDetails(ProblemSeverity.ERROR, { errorMessage })

public data class UndefinedVariable(val variableName: String, val caseSensitive: Boolean) :
PlanningProblemDetails(
ProblemSeverity.ERROR,
{
"Undefined variable '$variableName'." +
quotationHint(caseSensitive)
/**
 * Problem raised when a variable reference cannot be found in the database environment and is not an
 * attribute of any in-scope variable.
 *
 * @property name the identifier that could not be resolved
 * @property inScopeVariables the names of the in-scope variables listed in the error message
 */
public data class UndefinedVariable(
    val name: Identifier,
    val inScopeVariables: Set<String>
) : PlanningProblemDetails(
    ProblemSeverity.ERROR,
    {
        val displayName = PlanUtils.identifierToString(name)
        "Variable $displayName does not exist in the database environment and is not an attribute of the following in-scope variables $inScopeVariables." +
            quotationHint(isSymbolAndCaseSensitive(name))
    }
) {

    /**
     * Symbol text of [name]: the symbol itself, or for a qualified identifier its last step
     * (falling back to the root when there are no steps).
     */
    @Deprecated("This will be removed in a future major version release.", replaceWith = ReplaceWith("name"))
    val variableName: String = when (name) {
        is Identifier.Symbol -> name.symbol
        is Identifier.Qualified -> (name.steps.lastOrNull() ?: name.root).symbol
    }

    /**
     * Whether the symbol selected for [variableName] is case-sensitive.
     */
    @Deprecated("This will be removed in a future major version release.", replaceWith = ReplaceWith("name"))
    val caseSensitive: Boolean = when (name) {
        is Identifier.Symbol -> name.caseSensitivity
        is Identifier.Qualified -> (name.steps.lastOrNull() ?: name.root).caseSensitivity
    } == Identifier.CaseSensitivity.SENSITIVE

    /**
     * Legacy constructor: wraps [variableName] in an [Identifier.Symbol] with the matching case
     * sensitivity and supplies an empty set of in-scope variables.
     */
    @Deprecated("This will be removed in a future major version release.", replaceWith = ReplaceWith("UndefinedVariable(Identifier, Set<String>)"))
    public constructor(variableName: String, caseSensitive: Boolean) : this(
        Identifier.Symbol(
            variableName,
            if (caseSensitive) Identifier.CaseSensitivity.SENSITIVE else Identifier.CaseSensitivity.INSENSITIVE
        ),
        emptySet()
    )

    private companion object {
        /**
         * Returns true only when [id] is an [Identifier.Symbol] that is case-sensitive; qualified
         * identifiers always yield false. The result drives the [quotationHint] shown to the user.
         */
        private fun isSymbolAndCaseSensitive(id: Identifier): Boolean =
            id is Identifier.Symbol && id.caseSensitivity == Identifier.CaseSensitivity.SENSITIVE
    }
}

public data class UndefinedDmlTarget(val variableName: String, val caseSensitive: Boolean) :
PlanningProblemDetails(
ProblemSeverity.ERROR,
Expand Down Expand Up @@ -94,6 +141,14 @@ public sealed class PlanningProblemDetails(
"Unknown function `$identifier($types)"
})

/**
 * Problem raised when an aggregate function call cannot be matched to a known aggregation.
 *
 * @property identifier the identifier of the aggregate function that was called
 * @property args the static types of the arguments supplied to the call
 */
public data class UnknownAggregateFunction(
    val identifier: Identifier,
    val args: List<StaticType>,
) : PlanningProblemDetails(ProblemSeverity.ERROR, {
    // Render the identifier through PlanUtils, as other messages in this file do, instead of
    // interpolating the plan node directly.
    val name = PlanUtils.identifierToString(identifier)
    val types = args.joinToString { "<${it.toString().lowercase()}>" }
    // The closing backtick balances the opening one around the rendered call.
    "Unknown aggregate function `$name($types)`"
})

public object ExpressionAlwaysReturnsNullOrMissing : PlanningProblemDetails(
severity = ProblemSeverity.ERROR,
messageFormatter = { "Expression always returns null or missing." }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -702,13 +702,13 @@ internal object PartiQLHeader : Header() {
private fun count() = listOf(
FunctionSignature.Aggregation(
name = "count",
returns = INT32,
returns = INT64,
parameters = listOf(FunctionParameter("value", ANY)),
isNullable = false,
),
FunctionSignature.Aggregation(
name = "count_star",
returns = INT32,
returns = INT64,
parameters = listOf(),
isNullable = false,
),
Expand Down Expand Up @@ -741,6 +741,15 @@ internal object PartiQLHeader : Header() {
)
}

/**
* According to SQL:1999 Section 6.16 Syntax Rule 14.c and Rule 14.d:
* > If AVG is specified and DT is exact numeric, then the declared type of the result is exact
* numeric with implementation-defined precision not less than the precision of DT and
* implementation-defined scale not less than the scale of DT.
*
* > If DT is approximate numeric, then the declared type of the result is approximate numeric
* with implementation-defined precision not less than the precision of DT.
*/
private fun avg() = types.numeric.map {
FunctionSignature.Aggregation(
name = "avg",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package org.partiql.planner.internal.transforms

import org.partiql.errors.Problem
import org.partiql.errors.ProblemCallback
import org.partiql.errors.UNKNOWN_PROBLEM_LOCATION
import org.partiql.plan.PlanNode
import org.partiql.plan.partiQLPlan
import org.partiql.planner.PlanningProblemDetails
import org.partiql.planner.internal.ir.Agg
import org.partiql.planner.internal.ir.Catalog
import org.partiql.planner.internal.ir.Fn
Expand All @@ -12,7 +15,10 @@ import org.partiql.planner.internal.ir.Rel
import org.partiql.planner.internal.ir.Rex
import org.partiql.planner.internal.ir.Statement
import org.partiql.planner.internal.ir.visitor.PlanBaseVisitor
import org.partiql.planner.internal.utils.PlanUtils
import org.partiql.types.function.FunctionSignature
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.PartiQLValueType

/**
* This is an internal utility to translate from the internal unresolved plan used for typing to the public plan IR.
Expand Down Expand Up @@ -58,7 +64,7 @@ internal object PlanTransform : PlanBaseVisitor<PlanNode, ProblemCallback>() {
/**
 * Builds the public plan aggregation from the signature of the resolved internal aggregation.
 * The [ctx] callback is not used here.
 */
override fun visitAggResolved(node: Agg.Resolved, ctx: ProblemCallback): org.partiql.plan.Agg {
    return org.partiql.plan.Agg(node.signature)
}

override fun visitAggUnresolved(node: Agg.Unresolved, ctx: ProblemCallback): org.partiql.plan.Rex.Op {
error("Unresolved aggregation ${node.identifier}")
error("Internal error: This should have been handled somewhere else. Cause: Unresolved aggregation ${node.identifier}.")
}

override fun visitStatement(node: Statement, ctx: ProblemCallback) =
Expand Down Expand Up @@ -342,11 +348,37 @@ internal object PlanTransform : PlanBaseVisitor<PlanNode, ProblemCallback>() {
groups = node.groups.map { visitRex(it, ctx) },
)

override fun visitRelOpAggregateCall(node: Rel.Op.Aggregate.Call, ctx: ProblemCallback) =
org.partiql.plan.Rel.Op.Aggregate.Call(
agg = visitAgg(node.agg, ctx),
/**
 * Converts an internal aggregation call into the public plan's aggregation call.
 *
 * A resolved aggregation is translated directly. An unresolved aggregation is reported to [ctx] as a
 * [PlanningProblemDetails.UnknownAggregateFunction], and a placeholder aggregation named
 * `UNKNOWN_AGG::<name>` (returning MISSING, with no parameters) is emitted in its place so that a
 * plan node is still produced.
 *
 * @param node the internal aggregation call to convert
 * @param ctx callback that receives the problem raised for an unresolved aggregation
 * @return the public aggregation call whose arguments are the converted arguments of [node]
 */
@OptIn(PartiQLValueExperimental::class)
override fun visitRelOpAggregateCall(node: Rel.Op.Aggregate.Call, ctx: ProblemCallback): org.partiql.plan.Rel.Op.Aggregate.Call {
    val agg = when (val internalAgg = node.agg) {
        is Agg.Unresolved -> {
            // Convert the identifier once and reuse it for both the placeholder name and the problem.
            val identifier = visitIdentifier(internalAgg.identifier, ctx)
            val name = PlanUtils.identifierToString(identifier)
            ctx.invoke(
                Problem(
                    UNKNOWN_PROBLEM_LOCATION,
                    PlanningProblemDetails.UnknownAggregateFunction(
                        identifier,
                        node.args.map { it.type }
                    )
                )
            )
            org.partiql.plan.Agg(
                FunctionSignature.Aggregation(
                    "UNKNOWN_AGG::$name",
                    returns = PartiQLValueType.MISSING,
                    parameters = emptyList()
                )
            )
        }
        is Agg.Resolved -> {
            visitAggResolved(internalAgg, ctx)
        }
    }
    return org.partiql.plan.Rel.Op.Aggregate.Call(
        agg = agg,
        args = node.args.map { visitRex(it, ctx) },
    )
}

override fun visitRelOpExclude(node: Rel.Op.Exclude, ctx: ProblemCallback) = org.partiql.plan.Rel.Op.Exclude(
input = visitRel(node.input, ctx),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ import org.partiql.planner.internal.ir.rexOpStructField
import org.partiql.planner.internal.ir.rexOpTupleUnion
import org.partiql.planner.internal.ir.statementQuery
import org.partiql.planner.internal.ir.util.PlanRewriter
import org.partiql.planner.internal.transforms.PlanTransform
import org.partiql.spi.BindingCase
import org.partiql.spi.BindingName
import org.partiql.spi.BindingPath
Expand Down Expand Up @@ -451,8 +452,8 @@ internal class PlanTyper(
}
val resolvedVar = env.resolve(path, locals, strategy)
if (resolvedVar == null) {
handleUndefinedVariable(path.steps.last())
return rex(ANY, rexOpErr("Undefined variable ${node.identifier}"))
val details = handleUndefinedVariable(node.identifier, locals.schema.map { it.name }.toSet())
return rex(ANY, rexOpErr(details.message))
}
return visitRex(resolvedVar, null)
}
Expand Down Expand Up @@ -1266,19 +1267,11 @@ internal class PlanTyper(
* to each row of T and eliminating null values <--- all NULL values are eliminated as inputs
*/
fun resolveAgg(agg: Agg.Unresolved, arguments: List<Rex>): Pair<Rel.Op.Aggregate.Call, StaticType> {
var missingArg = false
val args = arguments.map {
val arg = visitRex(it, null)
if (arg.type.isMissable()) missingArg = true
val arg = visitRex(it, it.type)
arg
}

//
if (missingArg) {
handleAlwaysMissing()
return relOpAggregateCall(agg, listOf(rexErr("MISSING"))) to MissingType
}

// Try to match the arguments to functions defined in the catalog
return when (val match = env.resolveAgg(agg, args)) {
is FnMatch.Ok -> {
Expand Down Expand Up @@ -1399,13 +1392,20 @@ internal class PlanTyper(

// ERRORS

private fun handleUndefinedVariable(name: BindingName) {
/**
* Invokes [onProblem] with a newly created [PlanningProblemDetails.UndefinedVariable] and returns the
* [PlanningProblemDetails.UndefinedVariable].
*/
private fun handleUndefinedVariable(name: Identifier, locals: Set<String>): PlanningProblemDetails.UndefinedVariable {
val planName = PlanTransform.visitIdentifier(name, onProblem)
val details = PlanningProblemDetails.UndefinedVariable(planName, locals)
onProblem(
Problem(
sourceLocation = UNKNOWN_PROBLEM_LOCATION,
details = PlanningProblemDetails.UndefinedVariable(name.name, name.bindingCase == BindingCase.SENSITIVE)
details = details
)
)
return details
}

private fun handleUnexpectedType(actual: StaticType, expected: Set<StaticType>) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import org.partiql.planner.internal.ir.rexOpVarResolved
import org.partiql.spi.BindingCase
import org.partiql.spi.BindingName
import org.partiql.spi.BindingPath
import org.partiql.types.AnyOfType
import org.partiql.types.AnyType
import org.partiql.types.StaticType
import org.partiql.types.StructType
import org.partiql.types.TupleConstraint
Expand Down Expand Up @@ -85,30 +87,28 @@ internal class TypeEnv(public val schema: List<Rel.Binding>) {
for (i in schema.indices) {
val local = schema[i]
val type = local.type
if (type is StructType) {
when (type.containsKey(name)) {
true -> {
if (c != null && known) {
// TODO root was already definitively matched, emit ambiguous error.
return null
}
c = rex(type, rexOpVarResolved(i))
known = true
when (type.containsKey(name)) {
true -> {
if (c != null && known) {
// TODO root was already definitively matched, emit ambiguous error.
return null
}
null -> {
if (c != null) {
if (known) {
continue
} else {
// TODO we have more than one possible match, emit ambiguous error.
return null
}
c = rex(type, rexOpVarResolved(i))
known = true
}
null -> {
if (c != null) {
if (known) {
continue
} else {
// TODO we have more than one possible match, emit ambiguous error.
return null
}
c = rex(type, rexOpVarResolved(i))
known = false
}
false -> continue
c = rex(type, rexOpVarResolved(i))
known = false
}
false -> continue
}
}
return c
Expand Down Expand Up @@ -152,4 +152,45 @@ internal class TypeEnv(public val schema: List<Rel.Binding>) {
val closed = constraints.contains(TupleConstraint.Open(false))
return if (closed) false else null
}

/**
 * Determines whether this [StaticType] has an attribute matching the given [BindingName].
 *
 * The receiver is flattened first, then:
 * - a [StructType] is answered by the struct-specific lookup;
 * - an [AnyOfType] combines the answers of all of its member types;
 * - [AnyType] is always undetermined;
 * - every other type is known to not contain the key.
 *
 * @param name the binding name to look for
 * @return `true` when the key is known to be present, `false` when it is known to be absent, and
 *  `null` when neither can be established
 */
private fun StaticType.containsKey(name: BindingName): Boolean? =
    when (val flattened = this.flatten()) {
        is StructType -> flattened.containsKey(name)
        is AnyOfType -> {
            // One answer per member type of the union.
            val answers = flattened.allTypes.map { member -> member.containsKey(name) }
            when {
                // Every member is known to contain the key (also the result for a union with no members).
                answers.all { it == true } -> true
                // Every member is known to NOT contain the key.
                answers.all { it == false } -> false
                // Mixed or undetermined answers.
                else -> null
            }
        }
        is AnyType -> null
        else -> false
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ internal fun StaticType.toRuntimeType(): PartiQLValueType {
// handle anyOf(null, T) cases
val t = types.filter { it !is NullType && it !is MissingType }
return if (t.size != 1) {
error("Cannot have a UNION runtime type: $this")
PartiQLValueType.ANY
} else {
t.first().asRuntimeType()
}
Expand Down
Loading
Loading