Commit 1dfa9a9e by Zachary Snow

simplify struct conversion

parent 6b81f87a
...@@ -18,36 +18,23 @@ import Language.SystemVerilog.AST ...@@ -18,36 +18,23 @@ import Language.SystemVerilog.AST
type TypeFunc = [Range] -> Type type TypeFunc = [Range] -> Type
type StructInfo = (Type, Map.Map Identifier (Range, Expr)) type StructInfo = (Type, Map.Map Identifier (Range, Expr))
type Structs = Map.Map TypeFunc StructInfo
type Types = Map.Map Identifier Type type Types = Map.Map Identifier Type
type Idents = Set.Set Identifier
convert :: [AST] -> [AST] convert :: [AST] -> [AST]
convert = map $ traverseDescriptions convertDescription convert = map $ traverseDescriptions convertDescription
convertDescription :: Description -> Description convertDescription :: Description -> Description
convertDescription (description @ (Part _ _ Module _ _ _ _)) = convertDescription (description @ (Part _ _ Module _ _ _ _)) =
traverseModuleItems (traverseTypes' ExcludeParamTypes $ convertType structs) $ traverseModuleItems (traverseTypes' ExcludeParamTypes convertType) $
Part attrs extern kw lifetime name ports (items ++ funcs)
where
description' @ (Part attrs extern kw lifetime name ports items) =
scopedConversion traverseDeclM' traverseModuleItemM scopedConversion traverseDeclM' traverseModuleItemM
traverseStmtM tfArgTypes description traverseStmtM tfArgTypes description
where
-- collect information about this description -- collect information about this description
structs = execWriter $ collectModuleItemsM
(collectTypesM collectStructM) description
tfArgTypes = execWriter $ collectModuleItemsM collectTFArgsM description tfArgTypes = execWriter $ collectModuleItemsM collectTFArgsM description
-- determine which of the packer functions we actually need
calledFuncs = execWriter $ collectModuleItemsM
(collectExprsM $ collectNestedExprsM collectCallsM) description'
packerFuncs = Set.map packerFnName $ Map.keysSet structs
calledPackedFuncs = Set.intersection calledFuncs packerFuncs
funcs = map packerFn $ filter isNeeded $ Map.keys structs
isNeeded tf = Set.member (packerFnName tf) calledPackedFuncs
-- helpers for the scoped traversal -- helpers for the scoped traversal
traverseDeclM' :: Decl -> State Types Decl traverseDeclM' :: Decl -> State Types Decl
traverseDeclM' decl = do traverseDeclM' decl = do
decl' <- traverseDeclM structs decl decl' <- traverseDeclM decl
res <- traverseModuleItemM $ MIPackageItem $ Decl decl' res <- traverseModuleItemM $ MIPackageItem $ Decl decl'
let MIPackageItem (Decl decl'') = res let MIPackageItem (Decl decl'') = res
return decl'' return decl''
...@@ -59,8 +46,7 @@ convertDescription (description @ (Part _ _ Module _ _ _ _)) = ...@@ -59,8 +46,7 @@ convertDescription (description @ (Part _ _ Module _ _ _ _)) =
traverseStmtM :: Stmt -> State Types Stmt traverseStmtM :: Stmt -> State Types Stmt
traverseStmtM (Subroutine expr args) = do traverseStmtM (Subroutine expr args) = do
stateTypes <- get stateTypes <- get
let stmt' = Subroutine expr $ convertCall let stmt' = Subroutine expr $ convertCall stateTypes expr args
structs stateTypes expr args
traverseStmtM' stmt' traverseStmtM' stmt'
traverseStmtM stmt = traverseStmtM' stmt traverseStmtM stmt = traverseStmtM' stmt
traverseStmtM' :: Stmt -> State Types Stmt traverseStmtM' :: Stmt -> State Types Stmt
...@@ -73,35 +59,32 @@ convertDescription (description @ (Part _ _ Module _ _ _ _)) = ...@@ -73,35 +59,32 @@ convertDescription (description @ (Part _ _ Module _ _ _ _)) =
where where
converter :: Types -> Expr -> Expr converter :: Types -> Expr -> Expr
converter types expr = converter types expr =
snd $ convertAsgn structs types (LHSIdent "", expr) snd $ convertAsgn types (LHSIdent "", expr)
traverseLHSM = traverseLHSM =
traverseNestedLHSsM $ stately converter traverseNestedLHSsM $ stately converter
where where
converter :: Types -> LHS -> LHS converter :: Types -> LHS -> LHS
converter types lhs = converter types lhs =
fst $ convertAsgn structs types (lhs, Ident "") fst $ convertAsgn types (lhs, Ident "")
traverseAsgnM = stately $ convertAsgn structs traverseAsgnM = stately convertAsgn
convertDescription other = other convertDescription other = other
-- write down unstructured versions of packed struct types -- write down unstructured versions of packed struct types
collectStructM :: Type -> Writer Structs ()
collectStructM (Struct Unpacked fields _) = convertStruct :: Type -> Maybe StructInfo
collectStructM' (Struct Unpacked) True Unspecified fields convertStruct (Struct Unpacked fields _) =
collectStructM (Struct (Packed sg) fields _) = convertStruct' True Unspecified fields
collectStructM' (Struct $ Packed sg) True sg fields convertStruct (Struct (Packed sg) fields _) =
collectStructM (Union (Packed sg) fields _) = convertStruct' True sg fields
collectStructM' (Union $ Packed sg) False sg fields convertStruct (Union (Packed sg) fields _) =
collectStructM _ = return () convertStruct' False sg fields
convertStruct _ = Nothing
collectStructM'
:: ([Field] -> [Range] -> Type) convertStruct' :: Bool -> Signing -> [Field] -> Maybe StructInfo
-> Bool -> Signing -> [Field] -> Writer Structs () convertStruct' isStruct sg fields =
collectStructM' constructor isStruct sg fields = do
if canUnstructure if canUnstructure
then tell $ Map.singleton then Just (unstructType, unstructFields)
(constructor fields) else Nothing
(unstructType, unstructFields)
else return ()
where where
zero = Number "0" zero = Number "0"
typeRange :: Type -> Range typeRange :: Type -> Range
...@@ -152,20 +135,18 @@ collectStructM' constructor isStruct sg fields = do ...@@ -152,20 +135,18 @@ collectStructM' constructor isStruct sg fields = do
isFlatIntVec _ = False isFlatIntVec _ = False
canUnstructure = all isFlatIntVec fieldTypes canUnstructure = all isFlatIntVec fieldTypes
isReadyStruct :: Type -> Bool
isReadyStruct = (Nothing /=) . convertStruct
-- convert a struct type to its unstructured equivalent -- convert a struct type to its unstructured equivalent
convertType :: Structs -> Type -> Type convertType :: Type -> Type
convertType structs t1 = convertType t1 =
case Map.lookup tf1 structs of case convertStruct t1 of
Nothing -> t1 Nothing -> t1
Just (t2, _) -> tf2 (rs1 ++ rs2) Just (t2, _) -> tf2 (rs1 ++ rs2)
where (tf2, rs2) = typeRanges t2 where (tf2, rs2) = typeRanges t2
where (tf1, rs1) = typeRanges t1 where (_, rs1) = typeRanges t1
-- writes down the names of called functions
collectCallsM :: Expr -> Writer Idents ()
collectCallsM (Call (Ident f) _) = tell $ Set.singleton f
collectCallsM _ = return ()
collectTFArgsM :: ModuleItem -> Writer Types () collectTFArgsM :: ModuleItem -> Writer Types ()
collectTFArgsM (MIPackageItem item) = do collectTFArgsM (MIPackageItem item) = do
...@@ -186,8 +167,8 @@ collectTFArgsM (MIPackageItem item) = do ...@@ -186,8 +167,8 @@ collectTFArgsM (MIPackageItem item) = do
collectTFArgsM _ = return () collectTFArgsM _ = return ()
-- write down the types of declarations -- write down the types of declarations
traverseDeclM :: Structs -> Decl -> State Types Decl traverseDeclM :: Decl -> State Types Decl
traverseDeclM structs origDecl = do traverseDeclM origDecl = do
case origDecl of case origDecl of
Variable d t x a e -> do Variable d t x a e -> do
let (tf, rs) = typeRanges t let (tf, rs) = typeRanges t
...@@ -206,30 +187,13 @@ traverseDeclM structs origDecl = do ...@@ -206,30 +187,13 @@ traverseDeclM structs origDecl = do
convertDeclExpr :: Identifier -> Expr -> State Types Expr convertDeclExpr :: Identifier -> Expr -> State Types Expr
convertDeclExpr x e = do convertDeclExpr x e = do
types <- get types <- get
let (LHSIdent _, e') = convertAsgn structs types (LHSIdent x, e) let (LHSIdent _, e') = convertAsgn types (LHSIdent x, e)
return e' return e'
isRangeable :: Type -> Bool isRangeable :: Type -> Bool
isRangeable (IntegerAtom _ _) = False isRangeable (IntegerAtom _ _) = False
isRangeable (NonInteger _ ) = False isRangeable (NonInteger _ ) = False
isRangeable _ = True isRangeable _ = True
-- produces a function which packs the components of a struct literal
packerFn :: TypeFunc -> ModuleItem
packerFn structTf =
MIPackageItem $
Function Automatic (structTf []) fnName decls [retStmt]
where
Struct _ fields [] = structTf []
toInput (t, x) = Variable Input t x [] Nil
decls = map toInput fields
retStmt = Return $ Concat $ map (Ident . snd) fields
fnName = packerFnName structTf
-- returns a "unique" name for the packer for a given struct type
packerFnName :: TypeFunc -> Identifier
packerFnName structTf =
"sv2v_struct_" ++ shortHash structTf
-- removes the innermost range from the given type, if possible -- removes the innermost range from the given type, if possible
dropInnerTypeRange :: Type -> Type dropInnerTypeRange :: Type -> Type
dropInnerTypeRange t = dropInnerTypeRange t =
...@@ -243,8 +207,8 @@ dropInnerTypeRange t = ...@@ -243,8 +207,8 @@ dropInnerTypeRange t =
-- looking at the innermost type of a node to convert outer uses of fields, and -- looking at the innermost type of a node to convert outer uses of fields, and
-- then using the outermost type to figure out the corresponding struct -- then using the outermost type to figure out the corresponding struct
-- definition for struct literals that are encountered. -- definition for struct literals that are encountered.
convertAsgn :: Structs -> Types -> (LHS, Expr) -> (LHS, Expr) convertAsgn :: Types -> (LHS, Expr) -> (LHS, Expr)
convertAsgn structs types (lhs, expr) = convertAsgn types (lhs, expr) =
(lhs', expr') (lhs', expr')
where where
(typ, lhs') = convertLHS lhs (typ, lhs') = convertLHS lhs
...@@ -311,10 +275,10 @@ convertAsgn structs types (lhs, expr) = ...@@ -311,10 +275,10 @@ convertAsgn structs types (lhs, expr) =
" has extra named fields: " ++ " has extra named fields: " ++
show (Set.toList extraNames) ++ " that are not in " ++ show (Set.toList extraNames) ++ " that are not in " ++
show structTf show structTf
else if Map.member structTf structs then else if isReadyStruct (structTf []) then
Call Concat
(Ident $ packerFnName structTf) $ map (uncurry $ Cast . Left)
(Args (map snd items) []) $ zip (map fst fields) (map snd items)
else else
Pattern items Pattern items
where where
...@@ -397,7 +361,7 @@ convertAsgn structs types (lhs, expr) = ...@@ -397,7 +361,7 @@ convertAsgn structs types (lhs, expr) =
convertSubExpr (Dot e x) = convertSubExpr (Dot e x) =
if maybeFields == Nothing if maybeFields == Nothing
then (Implicit Unspecified [], Dot e' x) then (Implicit Unspecified [], Dot e' x)
else if Map.notMember structTf structs else if not $ isReadyStruct (structTf [])
then (fieldType, Dot e' x) then (fieldType, Dot e' x)
else (dropInnerTypeRange fieldType, undotted) else (dropInnerTypeRange fieldType, undotted)
where where
...@@ -414,7 +378,7 @@ convertAsgn structs types (lhs, expr) = ...@@ -414,7 +378,7 @@ convertAsgn structs types (lhs, expr) =
convertSubExpr (Range (Dot e x) NonIndexed rOuter) = convertSubExpr (Range (Dot e x) NonIndexed rOuter) =
if maybeFields == Nothing if maybeFields == Nothing
then (Implicit Unspecified [], orig') then (Implicit Unspecified [], orig')
else if Map.notMember structTf structs else if not $ isReadyStruct (structTf [])
then (fieldType, orig') then (fieldType, orig')
else (dropInnerTypeRange fieldType, undotted) else (dropInnerTypeRange fieldType, undotted)
where where
...@@ -435,7 +399,7 @@ convertAsgn structs types (lhs, expr) = ...@@ -435,7 +399,7 @@ convertAsgn structs types (lhs, expr) =
convertSubExpr (Range (Dot e x) mode (baseO, lenO)) = convertSubExpr (Range (Dot e x) mode (baseO, lenO)) =
if maybeFields == Nothing if maybeFields == Nothing
then (Implicit Unspecified [], orig') then (Implicit Unspecified [], orig')
else if Map.notMember structTf structs else if not $ isReadyStruct (structTf [])
then (fieldType, orig') then (fieldType, orig')
else (dropInnerTypeRange fieldType, undotted) else (dropInnerTypeRange fieldType, undotted)
where where
...@@ -463,7 +427,7 @@ convertAsgn structs types (lhs, expr) = ...@@ -463,7 +427,7 @@ convertAsgn structs types (lhs, expr) =
convertSubExpr (Bit (Dot e x) i) = convertSubExpr (Bit (Dot e x) i) =
if maybeFields == Nothing if maybeFields == Nothing
then (Implicit Unspecified [], Bit (Dot e' x) i) then (Implicit Unspecified [], Bit (Dot e' x) i)
else if Map.notMember structTf structs else if not $ isReadyStruct (structTf [])
then (dropInnerTypeRange fieldType, Bit (Dot e' x) i) then (dropInnerTypeRange fieldType, Bit (Dot e' x) i)
else (dropInnerTypeRange fieldType, Bit e' i') else (dropInnerTypeRange fieldType, Bit e' i')
where where
...@@ -481,7 +445,7 @@ convertAsgn structs types (lhs, expr) = ...@@ -481,7 +445,7 @@ convertAsgn structs types (lhs, expr) =
(t, e') = convertSubExpr e (t, e') = convertSubExpr e
t' = dropInnerTypeRange t t' = dropInnerTypeRange t
convertSubExpr (Call e args) = convertSubExpr (Call e args) =
(retType, Call e $ convertCall structs types e' args) (retType, Call e $ convertCall types e' args)
where where
(_, e') = convertSubExpr e (_, e') = convertSubExpr e
retType = case e' of retType = case e' of
...@@ -514,7 +478,9 @@ convertAsgn structs types (lhs, expr) = ...@@ -514,7 +478,9 @@ convertAsgn structs types (lhs, expr) =
Nothing -> error $ "field '" ++ fieldName ++ Nothing -> error $ "field '" ++ fieldName ++
"' not found in struct: " ++ show structTf "' not found in struct: " ++ show structTf
Just r -> r Just r -> r
where fieldRangeMap = Map.map fst $ snd $ structs Map.! structTf where
Just structInfo = convertStruct $ structTf []
fieldRangeMap = Map.map fst $ snd structInfo
-- lookup the type of a field in the given field list -- lookup the type of a field in the given field list
lookupFieldType :: [(Type, Identifier)] -> Identifier -> Type lookupFieldType :: [(Type, Identifier)] -> Identifier -> Type
...@@ -538,8 +504,8 @@ convertAsgn structs types (lhs, expr) = ...@@ -538,8 +504,8 @@ convertAsgn structs types (lhs, expr) =
dims = snd $ typeRanges fieldType dims = snd $ typeRanges fieldType
-- attempts to convert based on the assignment-like contexts of TF arguments -- attempts to convert based on the assignment-like contexts of TF arguments
convertCall :: Structs -> Types -> Expr -> Args -> Args convertCall :: Types -> Expr -> Args -> Args
convertCall structs types fn (Args pnArgs kwArgs) = convertCall types fn (Args pnArgs kwArgs) =
case fn of case fn of
Ident _ -> args Ident _ -> args
_ -> Args pnArgs kwArgs _ -> Args pnArgs kwArgs
...@@ -552,6 +518,5 @@ convertCall structs types fn (Args pnArgs kwArgs) = ...@@ -552,6 +518,5 @@ convertCall structs types fn (Args pnArgs kwArgs) =
convertArg :: (Identifier, Expr) -> (Identifier, Expr) convertArg :: (Identifier, Expr) -> (Identifier, Expr)
convertArg (x, e) = (x, e') convertArg (x, e) = (x, e')
where where
(_, e') = convertAsgn structs types (_, e') = convertAsgn types
(LHSIdent $ f ++ ":" ++ x, e) (LHSIdent $ f ++ ":" ++ x, e)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment