{- sv2v
 - Author: Zachary Snow <zach@zachjs.com>
 -
 - Utilities for traversing AST transformations.
 -}

module Convert.Traverse
( MapperM
, Mapper
, CollectorM
, TFStrategy (..)
, unmonad
, collectify
, traverseDescriptionsM
, traverseDescriptions
, collectDescriptionsM
, traverseModuleItemsM
, traverseModuleItems
, collectModuleItemsM
, traverseStmtsM
, traverseStmts
, collectStmtsM
, traverseStmtsM'
, traverseStmts'
, collectStmtsM'
, traverseStmtLHSsM
, traverseStmtLHSs
, collectStmtLHSsM
, traverseExprsM
, traverseExprs
, collectExprsM
, traverseExprsM'
, traverseExprs'
, collectExprsM'
, traverseStmtExprsM
, traverseStmtExprs
, collectStmtExprsM
, traverseLHSsM
, traverseLHSs
, collectLHSsM
, traverseLHSsM'
, traverseLHSs'
, collectLHSsM'
, traverseDeclsM
, traverseDecls
, collectDeclsM
, traverseDeclsM'
, traverseDecls'
, collectDeclsM'
, traverseNestedTypesM
, traverseNestedTypes
, collectNestedTypesM
, traverseTypesM
, traverseTypes
, collectTypesM
, traverseGenItemsM
, traverseGenItems
, collectGenItemsM
, traverseAsgnsM
, traverseAsgns
, collectAsgnsM
, traverseAsgnsM'
, traverseAsgns'
, collectAsgnsM'
, traverseStmtAsgnsM
, traverseStmtAsgns
, collectStmtAsgnsM
, traverseNestedModuleItemsM
, traverseNestedModuleItems
, collectNestedModuleItemsM
, traverseNestedStmts
, collectNestedStmtsM
, traverseNestedExprsM
, traverseNestedExprs
, collectNestedExprsM
, traverseNestedLHSsM
, traverseNestedLHSs
, collectNestedLHSsM
, traverseScopesM
, scopedConversion
, scopedConversionM
, stately
, traverseFilesM
, traverseFiles
) where

import Data.Functor.Identity (runIdentity)
import Control.Monad.State
import Control.Monad.Writer

import Language.SystemVerilog.AST

-- | A rewrite of a value of type @a@ that runs in the monad @f@.
type MapperM f a = a -> f a

-- | A pure rewrite of a value of type @a@.
type Mapper a = a -> a

-- | A visitor over values of type @a@ that is run only for its effects in @f@.
type CollectorM f a = a -> f ()

-- | Selects whether a traversal also visits the contents of functions and
-- tasks ('IncludeTFs') or leaves them untouched ('ExcludeTFs').
data TFStrategy = IncludeTFs | ExcludeTFs
    deriving (Eq)

-- | Turn a monadic traverser into a pure one. The traverser is run in the
-- trivial @State ()@ monad, with the pure mapper lifted into that monad.
unmonad :: (MapperM (State ()) a -> MapperM (State ()) b) -> Mapper a -> Mapper b
unmonad traverser mapper = \thing ->
    let action = traverser (pure . mapper) thing
    in evalState action ()

-- | Turn a monadic traverser into a collector. Every element the traverser
-- visits is handed to the collector for its effects and kept unchanged. The
-- rebuilt structure is then discarded.
collectify :: Monad m => (MapperM m a -> MapperM m b) -> CollectorM m a -> CollectorM m b
collectify traverser collector thing = do
    _ <- traverser visit thing
    return ()
    where
        visit x = do
            () <- collector x
            return x

-- | Apply a monadic mapper to each top-level description of an AST, in order.
traverseDescriptionsM :: Monad m => MapperM m Description -> MapperM m AST
traverseDescriptionsM = mapM

-- | Pure variant of 'traverseDescriptionsM'.
traverseDescriptions :: Mapper Description -> Mapper AST
traverseDescriptions mapper = unmonad traverseDescriptionsM mapper

-- | Collector variant of 'traverseDescriptionsM'.
collectDescriptionsM :: Monad m => CollectorM m Description -> CollectorM m AST
collectDescriptionsM collector = collectify traverseDescriptionsM collector

-- | Run a monadic function on the contents of a 'Maybe' and keep the result
-- wrapped. For 'Nothing', the function is not run and 'Nothing' is returned.
maybeDo :: Monad m => (a -> m b) -> Maybe a -> m (Maybe b)
maybeDo fun = maybe (return Nothing) (fmap Just . fun)

-- | Apply a monadic mapper to every module item within a description.
--
-- For a 'Part', the mapper is applied to each top-level item. Through
-- 'traverseNestedGenItemsM', it is also applied to module items nested in
-- generate constructs. An 'MIAttr' wrapper is not passed to the mapper
-- itself; only the item it wraps is. Along the way, generate regions are
-- simplified:
--
--   * an unnamed block that is the only item of a generate region is unwrapped;
--   * other generate regions have their top-level 'GenNull' entries removed;
--   * @GenIf (Number "1")@ / @GenIf (Number "0")@ are replaced by the taken
--     branch, and empty blocks become 'GenNull';
--   * a nested module item that the mapper turns into a 'Generate' becomes an
--     unnamed 'GenBlock';
--   * a top-level generate region holding only plain module items is
--     flattened into the enclosing item list.
--
-- 'PackageItem' and 'Package' descriptions are traversed by wrapping their
-- items in a placeholder 'Part' named "DNE", traversing that, and unwrapping
-- the result. The mapper must therefore map package items to package items;
-- any other result is reported with a descriptive error.
traverseModuleItemsM :: Monad m => MapperM m ModuleItem -> MapperM m Description
traverseModuleItemsM mapper (Part attrs extern kw lifetime name ports items) = do
    items' <- mapM fullMapper items
    let items'' = concatMap breakGenerate items'
    return $ Part attrs extern kw lifetime name ports items''
    where
        fullMapper (Generate [GenBlock "" genItems]) =
            mapM fullGenItemMapper genItems >>= mapper . Generate
        fullMapper (Generate genItems) = do
            let genItems' = filter (/= GenNull) genItems
            mapM fullGenItemMapper genItems' >>= mapper . Generate
        fullMapper (MIAttr attr mi) =
            fullMapper mi >>= return . MIAttr attr
        fullMapper other = mapper other
        fullGenItemMapper = traverseNestedGenItemsM genItemMapper
        genItemMapper (GenModuleItem moduleItem) = do
            moduleItem' <- fullMapper moduleItem
            return $ case moduleItem' of
                Generate subItems -> GenBlock "" subItems
                _ -> GenModuleItem moduleItem'
        genItemMapper (GenIf (Number "1") s _) = return s
        genItemMapper (GenIf (Number "0") _ s) = return s
        genItemMapper (GenBlock _ []) = return GenNull
        genItemMapper other = return other
        -- splice a generate region made only of plain module items into the
        -- surrounding item list; the lambda is total thanks to the `all` check
        breakGenerate :: ModuleItem -> [ModuleItem]
        breakGenerate (Generate genItems) =
            if all isGenModuleItem genItems
                then map (\(GenModuleItem item) -> item) genItems
                else [Generate genItems]
            where
                isGenModuleItem :: GenItem -> Bool
                isGenModuleItem (GenModuleItem _) = True
                isGenModuleItem _ = False
        breakGenerate other = [other]
traverseModuleItemsM mapper (PackageItem packageItem) = do
    let item = MIPackageItem packageItem
    converted <-
        traverseModuleItemsM mapper (Part [] False Module Nothing "DNE" [] [item])
    let item' = case converted of
            Part [] False Module Nothing "DNE" [] [newItem] -> newItem
            _ -> error $ "redirected PackageItem traverse failed: "
                    ++ show converted
    return $ case item' of
        MIPackageItem packageItem' -> PackageItem packageItem'
        other -> error $ "encountered bad package module item: " ++ show other
traverseModuleItemsM mapper (Package lifetime name packageItems) = do
    let items = map MIPackageItem packageItems
    converted <-
        traverseModuleItemsM mapper (Part [] False Module Nothing "DNE" [] items)
    let items' = case converted of
            Part [] False Module Nothing "DNE" [] newItems -> newItems
            _ -> error $ "redirected Package traverse failed: "
                    ++ show converted
    return $ Package lifetime name $ map unwrapPackageItem items'
    where
        -- report a mapper that produced a non-package item the same way the
        -- PackageItem case does, instead of failing on a partial lambda
        unwrapPackageItem (MIPackageItem item) = item
        unwrapPackageItem other =
            error $ "encountered bad package module item: " ++ show other

-- | Pure variant of 'traverseModuleItemsM'.
traverseModuleItems :: Mapper ModuleItem -> Mapper Description
traverseModuleItems mapper = unmonad traverseModuleItemsM mapper

-- | Collector variant of 'traverseModuleItemsM'.
collectModuleItemsM :: Monad m => CollectorM m ModuleItem -> CollectorM m Description
collectModuleItemsM collector = collectify traverseModuleItemsM collector

-- | Apply a statement mapper to every statement, at every nesting level, in the
-- procedural parts of a module item. These are @always@ and @initial@ bodies
-- and, only under 'IncludeTFs', function and task bodies. All other module
-- items are returned unchanged.
traverseStmtsM' :: Monad m => TFStrategy -> MapperM m Stmt -> MapperM m ModuleItem
traverseStmtsM' strat mapper = moduleItemMapper
    where
        fullMapper = traverseNestedStmtsM mapper

        -- function and task bodies are only visited under IncludeTFs
        tfBodyMapper stmts
            | strat == IncludeTFs = mapM fullMapper stmts
            | otherwise = return stmts

        moduleItemMapper (AlwaysC kw stmt) = AlwaysC kw <$> fullMapper stmt
        moduleItemMapper (Initial stmt) = Initial <$> fullMapper stmt
        moduleItemMapper (MIPackageItem (Function lifetime ret name decls stmts)) =
            MIPackageItem . Function lifetime ret name decls <$> tfBodyMapper stmts
        moduleItemMapper (MIPackageItem (Task lifetime name decls stmts)) =
            MIPackageItem . Task lifetime name decls <$> tfBodyMapper stmts
        moduleItemMapper other = return other

-- | Pure variant of 'traverseStmtsM''.
traverseStmts' :: TFStrategy -> Mapper Stmt -> Mapper ModuleItem
traverseStmts' strat = unmonad (traverseStmtsM' strat)

-- | Collector variant of 'traverseStmtsM''.
collectStmtsM' :: Monad m => TFStrategy -> CollectorM m Stmt -> CollectorM m ModuleItem
collectStmtsM' strat = collectify (traverseStmtsM' strat)

-- | 'traverseStmtsM'' with function and task bodies included.
traverseStmtsM :: Monad m => MapperM m Stmt -> MapperM m ModuleItem
traverseStmtsM mapper = traverseStmtsM' IncludeTFs mapper

-- | 'traverseStmts'' with function and task bodies included.
traverseStmts :: Mapper Stmt -> Mapper ModuleItem
traverseStmts mapper = traverseStmts' IncludeTFs mapper

-- | 'collectStmtsM'' with function and task bodies included.
collectStmtsM :: Monad m => CollectorM m Stmt -> CollectorM m ModuleItem
collectStmtsM collector = collectStmtsM' IncludeTFs collector

-- Private utility that lifts a mapper for a single statement into one that
-- visits a whole statement tree. The traversal is pre-order (top-down): the
-- mapper is applied to a statement first, and the traversal then descends into
-- the children of whatever the mapper returned.
traverseNestedStmtsM :: Monad m => MapperM m Stmt -> MapperM m Stmt
traverseNestedStmtsM mapper = visit
    where
        visit stmt = do
            stmt' <- mapper stmt
            traverseSinglyNestedStmtsM visit stmt'

-- One-level counterpart of traverseNestedStmtsM. The given mapper is applied to
-- each direct child statement, and the parent is rebuilt around the results.
-- Unnamed blocks without declarations are simplified on the way: an empty one
-- becomes Null (without calling the mapper), a single-statement one is replaced
-- by its mapped statement, and unnamed declaration-free sequential blocks that
-- appear among a sequential block's mapped children are spliced into it.
traverseSinglyNestedStmtsM :: Monad m => MapperM m Stmt -> MapperM m Stmt
traverseSinglyNestedStmtsM fullMapper = descend
    where
        descend (StmtAttr a stmt) = StmtAttr a <$> fullMapper stmt
        -- the Block clauses overlap, so their order is significant
        descend (Block _ "" [] []) = return Null
        descend (Block _ "" [] [stmt]) = fullMapper stmt
        descend (Block Seq name decls stmts) = do
            stmts' <- mapM fullMapper stmts
            return $ Block Seq name decls $ concatMap splice stmts'
            where
                splice :: Stmt -> [Stmt]
                splice (Block Seq "" [] ss) = ss
                splice other = [other]
        descend (Block kw name decls stmts) =
            Block kw name decls <$> mapM fullMapper stmts
        descend (Case u kw expr cases def) = do
            let (labels, bodies) = unzip cases
            bodies' <- mapM fullMapper bodies
            def' <- maybeDo fullMapper def
            return $ Case u kw expr (zip labels bodies') def'
        descend (AsgnBlk op lhs expr) = return $ AsgnBlk op lhs expr
        descend (Asgn mt lhs expr) = return $ Asgn mt lhs expr
        descend (For a b c stmt) = For a b c <$> fullMapper stmt
        descend (While e stmt) = While e <$> fullMapper stmt
        descend (RepeatL e stmt) = RepeatL e <$> fullMapper stmt
        descend (DoWhile e stmt) = DoWhile e <$> fullMapper stmt
        descend (Forever stmt) = Forever <$> fullMapper stmt
        descend (Foreach x vars stmt) = Foreach x vars <$> fullMapper stmt
        descend (If u e s1 s2) = do
            s1' <- fullMapper s1
            s2' <- fullMapper s2
            return $ If u e s1' s2'
        descend (Timing event stmt) = Timing event <$> fullMapper stmt
        descend (Return expr) = return $ Return expr
        descend (Subroutine expr exprs) = return $ Subroutine expr exprs
        descend (Trigger blocks x) = return $ Trigger blocks x
        descend (Assertion a) = Assertion <$> traverseAssertionStmtsM fullMapper a
        descend Continue = return Continue
        descend Break = return Break
        descend Null = return Null

-- | Apply a statement mapper to the statements attached to an assertion: the
-- action block of an @assert@ or @assume@ (both branches, if present), or the
-- statement of a @cover@. The asserted expressions are left unchanged.
traverseAssertionStmtsM :: Monad m => MapperM m Stmt -> MapperM m Assertion
traverseAssertionStmtsM mapper = assertionMapper
    where
        actionBlockMapper (ActionBlockIf stmt) = ActionBlockIf <$> mapper stmt
        actionBlockMapper (ActionBlockElse ms stmt) = do
            ms' <- mapM mapper ms
            stmt' <- mapper stmt
            return $ ActionBlockElse ms' stmt'

        assertionMapper (Assert e ab) = Assert e <$> actionBlockMapper ab
        assertionMapper (Assume e ab) = Assume e <$> actionBlockMapper ab
        assertionMapper (Cover e stmt) = Cover e <$> mapper stmt

-- | Apply an expression mapper to the expressions of an assertion itself. This
-- covers the asserted expression or property specification, including the
-- expressions nested in its sequence and property operators, the optional
-- expression of a property specification, and the arguments of sequence
-- match items. The statements of the assertion's action blocks, and any
-- expressions within those statements, are NOT visited.
traverseAssertionExprsM :: Monad m => MapperM m Expr -> MapperM m Assertion
traverseAssertionExprsM mapper = assertionMapper
    where
        maybeExprMapper Nothing = return Nothing
        maybeExprMapper (Just e) = Just <$> mapper e

        seqExprMapper (SeqExpr e) = SeqExpr <$> mapper e
        seqExprMapper (SeqExprAnd s1 s2) = seqPair SeqExprAnd s1 s2
        seqExprMapper (SeqExprOr s1 s2) = seqPair SeqExprOr s1 s2
        seqExprMapper (SeqExprIntersect s1 s2) = seqPair SeqExprIntersect s1 s2
        seqExprMapper (SeqExprWithin s1 s2) = seqPair SeqExprWithin s1 s2
        seqExprMapper (SeqExprThroughout e s) = do
            e' <- mapper e
            s' <- seqExprMapper s
            return $ SeqExprThroughout e' s'
        seqExprMapper (SeqExprDelay ms e s) = do
            ms' <- mapM seqExprMapper ms
            e' <- mapper e
            s' <- seqExprMapper s
            return $ SeqExprDelay ms' e' s'
        seqExprMapper (SeqExprFirstMatch s items) = do
            s' <- seqExprMapper s
            items' <- mapM matchItemMapper items
            return $ SeqExprFirstMatch s' items'

        seqPair constructor s1 s2 = do
            s1' <- seqExprMapper s1
            s2' <- seqExprMapper s2
            return $ constructor s1' s2'

        matchItemMapper (Left (a, b, c)) = do
            c' <- mapper c
            return $ Left (a, b, c')
        matchItemMapper (Right (x, Args l p)) = do
            l' <- mapM maybeExprMapper l
            let (names, args) = unzip p
            args' <- mapM maybeExprMapper args
            return $ Right (x, Args l' (zip names args'))

        propExprMapper (PropExpr se) = PropExpr <$> seqExprMapper se
        propExprMapper (PropExprImpliesO se pe) = seqPropPair PropExprImpliesO se pe
        propExprMapper (PropExprImpliesNO se pe) = seqPropPair PropExprImpliesNO se pe
        propExprMapper (PropExprFollowsO se pe) = seqPropPair PropExprFollowsO se pe
        propExprMapper (PropExprFollowsNO se pe) = seqPropPair PropExprFollowsNO se pe
        propExprMapper (PropExprIff p1 p2) = do
            p1' <- propExprMapper p1
            p2' <- propExprMapper p2
            return $ PropExprIff p1' p2'

        seqPropPair constructor se pe = do
            se' <- seqExprMapper se
            pe' <- propExprMapper pe
            return $ constructor se' pe'

        propSpecMapper (PropertySpec ms me pe) = do
            me' <- maybeExprMapper me
            pe' <- propExprMapper pe
            return $ PropertySpec ms me' pe'

        assertionExprMapper (Left spec) = Left <$> propSpecMapper spec
        assertionExprMapper (Right e) = Right <$> mapper e

        assertionMapper (Assert e ab) = (\e' -> Assert e' ab) <$> assertionExprMapper e
        assertionMapper (Assume e ab) = (\e' -> Assume e' ab) <$> assertionExprMapper e
        assertionMapper (Cover e stmt) = (\e' -> Cover e' stmt) <$> assertionExprMapper e

-- | Apply an LHS mapper to the LHSs directly contained in a single statement.
-- These are assignment targets (including for-loop initializer and increment
-- targets), the LHSs in the event sensitivity of timing statements and of
-- event-controlled assignments, and the sensitivity of an assertion's property
-- specification. Child statements are not visited.
traverseStmtLHSsM :: Monad m => MapperM m LHS -> MapperM m Stmt
traverseStmtLHSsM mapper = stmtMapper
    where
        stmtMapper (Timing (Event sense) stmt) = do
            sense' <- senseMapper sense
            return $ Timing (Event sense') stmt
        stmtMapper (Asgn (Just (Event sense)) lhs expr) = do
            lhs' <- mapper lhs
            sense' <- senseMapper sense
            return $ Asgn (Just $ Event sense') lhs' expr
        stmtMapper (AsgnBlk op lhs expr) = do
            lhs' <- mapper lhs
            return $ AsgnBlk op lhs' expr
        stmtMapper (Asgn mt lhs expr) = do
            lhs' <- mapper lhs
            return $ Asgn mt lhs' expr
        stmtMapper (For inits me incrs stmt) = do
            inits' <- initsMapper inits
            incrs' <- mapM incrMapper incrs
            return $ For inits' me incrs' stmt
        stmtMapper (Assertion a) = Assertion <$> assertionMapper a
        stmtMapper other = return other

        -- declaration-style initializers have no LHS to map
        initsMapper (Left decls) = return $ Left decls
        initsMapper (Right asgns) = Right <$> mapM initMapper asgns
        initMapper (lhs, expr) = do
            lhs' <- mapper lhs
            return (lhs', expr)
        incrMapper (lhs, op, expr) = do
            lhs' <- mapper lhs
            return (lhs', op, expr)

        senseMapper (Sense lhs) = Sense <$> mapper lhs
        senseMapper (SensePosedge lhs) = SensePosedge <$> mapper lhs
        senseMapper (SenseNegedge lhs) = SenseNegedge <$> mapper lhs
        senseMapper (SenseOr s1 s2) = do
            s1' <- senseMapper s1
            s2' <- senseMapper s2
            return $ SenseOr s1' s2'
        senseMapper SenseStar = return SenseStar

        assertionExprMapper (Left (PropertySpec (Just sense) me pe)) = do
            sense' <- senseMapper sense
            return $ Left $ PropertySpec (Just sense') me pe
        assertionExprMapper other = return other

        assertionMapper (Assert e ab) = (\e' -> Assert e' ab) <$> assertionExprMapper e
        assertionMapper (Assume e ab) = (\e' -> Assume e' ab) <$> assertionExprMapper e
        assertionMapper (Cover e stmt) = (\e' -> Cover e' stmt) <$> assertionExprMapper e

-- | Pure variant of 'traverseStmtLHSsM'.
traverseStmtLHSs :: Mapper LHS -> Mapper Stmt
traverseStmtLHSs mapper = unmonad traverseStmtLHSsM mapper

-- | Collector variant of 'traverseStmtLHSsM'.
collectStmtLHSsM :: Monad m => CollectorM m LHS -> CollectorM m Stmt
collectStmtLHSsM collector = collectify traverseStmtLHSsM collector

-- | Lift an expression mapper into one that visits an entire expression tree.
-- The traversal is pre-order: the mapper is applied to an expression first,
-- and the traversal then descends into every subexpression of the result.
-- Types appearing inside expressions (in casts and dimension functions) are
-- passed through untouched.
traverseNestedExprsM :: Monad m => MapperM m Expr -> MapperM m Expr
traverseNestedExprsM mapper = exprMapper
    where
        exprMapper e = mapper e >>= descend

        maybeExprMapper Nothing = return Nothing
        maybeExprMapper (Just e) = Just <$> exprMapper e

        typeOrExprMapper (Left t) = return $ Left t
        typeOrExprMapper (Right e) = Right <$> exprMapper e

        descend (String s) = return $ String s
        descend (Number s) = return $ Number s
        descend (Time s) = return $ Time s
        descend (Ident i) = return $ Ident i
        descend (PSIdent x y) = return $ PSIdent x y
        descend (Range e m (e1, e2)) = do
            e' <- exprMapper e
            e1' <- exprMapper e1
            e2' <- exprMapper e2
            return $ Range e' m (e1', e2')
        descend (Bit e1 e2) = do
            e1' <- exprMapper e1
            e2' <- exprMapper e2
            return $ Bit e1' e2'
        descend (Repeat e l) = do
            e' <- exprMapper e
            l' <- mapM exprMapper l
            return $ Repeat e' l'
        descend (Concat l) = Concat <$> mapM exprMapper l
        descend (Stream o e l) = do
            e' <- exprMapper e
            l' <- mapM exprMapper l
            return $ Stream o e' l'
        descend (Call e (Args l p)) = do
            e' <- exprMapper e
            l' <- mapM maybeExprMapper l
            let (names, args) = unzip p
            args' <- mapM maybeExprMapper args
            return $ Call e' (Args l' (zip names args'))
        descend (UniOp o e) = UniOp o <$> exprMapper e
        descend (BinOp o e1 e2) = do
            e1' <- exprMapper e1
            e2' <- exprMapper e2
            return $ BinOp o e1' e2'
        descend (Mux e1 e2 e3) = do
            e1' <- exprMapper e1
            e2' <- exprMapper e2
            e3' <- exprMapper e3
            return $ Mux e1' e2' e3'
        descend (Cast (Left t) e) = Cast (Left t) <$> exprMapper e
        descend (Cast (Right e1) e2) = do
            e1' <- exprMapper e1
            e2' <- exprMapper e2
            return $ Cast (Right e1') e2'
        descend (DimsFn f tore) = DimsFn f <$> typeOrExprMapper tore
        descend (DimFn f tore e) = do
            tore' <- typeOrExprMapper tore
            e' <- exprMapper e
            return $ DimFn f tore' e'
        descend (Dot e x) = (\e' -> Dot e' x) <$> exprMapper e
        descend (Pattern l) = do
            let (names, exprs) = unzip l
            exprs' <- mapM exprMapper exprs
            return $ Pattern $ zip names exprs'
        descend (MinTypMax e1 e2 e3) = do
            e1' <- exprMapper e1
            e2' <- exprMapper e2
            e3' <- exprMapper e3
            return $ MinTypMax e1' e2' e3'
        descend Nil = return Nil

-- | Derive the helper mappers shared by the expression traversals from a
-- single expression mapper. The returned tuple holds, in order, mappers for:
--
--   * ranges (both bounds);
--   * optional expressions;
--   * declarations (types, initializers, and unpacked-dimension ranges);
--   * LHSs, applied through 'traverseNestedLHSsM' to nested LHSs as well
--     (range bounds, bit indices, and stream expressions);
--   * types, applied through 'traverseNestedTypesM' to the ranges exposed by
--     'typeRanges'.
exprMapperHelpers :: Monad m => MapperM m Expr ->
    (MapperM m Range, MapperM m (Maybe Expr), MapperM m Decl, MapperM m LHS, MapperM m Type)
exprMapperHelpers exprMapper =
    ( rangeMapper
    , maybeExprMapper
    , declMapper
    , traverseNestedLHSsM lhsMapper
    , typeMapper
    )
    where
        rangeMapper (r1, r2) = do
            r1' <- exprMapper r1
            r2' <- exprMapper r2
            return (r1', r2')

        maybeExprMapper Nothing = return Nothing
        maybeExprMapper (Just e) = Just <$> exprMapper e

        typeRangesMapper t = do
            let (rebuild, ranges) = typeRanges t
            ranges' <- mapM rangeMapper ranges
            return $ rebuild ranges'
        typeMapper = traverseNestedTypesM typeRangesMapper

        maybeTypeMapper Nothing = return Nothing
        maybeTypeMapper (Just t) = Just <$> typeMapper t

        declMapper (Param s t x e) = do
            t' <- typeMapper t
            e' <- exprMapper e
            return $ Param s t' x e'
        declMapper (ParamType s x mt) = ParamType s x <$> maybeTypeMapper mt
        declMapper (Variable d t x a me) = do
            t' <- typeMapper t
            a' <- mapM rangeMapper a
            me' <- maybeExprMapper me
            return $ Variable d t' x a' me'

        lhsMapper (LHSRange l m r) = LHSRange l m <$> rangeMapper r
        lhsMapper (LHSBit l e) = LHSBit l <$> exprMapper e
        lhsMapper (LHSStream o e ls) = (\e' -> LHSStream o e' ls) <$> exprMapper e
        lhsMapper other = return other

-- | Apply an expression mapper to the expressions contained in a module item.
-- This covers declarations, types, LHSs, assignment delays, instance parameter
-- and port bindings, instance ranges, modport declarations, gate terminals,
-- statements (at every nesting level), assertions, and the control
-- expressions of generate loops, conditionals and cases. Under 'ExcludeTFs',
-- the declarations and statements of functions and tasks are skipped; a
-- function's return type is still visited. An attributed item ('MIAttr') is
-- returned unchanged, so expressions in attributes are excluded from
-- conversion.
traverseExprsM' :: Monad m => TFStrategy -> MapperM m Expr -> MapperM m ModuleItem
traverseExprsM' strat exprMapper = moduleItemMapper
    where
        (rangeMapper, maybeExprMapper, declMapper, lhsMapper, typeMapper)
            = exprMapperHelpers exprMapper

        stmtMapper = traverseNestedStmtsM (traverseStmtExprsM exprMapper)

        -- function and task contents are only visited under IncludeTFs
        tfDeclsMapper decls
            | strat == IncludeTFs = mapM declMapper decls
            | otherwise = return decls
        tfStmtsMapper stmts
            | strat == IncludeTFs = mapM stmtMapper stmts
            | otherwise = return stmts

        portBindingMapper (p, me) = (\me' -> (p, me')) <$> maybeExprMapper me

        paramBindingMapper (p, Left t) = (\t' -> (p, Left t')) <$> typeMapper t
        paramBindingMapper (p, Right e) = (\e' -> (p, Right e')) <$> exprMapper e

        moduleItemMapper (MIAttr attr mi) =
            -- note: we exclude expressions in attributes from conversion
            return $ MIAttr attr mi
        moduleItemMapper (MIPackageItem (Typedef t x)) =
            (\t' -> MIPackageItem $ Typedef t' x) <$> typeMapper t
        moduleItemMapper (MIPackageItem (Decl decl)) =
            MIPackageItem . Decl <$> declMapper decl
        moduleItemMapper (Defparam lhs expr) = do
            lhs' <- lhsMapper lhs
            expr' <- exprMapper expr
            return $ Defparam lhs' expr'
        moduleItemMapper (AlwaysC kw stmt) = AlwaysC kw <$> stmtMapper stmt
        moduleItemMapper (Initial stmt) = Initial <$> stmtMapper stmt
        moduleItemMapper (Assign delay lhs expr) = do
            delay' <- maybeExprMapper delay
            lhs' <- lhsMapper lhs
            expr' <- exprMapper expr
            return $ Assign delay' lhs' expr'
        moduleItemMapper (MIPackageItem (Function lifetime ret f decls stmts)) = do
            ret' <- typeMapper ret
            decls' <- tfDeclsMapper decls
            stmts' <- tfStmtsMapper stmts
            return $ MIPackageItem $ Function lifetime ret' f decls' stmts'
        moduleItemMapper (MIPackageItem (Task lifetime f decls stmts)) = do
            decls' <- tfDeclsMapper decls
            stmts' <- tfStmtsMapper stmts
            return $ MIPackageItem $ Task lifetime f decls' stmts'
        moduleItemMapper (Instance m p x r l) = do
            p' <- mapM paramBindingMapper p
            l' <- mapM portBindingMapper l
            r' <- mapM rangeMapper r
            return $ Instance m p' x r' l'
        moduleItemMapper (Modport x l) = Modport x <$> mapM modportDeclMapper l
        moduleItemMapper (NInputGate kw x lhs exprs) = do
            exprs' <- mapM exprMapper exprs
            lhs' <- lhsMapper lhs
            return $ NInputGate kw x lhs' exprs'
        moduleItemMapper (NOutputGate kw x lhss expr) = do
            lhss' <- mapM lhsMapper lhss
            expr' <- exprMapper expr
            return $ NOutputGate kw x lhss' expr'
        moduleItemMapper (Genvar x) = return $ Genvar x
        moduleItemMapper (Generate items) =
            Generate <$> mapM (traverseNestedGenItemsM genItemMapper) items
        moduleItemMapper (MIPackageItem (Directive c)) =
            return $ MIPackageItem $ Directive c
        moduleItemMapper (MIPackageItem (Comment c)) =
            return $ MIPackageItem $ Comment c
        moduleItemMapper (MIPackageItem (Import x y)) =
            return $ MIPackageItem $ Import x y
        moduleItemMapper (MIPackageItem (Export x)) =
            return $ MIPackageItem $ Export x
        moduleItemMapper (AssertionItem (mx, a)) = do
            a' <- traverseAssertionStmtsM stmtMapper a
            a'' <- traverseAssertionExprsM exprMapper a'
            return $ AssertionItem (mx, a'')

        genItemMapper (GenFor (n1, x1, e1) cc (x2, op2, e2) subItem) = do
            e1' <- exprMapper e1
            e2' <- exprMapper e2
            cc' <- exprMapper cc
            return $ GenFor (n1, x1, e1') cc' (x2, op2, e2') subItem
        genItemMapper (GenIf e i1 i2) = (\e' -> GenIf e' i1 i2) <$> exprMapper e
        genItemMapper (GenCase e cases def) = do
            e' <- exprMapper e
            let (labels, bodies) = unzip cases
            labels' <- mapM (mapM exprMapper) labels
            return $ GenCase e' (zip labels' bodies) def
        genItemMapper other = return other

        modportDeclMapper (dir, ident, Just e) = do
            e' <- exprMapper e
            return (dir, ident, Just e')
        modportDeclMapper other = return other

-- | Pure variant of 'traverseExprsM''.
traverseExprs' :: TFStrategy -> Mapper Expr -> Mapper ModuleItem
traverseExprs' strat = unmonad (traverseExprsM' strat)

-- | Collector variant of 'traverseExprsM''.
collectExprsM' :: Monad m => TFStrategy -> CollectorM m Expr -> CollectorM m ModuleItem
collectExprsM' strat = collectify (traverseExprsM' strat)

-- | 'traverseExprsM'' with function and task contents included.
traverseExprsM :: Monad m => MapperM m Expr -> MapperM m ModuleItem
traverseExprsM mapper = traverseExprsM' IncludeTFs mapper

-- | 'traverseExprs'' with function and task contents included.
traverseExprs :: Mapper Expr -> Mapper ModuleItem
traverseExprs mapper = traverseExprs' IncludeTFs mapper

-- | 'collectExprsM'' with function and task contents included.
collectExprsM :: Monad m => CollectorM m Expr -> CollectorM m ModuleItem
collectExprsM collector = collectExprsM' IncludeTFs collector

-- | Apply an expression mapper to the expressions directly contained in a
-- single statement. These are conditions, case scrutinees and labels,
-- assignment right-hand sides, and expressions inside assignment LHSs. The
-- LHS case includes for-loop initializer and increment targets, so their
-- indices are visited just like those of ordinary assignments. Also covered
-- are block and for-loop declarations, subroutine calls and arguments, and
-- return values.
-- Child statements are not visited, except the action-block statements of an
-- assertion, which are traversed in full. An attributed statement is returned
-- unchanged, so expressions in attributes are excluded from conversion.
traverseStmtExprsM :: Monad m => MapperM m Expr -> MapperM m Stmt
traverseStmtExprsM exprMapper = flatStmtMapper
    where

    (_, maybeExprMapper, declMapper, lhsMapper, _)
        = exprMapperHelpers exprMapper

    caseMapper (exprs, stmt) = do
        exprs' <- mapM exprMapper exprs
        return (exprs', stmt)
    stmtMapper = traverseNestedStmtsM flatStmtMapper
    flatStmtMapper (StmtAttr attr stmt) =
        -- note: we exclude expressions in attributes from conversion
        return $ StmtAttr attr stmt
    flatStmtMapper (Block kw name decls stmts) = do
        decls' <- mapM declMapper decls
        return $ Block kw name decls' stmts
    flatStmtMapper (Case u kw e cases def) = do
        e' <- exprMapper e
        cases' <- mapM caseMapper cases
        return $ Case u kw e' cases' def
    flatStmtMapper (AsgnBlk op lhs expr) = do
        lhs' <- lhsMapper lhs
        expr' <- exprMapper expr
        return $ AsgnBlk op lhs' expr'
    flatStmtMapper (Asgn    mt lhs expr) = do
        lhs' <- lhsMapper lhs
        expr' <- exprMapper expr
        return $ Asgn    mt lhs' expr'
    flatStmtMapper (For inits cc asgns stmt) = do
        inits' <- initsMapper inits
        cc' <- exprMapper cc
        asgns' <- mapM asgnMapper asgns
        return $ For inits' cc' asgns' stmt
    flatStmtMapper (While   e stmt) =
        exprMapper e >>= \e' -> return $ While   e' stmt
    flatStmtMapper (RepeatL e stmt) =
        exprMapper e >>= \e' -> return $ RepeatL e' stmt
    flatStmtMapper (DoWhile e stmt) =
        exprMapper e >>= \e' -> return $ DoWhile e' stmt
    flatStmtMapper (Forever   stmt) = return $ Forever stmt
    flatStmtMapper (Foreach x vars stmt) = return $ Foreach x vars stmt
    flatStmtMapper (If u cc s1 s2) =
        exprMapper cc >>= \cc' -> return $ If u cc' s1 s2
    flatStmtMapper (Timing event stmt) = return $ Timing event stmt
    flatStmtMapper (Subroutine e (Args l p)) = do
        e' <- exprMapper e
        l' <- mapM maybeExprMapper l
        pes <- mapM maybeExprMapper $ map snd p
        let p' = zip (map fst p) pes
        return $ Subroutine e' (Args l' p')
    flatStmtMapper (Return expr) =
        exprMapper expr >>= return . Return
    flatStmtMapper (Trigger blocks x) = return $ Trigger blocks x
    flatStmtMapper (Assertion a) = do
        a' <- traverseAssertionStmtsM stmtMapper a
        a'' <- traverseAssertionExprsM exprMapper a'
        return $ Assertion a''
    flatStmtMapper (Continue) = return Continue
    flatStmtMapper (Break) = return Break
    flatStmtMapper (Null) = return Null

    initsMapper (Left decls) = mapM declMapper decls >>= return . Left
    initsMapper (Right asgns) = mapM initMapper asgns >>= return . Right

    -- for-loop initializer and increment targets go through lhsMapper, just
    -- like the LHS of AsgnBlk/Asgn above, so that expressions nested in them
    -- (e.g. bit or range indices) are visited as well
    initMapper (l, e) = do
        l' <- lhsMapper l
        e' <- exprMapper e
        return (l', e')

    asgnMapper (l, op, e) = do
        l' <- lhsMapper l
        e' <- exprMapper e
        return (l', op, e')

-- pure variant of traverseStmtExprsM
traverseStmtExprs :: Mapper Expr -> Mapper Stmt
traverseStmtExprs mapper = unmonad traverseStmtExprsM mapper
-- collecting variant of traverseStmtExprsM
collectStmtExprsM :: Monad m => CollectorM m Expr -> CollectorM m Stmt
collectStmtExprsM collector = collectify traverseStmtExprsM collector

-- Applies the given LHS mapper to the LHSs within a ModuleItem. Statements are
-- visited via traverseStmtsM' (so the TFStrategy decides whether task and
-- function bodies are included). Continuous assignments, defparams, gate
-- outputs, assertion items, and generate-for loop variables are handled here.
traverseLHSsM' :: Monad m => TFStrategy -> MapperM m LHS -> MapperM m ModuleItem
traverseLHSsM' strat mapper item =
    traverseStmtsM' strat (traverseStmtLHSsM mapper) item >>= traverseModuleItemLHSsM
    where
        traverseModuleItemLHSsM (Assign delay lhs expr) = do
            lhs' <- mapper lhs
            return $ Assign delay lhs' expr
        traverseModuleItemLHSsM (Defparam lhs expr) = do
            lhs' <- mapper lhs
            return $ Defparam lhs' expr
        traverseModuleItemLHSsM (NOutputGate kw x lhss expr) = do
            lhss' <- mapM mapper lhss
            return $ NOutputGate kw x lhss' expr
        traverseModuleItemLHSsM (NInputGate  kw x lhs exprs) = do
            lhs' <- mapper lhs
            return $ NInputGate kw x lhs' exprs
        traverseModuleItemLHSsM (AssertionItem (mx, a)) = do
            -- reuse the statement traversal by temporarily wrapping the
            -- assertion in an Assertion statement
            converted <-
                traverseNestedStmtsM (traverseStmtLHSsM mapper) (Assertion a)
            return $ case converted of
                Assertion a' -> AssertionItem (mx, a')
                _ -> error $ "redirected AssertionItem traverse failed: "
                        ++ show converted
        traverseModuleItemLHSsM (Generate items) = do
            items' <- mapM (traverseNestedGenItemsM traverseGenItemLHSsM) items
            return $ Generate items'
        traverseModuleItemLHSsM other = return other

        -- Generate-for loop variables are bare identifiers. They are wrapped
        -- as LHSIdent for the mapper and unwrapped afterwards. The initializer
        -- variable is only mapped when n1 is False (n1 presumably marks a
        -- genvar declared in the loop header -- confirm).
        traverseGenItemLHSsM (GenFor (n1, x1, e1) cc (x2, op2, e2) subItem) = do
            x1' <- if n1 then return x1 else mapIdent x1
            x2' <- mapIdent x2
            return $ GenFor (n1, x1', e1) cc (x2', op2, e2) subItem
        traverseGenItemLHSsM other = return other

        -- Maps a bare identifier through the LHS mapper. The mapper must
        -- produce another LHSIdent here; anything else reports a descriptive
        -- error instead of an opaque irrefutable-pattern failure.
        mapIdent x = do
            lhs' <- mapper $ LHSIdent x
            return $ case lhs' of
                LHSIdent x' -> x'
                _ -> error $ "redirected GenFor LHS traverse failed: "
                        ++ show lhs'

-- pure and collecting variants of traverseLHSsM'
traverseLHSs' :: TFStrategy -> Mapper LHS -> Mapper ModuleItem
traverseLHSs' strat = unmonad $ traverseLHSsM' strat
collectLHSsM' :: Monad m => TFStrategy -> CollectorM m LHS -> CollectorM m ModuleItem
collectLHSsM' strat = collectify $ traverseLHSsM' strat

-- shorthands for the above using the IncludeTFs strategy
traverseLHSsM :: Monad m => MapperM m LHS -> MapperM m ModuleItem
traverseLHSsM = traverseLHSsM' IncludeTFs
traverseLHSs :: Mapper LHS -> Mapper ModuleItem
traverseLHSs = traverseLHSs' IncludeTFs
collectLHSsM :: Monad m => CollectorM m LHS -> CollectorM m ModuleItem
collectLHSsM = collectLHSsM' IncludeTFs

-- Applies the mapper to an LHS and then recursively to every sub-LHS of the
-- result. This is pre-order: the mapper sees each node before its children.
traverseNestedLHSsM :: Monad m => MapperM m LHS -> MapperM m LHS
traverseNestedLHSsM mapper = fullMapper
    where
        fullMapper lhs = mapper lhs >>= tl
        -- descend into the immediate child LHSs; Expr fields are untouched
        tl (LHSIdent  x       ) = return $ LHSIdent x
        tl (LHSBit    l e     ) = fullMapper l >>= \l' -> return $ LHSBit    l' e
        tl (LHSRange  l m r   ) = fullMapper l >>= \l' -> return $ LHSRange  l' m r
        tl (LHSDot    l x     ) = fullMapper l >>= \l' -> return $ LHSDot    l' x
        tl (LHSConcat     lhss) = mapM fullMapper lhss >>= return . LHSConcat
        tl (LHSStream o e lhss) = mapM fullMapper lhss >>= return . LHSStream o e

-- pure and collecting variants of traverseNestedLHSsM
traverseNestedLHSs :: Mapper LHS -> Mapper LHS
traverseNestedLHSs = unmonad traverseNestedLHSsM
collectNestedLHSsM :: Monad m => CollectorM m LHS -> CollectorM m LHS
collectNestedLHSsM = collectify traverseNestedLHSsM

-- Applies the given Decl mapper to the declarations of a ModuleItem. It
-- covers package-item declarations and function and task declarations (only
-- under IncludeTFs). It also covers declarations of Blocks reached through
-- traverseStmtsM' with the same strategy.
traverseDeclsM' :: Monad m => TFStrategy -> MapperM m Decl -> MapperM m ModuleItem
traverseDeclsM' strat mapper item = do
    item' <- miMapper item
    traverseStmtsM' strat stmtMapper item'
    where
        miMapper (MIPackageItem (Decl decl)) =
            mapper decl >>= return . MIPackageItem . Decl
        miMapper (MIPackageItem (Function l t x decls stmts)) = do
            decls' <- tfDeclsMapper decls
            return $ MIPackageItem $ Function l t x decls' stmts
        miMapper (MIPackageItem (Task l x decls stmts)) = do
            decls' <- tfDeclsMapper decls
            return $ MIPackageItem $ Task l x decls' stmts
        miMapper other = return other

        -- task and function declarations are only visited under IncludeTFs
        tfDeclsMapper decls =
            if strat == IncludeTFs
                then mapM mapper decls
                else return decls

        stmtMapper (Block kw name decls stmts) = do
            decls' <- mapM mapper decls
            return $ Block kw name decls' stmts
        stmtMapper other = return other

-- pure and collecting variants of traverseDeclsM'
traverseDecls' :: TFStrategy -> Mapper Decl -> Mapper ModuleItem
traverseDecls' strat = unmonad $ traverseDeclsM' strat
collectDeclsM' :: Monad m => TFStrategy -> CollectorM m Decl -> CollectorM m ModuleItem
collectDeclsM' strat = collectify $ traverseDeclsM' strat

-- shorthands for the above using the IncludeTFs strategy
traverseDeclsM :: Monad m => MapperM m Decl -> MapperM m ModuleItem
traverseDeclsM = traverseDeclsM' IncludeTFs
traverseDecls :: Mapper Decl -> Mapper ModuleItem
traverseDecls = traverseDecls' IncludeTFs
collectDeclsM :: Monad m => CollectorM m Decl -> CollectorM m ModuleItem
collectDeclsM = collectDeclsM' IncludeTFs

-- Applies the mapper to a Type and to the types nested within it, in
-- post-order. An enum's base type and struct/union field types are mapped
-- before the enclosing type is passed to the mapper.
traverseNestedTypesM :: Monad m => MapperM m Type -> MapperM m Type
traverseNestedTypesM mapper = fullMapper
    where
        fullMapper t = tm t >>= mapper
        -- types with no nested Type to visit
        tm (Alias      ps xx    rs) = return $ Alias      ps xx    rs
        tm (Net           kw sg rs) = return $ Net           kw sg rs
        tm (Implicit         sg rs) = return $ Implicit         sg rs
        tm (IntegerVector kw sg rs) = return $ IntegerVector kw sg rs
        tm (IntegerAtom   kw sg   ) = return $ IntegerAtom   kw sg
        tm (NonInteger    kw      ) = return $ NonInteger    kw
        tm (InterfaceT x my r) = return $ InterfaceT x my r
        -- types with component types: recurse into them first
        tm (Enum Nothing vals r) =
            return $ Enum Nothing vals r
        tm (Enum (Just t) vals r) = do
            t' <- fullMapper t
            return $ Enum (Just t') vals r
        tm (Struct p fields r) = do
            types <- mapM fullMapper $ map fst fields
            let idents = map snd fields
            return $ Struct p (zip types idents) r
        tm (Union p fields r) = do
            types <- mapM fullMapper $ map fst fields
            let idents = map snd fields
            return $ Union p (zip types idents) r

-- pure and collecting variants of traverseNestedTypesM
traverseNestedTypes :: Mapper Type -> Mapper Type
traverseNestedTypes = unmonad traverseNestedTypesM
collectNestedTypesM :: Monad m => CollectorM m Type -> CollectorM m Type
collectNestedTypesM = collectify traverseNestedTypesM

-- Applies the given Type mapper (wrapped in traverseNestedTypesM, so nested
-- types are visited too) to the types within a ModuleItem:
--   * typedefs and function return types;
--   * type parameters of instances;
--   * declaration types, via traverseDeclsM;
--   * types in casts and dimension functions, via traverseExprsM and
--     traverseNestedExprsM.
traverseTypesM :: Monad m => MapperM m Type -> MapperM m ModuleItem
traverseTypesM mapper item =
    miMapper item >>=
    traverseDeclsM declMapper >>=
    traverseExprsM (traverseNestedExprsM exprMapper)
    where
        fullMapper = traverseNestedTypesM mapper

        maybeMapper Nothing = return Nothing
        maybeMapper (Just t) = fullMapper t >>= return . Just

        typeOrExprMapper (Right e) = return $ Right e
        typeOrExprMapper (Left t) =
            fullMapper t >>= return . Left

        exprMapper (Cast (Left t) e) =
            fullMapper t >>= \t' -> return $ Cast (Left t') e
        exprMapper (DimsFn f tore) =
            typeOrExprMapper tore >>= return . DimsFn f
        exprMapper (DimFn f tore e) = do
            tore' <- typeOrExprMapper tore
            return $ DimFn f tore' e
        exprMapper other = return other

        declMapper (Param s t x e) =
            fullMapper t >>= \t' -> return $ Param s t' x e
        declMapper (ParamType s x mt) =
            maybeMapper mt >>= \mt' -> return $ ParamType s x mt'
        declMapper (Variable d t x a me) =
            fullMapper t >>= \t' -> return $ Variable d t' x a me

        miMapper (MIPackageItem (Typedef t x)) =
            fullMapper t >>= \t' -> return $ MIPackageItem $ Typedef t' x
        miMapper (MIPackageItem (Function l t x d s)) =
            fullMapper t >>= \t' -> return $ MIPackageItem $ Function l t' x d s
        -- Tasks carry no return type, so they fall through to the catch-all
        -- below unchanged. (They were previously matched by a dedicated
        -- `other @ (Task ...)` clause. That clause was redundant, and its
        -- whitespace-separated as-pattern is rejected by GHC >= 9.0.)
        miMapper (Instance m params x r p) = do
            params' <- mapM mapParam params
            return $ Instance m params' x r p
            where
                mapParam (i, Left t) =
                    fullMapper t >>= \t' -> return (i, Left t')
                mapParam (i, Right e) = return $ (i, Right e)
        miMapper other = return other

-- pure and collecting variants of traverseTypesM
traverseTypes :: Mapper Type -> Mapper ModuleItem
traverseTypes = unmonad traverseTypesM
collectTypesM :: Monad m => CollectorM m Type -> CollectorM m ModuleItem
collectTypesM = collectify traverseTypesM

-- Applies the mapper to each GenItem of a Generate region, recursively via
-- traverseNestedGenItemsM. Any other module item is returned unchanged.
traverseGenItemsM :: Monad m => MapperM m GenItem -> MapperM m ModuleItem
traverseGenItemsM mapper item =
    case item of
        Generate genItems -> do
            genItems' <- mapM (traverseNestedGenItemsM mapper) genItems
            return $ Generate genItems'
        _ -> return item

-- pure and collecting variants of traverseGenItemsM
traverseGenItems :: Mapper GenItem -> Mapper ModuleItem
traverseGenItems mapper = unmonad traverseGenItemsM mapper
collectGenItemsM :: Monad m => CollectorM m GenItem -> CollectorM m ModuleItem
collectGenItemsM collector = collectify traverseGenItemsM collector

-- Traverses all GenItems within a given GenItem, but doesn't inspect within
-- GenModuleItems. This is pre-order: the mapper sees each item before its
-- children are visited.
traverseNestedGenItemsM :: Monad m => MapperM m GenItem -> MapperM m GenItem
traverseNestedGenItemsM mapper = fullMapper
    where
        fullMapper genItem =
            mapper genItem >>= traverseSinglyNestedGenItemsM fullMapper

-- Applies the given mapper to the immediate GenItem children of a GenItem:
-- block members, loop bodies, if/else branches, case items, and the case
-- default (via maybeDo). GenModuleItems are not inspected. After mapping,
-- unnamed sub-blocks (GenBlock "") of a GenBlock are spliced into their
-- parent.
traverseSinglyNestedGenItemsM :: Monad m => MapperM m GenItem -> MapperM m GenItem
traverseSinglyNestedGenItemsM fullMapper = gim
    where
        gim (GenBlock x subItems) = do
            subItems' <- mapM fullMapper subItems
            return $ GenBlock x (concatMap flattenBlocks subItems')
        gim (GenFor a b c subItem) = do
            subItem' <- fullMapper subItem
            return $ GenFor a b c subItem'
        gim (GenIf e i1 i2) = do
            i1' <- fullMapper i1
            i2' <- fullMapper i2
            return $ GenIf e i1' i2'
        gim (GenCase e cases def) = do
            caseItems <- mapM (fullMapper . snd) cases
            let cases' = zip (map fst cases) caseItems
            def' <- maybeDo fullMapper def
            return $ GenCase e cases' def'
        gim (GenModuleItem moduleItem) =
            return $ GenModuleItem moduleItem
        gim (GenNull) = return GenNull

        flattenBlocks :: GenItem -> [GenItem]
        flattenBlocks (GenBlock "" items) = items
        flattenBlocks other = [other]

-- Applies the mapper to each (LHS, Expr) assignment pair within a ModuleItem.
-- Continuous assignments and defparams are mapped first. Assignment
-- statements are then mapped via traverseStmtsM' with the given strategy.
traverseAsgnsM' :: Monad m => TFStrategy -> MapperM m (LHS, Expr) -> MapperM m ModuleItem
traverseAsgnsM' strat mapper = moduleItemMapper
    where
        moduleItemMapper item = miMapperA item >>= miMapperB

        -- module-level assignments
        miMapperA (Assign delay lhs expr) = do
            (lhs', expr') <- mapper (lhs, expr)
            return $ Assign delay lhs' expr'
        miMapperA (Defparam lhs expr) = do
            (lhs', expr') <- mapper (lhs, expr)
            return $ Defparam lhs' expr'
        miMapperA other = return other

        -- assignments within statements
        miMapperB = traverseStmtsM' strat stmtMapper
        stmtMapper = traverseStmtAsgnsM mapper

-- pure and collecting variants of traverseAsgnsM'
traverseAsgns' :: TFStrategy -> Mapper (LHS, Expr) -> Mapper ModuleItem
traverseAsgns' strat = unmonad $ traverseAsgnsM' strat
collectAsgnsM' :: Monad m => TFStrategy -> CollectorM m (LHS, Expr) -> CollectorM m ModuleItem
collectAsgnsM' strat = collectify $ traverseAsgnsM' strat

-- shorthands for the above using the IncludeTFs strategy
traverseAsgnsM :: Monad m => MapperM m (LHS, Expr) -> MapperM m ModuleItem
traverseAsgnsM = traverseAsgnsM' IncludeTFs
traverseAsgns :: Mapper (LHS, Expr) -> Mapper ModuleItem
traverseAsgns = traverseAsgns' IncludeTFs
collectAsgnsM :: Monad m => CollectorM m (LHS, Expr) -> CollectorM m ModuleItem
collectAsgnsM = collectAsgnsM' IncludeTFs

-- Applies the mapper to the (LHS, Expr) pair of an AsgnBlk or Asgn statement.
-- Other statements are returned unchanged. This does not descend into
-- nested statements itself.
traverseStmtAsgnsM :: Monad m => MapperM m (LHS, Expr) -> MapperM m Stmt
traverseStmtAsgnsM mapper = stmtMapper
    where
        stmtMapper (AsgnBlk op lhs expr) = do
            (lhs', expr') <- mapper (lhs, expr)
            return $ AsgnBlk op lhs' expr'
        stmtMapper (Asgn    mt lhs expr) = do
            (lhs', expr') <- mapper (lhs, expr)
            return $ Asgn    mt lhs' expr'
        stmtMapper other = return other

-- pure and collecting variants of traverseStmtAsgnsM
traverseStmtAsgns :: Mapper (LHS, Expr) -> Mapper Stmt
traverseStmtAsgns = unmonad traverseStmtAsgnsM
collectStmtAsgnsM :: Monad m => CollectorM m (LHS, Expr) -> CollectorM m Stmt
collectStmtAsgnsM = collectify traverseStmtAsgnsM

-- Applies traverseModuleItemsM to a single ModuleItem. The item is wrapped in
-- a placeholder Part description (with "DNE" as its string field), and the
-- result is unwrapped afterwards. A single resulting item is returned
-- directly; zero or several items are grouped into a Generate region.
traverseNestedModuleItemsM :: Monad m => MapperM m ModuleItem -> MapperM m ModuleItem
traverseNestedModuleItemsM mapper item = do
    converted <-
        traverseModuleItemsM mapper (Part [] False Module Nothing "DNE" [] [item])
    let items' = case converted of
            Part [] False Module Nothing "DNE" [] newItems -> newItems
            _ -> error $ "redirected NestedModuleItems traverse failed: "
                    ++ show converted
    return $ case items' of
        [item'] -> item'
        _ -> Generate $ map GenModuleItem items'

-- pure variant of traverseNestedModuleItemsM
traverseNestedModuleItems :: Mapper ModuleItem -> Mapper ModuleItem
traverseNestedModuleItems mapper = unmonad traverseNestedModuleItemsM mapper
-- collecting variant of traverseNestedModuleItemsM
collectNestedModuleItemsM :: Monad m => CollectorM m ModuleItem -> CollectorM m ModuleItem
collectNestedModuleItemsM collector = collectify traverseNestedModuleItemsM collector

-- pure and collecting variants of traverseNestedStmtsM
traverseNestedStmts :: Mapper Stmt -> Mapper Stmt
traverseNestedStmts = unmonad traverseNestedStmtsM
collectNestedStmtsM :: Monad m => CollectorM m Stmt -> CollectorM m Stmt
collectNestedStmtsM = collectify traverseNestedStmtsM

-- pure and collecting variants of traverseNestedExprsM
traverseNestedExprs :: Mapper Expr -> Mapper Expr
traverseNestedExprs = unmonad traverseNestedExprsM
collectNestedExprsM :: Monad m => CollectorM m Expr -> CollectorM m Expr
collectNestedExprsM = collectify traverseNestedExprsM

-- Traverse all the declaration scopes within a ModuleItem.
--
-- Functions, Tasks, Always and Initial blocks are NOT passed through the
-- ModuleItem mapper, and Decl ModuleItems are NOT passed through the Decl
-- mapper.
--
-- The state is restored to its previous value after each scope is exited.
-- Only the Decl mapper may modify the state; we maintain the invariant that
-- all other functions restore the state on exit.
--
-- The Stmt mapper must not traverse statements recursively, as we add a
-- recursive wrapper here.
traverseScopesM
    :: (Eq s, Show s)
    => Monad m
    => MapperM (StateT s m) Decl
    -> MapperM (StateT s m) ModuleItem
    -> MapperM (StateT s m) Stmt
    -> MapperM (StateT s m) ModuleItem
traverseScopesM declMapper moduleItemMapper stmtMapper =
    fullModuleItemMapper
    where

        nestedStmtMapper stmt =
            stmtMapper stmt >>= traverseSinglyNestedStmtsM fullStmtMapper
        -- A Block opens a new scope. Its declarations may update the state,
        -- which is rolled back once the whole block has been traversed.
        fullStmtMapper (Block kw name decls stmts) = do
            prevState <- get
            decls' <- mapM declMapper decls
            block <- nestedStmtMapper $ Block kw name decls' stmts
            put prevState
            return block
        fullStmtMapper other = nestedStmtMapper other

        -- A function's return type is passed through the Decl mapper by
        -- wrapping it as a local Variable named after the function.
        redirectModuleItem (MIPackageItem (Function ml t x decls stmts)) = do
            prevState <- get
            t' <- do
                res <- declMapper $ Variable Local t x [] Nothing
                case res of
                    Variable Local newType _ [] Nothing -> return newType
                    _ -> error $ "redirected func ret traverse failed: " ++ show res
            decls' <- mapM declMapper decls
            stmts' <- mapM fullStmtMapper stmts
            put prevState
            return $ MIPackageItem $ Function ml t' x decls' stmts'
        redirectModuleItem (MIPackageItem (Task     ml   x decls stmts)) = do
            prevState <- get
            decls' <- mapM declMapper decls
            stmts' <- mapM fullStmtMapper stmts
            put prevState
            return $ MIPackageItem $ Task     ml    x decls' stmts'
        redirectModuleItem (AlwaysC kw stmt) =
            fullStmtMapper stmt >>= return . AlwaysC kw
        redirectModuleItem (Initial stmt) =
            fullStmtMapper stmt >>= return . Initial
        redirectModuleItem item =
            moduleItemMapper item

        -- This previously checked the invariant that the module item mappers
        -- should not modify the state. Now we simply "enforce" it by resetting
        -- the state to its previous value. Comparing the state, as we did
        -- previously, incurs a noticeable performance hit.
        fullModuleItemMapper item = do
            prevState <- get
            item' <- redirectModuleItem item
            put prevState
            return item'

-- Applies the given decl conversion across the description, and then performs a
-- scoped traversal for each ModuleItem in the description. This is the pure
-- (State s) form of scopedConversionM.
scopedConversion
    :: (Eq s, Show s)
    => MapperM (State s) Decl
    -> MapperM (State s) ModuleItem
    -> MapperM (State s) Stmt
    -> s
    -> Description
    -> Description
scopedConversion traverseDeclM traverseModuleItemM traverseStmtM s description =
    runIdentity $ scopedConversionM traverseDeclM traverseModuleItemM traverseStmtM s description

-- Monadic form of scopedConversion, starting from state s.
--
-- First, top-level package-item declarations (MIPackageItem (Decl _)) are
-- passed through the Decl mapper. Then every ModuleItem is traversed with
-- traverseScopesM. Both passes run inside one evalStateT, so the state left
-- by the first pass is the starting state of the second. The final state is
-- discarded.
scopedConversionM
    :: (Eq s, Show s)
    => Monad m
    => MapperM (StateT s m) Decl
    -> MapperM (StateT s m) ModuleItem
    -> MapperM (StateT s m) Stmt
    -> s
    -> Description
    -> m Description
scopedConversionM traverseDeclM traverseModuleItemM traverseStmtM s description =
    evalStateT (initialTraverse description >>= scopedTraverse) s
    where
        initialTraverse = traverseModuleItemsM traverseMIPackageItemDecl
        scopedTraverse = traverseModuleItemsM $
            traverseScopesM traverseDeclM traverseModuleItemM traverseStmtM
        traverseMIPackageItemDecl (MIPackageItem (Decl decl)) =
            traverseDeclM decl >>= return . MIPackageItem . Decl
        traverseMIPackageItemDecl other = return other

-- Converts a basic mapper taking the state as an initial argument into a
-- stateful mapper. The state is only read, never modified.
stately :: (Eq s, Show s) => (s -> Mapper a) -> MapperM (State s) a
stately mapper thing = gets $ \s -> mapper s thing

-- In many conversions, we want to resolve items locally first, and then fall
-- back to looking at other source files, if necessary. This helper captures
-- this behavior, allowing a conversion to fall back to an arbitrary globally
-- collected item, if one exists.
--
-- This isn't foolproof: we could inadvertently resolve a name that doesn't
-- exist in the given file. However, many projects rely on their toolchain to
-- locate their modules, interfaces, packages, or typenames in other files.
-- Global resolution of modules and interfaces is more commonly expected than
-- global resolution of typenames and packages.
--
-- The collector runs once over all files (globalNotes) and once per file
-- (localNotes). Each file is mapped with `localNotes <> globalNotes`, so the
-- local notes come first. For a left-biased monoid such as Data.Map, local
-- entries therefore take precedence.
traverseFilesM
    :: (Monoid w, Monad m)
    => CollectorM (Writer w) AST
    -> (w -> MapperM m AST)
    -> MapperM m [AST]
traverseFilesM fileCollectorM fileMapperM files =
    mapM traverseFileM files
    where
        globalNotes = execWriter $ mapM fileCollectorM files
        traverseFileM file =
            fileMapperM notes file
            where
                localNotes = execWriter $ fileCollectorM file
                notes = localNotes <> globalNotes

-- Pure form of traverseFilesM. The per-file mapper is lifted into a trivial
-- State () monad, which is then evaluated and discarded.
traverseFiles
    :: Monoid w
    => CollectorM (Writer w) AST
    -> (w -> Mapper AST)
    -> Mapper [AST]
traverseFiles fileCollectorM fileMapper =
    flip evalState () . traverseFilesM fileCollectorM liftedMapper
    where liftedMapper notes file = return $ fileMapper notes file