11package scala .meta .internal .metals .codeactions
22
3+ import org .eclipse .lsp4j as l
4+ import org .eclipse .lsp4j .CodeActionParams
5+
36import scala .annotation .tailrec
4- import scala .concurrent .ExecutionContext
5- import scala .concurrent .Future
6-
7- import scala .meta .Case
8- import scala .meta .Enumerator
9- import scala .meta .Lit
10- import scala .meta .Name
11- import scala .meta .Pat
12- import scala .meta .Term
13- import scala .meta .Tree
14- import scala .meta .Type
15- import scala .meta .XtensionClassifiable
16- import scala .meta .XtensionSyntax
7+ import scala .concurrent .{ExecutionContext , Future }
178import scala .meta .inputs .Position
189import scala .meta .internal .metals .Buffers
19- import scala .meta .internal .metals .MetalsEnrichments ._
20- import scala .meta .internal .metals .codeactions .CodeAction
21- import scala .meta .internal .metals .codeactions .CodeActionBuilder
10+ import scala .meta .internal .metals .MetalsEnrichments .*
2211import scala .meta .internal .parsing .Trees
2312import scala .meta .io .AbsolutePath
2413import scala .meta .pc .CancelToken
25-
26- import org .eclipse .lsp4j .CodeActionParams
27- import org .eclipse .{lsp4j => l }
14+ import scala .meta .{
15+ Case ,
16+ Enumerator ,
17+ Init ,
18+ Lit ,
19+ Name ,
20+ Pat ,
21+ Stat ,
22+ Term ,
23+ Tree ,
24+ Type ,
25+ XtensionClassifiable ,
26+ XtensionSyntax ,
27+ }
28+ import scala .util .Try
2829
2930class FlatMapToForComprehensionCodeAction (
3031 trees : Trees ,
@@ -138,14 +139,15 @@ class FlatMapToForComprehensionCodeAction(
138139 * .filter(_ > 3)
139140 * }}}
140141 * <p>Now when the cursor is on `map`, we want to start the conversion
141- * on `filter`` instead, which is the parentMost or `outerMost` apply.
142+ * on `filter` instead, which is the parentMost or `outerMost` apply.
143+ *
142144 * @param currentTermApply the termApply on which the cursor is
143145 * when invoking the code action
144146 * @param lastValidTermApply the last inner [[Term.Apply ]] from the previous
145147 * iteration which had one of the functions of
146148 * `map`, `flatMap`, `filter`, `filterNot`, or `withFilter`
147149 * in its [[Term.Select ]]
148- * @return the `lastValidTermApply`` if the `currenTermApply` does not have
150+ * @return the `lastValidTermApply` if the `currenTermApply` does not have
149151 * an interesting function. Otherwise, the currentTermApply.
150152 */
151153 @ tailrec
@@ -273,8 +275,150 @@ class FlatMapToForComprehensionCodeAction(
273275 else None
274276 }
275277
278+ private def replaceNameInTermWithNewName (
279+ term : Term ,
280+ nameGenerator : MetalsNames ,
281+ nameToReplace : Name ,
282+ ): Option [(Pat , Term )] = {
283+ def replaceName (
284+ tree : Term ,
285+ newName : Term .Name ,
286+ ): Term = {
287+ def handleCase (caseTree : Case ) = caseTree match {
288+ case Case (Pat .Var (Term .Name (patVarName)), _, _)
289+ if patVarName == nameToReplace.value =>
290+ caseTree // New scope, outer nameToReplace is unreachable
291+ case Case (Pat .Typed (Pat .Var (Term .Name (patVarName)), _), _, _)
292+ if patVarName == nameToReplace.value =>
293+ caseTree // New scope, outer nameToReplace is unreachable
294+ case Case (Pat .Extract (_, argClause), _, _)
295+ if ! argClause.forall(_.isInstanceOf [Pat .Var ]) =>
296+ throw new IllegalStateException (" Too complex to handle" )
297+ case Case (Pat .Extract (_, argClause), _, _) if argClause.exists {
298+ case Pat .Var (Term .Name (patVarName)) =>
299+ patVarName == nameToReplace.value
300+ } =>
301+ caseTree // New scope, outer nameToReplace is unreachable
302+ case Case (pat, perhapsGuard, body) =>
303+ Case (
304+ pat,
305+ perhapsGuard.map(replaceName(_, newName)),
306+ replaceName(body, newName),
307+ )
308+ }
309+
310+ tree match {
311+ case apply @ Term .Apply (fun, args) =>
312+ val newFun = replaceName(fun, newName)
313+ val newArgs = args.map(replaceName(_, newName))
314+ Term .Apply (newFun, Term .ArgClause (newArgs, apply.argClause.mod))
315+
316+ case apply @ Term .ApplyInfix (lhs, op, targs, args) =>
317+ val newLHS = replaceName(lhs, newName)
318+ val newArgs = args.map(replaceName(_, newName))
319+ Term .ApplyInfix (
320+ newLHS,
321+ op,
322+ Type .ArgClause (targs),
323+ Term .ArgClause (newArgs, apply.argClause.mod),
324+ )
325+
326+ case Term .Select (qual, name) =>
327+ Term .Select (replaceName(qual, newName), name)
328+
329+ case Term .Block (stats) =>
330+ Term .Block (stats.map {
331+ case term : Term => replaceName(term, newName)
332+ case _ : Stat =>
333+ throw new IllegalStateException (" Too complex to handle" )
334+ })
335+
336+ case termFor @ Term .For (enums, yieldExpr) =>
337+ val anyNameOverlaps = enums.exists {
338+ case Enumerator .Assign (Pat .Var (Term .Name (name)), _) =>
339+ name == nameToReplace.value
340+ case _ => false
341+ }
342+ if (anyNameOverlaps)
343+ termFor // New scope, outer nameToReplace is unreachable
344+ else
345+ Term .For (
346+ enums.map {
347+ case Enumerator .Val (pat, rhs) =>
348+ Enumerator .Val (pat, replaceName(rhs, newName))
349+ case Enumerator .Guard (term) =>
350+ Enumerator .Guard (replaceName(term, newName))
351+ case Enumerator .Generator (pat, rhs) =>
352+ Enumerator .Generator (pat, replaceName(rhs, newName))
353+ case Enumerator .CaseGenerator (pat, rhs) =>
354+ Enumerator .CaseGenerator (pat, replaceName(rhs, newName))
355+ case other => other
356+ },
357+ replaceName(yieldExpr, newName),
358+ )
359+
360+ case Term .If (cond, thenTerm, elseTerm) =>
361+ Term .If (
362+ replaceName(cond, newName),
363+ replaceName(thenTerm, newName),
364+ replaceName(elseTerm, newName),
365+ )
366+
367+ case Term .Assign (lhs @ Term .Name (_), rhs) =>
368+ Term .Assign (
369+ replaceName(lhs, newName),
370+ replaceName(rhs, newName),
371+ )
372+
373+ case Term .AnonymousFunction (term) =>
374+ Term .AnonymousFunction (replaceName(term, newName))
375+
376+ case Term .PartialFunction (cases) =>
377+ Term .PartialFunction (cases.map(handleCase))
378+
379+ case Term .Match (term, cases) =>
380+ Term .Match (
381+ replaceName(term, newName),
382+ cases.map(handleCase),
383+ )
384+
385+ case Term .Try (expr, catchClause, finallyClause) =>
386+ Term .Try (
387+ replaceName(expr, newName),
388+ catchClause.map(handleCase),
389+ finallyClause.map(replaceName(_, newName)),
390+ )
391+
392+ case Term .Throw (term) =>
393+ Term .Throw (replaceName(term, newName))
394+
395+ case Term .New (Init (tpe, name, termsNested)) =>
396+ Term .New (
397+ Init (tpe, name, termsNested.map(_.map(replaceName(_, newName))))
398+ )
399+
400+ case Term .Interpolate (name, literals, terms) =>
401+ Term .Interpolate (
402+ name,
403+ literals,
404+ terms.map(replaceName(_, newName)),
405+ )
406+
407+ case Term .Name (name) if nameToReplace.value == name =>
408+ newName
409+
410+ case other => other
411+ }
412+ }
413+
414+ val newName = nameGenerator.createNewName()
415+ Try (replaceName(term, Term .Name (newName)))
416+ .map(newTerm => (Pat .Var (Term .Name (newName)), newTerm))
417+ .toOption
418+ }
419+
276420 private def isSimple (pat : Pat ): Boolean = { // this is to decide whether to
277- // put pat in the left side of an Enumerator
421+ // put pat on the left side of an Enumerator
278422 pat match {
279423 case _ : Pat .Extract | _ : Pat .ExtractInfix | _ : Pat .Interpolate | _ : Lit |
280424 _ : Term .Name | _ : Pat .Typed | _ : Pat .Var =>
@@ -286,6 +430,7 @@ class FlatMapToForComprehensionCodeAction(
286430 }
287431 }
288432
433+ @ tailrec
289434 private def processPatAndNextQual (
290435 tree : Tree ,
291436 nameGenerator : MetalsNames ,
@@ -296,7 +441,12 @@ class FlatMapToForComprehensionCodeAction(
296441 val newName = nameGenerator.createNewName()
297442 Some (Pat .Var (Term .Name (newName)), term)
298443 case Term .Function (List (param), term) =>
299- Some (Pat .Var (Term .Name (param.name.value)), term)
444+ if (nameGenerator.isNameEncountered(param.name.value)) {
445+ replaceNameInTermWithNewName(term, nameGenerator, param.name)
446+ } else {
447+ nameGenerator.recordNameEncountered(param.name.value)
448+ Some (Pat .Var (Term .Name (param.name.value)), term)
449+ }
300450 case Term .AnonymousFunction (term) =>
301451 replacePlaceHolderInTermWithNewName(term, nameGenerator)
302452 case term : Term .Eta =>
@@ -314,10 +464,6 @@ class FlatMapToForComprehensionCodeAction(
314464 Pat .Var (Term .Name (newName)),
315465 Term .Apply (term, List (Term .Name (newName))),
316466 )
317- Some (
318- Pat .Var (Term .Name (newName)),
319- Term .Apply (term, List (Term .Name (newName))),
320- )
321467 case _ => None
322468 }
323469 }
@@ -329,14 +475,14 @@ class FlatMapToForComprehensionCodeAction(
329475 *
330476 * @param nameGenerator the stateful mutable name generator object for
331477 * creating a new Metals generated name in each call.
332- * @param perhapsLastName paramName from previous iteration
478+ * @param perhapsLastPat pat from previous iteration
333479 * in `list.map(x => x + 1).flatMap(b => Some(b - 1))`,
334- * if we are now processing `map`, it would be `b` `
480+ * if we are now processing `map`, it would be `b`
335481 * @param shouldFlat is it map or flatMap
336482 * @param existingForElements list of enumerators obtained from previous iterations
337483 * @param maybeCurrentYieldTerm the yield term from previous iterations if they
338- * existed or `None` `
339- * @param nextQual in `list.map(x => x + 1)`, it is `x + 1``
484+ * existed or `None`
485+ * @param nextQual in `list.map(x => x + 1)`, it is `x + 1`
340486 * @return (the list of deducted enumerators, maybe the deducted yield term)
341487 */
342488 private def obtainNextYieldAndElemsForMap (
@@ -358,7 +504,7 @@ class FlatMapToForComprehensionCodeAction(
358504 nextQual,
359505 )
360506 } else
361- Enumerator .Val ( // when it is map
507+ Enumerator .Val ( // when it is map,
362508 // it is lastName = nextQual
363509 lastPat,
364510 nextQual,
@@ -565,7 +711,7 @@ class FlatMapToForComprehensionCodeAction(
565711 * .filter(s => s > 7)
566712 * }}}
567713 * <p>if it had traversed `filter` in the previous iteration, the value of
568- * `perhapseLastName ` would be `s` which would be passed as argument when
714+ * `perhapsLastPat ` would be `s` which would be passed as argument when
569715 * we are passing `map` as the termApply. Also, `s` itself would be the value
570716 * of the so far extracted `maybeCurrentYieldTerm`, because of `filter`.
571717 *
@@ -579,11 +725,10 @@ class FlatMapToForComprehensionCodeAction(
579725 * case `List(1, 2, 3)`, or it is to be paired with the qual extracted from the
580726 * next iteration, in case, there is a `map`/`flatMap` before it.
581727 *
582- * @param perhapsLastName the param name extracted from the termApply
583- * in the last iteration
584- * @param currentYieldTerm the so far extraxcted yield term from the previous iterations
585- * @param existingForElements
586- * @param termApply the termApply to be traveresed in this iteration
728+ * @param perhapsLastPat the pat extracted from the termApply in the last iteration
729+ * @param currentYieldTerm the so far extracted yield term from the previous iterations
730+ * @param existingForElements Tail of already processed enumerators
731+ * @param termApply the termApply to be traversed in this iteration
587732 * @param nameGenerator a stateful mutable object which is used for creating
588733 * non-overlapping
589734 * names for the anonymous parameters/placeholders of
@@ -599,7 +744,7 @@ class FlatMapToForComprehensionCodeAction(
599744 termApply : Term .Apply ,
600745 nameGenerator : MetalsNames ,
601746 ): (List [Enumerator ], Option [Term ]) = {
602- val perhapsValueNameAndNextQual = termApply.args.headOption.flatMap {
747+ def perhapsValueNameAndNextQual = termApply.args.headOption.flatMap {
603748 processPatAndNextQual(
604749 _,
605750 nameGenerator,
0 commit comments