Traverse.hs 46.7 KB
Newer Older
1 2 3 4 5 6 7 8 9
{- sv2v
 - Author: Zachary Snow <zach@zachjs.com>
 -
 - Utilities for traversing AST transformations.
 -}

module Convert.Traverse
( MapperM
, Mapper
10 11
, CollectorM
, TFStrategy (..)
12
, unmonad
13
, collectify
14 15
, traverseDescriptionsM
, traverseDescriptions
16
, collectDescriptionsM
17 18
, traverseModuleItemsM
, traverseModuleItems
19
, collectModuleItemsM
20 21
, traverseStmtsM
, traverseStmts
22
, collectStmtsM
23 24 25
, traverseStmtsM'
, traverseStmts'
, collectStmtsM'
26 27 28
, traverseStmtLHSsM
, traverseStmtLHSs
, collectStmtLHSsM
29 30 31
, traverseExprsM
, traverseExprs
, collectExprsM
32 33 34
, traverseExprsM'
, traverseExprs'
, collectExprsM'
35 36 37
, traverseStmtExprsM
, traverseStmtExprs
, collectStmtExprsM
38 39 40
, traverseLHSsM
, traverseLHSs
, collectLHSsM
41 42 43
, traverseLHSsM'
, traverseLHSs'
, collectLHSsM'
44 45 46
, traverseDeclsM
, traverseDecls
, collectDeclsM
47 48 49
, traverseDeclsM'
, traverseDecls'
, collectDeclsM'
50 51 52
, traverseNestedTypesM
, traverseNestedTypes
, collectNestedTypesM
53 54 55
, traverseTypesM
, traverseTypes
, collectTypesM
56 57 58
, traverseGenItemsM
, traverseGenItems
, collectGenItemsM
59 60 61
, traverseAsgnsM
, traverseAsgns
, collectAsgnsM
62 63 64
, traverseAsgnsM'
, traverseAsgns'
, collectAsgnsM'
65 66 67
, traverseStmtAsgnsM
, traverseStmtAsgns
, collectStmtAsgnsM
68 69 70
, traverseNestedModuleItemsM
, traverseNestedModuleItems
, collectNestedModuleItemsM
71
, traverseNestedStmts
72
, collectNestedStmtsM
73
, traverseNestedExprsM
74 75 76 77 78
, traverseNestedExprs
, collectNestedExprsM
, traverseNestedLHSsM
, traverseNestedLHSs
, collectNestedLHSsM
79
, traverseScopesM
80
, scopedConversion
81
, scopedConversionM
82
, stately
83
, traverseFilesM
84
, traverseFiles
85 86
) where

87
import Data.Functor.Identity (runIdentity)
88
import Control.Monad.State
89
import Control.Monad.Writer
90 91
import Language.SystemVerilog.AST

92
-- a monadic transformation of a `t`: yields the replacement value within `m`
type MapperM m t = t -> m t
93
-- a pure transformation of a `t` into another `t`
type Mapper t = t -> t
94
-- a monadic visitor of a `t`: run only for its effects in `m`
type CollectorM m t = t -> m ()
95

96 97 98 99 100
-- selects whether the primed traversals (e.g. traverseStmtsM' and
-- traverseExprsM') also visit the contents of tasks and functions
data TFStrategy
    = IncludeTFs -- task and function declarations/statements are traversed
    | ExcludeTFs -- task and function declarations/statements are left as-is
    deriving Eq

101
-- turns a monadic traverser into a pure one by lifting the pure mapper with
-- `return` and running the traversal in a State monad whose state is unit
unmonad :: (MapperM (State ()) a -> MapperM (State ()) b) -> Mapper a -> Mapper b
unmonad traverser mapper thing =
    flip evalState () $ traverser (\x -> return (mapper x)) thing

105 106 107 108 109 110
-- turns a traverser into a collector: the collector is run on every visited
-- item, each item is handed back unchanged, and the rebuilt result is dropped
collectify :: Monad m => (MapperM m a -> MapperM m b) -> CollectorM m a -> CollectorM m b
collectify traverser collector thing = do
    _ <- traverser passThrough thing
    return ()
    where
        passThrough x = do
            () <- collector x
            return x

-- applies the mapper to each description of the AST, in order
traverseDescriptionsM :: Monad m => MapperM m Description -> MapperM m AST
traverseDescriptionsM = mapM

-- pure counterpart of traverseDescriptionsM
traverseDescriptions :: Mapper Description -> Mapper AST
traverseDescriptions mapper = unmonad traverseDescriptionsM mapper
-- collector counterpart of traverseDescriptionsM
collectDescriptionsM :: Monad m => CollectorM m Description -> CollectorM m AST
collectDescriptionsM collector = collectify traverseDescriptionsM collector
118

Zachary Snow committed
119 120 121 122
-- runs a monadic function on the contents of a Maybe, if there are any,
-- rewrapping its result in Just; Nothing is passed through without effects
maybeDo :: Monad m => (a -> m b) -> Maybe a -> m (Maybe b)
maybeDo fun = maybe (return Nothing) (fmap Just . fun)

123
-- applies the given mapper to every module item of a description, including
-- module items wrapped inside generate regions; generate items are simplified
-- along the way, and a generate region left holding only plain module items
-- is flattened back into the enclosing list of items
traverseModuleItemsM :: Monad m => MapperM m ModuleItem -> MapperM m Description
traverseModuleItemsM mapper (Part attrs extern kw lifetime name ports items) = do
    items' <- mapM fullMapper items
    let items'' = concatMap breakGenerate items'
    return $ Part attrs extern kw lifetime name ports items''
    where
        -- a generate region made of a single unnamed block is traversed as
        -- though it held that block's items directly
        fullMapper (Generate [GenBlock "" genItems]) =
            mapM fullGenItemMapper genItems >>= mapper . Generate
        fullMapper (Generate genItems) =
            mapM fullGenItemMapper (filter (/= GenNull) genItems)
                >>= mapper . Generate
        fullMapper (MIAttr attr mi) = MIAttr attr <$> fullMapper mi
        fullMapper other = mapper other
        fullGenItemMapper = traverseNestedGenItemsM genItemMapper
        -- a module item which the mapper turned into a generate region
        -- becomes an unnamed generate block rather than a nested region
        genItemMapper (GenModuleItem moduleItem) = do
            moduleItem' <- fullMapper moduleItem
            return $ case moduleItem' of
                Generate subItems -> GenBlock "" subItems
                _ -> GenModuleItem moduleItem'
        -- conditionals on a literal 1 or 0 collapse to the selected branch
        genItemMapper (GenIf (Number "1") s _) = return s
        genItemMapper (GenIf (Number "0") _ s) = return s
        genItemMapper (GenBlock _ []) = return GenNull
        genItemMapper other = return other
        -- a generate region is replaced by its contents only if every one of
        -- its items is a plain module item; otherwise it is kept whole
        breakGenerate :: ModuleItem -> [ModuleItem]
        breakGenerate (Generate genItems) =
            case mapM plainModuleItem genItems of
                Just moduleItems -> moduleItems
                Nothing -> [Generate genItems]
        breakGenerate other = [other]
        plainModuleItem :: GenItem -> Maybe ModuleItem
        plainModuleItem (GenModuleItem item) = Just item
        plainModuleItem _ = Nothing
-- a lone package item is traversed by wrapping it in a throwaway module named
-- "DNE" and unwrapping the single resulting item
traverseModuleItemsM mapper (PackageItem packageItem) = do
    let item = MIPackageItem packageItem
    converted <-
        traverseModuleItemsM mapper (Part [] False Module Nothing "DNE" [] [item])
    let item' = case converted of
            Part [] False Module Nothing "DNE" [] [newItem] -> newItem
            _ -> error $ "redirected PackageItem traverse failed: "
                    ++ show converted
    return $ case item' of
        MIPackageItem packageItem' -> PackageItem packageItem'
        other -> error $ "encountered bad package module item: " ++ show other
-- a package is traversed the same way, with all of its items wrapped at once
traverseModuleItemsM mapper (Package lifetime name packageItems) = do
    let items = map MIPackageItem packageItems
    converted <-
        traverseModuleItemsM mapper (Part [] False Module Nothing "DNE" [] items)
    let items' = case converted of
            Part [] False Module Nothing "DNE" [] newItems -> newItems
            _ -> error $ "redirected Package traverse failed: "
                    ++ show converted
    return $ Package lifetime name $ map unwrapPackageItem items'
    where
        -- a mapper which produces something other than a package item is
        -- reported with the same message as in the PackageItem case above,
        -- rather than through an anonymous pattern match failure
        unwrapPackageItem (MIPackageItem packageItem) = packageItem
        unwrapPackageItem other =
            error $ "encountered bad package module item: " ++ show other
traverseModuleItemsM _ (Directive str) = return $ Directive str
178 179 180

-- pure counterpart of traverseModuleItemsM
traverseModuleItems :: Mapper ModuleItem -> Mapper Description
traverseModuleItems mapper = unmonad traverseModuleItemsM mapper
-- collector counterpart of traverseModuleItemsM
collectModuleItemsM :: Monad m => CollectorM m ModuleItem -> CollectorM m Description
collectModuleItemsM collector = collectify traverseModuleItemsM collector
183

184 185
-- applies the mapper to the statements (and their nested statements) held by
-- always blocks, initial blocks and, under IncludeTFs, tasks and functions;
-- any other module item is returned untouched
traverseStmtsM' :: Monad m => TFStrategy -> MapperM m Stmt -> MapperM m ModuleItem
traverseStmtsM' strat mapper = moduleItemMapper
    where
        fullMapper = traverseNestedStmtsM mapper
        -- task and function bodies are only visited when including TFs
        tfStmtsMapper stmts
            | strat == IncludeTFs = mapM fullMapper stmts
            | otherwise = return stmts
        moduleItemMapper (AlwaysC kw stmt) = AlwaysC kw <$> fullMapper stmt
        moduleItemMapper (Initial stmt) = Initial <$> fullMapper stmt
        moduleItemMapper (MIPackageItem (Function lifetime ret name decls stmts)) =
            MIPackageItem . Function lifetime ret name decls
                <$> tfStmtsMapper stmts
        moduleItemMapper (MIPackageItem (Task lifetime name decls stmts)) =
            MIPackageItem . Task lifetime name decls
                <$> tfStmtsMapper stmts
        moduleItemMapper other = return other

206 207 208 209 210 211 212
-- pure and collector counterparts of traverseStmtsM'
traverseStmts' :: TFStrategy -> Mapper Stmt -> Mapper ModuleItem
traverseStmts' strat mapper = unmonad (traverseStmtsM' strat) mapper
collectStmtsM' :: Monad m => TFStrategy -> CollectorM m Stmt -> CollectorM m ModuleItem
collectStmtsM' strat collector = collectify (traverseStmtsM' strat) collector

-- the unprimed variants always include tasks and functions
traverseStmtsM :: Monad m => MapperM m Stmt -> MapperM m ModuleItem
traverseStmtsM mapper = traverseStmtsM' IncludeTFs mapper
traverseStmts :: Mapper Stmt -> Mapper ModuleItem
traverseStmts mapper = traverseStmts' IncludeTFs mapper
collectStmtsM :: Monad m => CollectorM m Stmt -> CollectorM m ModuleItem
collectStmtsM collector = collectStmtsM' IncludeTFs collector
217 218 219 220 221 222 223

-- private utility for turning a thing which maps over a single level of
-- statements into one that maps over a statement first, and then over the
-- statements nested within the result
traverseNestedStmtsM :: Monad m => MapperM m Stmt -> MapperM m Stmt
traverseNestedStmtsM mapper stmt = do
    -- the mapper sees the statement itself before any of its children
    stmt' <- mapper stmt
    traverseSinglyNestedStmtsM (traverseNestedStmtsM mapper) stmt'

226
-- variant of the above which only traverses one level down
227 228 229
traverseSinglyNestedStmtsM :: Monad m => MapperM m Stmt -> MapperM m Stmt
traverseSinglyNestedStmtsM fullMapper = cs
    where
        -- statements holding other statements are rebuilt around the
        -- results of mapping their immediate children
        cs (StmtAttr a stmt) = StmtAttr a <$> fullMapper stmt
        -- an unnamed block with no declarations or statements becomes Null
        cs (Block _ "" [] []) = return Null
        cs (Block kw name decls stmts) =
            Block kw name decls <$> mapM fullMapper stmts
        cs (Case u kw expr cases def) = do
            cases' <- mapM caseMapper cases
            def' <- maybeDo fullMapper def
            return $ Case u kw expr cases' def'
        cs (For a b c stmt) = For a b c <$> fullMapper stmt
        cs (While   e stmt) = While   e <$> fullMapper stmt
        cs (RepeatL e stmt) = RepeatL e <$> fullMapper stmt
        cs (DoWhile e stmt) = DoWhile e <$> fullMapper stmt
        cs (Forever   stmt) = Forever   <$> fullMapper stmt
        cs (Foreach x vars stmt) = Foreach x vars <$> fullMapper stmt
        cs (If u e s1 s2) = If u e <$> fullMapper s1 <*> fullMapper s2
        cs (Timing event stmt) = Timing event <$> fullMapper stmt
        cs (Assertion a) = Assertion <$> traverseAssertionStmtsM fullMapper a
        -- statements without nested statements are returned as they are
        cs stmt@(AsgnBlk _ _ _) = return stmt
        cs stmt@(Asgn _ _ _) = return stmt
        cs stmt@(Return _) = return stmt
        cs stmt@(Subroutine _ _ _) = return stmt
        cs stmt@(Trigger _ _) = return stmt
        cs Null = return Null
        -- maps the statement of a case item, keeping its first component
        caseMapper item = (,) (fst item) <$> fullMapper (snd item)

259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361
-- applies the mapper to the statements directly attached to an assertion:
-- those of an assert/assume action block, or the statement of a cover
traverseAssertionStmtsM :: Monad m => MapperM m Stmt -> MapperM m Assertion
traverseAssertionStmtsM mapper = assertionMapper
    where
        maybeStmtMapper ms = maybe (return Nothing) (fmap Just . mapper) ms
        actionBlockMapper (ActionBlockIf stmt) = ActionBlockIf <$> mapper stmt
        -- the optional first statement, when present, is mapped before the
        -- second statement
        actionBlockMapper (ActionBlockElse ms stmt) =
            ActionBlockElse <$> maybeStmtMapper ms <*> mapper stmt
        assertionMapper (Assert e ab) = Assert e <$> actionBlockMapper ab
        assertionMapper (Assume e ab) = Assume e <$> actionBlockMapper ab
        assertionMapper (Cover e stmt) = Cover e <$> mapper stmt

-- Note that this does not include the expressions within the statements of the
-- actions associated with the assertions.
traverseAssertionExprsM :: Monad m => MapperM m Expr -> MapperM m Assertion
traverseAssertionExprsM mapper = assertionMapper
    where
        maybeExprMapper me = maybe (return Nothing) (fmap Just . mapper) me

        -- sequence expressions
        seqExprMapper (SeqExpr e) = SeqExpr <$> mapper e
        seqExprMapper (SeqExprAnd        s1 s2) =
            ssMapper   SeqExprAnd        s1 s2
        seqExprMapper (SeqExprOr         s1 s2) =
            ssMapper   SeqExprOr         s1 s2
        seqExprMapper (SeqExprIntersect  s1 s2) =
            ssMapper   SeqExprIntersect  s1 s2
        seqExprMapper (SeqExprWithin     s1 s2) =
            ssMapper   SeqExprWithin     s1 s2
        seqExprMapper (SeqExprThroughout e s) =
            SeqExprThroughout <$> mapper e <*> seqExprMapper s
        seqExprMapper (SeqExprDelay ms e s) =
            SeqExprDelay
                <$> maybe (return Nothing) (fmap Just . seqExprMapper) ms
                <*> mapper e
                <*> seqExprMapper s
        seqExprMapper (SeqExprFirstMatch s items) =
            SeqExprFirstMatch
                <$> seqExprMapper s
                <*> mapM seqMatchItemMapper items
        -- only the last component of a Left match item is mapped; a Right
        -- match item has its positional and then its named arguments mapped
        seqMatchItemMapper (Left (a, b, c)) =
            (\c' -> Left (a, b, c')) <$> mapper c
        seqMatchItemMapper (Right (x, (Args l p))) = do
            l' <- mapM maybeExprMapper l
            p' <- mapM namedArgMapper p
            return $ Right (x, Args l' p')
        namedArgMapper arg = (,) (fst arg) <$> maybeExprMapper (snd arg)
        ssMapper constructor s1 s2 =
            constructor <$> seqExprMapper s1 <*> seqExprMapper s2

        -- property expressions
        spMapper constructor se pe =
            constructor <$> seqExprMapper se <*> propExprMapper pe
        propExprMapper (PropExpr se) = PropExpr <$> seqExprMapper se
        propExprMapper (PropExprImpliesO se pe) =
            spMapper PropExprImpliesO se pe
        propExprMapper (PropExprImpliesNO se pe) =
            spMapper PropExprImpliesNO se pe
        propExprMapper (PropExprFollowsO se pe) =
            spMapper PropExprFollowsO se pe
        propExprMapper (PropExprFollowsNO se pe) =
            spMapper PropExprFollowsNO se pe
        propExprMapper (PropExprIff p1 p2) =
            PropExprIff <$> propExprMapper p1 <*> propExprMapper p2
        -- the first field of the property spec is carried over unchanged
        propSpecMapper (PropertySpec ms me pe) =
            PropertySpec ms <$> maybeExprMapper me <*> propExprMapper pe

        -- the asserted item is either a property spec or a plain expression
        assertionExprMapper (Left e) = Left <$> propSpecMapper e
        assertionExprMapper (Right e) = Right <$> mapper e
        assertionMapper (Assert e ab) =
            flip Assert ab <$> assertionExprMapper e
        assertionMapper (Assume e ab) =
            flip Assume ab <$> assertionExprMapper e
        assertionMapper (Cover e stmt) =
            flip Cover stmt <$> assertionExprMapper e

362
-- applies the mapper to the LHSs appearing directly in a statement: those of
-- assignments, for-loop initializations and increments, and the senses of
-- events and assertion property specs; nested statements are not visited
traverseStmtLHSsM :: Monad m => MapperM m LHS -> MapperM m Stmt
traverseStmtLHSsM mapper = stmtMapper
    where
        stmtMapper (Timing (Event sense) stmt) =
            (\sense' -> Timing (Event sense') stmt) <$> senseMapper sense
        -- the assigned LHS is mapped before the LHSs within the event
        stmtMapper (Asgn (Just (Event sense)) lhs expr) =
            (\lhs' sense' -> Asgn (Just $ Event sense') lhs' expr)
                <$> mapper lhs <*> senseMapper sense
        stmtMapper (AsgnBlk op lhs expr) =
            (\lhs' -> AsgnBlk op lhs' expr) <$> mapper lhs
        stmtMapper (Asgn    mt lhs expr) =
            (\lhs' -> Asgn    mt lhs' expr) <$> mapper lhs
        stmtMapper (For inits me incrs stmt) = do
            inits' <- either (return . Left) (fmap Right . mapM initMapper) inits
            incrs' <- mapM incrMapper incrs
            return $ For inits' me incrs' stmt
            where
                -- initializations given as declarations hold no LHSs
                initMapper (lhs, expr) =
                    (\lhs' -> (lhs', expr)) <$> mapper lhs
                incrMapper (lhs, asgnOp, expr) =
                    (\lhs' -> (lhs', asgnOp, expr)) <$> mapper lhs
        stmtMapper (Assertion a) = Assertion <$> assertionMapper a
        stmtMapper other = return other

        senseMapper (Sense        lhs) = Sense        <$> mapper lhs
        senseMapper (SensePosedge lhs) = SensePosedge <$> mapper lhs
        senseMapper (SenseNegedge lhs) = SenseNegedge <$> mapper lhs
        senseMapper (SenseOr    s1 s2) =
            SenseOr <$> senseMapper s1 <*> senseMapper s2
        senseMapper (SenseStar       ) = return SenseStar

        -- only a property spec carrying a sense contains LHSs
        assertionExprMapper (Left (PropertySpec (Just sense) me pe)) =
            (\sense' -> Left $ PropertySpec (Just sense') me pe)
                <$> senseMapper sense
        assertionExprMapper other = return other
        assertionMapper (Assert e ab) =
            flip Assert ab <$> assertionExprMapper e
        assertionMapper (Assume e ab) =
            flip Assume ab <$> assertionExprMapper e
        assertionMapper (Cover e stmt) =
            flip Cover stmt <$> assertionExprMapper e
411 412 413 414 415

-- pure counterpart of traverseStmtLHSsM
traverseStmtLHSs :: Mapper LHS -> Mapper Stmt
traverseStmtLHSs mapper = unmonad traverseStmtLHSsM mapper
-- collector counterpart of traverseStmtLHSsM
collectStmtLHSsM :: Monad m => CollectorM m LHS -> CollectorM m Stmt
collectStmtLHSsM collector = collectify traverseStmtLHSsM collector
416 417 418 419 420

-- applies the mapper to an expression, and then recursively to all of the
-- sub-expressions of the result
traverseNestedExprsM :: Monad m => MapperM m Expr -> MapperM m Expr
traverseNestedExprsM mapper = exprMapper
    where
        -- the mapper sees an expression before its sub-expressions do
        exprMapper e = mapper e >>= em
        maybeExprMapper me =
            maybe (return Nothing) (fmap Just . exprMapper) me
        -- types are passed through; only the expression side is traversed
        typeOrExprMapper tore =
            either (return . Left) (fmap Right . exprMapper) tore

        -- expressions without sub-expressions
        em e@(String _) = return e
        em e@(Number _) = return e
        em e@(Ident  _) = return e
        em e@(PSIdent _ _) = return e
        em Nil = return Nil
        -- expressions with sub-expressions, visited left to right
        em (Range e m (e1, e2)) =
            (\e' e1' e2' -> Range e' m (e1', e2'))
                <$> exprMapper e <*> exprMapper e1 <*> exprMapper e2
        em (Bit   e1 e2) = Bit <$> exprMapper e1 <*> exprMapper e2
        em (Repeat     e l) = Repeat <$> exprMapper e <*> mapM exprMapper l
        em (Concat     l) = Concat <$> mapM exprMapper l
        em (Stream o e l) = Stream o <$> exprMapper e <*> mapM exprMapper l
        em (Call    ps f (Args l p)) = do
            -- positional arguments first, then the named ones
            l' <- mapM maybeExprMapper l
            p' <- mapM (\arg -> (,) (fst arg) <$> maybeExprMapper (snd arg)) p
            return $ Call ps f (Args l' p')
        em (UniOp      o e) = UniOp o <$> exprMapper e
        em (BinOp      o e1 e2) =
            BinOp o <$> exprMapper e1 <*> exprMapper e2
        em (Mux        e1 e2 e3) =
            Mux <$> exprMapper e1 <*> exprMapper e2 <*> exprMapper e3
        em (Cast (Left t) e) = Cast (Left t) <$> exprMapper e
        em (Cast (Right e1) e2) =
            Cast . Right <$> exprMapper e1 <*> exprMapper e2
        em (DimsFn f tore) = DimsFn f <$> typeOrExprMapper tore
        em (DimFn f tore e) =
            DimFn f <$> typeOrExprMapper tore <*> exprMapper e
        em (Dot e x) = flip Dot x <$> exprMapper e
        em (Pattern l) =
            Pattern <$>
                mapM (\entry -> (,) (fst entry) <$> exprMapper (snd entry)) l
485

486
-- builds, from a mapper over expressions, mappers over the expressions held
-- within ranges, optional expressions, declarations, LHSs and types
exprMapperHelpers :: Monad m => MapperM m Expr ->
    (MapperM m Range, MapperM m (Maybe Expr), MapperM m Decl, MapperM m LHS, MapperM m Type)
exprMapperHelpers exprMapper =
    (rangeMapper, maybeExprMapper, declMapper, traverseNestedLHSsM lhsMapper, typeMapper)
    where

    rangeMapper (a, b) = (,) <$> exprMapper a <*> exprMapper b

    maybeExprMapper me = maybe (return Nothing) (fmap Just . exprMapper) me

    -- a type is split into its ranges and a function which rebuilds the type
    -- from ranges; only the ranges hold expressions
    typeMapper' t = tf <$> mapM rangeMapper rs
        where (tf, rs) = typeRanges t
    typeMapper = traverseNestedTypesM typeMapper'

    maybeTypeMapper mt = maybe (return Nothing) (fmap Just . typeMapper) mt

    declMapper (Param s t x e) =
        (\t' e' -> Param s t' x e') <$> typeMapper t <*> exprMapper e
    declMapper (ParamType s x mt) = ParamType s x <$> maybeTypeMapper mt
    declMapper (Variable d t x a me) =
        (\t' a' me' -> Variable d t' x a' me')
            <$> typeMapper t <*> mapM rangeMapper a <*> maybeExprMapper me

    -- handles a single LHS; it is applied to nested LHSs through
    -- traverseNestedLHSsM in the tuple returned above
    lhsMapper (LHSRange l m r) = LHSRange l m <$> rangeMapper r
    lhsMapper (LHSBit l e) = LHSBit l <$> exprMapper e
    lhsMapper (LHSStream o e ls) =
        (\e' -> LHSStream o e' ls) <$> exprMapper e
    lhsMapper other = return other

533 534 535
-- applies a mapper over expressions to the expressions held by a module item,
-- including those within its statements, declarations, types and LHSs
traverseExprsM' :: Monad m => TFStrategy -> MapperM m Expr -> MapperM m ModuleItem
traverseExprsM' strat exprMapper = moduleItemMapper
    where

    (rangeMapper, maybeExprMapper, declMapper, lhsMapper, typeMapper)
        = exprMapperHelpers exprMapper

    stmtMapper = traverseNestedStmtsM (traverseStmtExprsM exprMapper)

    -- the declarations and statements of tasks and functions are only
    -- visited under the IncludeTFs strategy
    tfDeclsMapper decls
        | strat == IncludeTFs = mapM declMapper decls
        | otherwise = return decls
    tfStmtsMapper stmts
        | strat == IncludeTFs = mapM stmtMapper stmts
        | otherwise = return stmts

    portBindingMapper (p, me) = (,) p <$> maybeExprMapper me

    paramBindingMapper (p, Left t) =
        (\t' -> (p, Left t')) <$> typeMapper t
    paramBindingMapper (p, Right e) =
        (\e' -> (p, Right e')) <$> exprMapper e

    modportDeclMapper (dir, ident, Just e) =
        (\e' -> (dir, ident, Just e')) <$> exprMapper e
    modportDeclMapper other = return other

    moduleItemMapper item@(MIAttr _ _) =
        -- note: we exclude expressions in attributes from conversion
        return item
    moduleItemMapper (MIPackageItem (Typedef t x)) =
        MIPackageItem . flip Typedef x <$> typeMapper t
    moduleItemMapper (MIPackageItem (Decl decl)) =
        MIPackageItem . Decl <$> declMapper decl
    moduleItemMapper (Defparam lhs expr) =
        Defparam <$> lhsMapper lhs <*> exprMapper expr
    moduleItemMapper (AlwaysC kw stmt) = AlwaysC kw <$> stmtMapper stmt
    moduleItemMapper (Initial stmt) = Initial <$> stmtMapper stmt
    moduleItemMapper (Assign delay lhs expr) =
        Assign <$> maybeExprMapper delay <*> lhsMapper lhs <*> exprMapper expr
    -- the return type of a function is mapped under either strategy
    moduleItemMapper (MIPackageItem (Function lifetime ret f decls stmts)) =
        (\ret' decls' stmts' ->
            MIPackageItem $ Function lifetime ret' f decls' stmts')
            <$> typeMapper ret <*> tfDeclsMapper decls <*> tfStmtsMapper stmts
    moduleItemMapper (MIPackageItem (Task lifetime f decls stmts)) =
        (\decls' stmts' -> MIPackageItem $ Task lifetime f decls' stmts')
            <$> tfDeclsMapper decls <*> tfStmtsMapper stmts
    -- visited in the order: parameter bindings, port bindings, ranges
    moduleItemMapper (Instance m p x r l) =
        (\p' l' r' -> Instance m p' x r' l')
            <$> mapM paramBindingMapper p
            <*> mapM portBindingMapper l
            <*> mapM rangeMapper r
    moduleItemMapper (Modport x l) = Modport x <$> mapM modportDeclMapper l
    -- the expressions of an n-input gate are visited before its LHS
    moduleItemMapper (NInputGate  kw x lhs exprs) =
        (\exprs' lhs' -> NInputGate kw x lhs' exprs')
            <$> mapM exprMapper exprs <*> lhsMapper lhs
    moduleItemMapper (NOutputGate kw x lhss expr) =
        NOutputGate kw x <$> mapM lhsMapper lhss <*> exprMapper expr
    moduleItemMapper item@(Genvar _) = return item
    moduleItemMapper (Generate items) =
        Generate <$> mapM (traverseNestedGenItemsM genItemMapper) items
    moduleItemMapper item@(MIPackageItem (Comment _)) = return item
    moduleItemMapper item@(MIPackageItem (Import _ _)) = return item
    moduleItemMapper item@(MIPackageItem (Export _)) = return item
    -- the statements of an assertion are handled before its own expressions
    moduleItemMapper (AssertionItem (mx, a)) =
        (\a' -> AssertionItem (mx, a')) <$>
            (traverseAssertionStmtsM stmtMapper a
                >>= traverseAssertionExprsM exprMapper)

    -- visited in the order: initial value, increment value, condition
    genItemMapper (GenFor (n1, x1, e1) cc (x2, op2, e2) mn subItems) =
        (\e1' e2' cc' -> GenFor (n1, x1, e1') cc' (x2, op2, e2') mn subItems)
            <$> exprMapper e1 <*> exprMapper e2 <*> exprMapper cc
    genItemMapper (GenIf e i1 i2) =
        (\e' -> GenIf e' i1 i2) <$> exprMapper e
    genItemMapper (GenCase e cases def) = do
        e' <- exprMapper e
        cases' <- mapM caseMapper cases
        return $ GenCase e' cases' def
        where
            caseMapper c =
                (\exprs' -> (exprs', snd c)) <$> mapM exprMapper (fst c)
    genItemMapper other = return other
641

642 643 644 645 646 647 648
-- pure and collector counterparts of traverseExprsM'
traverseExprs' :: TFStrategy -> Mapper Expr -> Mapper ModuleItem
traverseExprs' strat mapper = unmonad (traverseExprsM' strat) mapper
collectExprsM' :: Monad m => TFStrategy -> CollectorM m Expr -> CollectorM m ModuleItem
collectExprsM' strat collector = collectify (traverseExprsM' strat) collector

-- the unprimed variants always include tasks and functions
traverseExprsM :: Monad m => MapperM m Expr -> MapperM m ModuleItem
traverseExprsM mapper = traverseExprsM' IncludeTFs mapper
traverseExprs :: Mapper Expr -> Mapper ModuleItem
traverseExprs mapper = traverseExprs' IncludeTFs mapper
collectExprsM :: Monad m => CollectorM m Expr -> CollectorM m ModuleItem
collectExprsM collector = collectExprsM' IncludeTFs collector
653

654 655 656 657
-- applies a mapper over expressions to the expressions held directly by a
-- statement; nested statements are left alone, except for those embedded in
-- assertions, which are traversed in full
traverseStmtExprsM :: Monad m => MapperM m Expr -> MapperM m Stmt
traverseStmtExprsM exprMapper = flatStmtMapper
    where

    (_, maybeExprMapper, declMapper, lhsMapper, _)
        = exprMapperHelpers exprMapper

    -- only used for the statements embedded in assertions
    stmtMapper = traverseNestedStmtsM flatStmtMapper

    caseMapper (exprs, stmt) =
        (\exprs' -> (exprs', stmt)) <$> mapM exprMapper exprs

    namedArgMapper arg = (,) (fst arg) <$> maybeExprMapper (snd arg)

    initsMapper (Left decls) = Left <$> mapM declMapper decls
    initsMapper (Right asgns) = Right <$> mapM initAsgnMapper asgns
        where initAsgnMapper (l, e) = (,) l <$> exprMapper e

    asgnMapper (l, op, e) = (\e' -> (l, op, e')) <$> exprMapper e

    flatStmtMapper stmt@(StmtAttr _ _) =
        -- note: we exclude expressions in attributes from conversion
        return stmt
    flatStmtMapper (Block kw name decls stmts) =
        (\decls' -> Block kw name decls' stmts) <$> mapM declMapper decls
    flatStmtMapper (Case u kw e cases def) =
        (\e' cases' -> Case u kw e' cases' def)
            <$> exprMapper e <*> mapM caseMapper cases
    flatStmtMapper (AsgnBlk op lhs expr) =
        AsgnBlk op <$> lhsMapper lhs <*> exprMapper expr
    flatStmtMapper (Asgn    mt lhs expr) =
        Asgn    mt <$> lhsMapper lhs <*> exprMapper expr
    flatStmtMapper (For inits cc asgns stmt) =
        (\inits' cc' asgns' -> For inits' cc' asgns' stmt)
            <$> initsMapper inits <*> exprMapper cc <*> mapM asgnMapper asgns
    flatStmtMapper (While   e stmt) =
        (\e' -> While   e' stmt) <$> exprMapper e
    flatStmtMapper (RepeatL e stmt) =
        (\e' -> RepeatL e' stmt) <$> exprMapper e
    flatStmtMapper (DoWhile e stmt) =
        (\e' -> DoWhile e' stmt) <$> exprMapper e
    flatStmtMapper stmt@(Forever _) = return stmt
    flatStmtMapper stmt@(Foreach _ _ _) = return stmt
    flatStmtMapper (If u cc s1 s2) =
        (\cc' -> If u cc' s1 s2) <$> exprMapper cc
    flatStmtMapper stmt@(Timing _ _) = return stmt
    flatStmtMapper (Subroutine ps f (Args l p)) = do
        -- positional arguments first, then the named ones
        l' <- mapM maybeExprMapper l
        p' <- mapM namedArgMapper p
        return $ Subroutine ps f (Args l' p')
    flatStmtMapper (Return expr) = Return <$> exprMapper expr
    flatStmtMapper stmt@(Trigger _ _) = return stmt
    -- the statements of an assertion are handled before its own expressions
    flatStmtMapper (Assertion a) =
        Assertion <$>
            (traverseAssertionStmtsM stmtMapper a
                >>= traverseAssertionExprsM exprMapper)
    flatStmtMapper Null = return Null

-- pure counterpart of traverseStmtExprsM
traverseStmtExprs :: Mapper Expr -> Mapper Stmt
traverseStmtExprs mapper = unmonad traverseStmtExprsM mapper
-- collector counterpart of traverseStmtExprsM
collectStmtExprsM :: Monad m => CollectorM m Expr -> CollectorM m Stmt
collectStmtExprsM collector = collectify traverseStmtExprsM collector

724 725 726
-- applies a mapper over LHSs to the LHSs of a module item: first those within
-- its statements, then those held directly by the (resulting) item
traverseLHSsM' :: Monad m => TFStrategy -> MapperM m LHS -> MapperM m ModuleItem
traverseLHSsM' strat mapper item = do
    item' <- traverseStmtsM' strat (traverseStmtLHSsM mapper) item
    moduleItemLHSMapper item'
    where
        moduleItemLHSMapper (Assign delay lhs expr) =
            (\lhs' -> Assign delay lhs' expr) <$> mapper lhs
        moduleItemLHSMapper (Defparam lhs expr) =
            flip Defparam expr <$> mapper lhs
        moduleItemLHSMapper (NOutputGate kw x lhss expr) =
            (\lhss' -> NOutputGate kw x lhss' expr) <$> mapM mapper lhss
        moduleItemLHSMapper (NInputGate  kw x lhs exprs) =
            (\lhs' -> NInputGate kw x lhs' exprs) <$> mapper lhs
        -- an assertion item is traversed by temporarily wrapping its
        -- assertion in a statement
        moduleItemLHSMapper (AssertionItem (mx, a)) = do
            converted <-
                traverseNestedStmtsM (traverseStmtLHSsM mapper) (Assertion a)
            return $ case converted of
                Assertion a' -> AssertionItem (mx, a')
                _ -> error $ "redirected AssertionItem traverse failed: "
                        ++ show converted
        moduleItemLHSMapper (Generate items) =
            Generate <$> mapM (traverseNestedGenItemsM genItemLHSMapper) items
        moduleItemLHSMapper other = return other
        -- the loop variables of a generate for are presented to the mapper
        -- as identifier LHSs; the first one is skipped when the flag paired
        -- with it is set
        genItemLHSMapper (GenFor (n1, x1, e1) cc (x2, op2, e2) mn subItems) = do
            lhs1 <- if n1
                then return (LHSIdent x1)
                else mapper (LHSIdent x1)
            lhs2 <- mapper (LHSIdent x2)
            -- lazy patterns: these only fail, once the names are demanded,
            -- if the mapper produced something other than an identifier
            let LHSIdent x1' = lhs1
                LHSIdent x2' = lhs2
            return $ GenFor (n1, x1', e1) cc (x2', op2, e2) mn subItems
        genItemLHSMapper other = return other
758

-- non-monadic and collector forms of traverseLHSsM', obtained through unmonad
-- and collectify respectively
traverseLHSs' :: TFStrategy -> Mapper LHS -> Mapper ModuleItem
traverseLHSs' strat = unmonad $ traverseLHSsM' strat
collectLHSsM' :: Monad m => TFStrategy -> CollectorM m LHS -> CollectorM m ModuleItem
collectLHSsM' strat = collectify $ traverseLHSsM' strat

-- the unprimed variants fix the strategy to IncludeTFs
traverseLHSsM :: Monad m => MapperM m LHS -> MapperM m ModuleItem
traverseLHSsM = traverseLHSsM' IncludeTFs
traverseLHSs :: Mapper LHS -> Mapper ModuleItem
traverseLHSs = traverseLHSs' IncludeTFs
collectLHSsM :: Monad m => CollectorM m LHS -> CollectorM m ModuleItem
collectLHSsM = collectLHSsM' IncludeTFs

-- Applies the mapper to an LHS and then recurses into the LHSs nested within
-- the mapped result, so an LHS is always visited before its children.
traverseNestedLHSsM :: Monad m => MapperM m LHS -> MapperM m LHS
traverseNestedLHSsM mapper = fullMapper
    where
        fullMapper lhs = mapper lhs >>= tl
        -- only the LHS children are traversed; expressions are left as they are
        tl (LHSIdent  x       ) = return $ LHSIdent x
        tl (LHSBit    l e     ) = fullMapper l >>= \l' -> return $ LHSBit    l' e
        tl (LHSRange  l m r   ) = fullMapper l >>= \l' -> return $ LHSRange  l' m r
        tl (LHSDot    l x     ) = fullMapper l >>= \l' -> return $ LHSDot    l' x
        tl (LHSConcat     lhss) = mapM fullMapper lhss >>= return . LHSConcat
        tl (LHSStream o e lhss) = mapM fullMapper lhss >>= return . LHSStream o e

-- non-monadic and collector forms of traverseNestedLHSsM, obtained through
-- unmonad and collectify respectively
traverseNestedLHSs :: Mapper LHS -> Mapper LHS
traverseNestedLHSs = unmonad traverseNestedLHSsM
collectNestedLHSsM :: Monad m => CollectorM m LHS -> CollectorM m LHS
collectNestedLHSsM = collectify traverseNestedLHSsM

-- Applies a Decl mapper across a ModuleItem: first to the declarations held
-- directly by the item (package-level Decls and the declaration lists of
-- tasks and functions), and then to the declarations of every Block statement
-- reached through traverseStmtsM' with the given TFStrategy.
traverseDeclsM' :: Monad m => TFStrategy -> MapperM m Decl -> MapperM m ModuleItem
traverseDeclsM' strat mapper item = do
    item' <- miMapper item
    traverseStmtsM' strat stmtMapper item'
    where
        -- task and function declarations are only mapped under IncludeTFs
        tfDeclsMapper decls =
            if strat == IncludeTFs
                then mapM mapper decls
                else return decls
        miMapper (MIPackageItem (Decl decl)) =
            mapper decl >>= return . MIPackageItem . Decl
        miMapper (MIPackageItem (Function l t x decls stmts)) = do
            decls' <- tfDeclsMapper decls
            return $ MIPackageItem $ Function l t x decls' stmts
        miMapper (MIPackageItem (Task l x decls stmts)) = do
            decls' <- tfDeclsMapper decls
            return $ MIPackageItem $ Task l x decls' stmts
        miMapper other = return other
        stmtMapper (Block kw name decls stmts) = do
            decls' <- mapM mapper decls
            return $ Block kw name decls' stmts
        stmtMapper other = return other

-- non-monadic and collector forms of traverseDeclsM', obtained through unmonad
-- and collectify respectively
traverseDecls' :: TFStrategy -> Mapper Decl -> Mapper ModuleItem
traverseDecls' strat = unmonad $ traverseDeclsM' strat
collectDeclsM' :: Monad m => TFStrategy -> CollectorM m Decl -> CollectorM m ModuleItem
collectDeclsM' strat = collectify $ traverseDeclsM' strat

-- the unprimed variants fix the strategy to IncludeTFs
traverseDeclsM :: Monad m => MapperM m Decl -> MapperM m ModuleItem
traverseDeclsM = traverseDeclsM' IncludeTFs
traverseDecls :: Mapper Decl -> Mapper ModuleItem
traverseDecls = traverseDecls' IncludeTFs
collectDeclsM :: Monad m => CollectorM m Decl -> CollectorM m ModuleItem
collectDeclsM = collectDeclsM' IncludeTFs

-- Applies the mapper to a Type and to the Types nested inside it (the base
-- type of an Enum and the field types of a Struct or Union). Nested types are
-- mapped first; the mapper then sees the type rebuilt from the mapped children.
traverseNestedTypesM :: Monad m => MapperM m Type -> MapperM m Type
traverseNestedTypesM mapper = fullMapper
    where
        fullMapper t = tm t >>= mapper
        tm (Alias      ps xx    rs) = return $ Alias      ps xx    rs
        tm (Net           kw sg rs) = return $ Net           kw sg rs
        tm (Implicit         sg rs) = return $ Implicit         sg rs
        tm (IntegerVector kw sg rs) = return $ IntegerVector kw sg rs
        tm (IntegerAtom   kw sg   ) = return $ IntegerAtom   kw sg
        tm (NonInteger    kw      ) = return $ NonInteger    kw
        tm (InterfaceT x my r) = return $ InterfaceT x my r
        tm (Enum Nothing vals r) =
            return $ Enum Nothing vals r
        tm (Enum (Just t) vals r) = do
            t' <- fullMapper t
            return $ Enum (Just t') vals r
        tm (Struct p fields r) = do
            fields' <- fieldsMapper fields
            return $ Struct p fields' r
        tm (Union p fields r) = do
            fields' <- fieldsMapper fields
            return $ Union p fields' r
        -- maps the type of each field while keeping the field identifiers
        fieldsMapper fields = do
            types <- mapM fullMapper $ map fst fields
            let idents = map snd fields
            return $ zip types idents

-- non-monadic and collector forms of traverseNestedTypesM, obtained through
-- unmonad and collectify respectively
traverseNestedTypes :: Mapper Type -> Mapper Type
traverseNestedTypes = unmonad traverseNestedTypesM
collectNestedTypesM :: Monad m => CollectorM m Type -> CollectorM m Type
collectNestedTypesM = collectify traverseNestedTypesM

-- Applies a Type mapper (including nested types, via traverseNestedTypesM) to
-- the types appearing in a ModuleItem. Three passes are chained: the types
-- held directly by the item, then the types of its declarations, and finally
-- the types embedded in its expressions (casts and dimension queries).
traverseTypesM :: Monad m => MapperM m Type -> MapperM m ModuleItem
traverseTypesM mapper item =
    miMapper item >>=
    traverseDeclsM declMapper >>=
    traverseExprsM (traverseNestedExprsM exprMapper)
    where
        fullMapper = traverseNestedTypesM mapper
        maybeMapper Nothing = return Nothing
        maybeMapper (Just t) = fullMapper t >>= return . Just
        typeOrExprMapper (Right e) = return $ Right e
        typeOrExprMapper (Left t) =
            fullMapper t >>= return . Left
        exprMapper (Cast (Left t) e) =
            fullMapper t >>= \t' -> return $ Cast (Left t') e
        exprMapper (DimsFn f tore) =
            typeOrExprMapper tore >>= return . DimsFn f
        exprMapper (DimFn f tore e) = do
            tore' <- typeOrExprMapper tore
            return $ DimFn f tore' e
        exprMapper other = return other
        declMapper (Param s t x e) =
            fullMapper t >>= \t' -> return $ Param s t' x e
        declMapper (ParamType s x mt) =
            maybeMapper mt >>= \mt' -> return $ ParamType s x mt'
        declMapper (Variable d t x a me) =
            fullMapper t >>= \t' -> return $ Variable d t' x a me
        miMapper (MIPackageItem (Typedef t x)) =
            fullMapper t >>= \t' -> return $ MIPackageItem $ Typedef t' x
        miMapper (MIPackageItem (Function l t x d s)) =
            fullMapper t >>= \t' -> return $ MIPackageItem $ Function l t' x d s
        -- tasks carry no type of their own; the as-pattern is written without
        -- surrounding spaces, as newer GHC releases require
        miMapper (MIPackageItem other@(Task _ _ _ _)) =
            return $ MIPackageItem other
        miMapper (Instance m params x r p) = do
            params' <- mapM mapParam params
            return $ Instance m params' x r p
            where
                -- only type-valued parameter bindings are mapped
                mapParam (i, Left t) =
                    fullMapper t >>= \t' -> return (i, Left t')
                mapParam (i, Right e) = return $ (i, Right e)
        miMapper other = return other

-- non-monadic and collector forms of traverseTypesM, obtained through unmonad
-- and collectify respectively
traverseTypes :: Mapper Type -> Mapper ModuleItem
traverseTypes = unmonad traverseTypesM
collectTypesM :: Monad m => CollectorM m Type -> CollectorM m ModuleItem
collectTypesM = collectify traverseTypesM

-- Runs the GenItem mapper over the generate items of a Generate ModuleItem,
-- descending into nested generate items through traverseNestedGenItemsM. Any
-- other kind of ModuleItem is returned unchanged.
traverseGenItemsM :: Monad m => MapperM m GenItem -> MapperM m ModuleItem
traverseGenItemsM mapper item =
    case item of
        Generate genItems -> do
            genItems' <- mapM (traverseNestedGenItemsM mapper) genItems
            return $ Generate genItems'
        _ -> return item

-- non-monadic form of traverseGenItemsM, obtained through unmonad
traverseGenItems :: Mapper GenItem -> Mapper ModuleItem
traverseGenItems genItemMapper = unmonad traverseGenItemsM genItemMapper
-- collector form of traverseGenItemsM, obtained through collectify
collectGenItemsM :: Monad m => CollectorM m GenItem -> CollectorM m ModuleItem
collectGenItemsM genItemCollector = collectify traverseGenItemsM genItemCollector

-- traverses all GenItems within a given GenItem, but doesn't inspect within
-- GenModuleItems; each GenItem is mapped before its children are visited
traverseNestedGenItemsM :: Monad m => MapperM m GenItem -> MapperM m GenItem
traverseNestedGenItemsM mapper = fullMapper
    where
        fullMapper genItem =
            mapper genItem >>= traverseSinglyNestedGenItemsM fullMapper

-- Applies the given mapper to the immediate child GenItems of a GenItem,
-- without recursing any further itself. GenModuleItem and GenNull have no
-- child GenItems and are returned as they are.
traverseSinglyNestedGenItemsM :: Monad m => MapperM m GenItem -> MapperM m GenItem
traverseSinglyNestedGenItemsM fullMapper = gim
    where
        gim (GenBlock x subItems) = do
            subItems' <- subItemsMapper subItems
            return $ GenBlock x subItems'
        gim (GenFor a b c d subItems) = do
            subItems' <- subItemsMapper subItems
            return $ GenFor a b c d subItems'
        gim (GenIf e i1 i2) = do
            i1' <- fullMapper i1
            i2' <- fullMapper i2
            return $ GenIf e i1' i2'
        gim (GenCase e cases def) = do
            caseItems <- mapM (fullMapper . snd) cases
            let cases' = zip (map fst cases) caseItems
            def' <- maybeDo fullMapper def
            return $ GenCase e cases' def'
        gim (GenModuleItem moduleItem) =
            return $ GenModuleItem moduleItem
        gim (GenNull) = return GenNull
        -- maps a list of sub-items and splices unnamed blocks into the result
        subItemsMapper subItems = do
            subItems' <- mapM fullMapper subItems
            return $ concatMap flattenBlocks subItems'
        -- a block with an empty name is replaced by its contents
        flattenBlocks :: GenItem -> [GenItem]
        flattenBlocks (GenBlock "" items) = items
        flattenBlocks other = [other]

-- Applies a mapper over (LHS, Expr) assignment pairs in a ModuleItem: first
-- the pair of an Assign or Defparam item itself, and then the assignment
-- statements (via traverseStmtAsgnsM) reached through traverseStmtsM' with the
-- given TFStrategy.
traverseAsgnsM' :: Monad m => TFStrategy -> MapperM m (LHS, Expr) -> MapperM m ModuleItem
traverseAsgnsM' strat mapper = moduleItemMapper
    where
        moduleItemMapper item = miMapperA item >>= miMapperB

        -- assignments held directly by the module item
        miMapperA (Assign delay lhs expr) = do
            (lhs', expr') <- mapper (lhs, expr)
            return $ Assign delay lhs' expr'
        miMapperA (Defparam lhs expr) = do
            (lhs', expr') <- mapper (lhs, expr)
            return $ Defparam lhs' expr'
        miMapperA other = return other

        -- assignments appearing in the statements of the module item
        miMapperB = traverseStmtsM' strat stmtMapper
        stmtMapper = traverseStmtAsgnsM mapper

-- non-monadic and collector forms of traverseAsgnsM', obtained through unmonad
-- and collectify respectively
traverseAsgns' :: TFStrategy -> Mapper (LHS, Expr) -> Mapper ModuleItem
traverseAsgns' strat = unmonad $ traverseAsgnsM' strat
collectAsgnsM' :: Monad m => TFStrategy -> CollectorM m (LHS, Expr) -> CollectorM m ModuleItem
collectAsgnsM' strat = collectify $ traverseAsgnsM' strat

-- the unprimed variants fix the strategy to IncludeTFs
traverseAsgnsM :: Monad m => MapperM m (LHS, Expr) -> MapperM m ModuleItem
traverseAsgnsM = traverseAsgnsM' IncludeTFs
traverseAsgns :: Mapper (LHS, Expr) -> Mapper ModuleItem
traverseAsgns = traverseAsgns' IncludeTFs
collectAsgnsM :: Monad m => CollectorM m (LHS, Expr) -> CollectorM m ModuleItem
collectAsgnsM = collectAsgnsM' IncludeTFs

-- Applies a mapper to the (LHS, Expr) pair of a single AsgnBlk or Asgn
-- statement. Other statements are returned unchanged, and no recursion into
-- nested statements happens here.
traverseStmtAsgnsM :: Monad m => MapperM m (LHS, Expr) -> MapperM m Stmt
traverseStmtAsgnsM mapper = stmtMapper
    where
        stmtMapper (AsgnBlk op lhs expr) = do
            (lhs', expr') <- mapper (lhs, expr)
            return $ AsgnBlk op lhs' expr'
        stmtMapper (Asgn    mt lhs expr) = do
            (lhs', expr') <- mapper (lhs, expr)
            return $ Asgn    mt lhs' expr'
        stmtMapper other = return other

-- non-monadic form of traverseStmtAsgnsM, obtained through unmonad
traverseStmtAsgns :: Mapper (LHS, Expr) -> Mapper Stmt
traverseStmtAsgns asgnMapper = unmonad traverseStmtAsgnsM asgnMapper
-- collector form of traverseStmtAsgnsM, obtained through collectify
collectStmtAsgnsM :: Monad m => CollectorM m (LHS, Expr) -> CollectorM m Stmt
collectStmtAsgnsM asgnCollector = collectify traverseStmtAsgnsM asgnCollector

-- Runs traverseModuleItemsM on a single ModuleItem by wrapping it in a
-- placeholder module Part named "DNE" and unwrapping the result. When the
-- traversal yields exactly one item, that item is returned; otherwise the
-- resulting items are bundled into a Generate block.
traverseNestedModuleItemsM :: Monad m => MapperM m ModuleItem -> MapperM m ModuleItem
traverseNestedModuleItemsM mapper item = do
    converted <-
        traverseModuleItemsM mapper (Part [] False Module Nothing "DNE" [] [item])
    let items' = case converted of
            Part [] False Module Nothing "DNE" [] newItems -> newItems
            _ -> error $ "redirected NestedModuleItems traverse failed: "
                    ++ show converted
    return $ case items' of
        [item'] -> item'
        _ -> Generate $ map GenModuleItem items'

-- non-monadic form of traverseNestedModuleItemsM, obtained through unmonad
traverseNestedModuleItems :: Mapper ModuleItem -> Mapper ModuleItem
traverseNestedModuleItems itemMapper = unmonad traverseNestedModuleItemsM itemMapper
-- collector form of traverseNestedModuleItemsM, obtained through collectify
collectNestedModuleItemsM :: Monad m => CollectorM m ModuleItem -> CollectorM m ModuleItem
collectNestedModuleItemsM itemCollector = collectify traverseNestedModuleItemsM itemCollector

-- non-monadic and collector forms of traverseNestedStmtsM, obtained through
-- unmonad and collectify respectively
traverseNestedStmts :: Mapper Stmt -> Mapper Stmt
traverseNestedStmts = unmonad traverseNestedStmtsM
collectNestedStmtsM :: Monad m => CollectorM m Stmt -> CollectorM m Stmt
collectNestedStmtsM = collectify traverseNestedStmtsM

-- non-monadic and collector forms of traverseNestedExprsM, obtained through
-- unmonad and collectify respectively
traverseNestedExprs :: Mapper Expr -> Mapper Expr
traverseNestedExprs = unmonad traverseNestedExprsM
collectNestedExprsM :: Monad m => CollectorM m Expr -> CollectorM m Expr
collectNestedExprsM = collectify traverseNestedExprsM

-- Traverse all the declaration scopes within a ModuleItem. Note that Functions,
-- Tasks, Always and Initial blocks are all NOT passed through ModuleItem
-- mapper, and Decl ModuleItems are NOT passed through the Decl mapper. The
-- state is restored to its previous value after each scope is exited. Only the
-- Decl mapper may modify the state, as we maintain the invariant that all other
-- functions restore the state on exit. The Stmt mapper must not traverse
-- statements recursively, as we add a recursive wrapper here.
traverseScopesM
    :: (Eq s, Show s)
    => Monad m
    => MapperM (StateT s m) Decl
    -> MapperM (StateT s m) ModuleItem
    -> MapperM (StateT s m) Stmt
    -> MapperM (StateT s m) ModuleItem
traverseScopesM declMapper moduleItemMapper stmtMapper =
    fullModuleItemMapper
    where

        -- maps a statement, then descends into its direct sub-statements
        nestedStmtMapper stmt =
            stmtMapper stmt >>= traverseSinglyNestedStmtsM fullStmtMapper
        -- a Block opens a scope: its declarations are mapped before its body,
        -- and the state is put back once the block has been processed
        fullStmtMapper (Block kw name decls stmts) = do
            prevState <- get
            decls' <- mapM declMapper decls
            block <- nestedStmtMapper $ Block kw name decls' stmts
            put prevState
            return block
        fullStmtMapper other = nestedStmtMapper other

        redirectModuleItem (MIPackageItem (Function ml t x decls stmts)) = do
            prevState <- get
            -- the return type is passed through the Decl mapper disguised as a
            -- local variable carrying the name of the function
            t' <- do
                res <- declMapper $ Variable Local t x [] Nothing
                case res of
                    Variable Local newType _ [] Nothing -> return newType
                    _ -> error $ "redirected func ret traverse failed: " ++ show res
            decls' <- mapM declMapper decls
            stmts' <- mapM fullStmtMapper stmts
            put prevState
            return $ MIPackageItem $ Function ml t' x decls' stmts'
        redirectModuleItem (MIPackageItem (Task     ml   x decls stmts)) = do
            prevState <- get
            decls' <- mapM declMapper decls
            stmts' <- mapM fullStmtMapper stmts
            put prevState
            return $ MIPackageItem $ Task     ml    x decls' stmts'
        redirectModuleItem (AlwaysC kw stmt) =
            fullStmtMapper stmt >>= return . AlwaysC kw
        redirectModuleItem (Initial stmt) =
            fullStmtMapper stmt >>= return . Initial
        redirectModuleItem item =
            moduleItemMapper item

        -- This previously checked the invariant that the module item mappers
        -- should not modify the state. Now we simply "enforce" it by resetting
        -- the state to its previous value. Comparing the state, as we did
        -- previously, incurs a noticeable performance hit.
        fullModuleItemMapper item = do
            prevState <- get
            item' <- redirectModuleItem item
            put prevState
            return item'

-- applies the given decl conversion across the description, and then performs a
-- scoped traversal for each ModuleItem in the description; this is the pure
-- form of scopedConversionM, run through runIdentity
scopedConversion
    :: (Eq s, Show s)
    => MapperM (State s) Decl
    -> MapperM (State s) ModuleItem
    -> MapperM (State s) Stmt
    -> s
    -> Description
    -> Description
scopedConversion traverseDeclM traverseModuleItemM traverseStmtM s description =
    runIdentity $ scopedConversionM traverseDeclM traverseModuleItemM traverseStmtM s description

-- Monadic scoped conversion of a Description, starting from the state s. The
-- Decl mapper is first run over the package-level Decl items of the
-- description; the state accumulated by that pass is then carried into a
-- traverseScopesM traversal of every ModuleItem. The final state is discarded.
scopedConversionM
    :: (Eq s, Show s)
    => Monad m
    => MapperM (StateT s m) Decl
    -> MapperM (StateT s m) ModuleItem
    -> MapperM (StateT s m) Stmt
    -> s
    -> Description
    -> m Description
scopedConversionM traverseDeclM traverseModuleItemM traverseStmtM s description =
    evalStateT (initialTraverse description >>= scopedTraverse) s
    where
        initialTraverse = traverseModuleItemsM traverseMIPackageItemDecl
        scopedTraverse = traverseModuleItemsM $
            traverseScopesM traverseDeclM traverseModuleItemM traverseStmtM
        traverseMIPackageItemDecl (MIPackageItem (Decl decl)) =
            traverseDeclM decl >>= return . MIPackageItem . Decl
        traverseMIPackageItemDecl other = return other

-- convert a basic mapper with an initial argument to a stateful mapper; the
-- current state is read and supplied as that argument, and is left unmodified
stately :: (Eq s, Show s) => (s -> Mapper a) -> MapperM (State s) a
stately mapper thing = gets $ \s -> mapper s thing

-- In many conversions, we want to resolve items locally first, and then fall
-- back to looking at other source files, if necessary. This helper captures
-- this behavior, allowing a conversion to fall back to arbitrary global
-- collected item, if one exists. While this isn't foolproof (we could
-- inadvertently resolve a name that doesn't exist in the given file), many
-- projects rely on their toolchain to locate their modules, interfaces,
-- packages, or typenames in other files. Global resolution of modules and
-- interfaces is more commonly expected than global resolution of typenames and
-- packages.
--
-- The collector is run over each file on its own and over all files together;
-- each file is then mapped with its own notes combined (local notes first)
-- with the notes gathered from every file.
traverseFilesM
    :: (Monoid w, Monad m)
    => CollectorM (Writer w) AST
    -> (w -> MapperM m AST)
    -> MapperM m [AST]
traverseFilesM fileCollectorM fileMapperM files =
    mapM traverseFileM files
    where
        -- only the written notes matter, so the collector results are dropped
        globalNotes = execWriter $ mapM_ fileCollectorM files
        traverseFileM file =
            fileMapperM notes file
            where
                localNotes = execWriter $ fileCollectorM file
                notes = localNotes <> globalNotes
-- non-monadic form of traverseFilesM: the pure file mapper is lifted with
-- return and the traversal is evaluated in a State monad with unit state
traverseFiles
    :: Monoid w
    => CollectorM (Writer w) AST
    -> (w -> Mapper AST)
    -> Mapper [AST]
traverseFiles fileCollectorM fileMapper files =
    evalState (traverseFilesM fileCollectorM liftedMapper files) ()
    where liftedMapper notes file = return $ fileMapper notes file