Skip to content

Commit 2bc493b

Browse files
committed
Implement InlineBinding and DeadBindingElimination
1 parent 76963a6 commit 2bc493b

24 files changed

Lines changed: 740 additions & 1378 deletions
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
{-# LANGUAGE GADTs #-}
2+
module LambdaComp.CBPV.Optimization.DeadBindingElimination
3+
( runDeadLetElimination
4+
) where
5+
6+
import Control.Monad.Writer.CPS (MonadWriter (tell), Writer, censor, runWriter, listens)
7+
import Data.Set (Set)
8+
import Data.Set qualified as Set
9+
10+
import LambdaComp.CBPV.Syntax
11+
import Control.Applicative (liftA3)
12+
import Data.Functor ((<&>))
13+
14+
runDeadLetElimination :: Tm Val -> Tm Val
15+
runDeadLetElimination = fst . runWriter . deadLetElimination
16+
17+
type WithFreeVars = Writer (Set Ident)
18+
19+
deadLetElimination :: Tm c -> WithFreeVars (Tm c)
20+
deadLetElimination tm@(TmVar x) = tm <$ tell (Set.singleton x)
21+
deadLetElimination tm@TmUnit = pure tm
22+
deadLetElimination tm@TmTrue = pure tm
23+
deadLetElimination tm@TmFalse = pure tm
24+
deadLetElimination tm@(TmInt _) = pure tm
25+
deadLetElimination tm@(TmDouble _) = pure tm
26+
deadLetElimination (TmThunk tm) = TmThunk <$> deadLetElimination tm
27+
deadLetElimination (TmIf tm0 tm1 tm2) = liftA3 TmIf (deadLetElimination tm0) (deadLetElimination tm1) (deadLetElimination tm2)
28+
deadLetElimination (TmLam p tm) = TmLam p <$> without (paramName p) (deadLetElimination tm)
29+
deadLetElimination (tmf `TmApp` tma) = liftA2 TmApp (deadLetElimination tmf) (deadLetElimination tma)
30+
deadLetElimination (TmForce tm) = TmForce <$> deadLetElimination tm
31+
deadLetElimination (TmReturn tm) = TmReturn <$> deadLetElimination tm
32+
deadLetElimination (TmTo tm0 x tm1) = liftA3 TmTo (deadLetElimination tm0) (pure x) (without x $ deadLetElimination tm1)
33+
deadLetElimination (TmLet x tm0 tm1) = do
34+
(tm1', withX) <- without x $ listens (x `Set.member`) $ deadLetElimination tm1
35+
if withX
36+
then TmLet x <$> deadLetElimination tm0 <&> ($ tm1')
37+
else pure tm1'
38+
deadLetElimination (TmPrimBinOp op tm0 tm1) = liftA2 (TmPrimBinOp op) (deadLetElimination tm0) (deadLetElimination tm1)
39+
deadLetElimination (TmPrimUnOp op tm) = TmPrimUnOp op <$> deadLetElimination tm
40+
deadLetElimination (TmPrintInt tm0 tm1) = liftA2 TmPrintInt (deadLetElimination tm0) (deadLetElimination tm1)
41+
deadLetElimination (TmPrintDouble tm0 tm1) = liftA2 TmPrintDouble (deadLetElimination tm0) (deadLetElimination tm1)
42+
deadLetElimination (TmRec p tm) = TmRec p <$> without (paramName p) (deadLetElimination tm)
43+
44+
without :: Ident -> WithFreeVars a -> WithFreeVars a
45+
without x = censor (Set.delete x)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
{-# LANGUAGE GADTs #-}
2+
module LambdaComp.CBPV.Optimization.InlineBinding
3+
( runInlineSimpleLet
4+
) where
5+
6+
import Control.Applicative (liftA3)
7+
import Control.Monad.Reader (MonadReader (local), Reader, asks, runReader)
8+
import Data.Map.Strict (Map)
9+
import Data.Map.Strict qualified as Map
10+
import Data.Maybe (fromMaybe)
11+
12+
import LambdaComp.CBPV.Syntax
13+
14+
runInlineSimpleLet :: Tm c -> Tm c
15+
runInlineSimpleLet = (`runReader` Map.empty) . inlineSimpleLet
16+
17+
type WithSimpleBinding = Reader (Map Ident (Tm Val))
18+
19+
inlineSimpleLet :: Tm c -> WithSimpleBinding (Tm c)
20+
inlineSimpleLet tm@(TmVar x) = do
21+
mayTm' <- asks (Map.!? x)
22+
pure $ fromMaybe tm mayTm'
23+
inlineSimpleLet tm@TmUnit = pure tm
24+
inlineSimpleLet tm@TmTrue = pure tm
25+
inlineSimpleLet tm@TmFalse = pure tm
26+
inlineSimpleLet tm@(TmInt _) = pure tm
27+
inlineSimpleLet tm@(TmDouble _) = pure tm
28+
inlineSimpleLet (TmThunk tm) = TmThunk <$> inlineSimpleLet tm
29+
inlineSimpleLet (TmIf tm0 tm1 tm2) = liftA3 TmIf (inlineSimpleLet tm0) (inlineSimpleLet tm1) (inlineSimpleLet tm2)
30+
inlineSimpleLet (TmLam p tm) = TmLam p <$> local (Map.insert (paramName p) (TmVar (paramName p))) (inlineSimpleLet tm)
31+
inlineSimpleLet (tmf `TmApp` tma) = liftA2 TmApp (inlineSimpleLet tmf) (inlineSimpleLet tma)
32+
inlineSimpleLet (TmForce tm) = TmForce <$> inlineSimpleLet tm
33+
inlineSimpleLet (TmReturn tm) = TmReturn <$> inlineSimpleLet tm
34+
inlineSimpleLet (TmTo tm0 x tm1) = liftA3 TmTo (inlineSimpleLet tm0) (pure x) (local (Map.insert x (TmVar x)) $ inlineSimpleLet tm1)
35+
inlineSimpleLet (TmLet x tm0 tm1) = liftA2 (TmLet x) (inlineSimpleLet tm0) (local (Map.insert x xBinding) $ inlineSimpleLet tm1)
36+
where
37+
xBinding
38+
| isSimpleTm x tm0 = tm0
39+
| otherwise = TmVar x
40+
inlineSimpleLet (TmPrimBinOp op tm0 tm1) = liftA2 (TmPrimBinOp op) (inlineSimpleLet tm0) (inlineSimpleLet tm1)
41+
inlineSimpleLet (TmPrimUnOp op tm) = TmPrimUnOp op <$> inlineSimpleLet tm
42+
inlineSimpleLet (TmPrintInt tm0 tm1) = liftA2 TmPrintInt (inlineSimpleLet tm0) (inlineSimpleLet tm1)
43+
inlineSimpleLet (TmPrintDouble tm0 tm1) = liftA2 TmPrintDouble (inlineSimpleLet tm0) (inlineSimpleLet tm1)
44+
inlineSimpleLet (TmRec p tm) = TmRec p <$> local (Map.insert (paramName p) (TmVar (paramName p))) (inlineSimpleLet tm)
45+
46+
isSimpleTm :: Ident -> Tm Val -> Bool
47+
isSimpleTm x (TmVar y) = x /= y
48+
isSimpleTm _ TmUnit = True
49+
isSimpleTm _ TmTrue = True
50+
isSimpleTm _ TmFalse = True
51+
isSimpleTm _ (TmInt _) = True
52+
isSimpleTm _ (TmDouble _) = True
53+
isSimpleTm _ (TmThunk _) = False

src/LambdaComp/CBPV/Optimization/Local.hs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@ module LambdaComp.CBPV.Optimization.Local
22
( runLocalOptDefault
33
) where
44

5-
import LambdaComp.CBPV.Optimization.BindingConversion
6-
import LambdaComp.CBPV.Optimization.SkipReturn
5+
import LambdaComp.CBPV.Optimization.BindingConversion (runCommutingTo, runLiftingLet)
6+
import LambdaComp.CBPV.Optimization.DeadBindingElimination (runDeadLetElimination)
7+
import LambdaComp.CBPV.Optimization.InlineBinding (runInlineSimpleLet)
8+
import LambdaComp.CBPV.Optimization.SkipReturn (runSkipReturn)
79
import LambdaComp.CBPV.Syntax
810

911
runLocalOptDefault :: Program -> Program
1012
runLocalOptDefault = fmap runLocalOptDefaultTm
1113

1214
runLocalOptDefaultTm :: Tm Val -> Tm Val
13-
runLocalOptDefaultTm = runLiftingLet . runSkipReturn . runCommutingThen
15+
runLocalOptDefaultTm = runDeadLetElimination . runInlineSimpleLet . runLiftingLet . runSkipReturn . runCommutingTo

test/golden/Constant.lc.am.code.gen

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,8 @@
55
, ThunkCodeSection
66
{ thunkCodeSectionName = "sys_thunk_0"
77
, thunkCode =
8-
[ IAssign
9-
( AIdent "var_c_v_0" )
8+
[ IPrintInt
109
( VaInt 5 )
11-
, IPrintInt
12-
( VaAddr
13-
( AIdent "var_c_v_0" )
14-
)
1510
, ISetReturn
1611
( VaInt 0 )
1712
, IExit

test/golden/Constant.lc.c.code.gen

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@ item var_u_main;
55

66
void sys_thunk_1(item *const _, item *const ret)
77
{
8-
/* TmLet "c_v_0" (TmInt 5) (TmPrintInt (TmVar "c_v_0") (TmReturn (TmInt 0))) */
9-
const item var_c_v_0 = {.int_item = 5};
10-
const item sys_msg_0 = var_c_v_0;
8+
/* TmPrintInt (TmInt 5) (TmReturn (TmInt 0)) */
9+
const item sys_msg_0 = {.int_item = 5};
1110
printf("%d\n", sys_msg_0.int_item);
1211
(*ret).int_item = 0;
1312
}
@@ -18,7 +17,7 @@ int main(void)
1817
item retv;
1918
{
2019
item *const ret = &retv;
21-
/* TmThunk (TmLet "c_v_0" (TmInt 5) (TmPrintInt (TmVar "c_v_0") (TmReturn (TmInt 0)))) */
20+
/* TmThunk (TmPrintInt (TmInt 5) (TmReturn (TmInt 0))) */
2221
var_u_main.thunk_item.code = sys_thunk_1;
2322
var_u_main.thunk_item.env = NULL;
2423
var_u_main.thunk_item.code(var_u_main.thunk_item.env, ret);

test/golden/Constant.lc.cbpv.opt

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,10 @@ fromList
22
[
33
( "u_main"
44
, TmThunk
5-
( TmLet "c_v_0"
5+
( TmPrintInt
66
( TmInt 5 )
7-
( TmPrintInt
8-
( TmVar "c_v_0" )
9-
( TmReturn
10-
( TmInt 0 )
11-
)
7+
( TmReturn
8+
( TmInt 0 )
129
)
1310
)
1411
)

0 commit comments

Comments
 (0)