From ffa537c69611339b54e14fe85d4da44e44d44a87 Mon Sep 17 00:00:00 2001 From: John Ed Quinn Date: Mon, 1 Apr 2024 14:21:14 -0700 Subject: [PATCH 1/3] Fixes aggregations of attribute references to values of union types --- CHANGELOG.md | 5 +- .../main/kotlin/org/partiql/planner/Errors.kt | 8 + .../partiql/planner/internal/PartiQLHeader.kt | 13 +- .../internal/transforms/PlanTransform.kt | 58 +++++- .../planner/internal/typer/PlanTyper.kt | 4 +- .../partiql/planner/internal/typer/TypeEnv.kt | 75 +++++-- .../planner/internal/typer/TypeUtils.kt | 2 +- .../internal/typer/PlanTyperTestsPorted.kt | 194 +++++++++++++++++- .../catalogs/default/aggregations/T1.ion | 17 ++ .../catalogs/default/aggregations/T2.ion | 17 ++ .../catalogs/default/aggregations/T3.ion | 17 ++ .../kotlin/org/partiql/types/StaticType.kt | 5 +- 12 files changed, 377 insertions(+), 38 deletions(-) create mode 100644 partiql-planner/src/testFixtures/resources/catalogs/default/aggregations/T1.ion create mode 100644 partiql-planner/src/testFixtures/resources/catalogs/default/aggregations/T2.ion create mode 100644 partiql-planner/src/testFixtures/resources/catalogs/default/aggregations/T3.ion diff --git a/CHANGELOG.md b/CHANGELOG.md index 89d4d7814f..c17373049b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,8 @@ Thank you to all who have contributed! ### Changed - Change `StaticType.AnyOfType`'s `.toString` to not perform `.flatten()` - Change modeling of `COALESCE` and `NULLIF` to dedicated nodes in logical plan +- Function resolution logic: Now the function resolver would match all possible candidate (based on if the argument can be coerced to the Signature parameter type). If there are multiple match it will first attempt to pick the one requires the least cast, then pick the function with the highest precedence. +- **Behavioral change**: The COUNT aggregate function now returns INT64. ### Deprecated - The current SqlBlock, SqlDialect, and SqlLayout are marked as deprecated and will be slightly changed in the next release. @@ -40,7 +42,7 @@ Thank you to all who have contributed! ### Fixed - `StaticType.flatten()` on an `AnyOfType` with `AnyType` will return `AnyType` - Updates the default `.sql()` method to use a more efficient (internal) printer implementation. - +- Fixes aggregations of attribute references to values of union types. This fix also allows for proper error handling by passing the UnknownAggregateFunction problem to the ProblemCallback. Please note that, with this change, the planner will no longer immediately throw an IllegalStateException for this exact scenario. ### Removed @@ -51,6 +53,7 @@ Thank you to all who have contributed! - @ - @rchowell - @alancai98 +- @johnedquinn ## [0.14.4] diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/Errors.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/Errors.kt index 4866c350a5..5e21d9e45d 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/Errors.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/Errors.kt @@ -94,6 +94,14 @@ public sealed class PlanningProblemDetails( "Unknown function `$identifier($types)" }) + public data class UnknownAggregateFunction( + val identifier: String, + val args: List, + ) : PlanningProblemDetails(ProblemSeverity.ERROR, { + val types = args.joinToString { "<${it.toString().lowercase()}>" } + "Unknown aggregate function `$identifier($types)" + }) + public object ExpressionAlwaysReturnsNullOrMissing : PlanningProblemDetails( severity = ProblemSeverity.ERROR, messageFormatter = { "Expression always returns null or missing." } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PartiQLHeader.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PartiQLHeader.kt index 24e5e22bff..ab50c6ac09 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PartiQLHeader.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PartiQLHeader.kt @@ -702,13 +702,13 @@ internal object PartiQLHeader : Header() { private fun count() = listOf( FunctionSignature.Aggregation( name = "count", - returns = INT32, + returns = INT64, parameters = listOf(FunctionParameter("value", ANY)), isNullable = false, ), FunctionSignature.Aggregation( name = "count_star", - returns = INT32, + returns = INT64, parameters = listOf(), isNullable = false, ), @@ -741,6 +741,15 @@ internal object PartiQLHeader : Header() { ) } + /** + * According to SQL:1999 Section 6.16 Syntax Rule 14.c and Rule 14.d: + * > If AVG is specified and DT is exact numeric, then the declared type of the result is exact + * numeric with implementation-defined precision not less than the precision of DT and + * implementation-defined scale not less than the scale of DT. + * + * > If DT is approximate numeric, then the declared type of the result is approximate numeric + * with implementation-defined precision not less than the precision of DT. + */ private fun avg() = types.numeric.map { FunctionSignature.Aggregation( name = "avg", diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt index 531eb1c093..6692e18c13 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt @@ -1,8 +1,11 @@ package org.partiql.planner.internal.transforms +import org.partiql.errors.Problem import org.partiql.errors.ProblemCallback +import org.partiql.errors.UNKNOWN_PROBLEM_LOCATION import org.partiql.plan.PlanNode import org.partiql.plan.partiQLPlan +import org.partiql.planner.PlanningProblemDetails import org.partiql.planner.internal.ir.Agg import org.partiql.planner.internal.ir.Catalog import org.partiql.planner.internal.ir.Fn @@ -12,7 +15,9 @@ import org.partiql.planner.internal.ir.Rel import org.partiql.planner.internal.ir.Rex import org.partiql.planner.internal.ir.Statement import org.partiql.planner.internal.ir.visitor.PlanBaseVisitor +import org.partiql.types.function.FunctionSignature import org.partiql.value.PartiQLValueExperimental +import org.partiql.value.PartiQLValueType /** * This is an internal utility to translate from the internal unresolved plan used for typing to the public plan IR. @@ -58,7 +63,7 @@ internal object PlanTransform : PlanBaseVisitor() { override fun visitAggResolved(node: Agg.Resolved, ctx: ProblemCallback) = org.partiql.plan.Agg(node.signature) override fun visitAggUnresolved(node: Agg.Unresolved, ctx: ProblemCallback): org.partiql.plan.Rex.Op { - error("Unresolved aggregation ${node.identifier}") + error("Internal error: This should have been handled somewhere else. Cause: Unresolved aggregation ${node.identifier}.") } override fun visitStatement(node: Statement, ctx: ProblemCallback) = @@ -342,11 +347,56 @@ internal object PlanTransform : PlanBaseVisitor() { groups = node.groups.map { visitRex(it, ctx) }, ) - override fun visitRelOpAggregateCall(node: Rel.Op.Aggregate.Call, ctx: ProblemCallback) = - org.partiql.plan.Rel.Op.Aggregate.Call( - agg = visitAgg(node.agg, ctx), + @OptIn(PartiQLValueExperimental::class) + override fun visitRelOpAggregateCall(node: Rel.Op.Aggregate.Call, ctx: ProblemCallback): org.partiql.plan.Rel.Op.Aggregate.Call { + val agg = when (val agg = node.agg) { + is Agg.Unresolved -> { + val name = agg.identifier.toNormalizedString() + ctx.invoke( + Problem( + UNKNOWN_PROBLEM_LOCATION, + PlanningProblemDetails.UnknownAggregateFunction( + agg.identifier.toString(), + node.args.map { it.type } + ) + ) + ) + org.partiql.plan.Agg( + FunctionSignature.Aggregation( + "UNKNOWN_AGG::$name", + returns = PartiQLValueType.MISSING, + parameters = emptyList() + ) + ) + } + is Agg.Resolved -> { + visitAggResolved(agg, ctx) + } + } + return org.partiql.plan.Rel.Op.Aggregate.Call( + agg = agg, args = node.args.map { visitRex(it, ctx) }, ) + } + + private fun Identifier.toNormalizedString(): String { + return when (this) { + is Identifier.Symbol -> this.toNormalizedString() + is Identifier.Qualified -> { + val toJoin = listOf(this.root) + this.steps + toJoin.joinToString(separator = ".") { ident -> + ident.toNormalizedString() + } + } + } + } + + private fun Identifier.Symbol.toNormalizedString(): String { + return when (this.caseSensitivity) { + Identifier.CaseSensitivity.SENSITIVE -> "\"${this.symbol}\"" + Identifier.CaseSensitivity.INSENSITIVE -> this.symbol + } + } override fun visitRelOpExclude(node: Rel.Op.Exclude, ctx: ProblemCallback) = org.partiql.plan.Rel.Op.Exclude( input = visitRel(node.input, ctx), diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt index bf7845fb3c..cdd4b1ed9a 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt @@ -1268,8 +1268,8 @@ internal class PlanTyper( fun resolveAgg(agg: Agg.Unresolved, arguments: List): Pair { var missingArg = false val args = arguments.map { - val arg = visitRex(it, null) - if (arg.type.isMissable()) missingArg = true + val arg = visitRex(it, it.type) + if (arg.type is MissingType) missingArg = true arg } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeEnv.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeEnv.kt index d413abde04..44e01808fd 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeEnv.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeEnv.kt @@ -10,6 +10,8 @@ import org.partiql.planner.internal.ir.rexOpVarResolved import org.partiql.spi.BindingCase import org.partiql.spi.BindingName import org.partiql.spi.BindingPath +import org.partiql.types.AnyOfType +import org.partiql.types.AnyType import org.partiql.types.StaticType import org.partiql.types.StructType import org.partiql.types.TupleConstraint @@ -85,30 +87,28 @@ internal class TypeEnv(public val schema: List) { for (i in schema.indices) { val local = schema[i] val type = local.type - if (type is StructType) { - when (type.containsKey(name)) { - true -> { - if (c != null && known) { - // TODO root was already definitively matched, emit ambiguous error. - return null - } - c = rex(type, rexOpVarResolved(i)) - known = true + when (type.containsKey(name)) { + true -> { + if (c != null && known) { + // TODO root was already definitively matched, emit ambiguous error. + return null } - null -> { - if (c != null) { - if (known) { - continue - } else { - // TODO we have more than one possible match, emit ambiguous error. - return null - } + c = rex(type, rexOpVarResolved(i)) + known = true + } + null -> { + if (c != null) { + if (known) { + continue + } else { + // TODO we have more than one possible match, emit ambiguous error. + return null } - c = rex(type, rexOpVarResolved(i)) - known = false } - false -> continue + c = rex(type, rexOpVarResolved(i)) + known = false } + false -> continue } } return c @@ -152,4 +152,39 @@ internal class TypeEnv(public val schema: List) { val closed = constraints.contains(TupleConstraint.Open(false)) return if (closed) false else null } + + /** + * Searches for the [BindingName] within the given [StaticType]. + * + * Returns + * - true iff known to contain key + * - false iff known to NOT contain key + * - null iff NOT known to contain key + * + * @param name + * @return + */ + private fun StaticType.containsKey(name: BindingName): Boolean? { + return when (val type = this.flatten()) { + is StructType -> type.containsKey(name) + is AnyOfType -> { + val anyKnownToContainKey = type.allTypes.any { it.containsKey(name) == true } + val anyKnownToNotContainKey = type.allTypes.any { it.containsKey(name) == false } + val anyNotKnownToContainKey = type.allTypes.any { it.containsKey(name) == null } + when { + // There are: + // - No subtypes that are known to not contain the key + // - No subtypes that are not known to contain the key + anyKnownToNotContainKey.not() && anyNotKnownToContainKey.not() -> true + // There are: + // - No subtypes that are known to contain the key + // - No subtypes that are not known to contain the key + anyKnownToContainKey.not() && anyNotKnownToContainKey.not() -> false + else -> null + } + } + is AnyType -> null + else -> false + } + } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeUtils.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeUtils.kt index bccd22c451..d83c45de5f 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeUtils.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeUtils.kt @@ -93,7 +93,7 @@ internal fun StaticType.toRuntimeType(): PartiQLValueType { // handle anyOf(null, T) cases val t = types.filter { it !is NullType && it !is MissingType } return if (t.size != 1) { - error("Cannot have a UNION runtime type: $this") + PartiQLValueType.ANY } else { t.first().asRuntimeType() } diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt index db22b36dc4..3f20e25465 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt @@ -1,6 +1,8 @@ package org.partiql.planner.internal.typer import com.amazon.ionelement.api.loadSingleElement +import org.junit.jupiter.api.Disabled +import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows import org.junit.jupiter.api.extension.ExtensionContext import org.junit.jupiter.api.parallel.Execution @@ -129,6 +131,7 @@ class PlanTyperTestsPorted { } } map.entries.map { + println("Map Entry: ${it.key} to ${it.value}") it.key to MemoryConnector.Metadata.of(*it.value.toTypedArray()) } } @@ -3063,14 +3066,16 @@ class PlanTyperTestsPorted { fun aggregationCases() = listOf( SuccessTestCase( name = "AGGREGATE over INTS, without alias", - query = "SELECT a, COUNT(*), SUM(a), MIN(b) FROM << {'a': 1, 'b': 2} >> GROUP BY a", + query = "SELECT a, COUNT(*), COUNT(a), SUM(a), MIN(b), MAX(a) FROM << {'a': 1, 'b': 2} >> GROUP BY a", expected = BagType( StructType( fields = mapOf( "a" to StaticType.INT4, - "_1" to StaticType.INT4, - "_2" to StaticType.INT4.asNullable(), + "_1" to StaticType.INT8, + "_2" to StaticType.INT8, "_3" to StaticType.INT4.asNullable(), + "_4" to StaticType.INT4.asNullable(), + "_5" to StaticType.INT4.asNullable(), ), contentClosed = true, constraints = setOf( @@ -3083,12 +3088,13 @@ class PlanTyperTestsPorted { ), SuccessTestCase( name = "AGGREGATE over INTS, with alias", - query = "SELECT a, COUNT(*) AS c, SUM(a) AS s, MIN(b) AS m FROM << {'a': 1, 'b': 2} >> GROUP BY a", + query = "SELECT a, COUNT(*) AS c_s, COUNT(a) AS c, SUM(a) AS s, MIN(b) AS m FROM << {'a': 1, 'b': 2} >> GROUP BY a", expected = BagType( StructType( fields = mapOf( "a" to StaticType.INT4, - "c" to StaticType.INT4, + "c_s" to StaticType.INT8, + "c" to StaticType.INT8, "s" to StaticType.INT4.asNullable(), "m" to StaticType.INT4.asNullable(), ), @@ -3108,7 +3114,7 @@ class PlanTyperTestsPorted { StructType( fields = mapOf( "a" to StaticType.DECIMAL, - "c" to StaticType.INT4, + "c" to StaticType.INT8, "s" to StaticType.DECIMAL.asNullable(), "m" to StaticType.DECIMAL.asNullable(), ), @@ -3121,6 +3127,87 @@ class PlanTyperTestsPorted { ) ) ), + SuccessTestCase( + name = "AGGREGATE over nullable integers", + query = """ + SELECT + a AS a, + COUNT(*) AS count_star, + COUNT(a) AS count_a, + COUNT(b) AS count_b, + SUM(a) AS sum_a, + SUM(b) AS sum_b, + MIN(a) AS min_a, + MIN(b) AS min_b, + MAX(a) AS max_a, + MAX(b) AS max_b, + AVG(a) AS avg_a, + AVG(b) AS avg_b + FROM << + { 'a': 1, 'b': 2 }, + { 'a': 3, 'b': 4 }, + { 'a': 5, 'b': NULL } + >> GROUP BY a + """.trimIndent(), + expected = BagType( + StructType( + fields = mapOf( + "a" to StaticType.INT4, + "count_star" to StaticType.INT8, + "count_a" to StaticType.INT8, + "count_b" to StaticType.INT8, + "sum_a" to StaticType.INT4.asNullable(), + "sum_b" to StaticType.INT4.asNullable(), + "min_a" to StaticType.INT4.asNullable(), + "min_b" to StaticType.INT4.asNullable(), + "max_a" to StaticType.INT4.asNullable(), + "max_b" to StaticType.INT4.asNullable(), + "avg_a" to StaticType.INT4.asNullable(), + "avg_b" to StaticType.INT4.asNullable(), + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true), + TupleConstraint.Ordered + ) + ) + ) + ), + SuccessTestCase( + name = "AGGREGATE over nullable integers", + query = """ + SELECT + COUNT(a) AS count_a + FROM << + { 'a': 1, 'b': 2 } + >> t1 INNER JOIN << + { 'c': 1, 'd': 3 } + >> t2 + ON t1.a = t1.c + AND ( + 1 = ( + SELECT COUNT(e) AS count_e + FROM << + { 'e': 10 } + >> t3 + ) + ); + """.trimIndent(), + expected = BagType( + StructType( + fields = mapOf( + "count_a" to StaticType.INT8 + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true), + TupleConstraint.Ordered + ) + ) + ) + ), ) @JvmStatic @@ -3371,6 +3458,101 @@ class PlanTyperTestsPorted { // // Parameterized Tests // + + @Test + fun failingAggTest() { + val tc = SuccessTestCase( + name = "AGGREGATE over nullable integers", + query = """ + SELECT T1.a + FROM T1 + LEFT JOIN T2 AS T2_1 + ON T2_1.d = + ( + SELECT + CASE WHEN COUNT(f) = 1 THEN MAX(f) ELSE 0 END AS e + FROM T3 AS T3_mapping + ) + LEFT JOIN T2 AS T2_2 + ON T2_2.d = + ( + SELECT + CASE WHEN COUNT(f) = 1 THEN MAX(f) ELSE 0 END AS e + FROM T3 AS T3_mapping + ) + ; + """.trimIndent(), + expected = BagType( + StructType( + fields = mapOf( + "a" to StaticType.BOOL + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true), + TupleConstraint.Ordered + ) + ) + ), + catalog = "aggregations" + ) + runTest(tc) + } + + @Test + @Disabled("The planner doesn't support heterogeneous input to aggregation functions (yet?).") + fun failingTest() { + val tc = SuccessTestCase( + name = "AGGREGATE over heterogeneous data", + query = """ + SELECT + a AS a, + COUNT(*) AS count_star, + COUNT(a) AS count_a, + COUNT(b) AS count_b, + SUM(a) AS sum_a, + SUM(b) AS sum_b, + MIN(a) AS min_a, + MIN(b) AS min_b, + MAX(a) AS max_a, + MAX(b) AS max_b, + AVG(a) AS avg_a, + AVG(b) AS avg_b + FROM << + { 'a': 1.0, 'b': 2.0 }, + { 'a': 3, 'b': 4 }, + { 'a': 5, 'b': NULL } + >> GROUP BY a + """.trimIndent(), + expected = BagType( + StructType( + fields = mapOf( + "a" to StaticType.DECIMAL, + "count_star" to StaticType.INT8, + "count_a" to StaticType.INT8, + "count_b" to StaticType.INT8, + "sum_a" to StaticType.DECIMAL.asNullable(), + "sum_b" to StaticType.DECIMAL.asNullable(), + "min_a" to StaticType.DECIMAL.asNullable(), + "min_b" to StaticType.DECIMAL.asNullable(), + "max_a" to StaticType.DECIMAL.asNullable(), + "max_b" to StaticType.DECIMAL.asNullable(), + "avg_a" to StaticType.DECIMAL.asNullable(), + "avg_b" to StaticType.DECIMAL.asNullable(), + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true), + TupleConstraint.Ordered + ) + ) + ) + ) + runTest(tc) + } + @ParameterizedTest @ArgumentsSource(TestProvider::class) fun test(tc: TestCase) = runTest(tc) diff --git a/partiql-planner/src/testFixtures/resources/catalogs/default/aggregations/T1.ion b/partiql-planner/src/testFixtures/resources/catalogs/default/aggregations/T1.ion new file mode 100644 index 0000000000..f0defe828f --- /dev/null +++ b/partiql-planner/src/testFixtures/resources/catalogs/default/aggregations/T1.ion @@ -0,0 +1,17 @@ +{ + type: "bag", + items: { + type: "struct", + constraints: [ closed ], + fields: [ + { + name: "a", + type: "bool", + }, + { + name: "b", + type: "int32", + }, + ] + } +} diff --git a/partiql-planner/src/testFixtures/resources/catalogs/default/aggregations/T2.ion b/partiql-planner/src/testFixtures/resources/catalogs/default/aggregations/T2.ion new file mode 100644 index 0000000000..9f51c844e0 --- /dev/null +++ b/partiql-planner/src/testFixtures/resources/catalogs/default/aggregations/T2.ion @@ -0,0 +1,17 @@ +{ + type: "bag", + items: { + type: "struct", + constraints: [ closed ], + fields: [ + { + name: "c", + type: "bool", + }, + { + name: "d", + type: "int32", + }, + ] + } +} diff --git a/partiql-planner/src/testFixtures/resources/catalogs/default/aggregations/T3.ion b/partiql-planner/src/testFixtures/resources/catalogs/default/aggregations/T3.ion new file mode 100644 index 0000000000..40e7812425 --- /dev/null +++ b/partiql-planner/src/testFixtures/resources/catalogs/default/aggregations/T3.ion @@ -0,0 +1,17 @@ +{ + type: "bag", + items: { + type: "struct", + constraints: [ closed ], + fields: [ + { + name: "e", + type: "bool", + }, + { + name: "f", + type: "int32", + }, + ] + } +} diff --git a/partiql-types/src/main/kotlin/org/partiql/types/StaticType.kt b/partiql-types/src/main/kotlin/org/partiql/types/StaticType.kt index 66749a4c50..5eeba1b397 100644 --- a/partiql-types/src/main/kotlin/org/partiql/types/StaticType.kt +++ b/partiql-types/src/main/kotlin/org/partiql/types/StaticType.kt @@ -599,9 +599,10 @@ public data class StructType( get() = listOf(this) override fun toString(): String { - val firstSeveral = fields.take(3).joinToString { "${it.key}: ${it.value}" } + val firstFieldsSize = 15 + val firstSeveral = fields.take(firstFieldsSize).joinToString { "${it.key}: ${it.value}" } return when { - fields.size <= 3 -> "struct($firstSeveral, $constraints)" + fields.size <= firstFieldsSize -> "struct($firstSeveral, $constraints)" else -> "struct($firstSeveral, ... and ${fields.size - 3} other field(s), $constraints)" } } From 234518c251770b6ebfc8338182eaa2b6f5b15ae4 Mon Sep 17 00:00:00 2001 From: John Ed Quinn Date: Mon, 1 Apr 2024 14:22:57 -0700 Subject: [PATCH 2/3] Updates UndefinedVariable class, deprecates methods, and updates tests --- CHANGELOG.md | 2 + .../main/kotlin/org/partiql/planner/Errors.kt | 67 ++++++++++- .../planner/internal/typer/PlanTyper.kt | 35 +++++- .../internal/typer/PlanTyperTestsPorted.kt | 111 +++++++----------- 4 files changed, 136 insertions(+), 79 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c17373049b..02df8c217c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,6 +38,8 @@ Thank you to all who have contributed! ### Deprecated - The current SqlBlock, SqlDialect, and SqlLayout are marked as deprecated and will be slightly changed in the next release. +- Deprecates constructor and properties `variableName` and `caseSensitive` of `org.partiql.planner.PlanningProblemDetails.UndefinedVariable` + in favor of newly added constructor and properties `name` and `inScopeVariables`. ### Fixed - `StaticType.flatten()` on an `AnyOfType` with `AnyType` will return `AnyType` diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/Errors.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/Errors.kt index 5e21d9e45d..6527441d46 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/Errors.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/Errors.kt @@ -2,6 +2,7 @@ package org.partiql.planner import org.partiql.errors.ProblemDetails import org.partiql.errors.ProblemSeverity +import org.partiql.plan.Identifier import org.partiql.types.StaticType /** @@ -24,15 +25,69 @@ public sealed class PlanningProblemDetails( public data class CompileError(val errorMessage: String) : PlanningProblemDetails(ProblemSeverity.ERROR, { errorMessage }) - public data class UndefinedVariable(val variableName: String, val caseSensitive: Boolean) : - PlanningProblemDetails( - ProblemSeverity.ERROR, - { - "Undefined variable '$variableName'." + - quotationHint(caseSensitive) + public data class UndefinedVariable( + val name: Identifier, + val inScopeVariables: Set + ) : PlanningProblemDetails( + ProblemSeverity.ERROR, + { + "Variable ${pretty(name)} does not exist in the database environment and is not an attribute of the following in-scope variables $inScopeVariables." + + quotationHint(isSymbolAndCaseSensitive(name)) + } + ) { + + @Deprecated("This will be removed in a future major version release.", replaceWith = ReplaceWith("name")) + val variableName: String = when (name) { + is Identifier.Symbol -> name.symbol + is Identifier.Qualified -> when (name.steps.size) { + 0 -> name.root.symbol + else -> name.steps.last().symbol + } + } + + @Deprecated("This will be removed in a future major version release.", replaceWith = ReplaceWith("name")) + val caseSensitive: Boolean = when (name) { + is Identifier.Symbol -> name.caseSensitivity == Identifier.CaseSensitivity.SENSITIVE + is Identifier.Qualified -> when (name.steps.size) { + 0 -> name.root.caseSensitivity == Identifier.CaseSensitivity.SENSITIVE + else -> name.steps.last().caseSensitivity == Identifier.CaseSensitivity.SENSITIVE } + } + + @Deprecated("This will be removed in a future major version release.", replaceWith = ReplaceWith("UndefinedVariable(Identifier, Set)")) + public constructor(variableName: String, caseSensitive: Boolean) : this( + Identifier.Symbol( + variableName, + when (caseSensitive) { + true -> Identifier.CaseSensitivity.SENSITIVE + false -> Identifier.CaseSensitivity.INSENSITIVE + } + ), + emptySet() ) + private companion object { + /** + * Used to check whether the [id] is an [Identifier.Symbol] and whether it is case-sensitive. This is helpful + * for giving the [quotationHint] to the user. + */ + private fun isSymbolAndCaseSensitive(id: Identifier): Boolean = when (id) { + is Identifier.Symbol -> id.caseSensitivity == Identifier.CaseSensitivity.SENSITIVE + is Identifier.Qualified -> false + } + + private fun pretty(id: Identifier): String = when (id) { + is Identifier.Symbol -> pretty(id) + is Identifier.Qualified -> (listOf(id.root) + id.steps).joinToString(".") { pretty(it) } + } + + private fun pretty(id: Identifier.Symbol): String = when (id.caseSensitivity) { + Identifier.CaseSensitivity.INSENSITIVE -> id.symbol + Identifier.CaseSensitivity.SENSITIVE -> "\"${id.symbol}\"" + } + } + } + public data class UndefinedDmlTarget(val variableName: String, val caseSensitive: Boolean) : PlanningProblemDetails( ProblemSeverity.ERROR, diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt index cdd4b1ed9a..64198eb280 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt @@ -451,8 +451,8 @@ internal class PlanTyper( } val resolvedVar = env.resolve(path, locals, strategy) if (resolvedVar == null) { - handleUndefinedVariable(path.steps.last()) - return rex(ANY, rexOpErr("Undefined variable ${node.identifier}")) + val details = handleUndefinedVariable(node.identifier, locals.schema.map { it.name }.toSet()) + return rex(ANY, rexOpErr(details.message)) } return visitRex(resolvedVar, null) } @@ -1399,15 +1399,42 @@ internal class PlanTyper( // ERRORS - private fun handleUndefinedVariable(name: BindingName) { + /** + * Invokes [onProblem] with a newly created [PlanningProblemDetails.UndefinedVariable] and returns the + * [PlanningProblemDetails.UndefinedVariable]. + */ + private fun handleUndefinedVariable(name: Identifier, locals: Set): PlanningProblemDetails.UndefinedVariable { + val planName = name.toPlan() + val details = PlanningProblemDetails.UndefinedVariable(planName, locals) onProblem( Problem( sourceLocation = UNKNOWN_PROBLEM_LOCATION, - details = PlanningProblemDetails.UndefinedVariable(name.name, name.bindingCase == BindingCase.SENSITIVE) + details = details ) ) + return details + } + + private fun Identifier.CaseSensitivity.toPlan(): org.partiql.plan.Identifier.CaseSensitivity = when (this) { + Identifier.CaseSensitivity.SENSITIVE -> org.partiql.plan.Identifier.CaseSensitivity.SENSITIVE + Identifier.CaseSensitivity.INSENSITIVE -> org.partiql.plan.Identifier.CaseSensitivity.INSENSITIVE + } + + private fun Identifier.toPlan(): org.partiql.plan.Identifier = when (this) { + is Identifier.Symbol -> this.toPlan() + is Identifier.Qualified -> this.toPlan() } + private fun Identifier.Symbol.toPlan(): org.partiql.plan.Identifier.Symbol = org.partiql.plan.Identifier.Symbol( + this.symbol, + this.caseSensitivity.toPlan() + ) + + private fun Identifier.Qualified.toPlan(): org.partiql.plan.Identifier.Qualified = org.partiql.plan.Identifier.Qualified( + this.root.toPlan(), + this.steps.map { it.toPlan() } + ) + private fun handleUnexpectedType(actual: StaticType, expected: Set) { onProblem( Problem( diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt index 3f20e25465..88cf544510 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt @@ -15,6 +15,7 @@ import org.junit.jupiter.params.provider.MethodSource import org.partiql.errors.Problem import org.partiql.errors.UNKNOWN_PROBLEM_LOCATION import org.partiql.parser.PartiQLParser +import org.partiql.plan.Identifier import org.partiql.plan.PartiQLPlan import org.partiql.plan.Statement import org.partiql.plan.debug.PlanPrinter @@ -104,6 +105,18 @@ class PlanTyperTestsPorted { } } + private fun id(vararg parts: Identifier.Symbol): Identifier { + return when (parts.size) { + 0 -> error("Identifier requires more than one part.") + 1 -> parts.first() + else -> Identifier.Qualified(parts.first(), parts.drop(1)) + } + } + + private fun sensitive(part: String): Identifier.Symbol = Identifier.Symbol(part, Identifier.CaseSensitivity.SENSITIVE) + + private fun insensitive(part: String): Identifier.Symbol = Identifier.Symbol(part, Identifier.CaseSensitivity.INSENSITIVE) + /** * MemoryConnector.Factory from reading the resources in /resource_path.txt for Github CI/CD. */ @@ -131,7 +144,6 @@ class PlanTyperTestsPorted { } } map.entries.map { - println("Map Entry: ${it.key} to ${it.value}") it.key to MemoryConnector.Metadata.of(*it.value.toTypedArray()) } } @@ -805,7 +817,7 @@ class PlanTyperTestsPorted { problemHandler = assertProblemExists { Problem( UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UndefinedVariable("a", false) + PlanningProblemDetails.UndefinedVariable(insensitive("a"), setOf("t1", "t2")) ) } ), @@ -2022,7 +2034,7 @@ class PlanTyperTestsPorted { problemHandler = assertProblemExists { Problem( UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UndefinedVariable("unknown_col", false) + PlanningProblemDetails.UndefinedVariable(insensitive("unknown_col"), setOf("pets")) ) } ), @@ -2920,7 +2932,7 @@ class PlanTyperTestsPorted { problemHandler = assertProblemExists { Problem( UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UndefinedVariable("main", true) + PlanningProblemDetails.UndefinedVariable(id(sensitive("pql"), sensitive("main")), setOf()) ) } ), @@ -2933,7 +2945,7 @@ class PlanTyperTestsPorted { problemHandler = assertProblemExists { Problem( UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UndefinedVariable("pql", true) + PlanningProblemDetails.UndefinedVariable(sensitive("pql"), setOf()) ) } ), @@ -3177,27 +3189,28 @@ class PlanTyperTestsPorted { SuccessTestCase( name = "AGGREGATE over nullable integers", query = """ - SELECT - COUNT(a) AS count_a - FROM << - { 'a': 1, 'b': 2 } - >> t1 INNER JOIN << - { 'c': 1, 'd': 3 } - >> t2 - ON t1.a = t1.c - AND ( - 1 = ( - SELECT COUNT(e) AS count_e - FROM << - { 'e': 10 } - >> t3 + SELECT T1.a + FROM T1 + LEFT JOIN T2 AS T2_1 + ON T2_1.d = + ( + SELECT + CASE WHEN COUNT(f) = 1 THEN MAX(f) ELSE 0 END AS e + FROM T3 AS T3_mapping ) - ); + LEFT JOIN T2 AS T2_2 + ON T2_2.d = + ( + SELECT + CASE WHEN COUNT(f) = 1 THEN MAX(f) ELSE 0 END AS e + FROM T3 AS T3_mapping + ) + ; """.trimIndent(), expected = BagType( StructType( fields = mapOf( - "count_a" to StaticType.INT8 + "a" to StaticType.BOOL ), contentClosed = true, constraints = setOf( @@ -3206,8 +3219,9 @@ class PlanTyperTestsPorted { TupleConstraint.Ordered ) ) - ) - ), + ), + catalog = "aggregations" + ) ) @JvmStatic @@ -3459,47 +3473,6 @@ class PlanTyperTestsPorted { // Parameterized Tests // - @Test - fun failingAggTest() { - val tc = SuccessTestCase( - name = "AGGREGATE over nullable integers", - query = """ - SELECT T1.a - FROM T1 - LEFT JOIN T2 AS T2_1 - ON T2_1.d = - ( - SELECT - CASE WHEN COUNT(f) = 1 THEN MAX(f) ELSE 0 END AS e - FROM T3 AS T3_mapping - ) - LEFT JOIN T2 AS T2_2 - ON T2_2.d = - ( - SELECT - CASE WHEN COUNT(f) = 1 THEN MAX(f) ELSE 0 END AS e - FROM T3 AS T3_mapping - ) - ; - """.trimIndent(), - expected = BagType( - StructType( - fields = mapOf( - "a" to StaticType.BOOL - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), - TupleConstraint.UniqueAttrs(true), - TupleConstraint.Ordered - ) - ) - ), - catalog = "aggregations" - ) - runTest(tc) - } - @Test @Disabled("The planner doesn't support heterogeneous input to aggregation functions (yet?).") fun failingTest() { @@ -3824,7 +3797,7 @@ class PlanTyperTestsPorted { problemHandler = assertProblemExists { Problem( UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UndefinedVariable("pets", false) + PlanningProblemDetails.UndefinedVariable(insensitive("pets"), emptySet()) ) } ), @@ -3857,7 +3830,7 @@ class PlanTyperTestsPorted { problemHandler = assertProblemExists { Problem( UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UndefinedVariable("pets", false) + PlanningProblemDetails.UndefinedVariable(insensitive("pets"), emptySet()) ) } ), @@ -3924,7 +3897,7 @@ class PlanTyperTestsPorted { problemHandler = assertProblemExists { Problem( UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UndefinedVariable("pets", false) + PlanningProblemDetails.UndefinedVariable(id(insensitive("ddb"), insensitive("pets")), emptySet()) ) } ), @@ -4194,7 +4167,7 @@ class PlanTyperTestsPorted { problemHandler = assertProblemExists { Problem( UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UndefinedVariable("non_existing_column", false) + PlanningProblemDetails.UndefinedVariable(insensitive("non_existing_column"), emptySet()) ) } ), @@ -4249,7 +4222,7 @@ class PlanTyperTestsPorted { problemHandler = assertProblemExists { Problem( UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UndefinedVariable("unknown_col", false) + PlanningProblemDetails.UndefinedVariable(insensitive("unknown_col"), setOf("orders")) ) } ), From d767d086cbad285b691fc288536dda23f0b3805f Mon Sep 17 00:00:00 2001 From: John Ed Quinn Date: Fri, 29 Mar 2024 15:30:30 -0700 Subject: [PATCH 3/3] Addresses PR feedback --- .../main/kotlin/org/partiql/planner/Errors.kt | 16 +++------- .../internal/transforms/PlanTransform.kt | 24 ++------------ .../planner/internal/typer/PlanTyper.kt | 31 ++----------------- .../partiql/planner/internal/typer/TypeEnv.kt | 12 +++++-- .../planner/internal/utils/PlanUtils.kt | 26 ++++++++++++++++ 5 files changed, 44 insertions(+), 65 deletions(-) create mode 100644 partiql-planner/src/main/kotlin/org/partiql/planner/internal/utils/PlanUtils.kt diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/Errors.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/Errors.kt index 6527441d46..73058ede26 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/Errors.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/Errors.kt @@ -3,6 +3,7 @@ package org.partiql.planner import org.partiql.errors.ProblemDetails import org.partiql.errors.ProblemSeverity import org.partiql.plan.Identifier +import org.partiql.planner.internal.utils.PlanUtils import org.partiql.types.StaticType /** @@ -31,7 +32,8 @@ public sealed class PlanningProblemDetails( ) : PlanningProblemDetails( ProblemSeverity.ERROR, { - "Variable ${pretty(name)} does not exist in the database environment and is not an attribute of the following in-scope variables $inScopeVariables." + + val humanReadableName = PlanUtils.identifierToString(name) + "Variable $humanReadableName does not exist in the database environment and is not an attribute of the following in-scope variables $inScopeVariables." + quotationHint(isSymbolAndCaseSensitive(name)) } ) { @@ -75,16 +77,6 @@ public sealed class PlanningProblemDetails( is Identifier.Symbol -> id.caseSensitivity == Identifier.CaseSensitivity.SENSITIVE is Identifier.Qualified -> false } - - private fun pretty(id: Identifier): String = when (id) { - is Identifier.Symbol -> pretty(id) - is Identifier.Qualified -> (listOf(id.root) + id.steps).joinToString(".") { pretty(it) } - } - - private fun pretty(id: Identifier.Symbol): String = when (id.caseSensitivity) { - Identifier.CaseSensitivity.INSENSITIVE -> id.symbol - Identifier.CaseSensitivity.SENSITIVE -> "\"${id.symbol}\"" - } } } @@ -150,7 +142,7 @@ public sealed class PlanningProblemDetails( }) public data class UnknownAggregateFunction( - val identifier: String, + val identifier: Identifier, val args: List, ) : PlanningProblemDetails(ProblemSeverity.ERROR, { val types = args.joinToString { "<${it.toString().lowercase()}>" } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt index 6692e18c13..a814cc7a35 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt @@ -15,6 +15,7 @@ import org.partiql.planner.internal.ir.Rel import org.partiql.planner.internal.ir.Rex import org.partiql.planner.internal.ir.Statement import org.partiql.planner.internal.ir.visitor.PlanBaseVisitor +import org.partiql.planner.internal.utils.PlanUtils import org.partiql.types.function.FunctionSignature import org.partiql.value.PartiQLValueExperimental import org.partiql.value.PartiQLValueType @@ -351,12 +352,12 @@ internal object PlanTransform : PlanBaseVisitor() { override fun visitRelOpAggregateCall(node: Rel.Op.Aggregate.Call, ctx: ProblemCallback): org.partiql.plan.Rel.Op.Aggregate.Call { val agg = when (val agg = node.agg) { is Agg.Unresolved -> { - val name = agg.identifier.toNormalizedString() + val name = PlanUtils.identifierToString(visitIdentifier(agg.identifier, ctx)) ctx.invoke( Problem( UNKNOWN_PROBLEM_LOCATION, PlanningProblemDetails.UnknownAggregateFunction( - agg.identifier.toString(), + visitIdentifier(agg.identifier, ctx), node.args.map { it.type } ) ) @@ -379,25 +380,6 @@ internal object PlanTransform : PlanBaseVisitor() { ) } - private fun Identifier.toNormalizedString(): String { - return when (this) { - is Identifier.Symbol -> this.toNormalizedString() - is Identifier.Qualified -> { - val toJoin = listOf(this.root) + this.steps - toJoin.joinToString(separator = ".") { ident -> - ident.toNormalizedString() - } - } - } - } - - private fun Identifier.Symbol.toNormalizedString(): String { - return when (this.caseSensitivity) { - Identifier.CaseSensitivity.SENSITIVE -> "\"${this.symbol}\"" - Identifier.CaseSensitivity.INSENSITIVE -> this.symbol - } - } - override fun visitRelOpExclude(node: Rel.Op.Exclude, ctx: ProblemCallback) = org.partiql.plan.Rel.Op.Exclude( input = visitRel(node.input, ctx), items = node.items.map { visitRelOpExcludeItem(it, ctx) }, diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt index 64198eb280..3b275d9d39 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt @@ -70,6 +70,7 @@ import org.partiql.planner.internal.ir.rexOpStructField import org.partiql.planner.internal.ir.rexOpTupleUnion import org.partiql.planner.internal.ir.statementQuery import org.partiql.planner.internal.ir.util.PlanRewriter +import org.partiql.planner.internal.transforms.PlanTransform import org.partiql.spi.BindingCase import org.partiql.spi.BindingName import org.partiql.spi.BindingPath @@ -1266,19 +1267,11 @@ internal class PlanTyper( * to each row of T and eliminating null values <--- all NULL values are eliminated as inputs */ fun resolveAgg(agg: Agg.Unresolved, arguments: List): Pair { - var missingArg = false val args = arguments.map { val arg = visitRex(it, it.type) - if (arg.type is MissingType) missingArg = true arg } - // - if (missingArg) { - handleAlwaysMissing() - return relOpAggregateCall(agg, listOf(rexErr("MISSING"))) to MissingType - } - // Try to match the arguments to functions defined in the catalog return when (val match = env.resolveAgg(agg, args)) { is FnMatch.Ok -> { @@ -1404,7 +1397,7 @@ internal class PlanTyper( * [PlanningProblemDetails.UndefinedVariable]. */ private fun handleUndefinedVariable(name: Identifier, locals: Set): PlanningProblemDetails.UndefinedVariable { - val planName = name.toPlan() + val planName = PlanTransform.visitIdentifier(name, onProblem) val details = PlanningProblemDetails.UndefinedVariable(planName, locals) onProblem( Problem( @@ -1415,26 +1408,6 @@ internal class PlanTyper( return details } - private fun Identifier.CaseSensitivity.toPlan(): org.partiql.plan.Identifier.CaseSensitivity = when (this) { - Identifier.CaseSensitivity.SENSITIVE -> org.partiql.plan.Identifier.CaseSensitivity.SENSITIVE - Identifier.CaseSensitivity.INSENSITIVE -> org.partiql.plan.Identifier.CaseSensitivity.INSENSITIVE - } - - private fun Identifier.toPlan(): org.partiql.plan.Identifier = when (this) { - is Identifier.Symbol -> this.toPlan() - is Identifier.Qualified -> this.toPlan() - } - - private fun Identifier.Symbol.toPlan(): org.partiql.plan.Identifier.Symbol = org.partiql.plan.Identifier.Symbol( - this.symbol, - this.caseSensitivity.toPlan() - ) - - private fun Identifier.Qualified.toPlan(): org.partiql.plan.Identifier.Qualified = org.partiql.plan.Identifier.Qualified( - this.root.toPlan(), - this.steps.map { it.toPlan() } - ) - private fun handleUnexpectedType(actual: StaticType, expected: Set) { onProblem( Problem( diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeEnv.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeEnv.kt index 44e01808fd..1ed9c80db1 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeEnv.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeEnv.kt @@ -168,9 +168,15 @@ internal class TypeEnv(public val schema: List) { return when (val type = this.flatten()) { is StructType -> type.containsKey(name) is AnyOfType -> { - val anyKnownToContainKey = type.allTypes.any { it.containsKey(name) == true } - val anyKnownToNotContainKey = type.allTypes.any { it.containsKey(name) == false } - val anyNotKnownToContainKey = type.allTypes.any { it.containsKey(name) == null } + var anyKnownToContainKey = false + var anyKnownToNotContainKey = false + var anyNotKnownToContainKey = false + for (t in type.allTypes) { + val containsKey = t.containsKey(name) + anyKnownToContainKey = anyKnownToContainKey || (containsKey == true) + anyKnownToNotContainKey = anyKnownToNotContainKey || (containsKey == false) + anyNotKnownToContainKey = anyNotKnownToContainKey || (containsKey == null) + } when { // There are: // - No subtypes that are known to not contain the key diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/utils/PlanUtils.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/utils/PlanUtils.kt new file mode 100644 index 0000000000..760abda76b --- /dev/null +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/utils/PlanUtils.kt @@ -0,0 +1,26 @@ +package org.partiql.planner.internal.utils + +import org.partiql.plan.Identifier + +internal object PlanUtils { + + /** + * Transforms an identifier to a human-readable string. + * + * Example output: aCaseInsensitiveCatalog."aCaseSensitiveSchema".aCaseInsensitiveTable + */ + fun identifierToString(node: Identifier): String = when (node) { + is Identifier.Symbol -> identifierSymbolToString(node) + is Identifier.Qualified -> { + val toJoin = listOf(node.root) + node.steps + toJoin.joinToString(separator = ".") { ident -> + identifierSymbolToString(ident) + } + } + } + + private fun identifierSymbolToString(node: Identifier.Symbol) = when (node.caseSensitivity) { + Identifier.CaseSensitivity.SENSITIVE -> "\"${node.symbol}\"" + Identifier.CaseSensitivity.INSENSITIVE -> node.symbol + } +}