Skip to content

Commit 7c68f7b

Browse files
committed
Code action for-comprehension: support overlapping param names
A chain of map/flatMap can reuse the same name, e.g. using subsequent `.map(x => x + 1)`. After converting to for-comprehension, those duplicate names need to be renamed. This PR tries to handle some common usages (like pattern matching, anonymous functions), but some are complex enough to justify to not support them, and the code action just gives up. Resolves #4069
1 parent 131b395 commit 7c68f7b

4 files changed

Lines changed: 549 additions & 46 deletions

File tree

metals/src/main/scala/scala/meta/internal/metals/codeactions/FlatMapToForComprehensionCodeAction.scala

Lines changed: 184 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,31 @@
11
package scala.meta.internal.metals.codeactions
22

3+
import org.eclipse.lsp4j as l
4+
import org.eclipse.lsp4j.CodeActionParams
5+
36
import scala.annotation.tailrec
4-
import scala.concurrent.ExecutionContext
5-
import scala.concurrent.Future
6-
7-
import scala.meta.Case
8-
import scala.meta.Enumerator
9-
import scala.meta.Lit
10-
import scala.meta.Name
11-
import scala.meta.Pat
12-
import scala.meta.Term
13-
import scala.meta.Tree
14-
import scala.meta.Type
15-
import scala.meta.XtensionClassifiable
16-
import scala.meta.XtensionSyntax
7+
import scala.concurrent.{ExecutionContext, Future}
178
import scala.meta.inputs.Position
189
import scala.meta.internal.metals.Buffers
19-
import scala.meta.internal.metals.MetalsEnrichments._
20-
import scala.meta.internal.metals.codeactions.CodeAction
21-
import scala.meta.internal.metals.codeactions.CodeActionBuilder
10+
import scala.meta.internal.metals.MetalsEnrichments.*
2211
import scala.meta.internal.parsing.Trees
2312
import scala.meta.io.AbsolutePath
2413
import scala.meta.pc.CancelToken
25-
26-
import org.eclipse.lsp4j.CodeActionParams
27-
import org.eclipse.{lsp4j => l}
14+
import scala.meta.{
15+
Case,
16+
Enumerator,
17+
Init,
18+
Lit,
19+
Name,
20+
Pat,
21+
Stat,
22+
Term,
23+
Tree,
24+
Type,
25+
XtensionClassifiable,
26+
XtensionSyntax,
27+
}
28+
import scala.util.Try
2829

2930
class FlatMapToForComprehensionCodeAction(
3031
trees: Trees,
@@ -138,14 +139,15 @@ class FlatMapToForComprehensionCodeAction(
138139
* .filter(_ > 3)
139140
* }}}
140141
* <p>Now when the cursor is on `map`, we want to start the conversion
141-
* on `filter`` instead, which is the parentMost or `outerMost` apply.
142+
* on `filter` instead, which is the parentMost or `outerMost` apply.
143+
*
142144
* @param currentTermApply the termApply on which the cursor is
143145
* when invoking the code action
144146
* @param lastValidTermApply the last inner [[Term.Apply]] from the previous
145147
* iteration which had one of the functions of
146148
* `map`, `flatMap`, `filter`, `filterNot`, or `withFilter`
147149
* in its [[Term.Select]]
148-
* @return the `lastValidTermApply`` if the `currenTermApply` does not have
150+
* @return the `lastValidTermApply` if the `currenTermApply` does not have
149151
* an interesting function. Otherwise, the currentTermApply.
150152
*/
151153
@tailrec
@@ -273,8 +275,150 @@ class FlatMapToForComprehensionCodeAction(
273275
else None
274276
}
275277

278+
private def replaceNameInTermWithNewName(
279+
term: Term,
280+
nameGenerator: MetalsNames,
281+
nameToReplace: Name,
282+
): Option[(Pat, Term)] = {
283+
def replaceName(
284+
tree: Term,
285+
newName: Term.Name,
286+
): Term = {
287+
def handleCase(caseTree: Case) = caseTree match {
288+
case Case(Pat.Var(Term.Name(patVarName)), _, _)
289+
if patVarName == nameToReplace.value =>
290+
caseTree // New scope, outer nameToReplace is unreachable
291+
case Case(Pat.Typed(Pat.Var(Term.Name(patVarName)), _), _, _)
292+
if patVarName == nameToReplace.value =>
293+
caseTree // New scope, outer nameToReplace is unreachable
294+
case Case(Pat.Extract(_, argClause), _, _)
295+
if !argClause.forall(_.isInstanceOf[Pat.Var]) =>
296+
throw new IllegalStateException("Too complex to handle")
297+
case Case(Pat.Extract(_, argClause), _, _) if argClause.exists {
298+
case Pat.Var(Term.Name(patVarName)) =>
299+
patVarName == nameToReplace.value
300+
} =>
301+
caseTree // New scope, outer nameToReplace is unreachable
302+
case Case(pat, perhapsGuard, body) =>
303+
Case(
304+
pat,
305+
perhapsGuard.map(replaceName(_, newName)),
306+
replaceName(body, newName),
307+
)
308+
}
309+
310+
tree match {
311+
case apply @ Term.Apply(fun, args) =>
312+
val newFun = replaceName(fun, newName)
313+
val newArgs = args.map(replaceName(_, newName))
314+
Term.Apply(newFun, Term.ArgClause(newArgs, apply.argClause.mod))
315+
316+
case apply @ Term.ApplyInfix(lhs, op, targs, args) =>
317+
val newLHS = replaceName(lhs, newName)
318+
val newArgs = args.map(replaceName(_, newName))
319+
Term.ApplyInfix(
320+
newLHS,
321+
op,
322+
Type.ArgClause(targs),
323+
Term.ArgClause(newArgs, apply.argClause.mod),
324+
)
325+
326+
case Term.Select(qual, name) =>
327+
Term.Select(replaceName(qual, newName), name)
328+
329+
case Term.Block(stats) =>
330+
Term.Block(stats.map {
331+
case term: Term => replaceName(term, newName)
332+
case _: Stat =>
333+
throw new IllegalStateException("Too complex to handle")
334+
})
335+
336+
case termFor @ Term.For(enums, yieldExpr) =>
337+
val anyNameOverlaps = enums.exists {
338+
case Enumerator.Assign(Pat.Var(Term.Name(name)), _) =>
339+
name == nameToReplace.value
340+
case _ => false
341+
}
342+
if (anyNameOverlaps)
343+
termFor // New scope, outer nameToReplace is unreachable
344+
else
345+
Term.For(
346+
enums.map {
347+
case Enumerator.Val(pat, rhs) =>
348+
Enumerator.Val(pat, replaceName(rhs, newName))
349+
case Enumerator.Guard(term) =>
350+
Enumerator.Guard(replaceName(term, newName))
351+
case Enumerator.Generator(pat, rhs) =>
352+
Enumerator.Generator(pat, replaceName(rhs, newName))
353+
case Enumerator.CaseGenerator(pat, rhs) =>
354+
Enumerator.CaseGenerator(pat, replaceName(rhs, newName))
355+
case other => other
356+
},
357+
replaceName(yieldExpr, newName),
358+
)
359+
360+
case Term.If(cond, thenTerm, elseTerm) =>
361+
Term.If(
362+
replaceName(cond, newName),
363+
replaceName(thenTerm, newName),
364+
replaceName(elseTerm, newName),
365+
)
366+
367+
case Term.Assign(lhs @ Term.Name(_), rhs) =>
368+
Term.Assign(
369+
replaceName(lhs, newName),
370+
replaceName(rhs, newName),
371+
)
372+
373+
case Term.AnonymousFunction(term) =>
374+
Term.AnonymousFunction(replaceName(term, newName))
375+
376+
case Term.PartialFunction(cases) =>
377+
Term.PartialFunction(cases.map(handleCase))
378+
379+
case Term.Match(term, cases) =>
380+
Term.Match(
381+
replaceName(term, newName),
382+
cases.map(handleCase),
383+
)
384+
385+
case Term.Try(expr, catchClause, finallyClause) =>
386+
Term.Try(
387+
replaceName(expr, newName),
388+
catchClause.map(handleCase),
389+
finallyClause.map(replaceName(_, newName)),
390+
)
391+
392+
case Term.Throw(term) =>
393+
Term.Throw(replaceName(term, newName))
394+
395+
case Term.New(Init(tpe, name, termsNested)) =>
396+
Term.New(
397+
Init(tpe, name, termsNested.map(_.map(replaceName(_, newName))))
398+
)
399+
400+
case Term.Interpolate(name, literals, terms) =>
401+
Term.Interpolate(
402+
name,
403+
literals,
404+
terms.map(replaceName(_, newName)),
405+
)
406+
407+
case Term.Name(name) if nameToReplace.value == name =>
408+
newName
409+
410+
case other => other
411+
}
412+
}
413+
414+
val newName = nameGenerator.createNewName()
415+
Try(replaceName(term, Term.Name(newName)))
416+
.map(newTerm => (Pat.Var(Term.Name(newName)), newTerm))
417+
.toOption
418+
}
419+
276420
private def isSimple(pat: Pat): Boolean = { // this is to decide whether to
277-
// put pat in the left side of an Enumerator
421+
// put pat on the left side of an Enumerator
278422
pat match {
279423
case _: Pat.Extract | _: Pat.ExtractInfix | _: Pat.Interpolate | _: Lit |
280424
_: Term.Name | _: Pat.Typed | _: Pat.Var =>
@@ -286,6 +430,7 @@ class FlatMapToForComprehensionCodeAction(
286430
}
287431
}
288432

433+
@tailrec
289434
private def processPatAndNextQual(
290435
tree: Tree,
291436
nameGenerator: MetalsNames,
@@ -296,7 +441,12 @@ class FlatMapToForComprehensionCodeAction(
296441
val newName = nameGenerator.createNewName()
297442
Some(Pat.Var(Term.Name(newName)), term)
298443
case Term.Function(List(param), term) =>
299-
Some(Pat.Var(Term.Name(param.name.value)), term)
444+
if (nameGenerator.isNameEncountered(param.name.value)) {
445+
replaceNameInTermWithNewName(term, nameGenerator, param.name)
446+
} else {
447+
nameGenerator.recordNameEncountered(param.name.value)
448+
Some(Pat.Var(Term.Name(param.name.value)), term)
449+
}
300450
case Term.AnonymousFunction(term) =>
301451
replacePlaceHolderInTermWithNewName(term, nameGenerator)
302452
case term: Term.Eta =>
@@ -314,10 +464,6 @@ class FlatMapToForComprehensionCodeAction(
314464
Pat.Var(Term.Name(newName)),
315465
Term.Apply(term, List(Term.Name(newName))),
316466
)
317-
Some(
318-
Pat.Var(Term.Name(newName)),
319-
Term.Apply(term, List(Term.Name(newName))),
320-
)
321467
case _ => None
322468
}
323469
}
@@ -329,14 +475,14 @@ class FlatMapToForComprehensionCodeAction(
329475
*
330476
* @param nameGenerator the stateful mutable name generator object for
331477
* creating a new Metals generated name in each call.
332-
* @param perhapsLastName paramName from previous iteration
478+
* @param perhapsLastPat pat from previous iteration
333479
* in `list.map(x => x + 1).flatMap(b => Some(b - 1))`,
334-
* if we are now processing `map`, it would be `b``
480+
* if we are now processing `map`, it would be `b`
335481
* @param shouldFlat is it map or flatMap
336482
* @param existingForElements list of enumerators obtained from previous iterations
337483
* @param maybeCurrentYieldTerm the yield term from previous iterations if they
338-
* existed or `None``
339-
* @param nextQual in `list.map(x => x + 1)`, it is `x + 1``
484+
* existed or `None`
485+
* @param nextQual in `list.map(x => x + 1)`, it is `x + 1`
340486
* @return (the list of deducted enumerators, maybe the deducted yield term)
341487
*/
342488
private def obtainNextYieldAndElemsForMap(
@@ -358,7 +504,7 @@ class FlatMapToForComprehensionCodeAction(
358504
nextQual,
359505
)
360506
} else
361-
Enumerator.Val( // when it is map
507+
Enumerator.Val( // when it is map,
362508
// it is lastName = nextQual
363509
lastPat,
364510
nextQual,
@@ -565,7 +711,7 @@ class FlatMapToForComprehensionCodeAction(
565711
* .filter(s => s > 7)
566712
* }}}
567713
* <p>if it had traversed `filter` in the previous iteration, the value of
568-
* `perhapseLastName` would be `s` which would be passed as argument when
714+
* `perhapsLastPat` would be `s` which would be passed as argument when
569715
* we are passing `map` as the termApply. Also, `s` itself would be the value
570716
* of the so far extracted `maybeCurrentYieldTerm`, because of `filter`.
571717
*
@@ -579,11 +725,10 @@ class FlatMapToForComprehensionCodeAction(
579725
* case `List(1, 2, 3)`, or it is to be paired with the qual extracted from the
580726
* next iteration, in case, there is a `map`/`flatMap` before it.
581727
*
582-
* @param perhapsLastName the param name extracted from the termApply
583-
* in the last iteration
584-
* @param currentYieldTerm the so far extraxcted yield term from the previous iterations
585-
* @param existingForElements
586-
* @param termApply the termApply to be traveresed in this iteration
728+
* @param perhapsLastPat the pat extracted from the termApply in the last iteration
729+
* @param currentYieldTerm the so far extracted yield term from the previous iterations
730+
* @param existingForElements Tail of already processed enumerators
731+
* @param termApply the termApply to be traversed in this iteration
587732
* @param nameGenerator a stateful mutable object which is used for creating
588733
* non-overlapping
589734
* names for the anonymous parameters/placeholders of
@@ -599,7 +744,7 @@ class FlatMapToForComprehensionCodeAction(
599744
termApply: Term.Apply,
600745
nameGenerator: MetalsNames,
601746
): (List[Enumerator], Option[Term]) = {
602-
val perhapsValueNameAndNextQual = termApply.args.headOption.flatMap {
747+
def perhapsValueNameAndNextQual = termApply.args.headOption.flatMap {
603748
processPatAndNextQual(
604749
_,
605750
nameGenerator,

metals/src/main/scala/scala/meta/internal/metals/codeactions/MetalsNames.scala

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
package scala.meta.internal.metals.codeactions
22

33
import scala.annotation.tailrec
4+
import scala.collection.mutable
45
import scala.collection.mutable.ListBuffer
5-
6-
import scala.meta._
7-
import scala.meta.internal.mtags.MtagsEnrichments._
6+
import scala.meta.*
7+
import scala.meta.internal.mtags.MtagsEnrichments.*
88

99
case class MetalsNames(tree: Tree, prefix: String) {
1010

11+
private val seenNames = mutable.Set.empty[String]
12+
1113
private lazy val allNames = {
1214
val top = lastEnclosingStatsTree(tree)
1315
findAllNames(top)
@@ -50,4 +52,12 @@ case class MetalsNames(tree: Tree, prefix: String) {
5052
if (allNames(name)) createNewName()
5153
else name
5254
}
55+
56+
def recordNameEncountered(name: String): Unit = {
57+
seenNames += name
58+
}
59+
60+
def isNameEncountered(name: String): Boolean = {
61+
seenNames(name)
62+
}
5363
}

tests/unit/src/test/scala/tests/codeactions/FilterMapToCollectCodeActionSuite.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,6 @@ class FilterMapToCollectCodeActionSuite
236236
|}
237237
|""".stripMargin,
238238
s"""|${RewriteBracesParensCodeAction.toParens("map")}
239-
|${FlatMapToForComprehensionCodeAction.flatMapToForComprehension}
240239
|${FilterMapToCollectCodeAction.title}
241240
|""".stripMargin,
242241
"""|object Main {
@@ -250,7 +249,7 @@ class FilterMapToCollectCodeActionSuite
250249
| }
251250
|}
252251
|""".stripMargin,
253-
selectedActionIndex = 2,
252+
selectedActionIndex = 1,
254253
)
255254

256255
check(

0 commit comments

Comments
 (0)