Skip to content

Commit 6add804

Browse files
committed
Splits: handle braces within if/for/while better
Previously, we'd assume that a brace within `if` or `while` contains the entire condition and let the brace rules handle it; however, this brace could enclose only a part of the condition and hence needs to be handled differently then. Additionally, we ignored the danglingParentheses.ctrlSite parameter in these cases as well.
1 parent f311e38 commit 6add804

File tree

10 files changed

+48
-39
lines changed

10 files changed

+48
-39
lines changed

scalafmt-core/shared/src/main/scala/org/scalafmt/internal/Splits.scala

+13-11
Original file line numberDiff line numberDiff line change
@@ -1977,23 +1977,18 @@ object SplitsAfterLeftParen extends Splits {
19771977
}
19781978

19791979
private def get2(implicit ft: FT, fo: FormatOps, cfg: ScalafmtConfig) =
1980-
ft.right match {
1980+
getIfForWhile.getOrElse(ft.right match {
19811981
case _: T.LeftBrace => Seq(Split(NoSplit, 0))
1982-
case _ => getIfForWhile.getOrElse(getRest)
1983-
}
1982+
case _ => getRest
1983+
})
19841984

19851985
private def getIfForWhile(implicit
19861986
ft: FT,
19871987
fo: FormatOps,
19881988
cfg: ScalafmtConfig,
19891989
) = {
19901990
import fo._, tokens._, ft._
1991-
val ifForWhile = leftOwner match { // If/For/While/For with (
1992-
case t: Term.EnumeratorsBlock => getHeadOpt(t).contains(ft)
1993-
case _: Term.If | _: Term.While => !isTokenHeadOrBefore(left, leftOwner)
1994-
case _ => false
1995-
}
1996-
if (ifForWhile) Some {
1991+
def impl(enclosedInBraces: => Boolean) = Some {
19971992
val close = matchingLeft(ft)
19981993
val indentLen = cfg.indent.ctrlSite.getOrElse(cfg.indent.callSite)
19991994
def indents =
@@ -2008,7 +2003,7 @@ object SplitsAfterLeftParen extends Splits {
20082003
if (cfg.align.openParenCtrlSite) baseNoSplit().withIndents(indents)
20092004
.withPolicy(penalizeNewlines)
20102005
.andPolicy(decideNewlinesOnlyBeforeCloseOnBreak(close))
2011-
else baseNoSplit().withSingleLine(close)
2006+
else baseNoSplit().withSingleLine(close, ignore = enclosedInBraces)
20122007
Seq(
20132008
noSplit,
20142009
Split(Newline, 1).withIndent(indentLen, close, Before)
@@ -2019,7 +2014,14 @@ object SplitsAfterLeftParen extends Splits {
20192014
baseNoSplit(Newline).withIndents(indents).withPolicy(penalizeNewlines),
20202015
)
20212016
}
2022-
else None
2017+
def withCond(t: Tree.WithCond) =
2018+
if (isTokenHeadOrBefore(left, t)) None else impl(isEnclosedInBraces(t.cond))
2019+
leftOwner match { // If/For/While/For with (
2020+
case t: Term.EnumeratorsBlock if getHeadOpt(t).contains(ft) => impl(false)
2021+
case t: Term.If => withCond(t)
2022+
case t: Term.While => withCond(t)
2023+
case _ => None
2024+
}
20232025
}
20242026

20252027
private def getRest(implicit ft: FT, fo: FormatOps, cfg: ScalafmtConfig) = {

scalafmt-tests-community/intellij/shared/src/test/scala/org/scalafmt/community/intellij/CommunityIntellijScalaSuite.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ abstract class CommunityIntellijScalaSuite(name: String)
1313
class CommunityIntellijScala_2024_2_Suite
1414
extends CommunityIntellijScalaSuite("intellij-scala-2024.2") {
1515

16-
override protected def totalStatesVisited: Option[Int] = Some(59395995)
16+
override protected def totalStatesVisited: Option[Int] = Some(59396121)
1717

1818
override protected def builds = Seq {
1919
getBuild(
@@ -54,7 +54,7 @@ class CommunityIntellijScala_2024_2_Suite
5454
class CommunityIntellijScala_2024_3_Suite
5555
extends CommunityIntellijScalaSuite("intellij-scala-2024.3") {
5656

57-
override protected def totalStatesVisited: Option[Int] = Some(59610714)
57+
override protected def totalStatesVisited: Option[Int] = Some(59610840)
5858

5959
override protected def builds = Seq {
6060
getBuild(

scalafmt-tests-community/scala2/shared/src/test/scala/org/scalafmt/community/scala2/CommunityScala2Suite.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class CommunityScala2_12Suite extends CommunityScala2Suite("scala-2.12") {
1818

1919
class CommunityScala2_13Suite extends CommunityScala2Suite("scala-2.13") {
2020

21-
override protected def totalStatesVisited: Option[Int] = Some(52987012)
21+
override protected def totalStatesVisited: Option[Int] = Some(52987016)
2222

2323
override protected def builds =
2424
Seq(getBuild("v2.13.14", dialects.Scala213, 1287))

scalafmt-tests-community/scala3/shared/src/test/scala/org/scalafmt/community/scala3/CommunityScala3Suite.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@ abstract class CommunityScala3Suite(name: String)
99

1010
class CommunityScala3_2Suite extends CommunityScala3Suite("scala-3.2") {
1111

12-
override protected def totalStatesVisited: Option[Int] = Some(39188529)
12+
override protected def totalStatesVisited: Option[Int] = Some(39188577)
1313

1414
override protected def builds = Seq(getBuild("3.2.2", dialects.Scala32, 791))
1515

1616
}
1717

1818
class CommunityScala3_3Suite extends CommunityScala3Suite("scala-3.3") {
1919

20-
override protected def totalStatesVisited: Option[Int] = Some(42292843)
20+
override protected def totalStatesVisited: Option[Int] = Some(42292891)
2121

2222
override protected def builds = Seq(getBuild("3.3.3", dialects.Scala33, 861))
2323

scalafmt-tests-community/spark/shared/src/test/scala/org/scalafmt/community/spark/CommunitySparkSuite.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@ abstract class CommunitySparkSuite(name: String)
99

1010
class CommunitySpark3_4Suite extends CommunitySparkSuite("spark-3.4") {
1111

12-
override protected def totalStatesVisited: Option[Int] = Some(85958083)
12+
override protected def totalStatesVisited: Option[Int] = Some(85958174)
1313

1414
override protected def builds = Seq(getBuild("v3.4.1", dialects.Scala213, 2585))
1515

1616
}
1717

1818
class CommunitySpark3_5Suite extends CommunitySparkSuite("spark-3.5") {
1919

20-
override protected def totalStatesVisited: Option[Int] = Some(90961547)
20+
override protected def totalStatesVisited: Option[Int] = Some(90961638)
2121

2222
override protected def builds = Seq(getBuild("v3.5.3", dialects.Scala213, 2756))
2323

scalafmt-tests/shared/src/test/resources/newlines/source_classic.stat

+5-3
Original file line numberDiff line numberDiff line change
@@ -9258,9 +9258,11 @@ object a {
92589258
}
92599259
>>>
92609260
object a {
9261-
if ({ rs = counterCells; rs } != null && {
9262-
m = rs.length; m
9263-
} > 0 && rs({ j = (m - 1) & h; j }) == null) {
9261+
if (
9262+
{ rs = counterCells; rs } != null && {
9263+
m = rs.length; m
9264+
} > 0 && rs({ j = (m - 1) & h; j }) == null
9265+
) {
92649266
rs(j) = r
92659267
}
92669268
}

scalafmt-tests/shared/src/test/resources/newlines/source_fold.stat

+4-3
Original file line numberDiff line numberDiff line change
@@ -8675,9 +8675,10 @@ object a {
86758675
}
86768676
>>>
86778677
object a {
8678-
if ({ rs = counterCells; rs } != null && { m = rs.length; m } > 0 && rs({
8679-
j = (m - 1) & h; j
8680-
}) == null) { rs(j) = r }
8678+
if (
8679+
{ rs = counterCells; rs } != null && { m = rs.length; m } > 0 &&
8680+
rs({ j = (m - 1) & h; j }) == null
8681+
) { rs(j) = r }
86818682
}
86828683
<<< #4133 case body enclosed, no break before lparen, no break after lparen
86838684
object a {

scalafmt-tests/shared/src/test/resources/newlines/source_keep.stat

+5-3
Original file line numberDiff line numberDiff line change
@@ -9063,9 +9063,11 @@ object a {
90639063
}
90649064
>>>
90659065
object a {
9066-
if ({ rs = counterCells; rs } != null && { m = rs.length; m } > 0 && rs({
9067-
j = (m - 1) & h; j
9068-
}) == null) {
9066+
if (
9067+
{ rs = counterCells; rs } != null && { m = rs.length; m } > 0 && rs({
9068+
j = (m - 1) & h; j
9069+
}) == null
9070+
) {
90699071
rs(j) = r
90709072
}
90719073
}

scalafmt-tests/shared/src/test/resources/newlines/source_unfold.stat

+13-11
Original file line numberDiff line numberDiff line change
@@ -9375,17 +9375,19 @@ object a {
93759375
}
93769376
>>>
93779377
object a {
9378-
if ({
9379-
rs = counterCells;
9380-
rs
9381-
} != null && {
9382-
m = rs.length;
9383-
m
9384-
} > 0 &&
9385-
rs({
9386-
j = (m - 1) & h;
9387-
j
9388-
}) == null) {
9378+
if (
9379+
{
9380+
rs = counterCells;
9381+
rs
9382+
} != null && {
9383+
m = rs.length;
9384+
m
9385+
} > 0 &&
9386+
rs({
9387+
j = (m - 1) & h;
9388+
j
9389+
}) == null
9390+
) {
93899391
rs(j) = r
93909392
}
93919393
}

scalafmt-tests/shared/src/test/scala/org/scalafmt/FormatTests.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ class FormatTests extends FunSuite with CanRunTests with FormatAssertions {
148148
val explored = Debug.explored.get()
149149
logger.debug(s"Total explored: $explored")
150150
if (!onlyUnit && !onlyManual)
151-
assertEquals(explored, 2540816, "total explored")
151+
assertEquals(explored, 2541354, "total explored")
152152
// TODO(olafur) don't block printing out test results.
153153
TestPlatformCompat.executeAndWait(PlatformFileOps.writeFile(
154154
FileOps.getPath("target", "index.html"),

0 commit comments

Comments
 (0)