Skip to content

Commit 3bc9116

Browse files
committed
Add Thread.interrupted() checks
1 parent eee99ae commit 3bc9116

23 files changed

+390
-26
lines changed

lang/src/org/partiql/lang/CompilerPipeline.kt

+10-4
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import org.partiql.lang.eval.*
2020
import org.partiql.lang.eval.builtins.*
2121
import org.partiql.lang.eval.builtins.storedprocedure.StoredProcedure
2222
import org.partiql.lang.syntax.*
23+
import org.partiql.lang.util.interruptibleFold
2324

2425
/**
2526
* Contains all of the information needed for processing steps.
@@ -180,7 +181,7 @@ interface CompilerPipeline {
180181
}
181182
}
182183

183-
private class CompilerPipelineImpl(
184+
internal class CompilerPipelineImpl(
184185
override val valueFactory: ExprValueFactory,
185186
private val parser: Parser,
186187
override val compileOptions: CompileOptions,
@@ -198,10 +199,15 @@ private class CompilerPipelineImpl(
198199
override fun compile(query: ExprNode): Expression {
199200
val context = StepContext(valueFactory, compileOptions, functions, procedures)
200201

201-
val preProcessedQuery = preProcessingSteps.fold(query) { currentExprNode, step ->
202-
step(currentExprNode, context)
203-
}
202+
val preProcessedQuery = executePreProcessingSteps(query, context)
204203

205204
return compiler.compile(preProcessedQuery)
206205
}
206+
207+
internal fun executePreProcessingSteps(
208+
query: ExprNode,
209+
context: StepContext
210+
) = preProcessingSteps.interruptibleFold(query) { currentExprNode, step ->
211+
step(currentExprNode, context)
212+
}
207213
}

lang/src/org/partiql/lang/ast/AstDeserialization.kt

+8-4
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ class AstDeserializerBuilder(val ion: IonSystem) {
248248
}
249249
}
250250

251-
private class AstDeserializerInternal(
251+
internal class AstDeserializerInternal(
252252
val astVersion: AstVersion,
253253
val ion: IonSystem,
254254
private val metaDeserializers: Map<String, MetaDeserializer>
@@ -262,7 +262,9 @@ private class AstDeserializerInternal(
262262
return deserializeExprNode(sexp)
263263
}
264264

265-
private fun validate(rootSexp: IonSexp) {
265+
internal fun validate(rootSexp: IonSexp) {
266+
checkThreadInterrupted()
267+
266268
val nodeTag = rootSexp.nodeTag // Throws if nodeTag is invalid for the current AstVersion
267269
val nodeArgs = rootSexp.args
268270

@@ -321,8 +323,9 @@ private class AstDeserializerInternal(
321323
/**
322324
* Given a serialized AST, return its [ExprNode] representation.
323325
*/
324-
private fun deserializeExprNode(metaOrTermOrExp: IonSexp): ExprNode =
325-
deserializeSexpMetaOrTerm(metaOrTermOrExp) { target, metas ->
326+
internal fun deserializeExprNode(metaOrTermOrExp: IonSexp): ExprNode {
327+
checkThreadInterrupted()
328+
return deserializeSexpMetaOrTerm(metaOrTermOrExp) { target, metas ->
326329
val nodeTag = target.nodeTag
327330
val targetArgs = target.args //args is an extension property--call it once for efficiency
328331
//.toList() forces immutability
@@ -417,6 +420,7 @@ private class AstDeserializerInternal(
417420
NodeTag.TYPE -> errInvalidContext(nodeTag)
418421
}
419422
}
423+
}
420424

421425
private fun deserializeLit(targetArgs: List<IonValue>, metas: MetaContainer) = Literal(targetArgs.first(), metas)
422426
private fun deserializeMissing(metas: MetaContainer) = LiteralMissing(metas)

lang/src/org/partiql/lang/ast/AstSerialization.kt

+2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import com.amazon.ion.IonSystem
2020
import org.partiql.lang.util.IonWriterContext
2121
import org.partiql.lang.util.asIonSexp
2222
import org.partiql.lang.util.case
23+
import org.partiql.lang.util.checkThreadInterrupted
2324
import kotlin.UnsupportedOperationException
2425

2526
/**
@@ -73,6 +74,7 @@ private class AstSerializerImpl(val astVersion: AstVersion, val ion: IonSystem):
7374

7475
private fun IonWriterContext.writeExprNode(expr: ExprNode): Unit =
7576
writeAsTerm(expr.metas) {
77+
checkThreadInterrupted()
7678
sexp {
7779
when (expr) {
7880
// Leaf nodes

lang/src/org/partiql/lang/ast/ExprNodeToStatement.kt

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package org.partiql.lang.ast
33
import com.amazon.ionelement.api.emptyMetaContainer
44
import com.amazon.ionelement.api.toIonElement
55
import org.partiql.lang.domains.PartiqlAst
6+
import org.partiql.lang.util.checkThreadInterrupted
67
import org.partiql.pig.runtime.SymbolPrimitive
78
import org.partiql.pig.runtime.asPrimitive
89

@@ -67,6 +68,7 @@ private fun ExprNode.toAstExec() : PartiqlAst.Statement {
6768
}
6869

6970
fun ExprNode.toAstExpr(): PartiqlAst.Expr {
71+
checkThreadInterrupted()
7072
val node = this
7173
val metas = this.metas.toIonElementMetaContainer()
7274

lang/src/org/partiql/lang/ast/StatementToExprNode.kt

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import com.amazon.ion.IonSystem
66
import com.amazon.ionelement.api.toIonValue
77
import org.partiql.lang.domains.PartiqlAst
88
import org.partiql.lang.domains.PartiqlAst.*
9+
import org.partiql.lang.util.checkThreadInterrupted
910

1011
import org.partiql.pig.runtime.SymbolPrimitive
1112
import org.partiql.lang.ast.SetQuantifier as ExprNodeSetQuantifier // Conflicts with PartiqlAst.SetQuantifier
@@ -69,6 +70,7 @@ private class StatementTransformer(val ion: IonSystem) {
6970
this.map { it.toExprNode() }
7071

7172
private fun Expr.toExprNode(): ExprNode {
73+
checkThreadInterrupted()
7274
val metas = this.metas.toPartiQlMetaContainer()
7375
return when (this) {
7476
is Expr.Missing -> LiteralMissing(metas)

lang/src/org/partiql/lang/ast/ast.kt

+18-2
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,35 @@ import java.util.*
2424
sealed class AstNode : Iterable<AstNode> {
2525

2626
/**
27-
* returns all the children nodes.
27+
* Returns all the children nodes.
28+
*
29+
* This property is [deprecated](see https://github.com/partiql/partiql-lang-kotlin/issues/396). Use
30+
* one of the following PIG-generated classes to analyze AST nodes instead:
31+
*
32+
* - [org.partiql.lang.domains.PartiqlAst.Visitor]
33+
* - [org.partiql.lang.domains.PartiqlAst.VisitorFold]
2834
*/
35+
@Deprecated("DO NOT USE - see kdoc, see https://github.com/partiql/partiql-lang-kotlin/issues/396")
2936
abstract val children: List<AstNode>
3037

3138
/**
3239
* Depth first iterator over all nodes.
40+
*
41+
* While collecting child nodes, throws [InterruptedException] if the [Thread.interrupted] flag has been set.
42+
*
43+
* This property is [deprecated](see https://github.com/partiql/partiql-lang-kotlin/issues/396). Use
44+
* one of the following PIG-generated classes to analyze AST nodes instead:
45+
*
46+
* - [org.partiql.lang.domains.PartiqlAst.Visitor]
47+
* - [org.partiql.lang.domains.PartiqlAst.VisitorFold]
3348
*/
49+
@Deprecated("DO NOT USE - see kdoc for alternatives")
3450
override operator fun iterator(): Iterator<AstNode> {
3551
val allNodes = mutableListOf<AstNode>()
3652

3753
fun depthFirstSequence(node: AstNode) {
3854
allNodes.add(node)
39-
node.children.map { depthFirstSequence(it) }
55+
node.children.interruptibleMap { depthFirstSequence(it) }
4056
}
4157

4258
depthFirstSequence(this)

lang/src/org/partiql/lang/ast/passes/AstRewriterBase.kt

+6-3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package org.partiql.lang.ast.passes
1616

1717
import org.partiql.lang.ast.*
18+
import org.partiql.lang.util.checkThreadInterrupted
1819

1920
/**
2021
* Provides a minimal interface for an AST rewriter implementation.
@@ -28,11 +29,12 @@ interface AstRewriter {
2829
* This is the base-class for an AST rewriter which simply makes an exact copy of the original AST.
2930
* Simple rewrites can be performed by inheritors.
3031
*/
31-
@Deprecated("New rewriters should implement PIG's PartiqlAst.VisitorTransform instead")
32+
@Deprecated("New rewriters should implement PIG's VisitorTransformBase instead")
3233
open class AstRewriterBase : AstRewriter {
3334

34-
override fun rewriteExprNode(node: ExprNode): ExprNode =
35-
when (node) {
35+
override fun rewriteExprNode(node: ExprNode): ExprNode {
36+
checkThreadInterrupted()
37+
return when (node) {
3638
is Literal -> rewriteLiteral(node)
3739
is LiteralMissing -> rewriteLiteralMissing(node)
3840
is VariableReference -> rewriteVariableReference(node)
@@ -55,6 +57,7 @@ open class AstRewriterBase : AstRewriter {
5557
is DateTimeType.Date -> rewriteDate(node)
5658
is DateTimeType.Time -> rewriteTime(node)
5759
}
60+
}
5861

5962
open fun rewriteMetas(itemWithMetas: HasMetas): MetaContainer = itemWithMetas.metas
6063

lang/src/org/partiql/lang/ast/passes/AstVisitor.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import org.partiql.lang.ast.*
2727
*
2828
* One `visit*` function is included for each base type in the AST.
2929
*/
30-
@Deprecated("Use AstNode#iterator() or AstNode#children()")
30+
@Deprecated("Use org.lang.partiql.domains.PartiqlAst.Visitor instead")
3131
interface AstVisitor {
3232
/**
3333
* Invoked by [AstWalker] for every instance of [ExprNode] encountered.

lang/src/org/partiql/lang/ast/passes/AstWalker.kt

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ open class AstWalker(private val visitor: AstVisitor) {
2828

2929
protected open fun walkExprNode(vararg exprs: ExprNode?) {
3030
exprs.filterNotNull().forEach { expr: ExprNode ->
31+
checkThreadInterrupted()
3132
visitor.visitExprNode(expr)
3233

3334
when (expr) {

lang/src/org/partiql/lang/eval/EvaluatingCompiler.kt

+10-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ import org.partiql.lang.eval.visitors.PartiqlAstSanityValidator
2828
import org.partiql.lang.syntax.SqlParser
2929
import org.partiql.lang.util.*
3030
import java.math.*
31-
import java.time.LocalDate
3231
import java.util.*
3332
import kotlin.collections.*
3433

@@ -207,6 +206,10 @@ internal class EvaluatingCompiler(
207206

208207
/**
209208
* Compiles an [ExprNode] tree to an [Expression].
209+
*
210+
* Checks [Thread.interrupted] before every expression and sub-expression is compiled
211+
* and throws [InterruptedException] if [Thread.interrupted] it has been set in the
212+
* hope that long running compilations may be aborted by the caller.
210213
*/
211214
fun compile(originalAst: ExprNode): Expression {
212215
val visitorTransformer = compileOptions.visitorTransformMode.createVisitorTransform()
@@ -257,7 +260,13 @@ internal class EvaluatingCompiler(
257260
*/
258261
fun eval(ast: ExprNode, session: EvaluationSession): ExprValue = compile(ast).eval(session)
259262

263+
/**
264+
* Compiles the specified [ExprNode] into a [ThunkEnv].
265+
*
266+
* This function will [InterruptedException] if [Thread.interrupted] has been set.
267+
*/
260268
private fun compileExprNode(expr: ExprNode): ThunkEnv {
269+
checkThreadInterrupted()
261270
return when (expr) {
262271
is Literal -> compileLiteral(expr)
263272
is LiteralMissing -> compileLiteralMissing(expr)

lang/src/org/partiql/lang/eval/visitors/FromSourceAliasVisitorTransform.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ import org.partiql.pig.runtime.SymbolPrimitive
1515
*
1616
* If provided with a query that has all of the from source aliases already specified, an exact clone is returned.
1717
*/
18-
class FromSourceAliasVisitorTransform : PartiqlAst.VisitorTransform() {
18+
class FromSourceAliasVisitorTransform : VisitorTransformBase() {
1919

20-
private class InnerFromSourceAliasVisitorTransform : PartiqlAst.VisitorTransform() {
20+
private class InnerFromSourceAliasVisitorTransform : VisitorTransformBase() {
2121
private var fromSourceCounter = 0
2222

2323
override fun transformFromSourceScan_asAlias(node: PartiqlAst.FromSource.Scan): SymbolPrimitive? {

lang/src/org/partiql/lang/eval/visitors/GroupByItemAliasVisitorTransform.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ import org.partiql.pig.runtime.SymbolPrimitive
3333
*
3434
* If provided with a query with all of the group by item aliases already specified, an exact clone is returned.
3535
*/
36-
class GroupByItemAliasVisitorTransform(var nestLevel: Int = 0) : PartiqlAst.VisitorTransform() {
36+
class GroupByItemAliasVisitorTransform(var nestLevel: Int = 0) : VisitorTransformBase() {
3737

3838
override fun transformGroupBy(node: PartiqlAst.GroupBy): PartiqlAst.GroupBy {
3939
return PartiqlAst.build {

lang/src/org/partiql/lang/eval/visitors/PipelinedVisitorTransform.kt

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package org.partiql.lang.eval.visitors
22

33
import org.partiql.lang.domains.PartiqlAst
4+
import org.partiql.lang.util.interruptibleFold
45

56
/**
67
* A simple visitor transformer that provides a pipeline of transformers to be executed in sequential order.
@@ -11,7 +12,7 @@ class PipelinedVisitorTransform(vararg transformers: PartiqlAst.VisitorTransform
1112
private val transformerList = transformers.toList()
1213

1314
override fun transformStatement(node: PartiqlAst.Statement): PartiqlAst.Statement =
14-
transformerList.fold(node) {
15+
transformerList.interruptibleFold(node) {
1516
intermediateNode, transformer ->
1617
transformer.transformStatement(intermediateNode)
1718
}

lang/src/org/partiql/lang/eval/visitors/SelectListItemAliasVisitorTransform.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ import org.partiql.pig.runtime.SymbolPrimitive
3232
*
3333
* ```
3434
*/
35-
class SelectListItemAliasVisitorTransform : PartiqlAst.VisitorTransform() {
35+
class SelectListItemAliasVisitorTransform : VisitorTransformBase() {
3636

3737
override fun transformProjectionProjectList(node: PartiqlAst.Projection.ProjectList): PartiqlAst.Projection {
3838
return PartiqlAst.build {

lang/src/org/partiql/lang/eval/visitors/SelectStarVisitorTransform.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import org.partiql.lang.ast.UniqueNameMeta
55
import org.partiql.lang.domains.PartiqlAst
66
import org.partiql.lang.eval.errNoContext
77

8-
class SelectStarVisitorTransform : PartiqlAst.VisitorTransform() {
8+
class SelectStarVisitorTransform : VisitorTransformBase() {
99

1010
/**
1111
* Copies all parts of [PartiqlAst.Expr.Select] except [newProjection] for [PartiqlAst.Projection].

lang/src/org/partiql/lang/eval/visitors/StaticTypeVisitorTransform.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ enum class StaticTypeVisitorTransformConstraints {
6060
*/
6161
class StaticTypeVisitorTransform(private val ion: IonSystem,
6262
globalBindings: Bindings<StaticType>,
63-
constraints: Set<StaticTypeVisitorTransformConstraints> = setOf()) : PartiqlAst.VisitorTransform() {
63+
constraints: Set<StaticTypeVisitorTransformConstraints> = setOf()) : VisitorTransformBase() {
6464

6565
/** Used to allow certain binding lookups to occur directly in the global scope. */
6666
private val globalEnv = wrapBindings(globalBindings, 0)

lang/src/org/partiql/lang/eval/visitors/SubstitutionVisitorTransform.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ data class SubstitutionPair(val target: PartiqlAst.Expr, val replacement: Partiq
3535
*
3636
* This class is `open` to allow subclasses to restrict the nodes to which the substitution should occur.
3737
*/
38-
open class SubstitutionVisitorTransform(protected val substitutions: Map<PartiqlAst.Expr, SubstitutionPair>): PartiqlAst.VisitorTransform() {
38+
open class SubstitutionVisitorTransform(protected val substitutions: Map<PartiqlAst.Expr, SubstitutionPair>): VisitorTransformBase() {
3939

4040
/**
4141
* If [node] matches any of the target nodes in [substitutions], replaces the node with the replacement.
@@ -59,7 +59,7 @@ open class SubstitutionVisitorTransform(protected val substitutions: Map<Partiql
5959
* After .copy() and copying metas is added to PIG (https://github.com/partiql/partiql-ir-generator/pull/53) change
6060
* this and its usages to use .copy().
6161
*/
62-
inner class MetaVisitorTransform(private val newMetas: MetaContainer) : PartiqlAst.VisitorTransform() {
62+
inner class MetaVisitorTransform(private val newMetas: MetaContainer) : VisitorTransformBase() {
6363
override fun transformMetas(metas: MetaContainer): MetaContainer = newMetas
6464
}
6565

lang/src/org/partiql/lang/eval/visitors/VisitorTransformBase.kt

+13-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,23 @@
11
package org.partiql.lang.eval.visitors
22

33
import org.partiql.lang.domains.PartiqlAst
4+
import org.partiql.lang.util.checkThreadInterrupted
45

56
/**
6-
* Base-class for visitor transforms that provides additional functions outside of [PartiqlAst.VisitorTransform].
7+
* Base-class for visitor transforms that provides additional `transform*` functions that outside of
8+
* the PIG-generated [PartiqlAst.VisitorTransform] class and adds a [Thread.interrupted] check
9+
* to [transformExpr].
10+
*
11+
* All transforms should derive from this class instead of [PartiqlAst.VisitorTransform] so that they can
12+
* be interrupted of they take a long time to process large ASTs.
713
*/
814
abstract class VisitorTransformBase : PartiqlAst.VisitorTransform() {
15+
16+
override fun transformExpr(node: PartiqlAst.Expr): PartiqlAst.Expr {
17+
checkThreadInterrupted()
18+
return super.transformExpr(node)
19+
}
20+
921
/**
1022
* Transforms the [PartiqlAst.Expr.Select] expression following the PartiQL evaluation order. That is:
1123
*

lang/src/org/partiql/lang/eval/visitors/VisitorTransforms.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,6 @@ fun basicVisitorTransforms() = PipelinedVisitorTransform(
2424

2525
/** A stateless visitor transform that returns the input. */
2626
@JvmField
27-
internal val IDENTITY_VISITOR_TRANSFORM: PartiqlAst.VisitorTransform = object : PartiqlAst.VisitorTransform() {
27+
internal val IDENTITY_VISITOR_TRANSFORM: PartiqlAst.VisitorTransform = object : VisitorTransformBase() {
2828
override fun transformStatement(node: PartiqlAst.Statement): PartiqlAst.Statement = node
2929
}

lang/src/org/partiql/lang/syntax/SqlParser.kt

+6
Original file line numberDiff line numberDiff line change
@@ -1039,12 +1039,17 @@ class SqlParser(private val ion: IonSystem) : Parser {
10391039
/**
10401040
* Parses the given token list.
10411041
*
1042+
* Throws [InterruptedException] if [Thread.interrupted] is set. This is the best place to do
1043+
* that for the parser because this is the main function called to parse an expression and so
1044+
* is called quite frequently during parsing by many parts of the parser.
1045+
*
10421046
* @param precedence The precedence of the current expression parsing.
10431047
* A negative value represents the "top-level" parsing.
10441048
*
10451049
* @return The parse tree for the given expression.
10461050
*/
10471051
internal fun List<Token>.parseExpression(precedence: Int = -1): ParseNode {
1052+
checkThreadInterrupted()
10481053
var expr = parseUnaryTerm()
10491054
var rem = expr.remaining
10501055

@@ -2815,6 +2820,7 @@ class SqlParser(private val ion: IonSystem) : Parser {
28152820
* If [dmlListTokenSeen] is true, it means it has been encountered at least once before while traversing the parse tree.
28162821
*/
28172822
private fun validateTopLevelNodes(node: ParseNode, level: Int, topLevelTokenSeen: Boolean, dmlListTokenSeen: Boolean) {
2823+
checkThreadInterrupted()
28182824
val isTopLevelType = when (node.type.isDml) {
28192825
// DML_LIST token type allows multiple DML keywords to be used in the same statement.
28202826
// Hence, DML keyword tokens are not treated as top level tokens if present with the DML_LIST token type

0 commit comments

Comments
 (0)