Skip to content

Commit

Permalink
Keep aggregation in Calcite consistent with current PPL behavior
Browse files Browse the repository at this point in the history
Signed-off-by: Lantao Jin <[email protected]>
  • Loading branch information
LantaoJin committed Mar 9, 2025
1 parent dcf2057 commit 9f2911a
Show file tree
Hide file tree
Showing 38 changed files with 1,489 additions and 244 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public enum Key {
CALCITE_ENGINE_ENABLED("plugins.calcite.enabled"),
CALCITE_FALLBACK_ALLOWED("plugins.calcite.fallback.allowed"),
CALCITE_PUSHDOWN_ENABLED("plugins.calcite.pushdown.enabled"),
CALCITE_LEGACY_ENABLED("plugins.calcite.legacy.enabled"),

/** Query Settings. */
FIELD_TYPE_TOLERANCE("plugins.query.field_type_tolerance"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.expression;

import java.util.Optional;

/** marker interface for numeric based count aggregation (specific number of returned results) */
public interface CountedAggregation {
Optional<Literal> getResults();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.expression;

import java.util.Collections;
import java.util.List;
import java.util.Optional;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import org.opensearch.sql.ast.tree.Aggregation;

/**
* Logical plan node of Rare (Aggregation) command, the interface for building aggregation actions
* in queries.
*/
@ToString
@Getter
@EqualsAndHashCode(callSuper = true)
public class RareAggregation extends Aggregation implements CountedAggregation {
private final Optional<Literal> results;

/** Aggregation Constructor without span and argument. */
public RareAggregation(
Optional<Literal> results,
List<UnresolvedExpression> aggExprList,
List<UnresolvedExpression> sortExprList,
List<UnresolvedExpression> groupExprList) {
super(aggExprList, sortExprList, groupExprList, null, Collections.emptyList());
this.results = results;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.expression;

import java.util.Collections;
import java.util.List;
import java.util.Optional;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import org.opensearch.sql.ast.tree.Aggregation;

/**
* Logical plan node of Top (Aggregation) command, the interface for building aggregation actions in
* queries.
*/
@ToString
@Getter
@EqualsAndHashCode(callSuper = true)
public class TopAggregation extends Aggregation implements CountedAggregation {
private final Optional<Literal> results;

/** Aggregation Constructor without span and argument. */
public TopAggregation(
Optional<Literal> results,
List<UnresolvedExpression> aggExprList,
List<UnresolvedExpression> sortExprList,
List<UnresolvedExpression> groupExprList) {
super(aggExprList, sortExprList, groupExprList, null, Collections.emptyList());
this.results = results;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilder.AggCall;
import org.apache.calcite.util.Holder;
Expand All @@ -35,9 +37,12 @@
import org.opensearch.sql.ast.Node;
import org.opensearch.sql.ast.expression.AllFields;
import org.opensearch.sql.ast.expression.Argument;
import org.opensearch.sql.ast.expression.CountedAggregation;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.Let;
import org.opensearch.sql.ast.expression.Map;
import org.opensearch.sql.ast.expression.QualifiedName;
import org.opensearch.sql.ast.expression.TopAggregation;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.expression.subquery.SubqueryExpression;
import org.opensearch.sql.ast.tree.Aggregation;
Expand All @@ -48,10 +53,12 @@
import org.opensearch.sql.ast.tree.Lookup;
import org.opensearch.sql.ast.tree.Project;
import org.opensearch.sql.ast.tree.Relation;
import org.opensearch.sql.ast.tree.Rename;
import org.opensearch.sql.ast.tree.Sort;
import org.opensearch.sql.ast.tree.SubqueryAlias;
import org.opensearch.sql.ast.tree.UnresolvedPlan;
import org.opensearch.sql.calcite.utils.JoinAndLookupUtils;
import org.opensearch.sql.exception.SemanticCheckException;

public class CalciteRelNodeVisitor extends AbstractNodeVisitor<RelNode, CalcitePlanContext> {

Expand Down Expand Up @@ -146,6 +153,30 @@ public RelNode visitProject(Project node, CalcitePlanContext context) {
return context.relBuilder.peek();
}

@Override
public RelNode visitRename(Rename node, CalcitePlanContext context) {
visitChildren(node, context);
List<String> originalNames = context.relBuilder.peek().getRowType().getFieldNames();
List<String> newNames = new ArrayList<>(originalNames);
for (Map renameMap : node.getRenameList()) {
if (renameMap.getTarget() instanceof Field t) {
String newName = t.getField().toString();
RexNode check = rexVisitor.analyze(renameMap.getOrigin(), context);
if (check instanceof RexInputRef ref) {
newNames.set(ref.getIndex(), newName);
} else {
throw new SemanticCheckException(
String.format("the original field %s cannot be resolved", renameMap.getOrigin()));
}
} else {
throw new SemanticCheckException(
String.format("the target expected to be field, but is %s", renameMap.getTarget()));
}
}
context.relBuilder.rename(newNames);
return context.relBuilder.peek();
}

@Override
public RelNode visitSort(Sort node, CalcitePlanContext context) {
visitChildren(node, context);
Expand Down Expand Up @@ -256,21 +287,91 @@ public RelNode visitAggregation(Aggregation node, CalcitePlanContext context) {
node.getAggExprList().stream()
.map(expr -> aggVisitor.analyze(expr, context))
.collect(Collectors.toList());
List<RexNode> groupByList =
node.getGroupExprList().stream()
.map(expr -> rexVisitor.analyze(expr, context))
.collect(Collectors.toList());

// The span column is always the first column in result whatever
// the order of span in query is first or last one
List<RexNode> groupByList = new ArrayList<>();
UnresolvedExpression span = node.getSpan();
if (!Objects.isNull(span)) {
RexNode spanRex = rexVisitor.analyze(span, context);
groupByList.add(spanRex);
// add span's group alias field (most recent added expression)
}
groupByList.addAll(
node.getGroupExprList().stream().map(expr -> rexVisitor.analyze(expr, context)).toList());

context.relBuilder.aggregate(context.relBuilder.groupKey(groupByList), aggList);

if (node instanceof CountedAggregation c) {
// handle top and rare through aggregate
List<RexNode> sortWithAliasList;
if (node instanceof TopAggregation) {
sortWithAliasList =
node.getSortExprList().stream()
.map(expr -> rexVisitor.analyze(expr, context))
.map(s -> context.relBuilder.call(SqlStdOperatorTable.DESC, s))
.toList();
} else {
sortWithAliasList =
node.getSortExprList().stream().map(expr -> rexVisitor.analyze(expr, context)).toList();
}
List<RexNode> sortList = removeAliasForSort(sortWithAliasList, context);

if (c.getResults().isPresent()) {
context.relBuilder.sortLimit(0, (Integer) c.getResults().get().getValue(), sortList);
} else {
context.relBuilder.sort(sortList);
}
// remove the sort list from projects in Top and Rare
context.relBuilder.projectExcept(sortList);
} else {
// handle normal aggregate
// TODO Should we keep alignment with V2 behaviour in new Calcite implementation?
// TODO how about add a legacy enable config to control behaviour in Calcite?
// Some behaviours between PPL and Databases are different.
// As an example, in command `stats count() by colA, colB`:
// 1. the sequence of output schema is different:
// In PPL v2, the sequence of output schema is "count, colA, colB".
// But in most databases, the sequence of output schema is "colA, colB, count".
// 2. the output order is different:
// In PPL v2, the order of output results is ordered by "colA + colB".
// But in most databases, the output order is random.
// User must add ORDER BY clause after GROUP BY clause to keep the results aligning.
// Following logic is to align with the PPL legacy behaviour.

// alignment for 1.sequence of output schema: adding order-by
// we use the groupByList instead of node.getSortExprList as input because
// the groupByList may include span column.
List<RexNode> groupByListWithoutAlias = removeAliasForSort(groupByList, context);
context.relBuilder.sort(groupByListWithoutAlias);

// alignment for 2.the output order: schema reordering
List<RexNode> outputFields = context.relBuilder.fields();
int numOfOutputFields = outputFields.size();
int numOfAggList = aggList.size();
List<RexNode> reordered = new ArrayList<>(numOfOutputFields);
// Add aggregation results first
List<RexNode> aggRexList =
outputFields.subList(numOfOutputFields - numOfAggList, numOfOutputFields);
reordered.addAll(aggRexList);
// Add group by columns
reordered.addAll(groupByListWithoutAlias);
context.relBuilder.project(reordered);
}
return context.relBuilder.peek();
}

/** TODO sort with aliased list will not work, so we add this converting */
private List<RexNode> removeAliasForSort(
List<RexNode> rexWithAliasList, CalcitePlanContext context) {
return rexWithAliasList.stream()
.map(context.rexBuilder::extractAlias)
.flatMap(Optional::stream)
.map(ref -> ((RexLiteral) ref).getValueAs(String.class))
.map(context.relBuilder::field)
.map(f -> (RexNode) f)
.toList();
}

@Override
public RelNode visitJoin(Join node, CalcitePlanContext context) {
List<UnresolvedPlan> children = node.getChildren();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,16 +235,17 @@ public RexNode visitSpan(Span node, CalcitePlanContext context) {
return context.rexBuilder.makeIntervalLiteral(new BigDecimal(millis), intervalQualifier);
} else {
// if the unit is not time base - create a math expression to bucket the span partitions
SqlTypeName type = field.getType().getSqlTypeName();
return context.rexBuilder.makeCall(
typeFactory.createSqlType(SqlTypeName.DOUBLE),
typeFactory.createSqlType(type),
SqlStdOperatorTable.MULTIPLY,
List.of(
context.rexBuilder.makeCall(
typeFactory.createSqlType(SqlTypeName.DOUBLE),
typeFactory.createSqlType(type),
SqlStdOperatorTable.FLOOR,
List.of(
context.rexBuilder.makeCall(
typeFactory.createSqlType(SqlTypeName.DOUBLE),
typeFactory.createSqlType(type),
SqlStdOperatorTable.DIVIDE,
List.of(field, value)))),
value));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,15 @@

package org.opensearch.sql.calcite;

import java.util.Optional;
import org.apache.calcite.avatica.util.TimeUnit;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlIntervalQualifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.opensearch.sql.ast.expression.SpanUnit;
Expand All @@ -23,6 +28,39 @@ public RexNode coalesce(RexNode... nodes) {
return this.makeCall(SqlStdOperatorTable.COALESCE, nodes);
}

/** extract the reference from the node */
public Optional<RexInputRef> extractRef(RexNode node) {
if (node == null) {
return Optional.empty();
} else if (node.getKind() == SqlKind.INPUT_REF) {
return Optional.of((RexInputRef) node);
} else if (node instanceof RexCall call) {
return call.getOperands().stream()
.map(this::extractRef)
.filter(Optional::isPresent)
.map(Optional::get)
.findFirst();
} else {
return Optional.empty();
}
}

public Optional<RexLiteral> extractAlias(RexNode node) {
if (node == null) {
return Optional.empty();
} else if (node.getKind() == SqlKind.AS) {
return Optional.of((RexLiteral) ((RexCall) node).getOperands().get(1));
} else if (node instanceof RexCall call) {
return call.getOperands().stream()
.map(this::extractAlias)
.filter(Optional::isPresent)
.map(Optional::get)
.findFirst();
} else {
return Optional.empty();
}
}

public RexNode equals(RexNode n1, RexNode n2) {
return this.makeCall(SqlStdOperatorTable.EQUALS, n1, n2);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.calcite.udf;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql.validate.SqlValidatorScope;
import org.apache.calcite.util.Optionality;

/**
* This class is used to support the legacy COUNT(*) function. It extends SqlAggFunction and
* overrides the deriveType method to return INTEGER type.
*/
public class LegacyPPLCountAggFunction extends SqlAggFunction {

public LegacyPPLCountAggFunction() {
super(
"COUNT",
null,
SqlKind.COUNT,
ReturnTypes.INTEGER, // Change return type to INTEGER
null,
OperandTypes.ONE_OR_MORE,
SqlFunctionCategory.NUMERIC,
false,
false,
Optionality.FORBIDDEN);
}

@Override
public RelDataType deriveType(SqlValidator validator, SqlValidatorScope scope, SqlCall call) {
return validator.getTypeFactory().createSqlType(SqlTypeName.INTEGER);
}
}
Loading

0 comments on commit 9f2911a

Please sign in to comment.