
Commit 9d06a9e

mpetruska authored and srowen committed
[SPARK-22393][SPARK-SHELL] spark-shell can't find imported types in class constructors, extends clause
## What changes were proposed in this pull request?

[SPARK-22393](https://issues.apache.org/jira/browse/SPARK-22393)

## How was this patch tested?

With a new test case in `ReplSuite`.

----

This code is a retrofit of the Scala [SI-9881](scala/bug#9881) bug fix, which never made it into the Scala 2.11 branches. Pushing these changes directly to the Scala repo is not practical (see: scala/scala#6195).

Author: Mark Petruska <[email protected]>

Closes #19846 from mpetruska/SPARK-22393.
1 parent 16adaf6 commit 9d06a9e
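
For context: before this change, the Scala 2.11 build of spark-shell resolved imported types in ordinary expressions but not in a class's constructor parameters or its extends clause. An illustrative transcript, reconstructed from the regression test added below rather than captured verbatim:

```
scala> import org.apache.spark.Partition
import org.apache.spark.Partition

scala> class P(p: Partition)
<console>:12: error: not found: type Partition
       class P(p: Partition)
                  ^
```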

File tree

4 files changed: +191 -0 lines changed

repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala (new)
repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala
repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala (new)
repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala (new file)

@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.repl
+
+import scala.tools.nsc.interpreter.{ExprTyper, IR}
+
+trait SparkExprTyper extends ExprTyper {
+
+  import repl._
+  import global.{reporter => _, Import => _, _}
+  import naming.freshInternalVarName
+
+  def doInterpret(code: String): IR.Result = {
+    // interpret/interpretSynthetic may change the phase,
+    // which would have unintended effects on types.
+    val savedPhase = phase
+    try interpretSynthetic(code) finally phase = savedPhase
+  }
+
+  override def symbolOfLine(code: String): Symbol = {
+    def asExpr(): Symbol = {
+      val name = freshInternalVarName()
+      // Typing it with a lazy val would give us the right type, but runs
+      // into compiler bugs with things like existentials, so we compile it
+      // behind a def and strip the NullaryMethodType which wraps the expr.
+      val line = "def " + name + " = " + code
+
+      doInterpret(line) match {
+        case IR.Success =>
+          val sym0 = symbolOfTerm(name)
+          // drop NullaryMethodType
+          sym0.cloneSymbol setInfo exitingTyper(sym0.tpe_*.finalResultType)
+        case _ => NoSymbol
+      }
+    }
+
+    def asDefn(): Symbol = {
+      val old = repl.definedSymbolList.toSet
+
+      doInterpret(code) match {
+        case IR.Success =>
+          repl.definedSymbolList filterNot old match {
+            case Nil => NoSymbol
+            case sym :: Nil => sym
+            case syms => NoSymbol.newOverloaded(NoPrefix, syms)
+          }
+        case _ => NoSymbol
+      }
+    }
+
+    def asError(): Symbol = {
+      doInterpret(code)
+      NoSymbol
+    }
+
+    beSilentDuring(asExpr()) orElse beSilentDuring(asDefn()) orElse asError()
+  }
+
+}
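
`SparkExprTyper` retrofits the SI-9881 fix by overriding `ExprTyper.symbolOfLine`: the input is first compiled as an expression (hidden behind a synthetic `def` whose `NullaryMethodType` wrapper is then stripped), then as a definition, and `doInterpret` saves and restores the compiler phase so the speculative compilation cannot disturb later typechecking. For orientation, here is a minimal sketch of driving the stock Scala 2.11 `IMain` through the entry points that `SparkILoopInterpreter` (below) forwards to this trait; the object name is invented for the example, and it assumes `scala-compiler` 2.11.x on the classpath:

```scala
import scala.tools.nsc.Settings
import scala.tools.nsc.interpreter.IMain

// Illustration only, not part of the commit.
object ExprTyperDemo {
  def main(args: Array[String]): Unit = {
    val settings = new Settings
    settings.usejavacp.value = true      // reuse the host JVM's classpath
    val intp = new IMain(settings)
    intp.initializeSynchronous()
    // typeOfExpression/symbolOfLine are the IMain entry points that
    // SparkILoopInterpreter routes through SparkExprTyper.
    println(intp.typeOfExpression("List(1, 2, 3).map(_ + 1)"))  // e.g. List[Int]
    intp.close()
  }
}
```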

repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala (+4)

@@ -35,6 +35,10 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter)
   def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out)
   def this() = this(None, new JPrintWriter(Console.out, true))
 
+  override def createInterpreter(): Unit = {
+    intp = new SparkILoopInterpreter(settings, out)
+  }
+
   val initializationCommands: Seq[String] = Seq(
     """
     @transient val spark = if (org.apache.spark.repl.Main.sparkSession != null) {
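
`createInterpreter` is the factory hook `ILoop` calls to populate its `intp` field, so overriding it is the supported way to substitute a custom `IMain`; that is all this four-line change does. A generic sketch of the same pattern, with a hypothetical `LoggingInterpreter` standing in for `SparkILoopInterpreter` (illustration only, assuming the Scala 2.11 `ILoop`/`IMain` API):

```scala
import scala.tools.nsc.Settings
import scala.tools.nsc.interpreter.{ILoop, IMain, IR, JPrintWriter}

// Hypothetical IMain subclass that traces every submitted line.
class LoggingInterpreter(settings: Settings, out0: JPrintWriter)
  extends IMain(settings, out0) {
  override def interpret(line: String): IR.Result = {
    out0.println(s"[repl] interpreting: $line")
    super.interpret(line)
  }
}

class LoggingLoop extends ILoop {
  // Install the custom interpreter, exactly as SparkILoop installs
  // SparkILoopInterpreter above.
  override def createInterpreter(): Unit = {
    intp = new LoggingInterpreter(settings, out)
  }
}
```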
repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala (new file)

@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.repl
+
+import scala.tools.nsc.Settings
+import scala.tools.nsc.interpreter._
+
+class SparkILoopInterpreter(settings: Settings, out: JPrintWriter) extends IMain(settings, out) {
+  self =>
+
+  override lazy val memberHandlers = new {
+    val intp: self.type = self
+  } with MemberHandlers {
+    import intp.global._
+
+    override def chooseHandler(member: intp.global.Tree): MemberHandler = member match {
+      case member: Import => new SparkImportHandler(member)
+      case _ => super.chooseHandler(member)
+    }
+
+    class SparkImportHandler(imp: Import) extends ImportHandler(imp: Import) {
+
+      override def targetType: Type = intp.global.rootMirror.getModuleIfDefined("" + expr) match {
+        case NoSymbol => intp.typeOfExpression("" + expr)
+        case sym => sym.tpe
+      }
+
+      private def safeIndexOf(name: Name, s: String): Int = fixIndexOf(name, pos(name, s))
+      private def fixIndexOf(name: Name, idx: Int): Int = if (idx == name.length) -1 else idx
+      private def pos(name: Name, s: String): Int = {
+        var i = name.pos(s.charAt(0), 0)
+        val sLen = s.length()
+        if (sLen == 1) return i
+        while (i + sLen <= name.length) {
+          var j = 1
+          while (s.charAt(j) == name.charAt(i + j)) {
+            j += 1
+            if (j == sLen) return i
+          }
+          i = name.pos(s.charAt(0), i + 1)
+        }
+        name.length
+      }
+
+      private def isFlattenedSymbol(sym: Symbol): Boolean =
+        sym.owner.isPackageClass &&
+          sym.name.containsName(nme.NAME_JOIN_STRING) &&
+          sym.owner.info.member(sym.name.take(
+            safeIndexOf(sym.name, nme.NAME_JOIN_STRING))) != NoSymbol
+
+      private def importableTargetMembers =
+        importableMembers(exitingTyper(targetType)).filterNot(isFlattenedSymbol).toList
+
+      def isIndividualImport(s: ImportSelector): Boolean =
+        s.name != nme.WILDCARD && s.rename != nme.WILDCARD
+      def isWildcardImport(s: ImportSelector): Boolean =
+        s.name == nme.WILDCARD
+
+      // non-wildcard imports
+      private def individualSelectors = selectors filter isIndividualImport
+
+      override val importsWildcard: Boolean = selectors exists isWildcardImport
+
+      lazy val importableSymbolsWithRenames: List[(Symbol, Name)] = {
+        val selectorRenameMap =
+          individualSelectors.flatMap(x => x.name.bothNames zip x.rename.bothNames).toMap
+        importableTargetMembers flatMap (m => selectorRenameMap.get(m.name) map (m -> _))
+      }
+
+      override lazy val individualSymbols: List[Symbol] = importableSymbolsWithRenames map (_._1)
+      override lazy val wildcardSymbols: List[Symbol] =
+        if (importsWildcard) importableTargetMembers else Nil
+
+    }
+
+  }
+
+  object expressionTyper extends {
+    val repl: SparkILoopInterpreter.this.type = self
+  } with SparkExprTyper { }
+
+  override def symbolOfLine(code: String): global.Symbol =
+    expressionTyper.symbolOfLine(code)
+
+  override def typeOfExpression(expr: String, silent: Boolean): global.Type =
+    expressionTyper.typeOfExpression(expr, silent)
+
+}
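
A note on the subtlest part above: `isFlattenedSymbol` filters out compiler-flattened members, i.e. package-owned symbols whose names contain `nme.NAME_JOIN_STRING` (`"$"` on the JVM backend) and whose prefix resolves to a real member of the owner, so synthetic nested-class artifacts are not offered for import. `Name.pos` returns `name.length` rather than `-1` when nothing is found, which `fixIndexOf` translates; `pos` itself is hand-rolled because `Name` lacks a multi-character `indexOf`. A standalone analogue over plain `String`, for illustration only:

```scala
// Illustration only: the safeIndexOf contract on plain strings.
object FlattenedNameDemo {
  // First index of `s` in `name`, or -1 when absent -- the convention
  // safeIndexOf/fixIndexOf implement for compiler Names.
  def safeIndexOf(name: String, s: String): Int = name.indexOf(s)

  def main(args: Array[String]): Unit = {
    val flattened = "Outer$Inner"              // a flattened nested-class name
    val idx = safeIndexOf(flattened, "$")
    println(idx)                               // 5
    if (idx >= 0) println(flattened.take(idx)) // "Outer": the prefix looked up in the owner
    println(safeIndexOf("Plain", "$"))         // -1: not a flattened name
  }
}
```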

repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala (+10)

@@ -227,4 +227,14 @@ class ReplSuite extends SparkFunSuite {
     assertDoesNotContain("error: not found: value sc", output)
   }
 
+  test("spark-shell should find imported types in class constructors and extends clause") {
+    val output = runInterpreter("local",
+      """
+        |import org.apache.spark.Partition
+        |class P(p: Partition)
+        |class P(val index: Int) extends Partition
+      """.stripMargin)
+    assertDoesNotContain("error: not found: type Partition", output)
+  }
 }
