Commit 7f2fe54b by Zachary Snow

fix jump statement conversion

parent fb5fd393
...@@ -10,30 +10,35 @@ ...@@ -10,30 +10,35 @@
module Convert.Jump (convert) where module Convert.Jump (convert) where
import Control.Monad.State import Control.Monad.State
import Control.Monad.Writer
import Convert.Traverse import Convert.Traverse
import Language.SystemVerilog.AST import Language.SystemVerilog.AST
data JumpType
= JTNone
| JTContinue
| JTBreak
| JTReturn
deriving (Eq, Ord, Show)
data Info = Info data Info = Info
{ sJumpType :: JumpType { sLoopDepth :: Int
, sLoopID :: Identifier , sHasJump :: Bool
, sReturnAllowed :: Bool
, sJumpAllowed :: Bool
}
initialState :: Info
initialState = Info
{ sLoopDepth = 0
, sHasJump = False
, sReturnAllowed = False
, sJumpAllowed = True
} }
initialStateTF :: Info
initialStateTF = initialState { sReturnAllowed = True }
convert :: [AST] -> [AST] convert :: [AST] -> [AST]
convert = map $ traverseDescriptions $ traverseModuleItems convertModuleItem convert = map $ traverseDescriptions $ traverseModuleItems convertModuleItem
convertModuleItem :: ModuleItem -> ModuleItem convertModuleItem :: ModuleItem -> ModuleItem
convertModuleItem (MIPackageItem (Function ml t f decls stmtsOrig)) = convertModuleItem (MIPackageItem (Function ml t f decls stmtsOrig)) =
if sJumpType finalState == JTNone || sJumpType finalState == JTReturn MIPackageItem $ Function ml t f decls' stmts''
then MIPackageItem $ Function ml t f decls stmts'
else error "illegal jump statement within task"
where where
stmts = map (traverseNestedStmts convertReturn) stmtsOrig stmts = map (traverseNestedStmts convertReturn) stmtsOrig
convertReturn :: Stmt -> Stmt convertReturn :: Stmt -> Stmt
...@@ -44,37 +49,63 @@ convertModuleItem (MIPackageItem (Function ml t f decls stmtsOrig)) = ...@@ -44,37 +49,63 @@ convertModuleItem (MIPackageItem (Function ml t f decls stmtsOrig)) =
, Return Nil , Return Nil
] ]
convertReturn other = other convertReturn other = other
initialState = Info { sJumpType = JTNone, sLoopID = "" } stmts' = evalState (convertStmts stmts) initialStateTF
(stmts', finalState) = runState (convertStmts stmts) initialState (decls', stmts'') = addJumpStateDeclTF decls stmts'
convertModuleItem (MIPackageItem (Task ml f decls stmts)) = convertModuleItem (MIPackageItem (Task ml f decls stmts)) =
if sJumpType finalState == JTNone || sJumpType finalState == JTReturn MIPackageItem $ Task ml f decls' stmts''
then MIPackageItem $ Task ml f decls $ stmts'
else error "illegal jump statement within task"
where
initialState = Info { sJumpType = JTNone, sLoopID = "" }
(stmts', finalState) = runState (convertStmts stmts) initialState
convertModuleItem (Initial stmt) =
if sJumpType finalState == JTNone
then Initial stmt'
else error "illegal jump statement within initial construct"
where where
initialState = Info { sJumpType = JTNone, sLoopID = "" } stmts' = evalState (convertStmts stmts) initialStateTF
(stmt', finalState) = runState (convertStmt stmt) initialState (decls', stmts'') = addJumpStateDeclTF decls stmts'
convertModuleItem (Final stmt) = convertModuleItem (Initial stmt) = convertMIStmt Initial stmt
if sJumpType finalState == JTNone convertModuleItem (Final stmt) = convertMIStmt Final stmt
then Final stmt' convertModuleItem (AlwaysC kw stmt) = convertMIStmt (AlwaysC kw) stmt
else error "illegal jump statement within final construct" convertModuleItem other = other
convertMIStmt :: (Stmt -> ModuleItem) -> Stmt -> ModuleItem
convertMIStmt constructor stmt =
constructor stmt''
where where
initialState = Info { sJumpType = JTNone, sLoopID = "" } stmt' = evalState (convertStmt stmt) initialState
(stmt', finalState) = runState (convertStmt stmt) initialState stmt'' = addJumpStateDeclStmt stmt'
convertModuleItem (AlwaysC kw stmt) =
if sJumpType finalState == JTNone -- adds a declaration of the jump state variable if it is needed; if the jump
then AlwaysC kw stmt' -- state is not used at all, then it is removed from the given statements
else error "illegal jump statement within always construct" -- entirely
addJumpStateDeclTF :: [Decl] -> [Stmt] -> ([Decl], [Stmt])
addJumpStateDeclTF decls stmts =
if uses && not declares then
( decls ++
[Variable Local jumpStateType jumpState [] (Just jsNone)]
, stmts )
else if uses then
(decls, stmts)
else
(decls, map (traverseNestedStmts removeJumpState) stmts)
where where
initialState = Info { sJumpType = JTNone, sLoopID = "" } dummyModuleItem = Initial $ Block Seq "" decls stmts
(stmt', finalState) = runState (convertStmt stmt) initialState declares = elem jumpState $ execWriter $
convertModuleItem other = other collectDeclsM collectVarM dummyModuleItem
uses = elem jumpState $ execWriter $
collectExprsM (collectNestedExprsM collectExprIdentM) dummyModuleItem
collectVarM :: Decl -> Writer [String] ()
collectVarM (Variable Local _ ident _ _) = tell [ident]
collectVarM _ = return ()
collectExprIdentM :: Expr -> Writer [String] ()
collectExprIdentM (Ident ident) = tell [ident]
collectExprIdentM _ = return ()
addJumpStateDeclStmt :: Stmt -> Stmt
addJumpStateDeclStmt stmt =
if null decls
then stmt'
else Block Seq "" decls [stmt']
where (decls, [stmt']) = addJumpStateDeclTF [] [stmt]
removeJumpState :: Stmt -> Stmt
removeJumpState (orig @ (AsgnBlk _ (LHSIdent ident) _)) =
if ident == jumpState
then Null
else orig
removeJumpState other = other
convertStmts :: [Stmt] -> State Info [Stmt] convertStmts :: [Stmt] -> State Info [Stmt]
convertStmts stmts = do convertStmts stmts = do
...@@ -88,10 +119,11 @@ convertStmt :: Stmt -> State Info Stmt ...@@ -88,10 +119,11 @@ convertStmt :: Stmt -> State Info Stmt
convertStmt (Block Par x decls stmts) = do convertStmt (Block Par x decls stmts) = do
-- break, continue, and return disallowed in fork-join -- break, continue, and return disallowed in fork-join
modify $ \s -> s { sLoopID = "" } jumpAllowed <- gets sJumpAllowed
loopID <- gets sLoopID returnAllowed <- gets sReturnAllowed
modify $ \s -> s { sJumpAllowed = False, sReturnAllowed = False }
stmts' <- mapM convertStmt stmts stmts' <- mapM convertStmt stmts
modify $ \s -> s { sLoopID = loopID } modify $ \s -> s { sJumpAllowed = jumpAllowed, sReturnAllowed = returnAllowed }
return $ Block Par x decls stmts' return $ Block Par x decls stmts'
convertStmt (Block Seq x decls stmts) = do convertStmt (Block Seq x decls stmts) = do
...@@ -101,41 +133,35 @@ convertStmt (Block Seq x decls stmts) = do ...@@ -101,41 +133,35 @@ convertStmt (Block Seq x decls stmts) = do
step :: [Stmt] -> State Info [Stmt] step :: [Stmt] -> State Info [Stmt]
step [] = return [] step [] = return []
step (s : ss) = do step (s : ss) = do
jt <- gets sJumpType hasJump <- gets sHasJump
loopID <- gets sLoopID loopDepth <- gets sLoopDepth
if jt == JTNone then do modify $ \st -> st { sHasJump = False }
s' <- convertStmt s s' <- convertStmt s
jt' <- gets sJumpType currHasJump <- gets sHasJump
if jt' == JTNone || not (isBranch s) || null loopID then do currLoopDepth <- gets sLoopDepth
assertMsg (loopDepth == currLoopDepth) "loop depth invariant failed"
modify $ \st -> st { sHasJump = hasJump || currHasJump }
ss' <- step ss ss' <- step ss
return $ s' : ss' if currHasJump && not (null ss)
else do then do
modify $ \t -> t { sJumpType = JTNone } let comp = BinOp Eq (Ident jumpState) jsNone
ss' <- step ss
let comp = BinOp Eq (Ident loopID) runLoop
let stmt = Block Seq "" [] ss' let stmt = Block Seq "" [] ss'
modify $ \t -> t { sJumpType = jt' }
return [s', If NoCheck comp stmt Null] return [s', If NoCheck comp stmt Null]
else do else do
return [Null] return $ s' : ss'
isBranch :: Stmt -> Bool
isBranch (If{}) = True
isBranch (Case{}) = True
isBranch _ = False
convertStmt (If unique expr thenStmt elseStmt) = do convertStmt (If unique expr thenStmt elseStmt) = do
(thenStmt', thenJT) <- convertSubStmt thenStmt (thenStmt', thenHasJump) <- convertSubStmt thenStmt
(elseStmt', elseJT) <- convertSubStmt elseStmt (elseStmt', elseHasJump) <- convertSubStmt elseStmt
let newJT = max thenJT elseJT modify $ \s -> s { sHasJump = thenHasJump || elseHasJump }
modify $ \s -> s { sJumpType = newJT }
return $ If unique expr thenStmt' elseStmt' return $ If unique expr thenStmt' elseStmt'
convertStmt (Case unique kw expr cases) = do convertStmt (Case unique kw expr cases) = do
results <- mapM convertSubStmt $ map snd cases results <- mapM convertSubStmt $ map snd cases
let (stmts', jts) = unzip results let (stmts', hasJumps) = unzip results
let cases' = zip (map fst cases) stmts' let cases' = zip (map fst cases) stmts'
let newJT = foldl max JTNone jts let hasJump = foldl (||) False hasJumps
modify $ \s -> s { sJumpType = newJT } modify $ \s -> s { sHasJump = hasJump }
return $ Case unique kw expr cases' return $ Case unique kw expr cases'
convertStmt (For inits comp incr stmt) = convertStmt (For inits comp incr stmt) =
...@@ -147,33 +173,42 @@ convertStmt (DoWhile comp stmt) = ...@@ -147,33 +173,42 @@ convertStmt (DoWhile comp stmt) =
convertLoop DoWhile comp stmt convertLoop DoWhile comp stmt
convertStmt (Continue) = do convertStmt (Continue) = do
loopID <- gets sLoopID loopDepth <- gets sLoopDepth
modify $ \s -> s { sJumpType = JTContinue } jumpAllowed <- gets sJumpAllowed
assertMsg (not $ null loopID) "encountered continue outside of loop" assertMsg (loopDepth > 0) "encountered continue outside of loop"
return $ asgn loopID continueLoop assertMsg jumpAllowed "encountered continue inside fork-join"
modify $ \s -> s { sHasJump = True }
return $ asgn jumpState jsContinue
convertStmt (Break) = do convertStmt (Break) = do
loopID <- gets sLoopID loopDepth <- gets sLoopDepth
modify $ \s -> s { sJumpType = JTBreak } jumpAllowed <- gets sJumpAllowed
assertMsg (not $ null loopID) "encountered break outside of loop" assertMsg (loopDepth > 0) "encountered break outside of loop"
return $ asgn loopID exitLoop assertMsg jumpAllowed "encountered break inside fork-join"
modify $ \s -> s { sHasJump = True }
return $ asgn jumpState jsBreak
convertStmt (Return Nil) = do convertStmt (Return Nil) = do
loopID <- gets sLoopID jumpAllowed <- gets sJumpAllowed
modify $ \s -> s { sJumpType = JTReturn } returnAllowed <- gets sReturnAllowed
if null loopID assertMsg jumpAllowed "encountered return inside fork-join"
then return Null assertMsg returnAllowed "encountered return outside of task or function"
else return $ asgn loopID exitLoop modify $ \s -> s { sHasJump = True }
return $ asgn jumpState jsReturn
convertStmt (RepeatL expr stmt) = do convertStmt (RepeatL expr stmt) = do
modify $ \s -> s { sLoopID = "repeat" } loopDepth <- gets sLoopDepth
modify $ \s -> s { sLoopDepth = loopDepth + 1 }
stmt' <- convertStmt stmt stmt' <- convertStmt stmt
jt <- gets sJumpType hasJump <- gets sHasJump
assertMsg (jt == JTNone) "jumps not supported within repeat loops" assertMsg (not hasJump) "jumps not supported within repeat loops"
modify $ \s -> s { sLoopDepth = loopDepth }
return $ RepeatL expr stmt' return $ RepeatL expr stmt'
convertStmt (Forever stmt) = do convertStmt (Forever stmt) = do
modify $ \s -> s { sLoopID = "forever" } loopDepth <- gets sLoopDepth
modify $ \s -> s { sLoopDepth = loopDepth + 1 }
stmt' <- convertStmt stmt stmt' <- convertStmt stmt
jt <- gets sJumpType hasJump <- gets sHasJump
assertMsg (jt == JTNone) "jumps not supported within forever loops" assertMsg (not hasJump) "jumps not supported within forever loops"
modify $ \s -> s { sLoopDepth = loopDepth }
return $ Forever stmt' return $ Forever stmt'
convertStmt (Timing timing stmt) = convertStmt (Timing timing stmt) =
...@@ -188,54 +223,74 @@ convertStmt (Foreach{}) = return $ ...@@ -188,54 +223,74 @@ convertStmt (Foreach{}) = return $
convertStmt other = return other convertStmt other = return other
-- convert a statement on its own without changing the state, but returning
-- convert a statement on its own without changing the state, but returning the -- whether or not the statement contains a jump; used to reconcile across
-- resulting jump type; used to reconcile across branching statements -- branching statements
convertSubStmt :: Stmt -> State Info (Stmt, JumpType) convertSubStmt :: Stmt -> State Info (Stmt, Bool)
convertSubStmt stmt = do convertSubStmt stmt = do
origState <- get origState <- get
stmt' <- convertStmt stmt stmt' <- convertStmt stmt
jt <- gets sJumpType hasJump <- gets sHasJump
put origState put origState
if sJumpType origState == JTNone return (stmt', hasJump)
then return (stmt', jt)
else error $ "convertStmt invariant failed on: " ++ show stmt
convertLoop :: (Expr -> Stmt -> Stmt) -> Expr -> Stmt -> State Info Stmt convertLoop :: (Expr -> Stmt -> Stmt) -> Expr -> Stmt -> State Info Stmt
convertLoop loop comp stmt = do convertLoop loop comp stmt = do
Info { sJumpType = origJT, sLoopID = origLoopID } <- get -- save the loop state and increment loop depth
let loopID = (++) "_sv2v_loop_" $ shortHash $ loop comp stmt Info { sLoopDepth = origLoopDepth, sHasJump = origHasJump } <- get
modify $ \s -> s { sLoopID = loopID } assertMsg (not origHasJump) "has jump invariant failed"
modify $ \s -> s { sLoopDepth = origLoopDepth + 1 }
-- convert the loop body
stmt' <- convertStmt stmt stmt' <- convertStmt stmt
jt <- gets sJumpType -- restore the loop state
let afterJT = if jt == JTReturn then jt else origJT Info { sLoopDepth = afterLoopDepth, sHasJump = afterHasJump } <- get
put $ Info { sJumpType = afterJT, sLoopID = origLoopID } assertMsg (origLoopDepth + 1 == afterLoopDepth) "loop depth invariant failed"
let comp' = BinOp LogAnd (BinOp Ne (Ident loopID) exitLoop) comp modify $ \s -> s { sLoopDepth = origLoopDepth }
return $ if jt == JTNone
then loop comp stmt' let comp' = BinOp LogAnd comp $ BinOp Lt (Ident jumpState) jsBreak
else Block Seq "" let body = Block Seq "" []
[ Variable Local loopStateType loopID [] (Just runLoop) [ asgn jumpState jsNone
]
[ loop comp' $ Block Seq "" []
[ asgn loopID runLoop
, stmt' , stmt'
] ]
, if afterJT == JTReturn && origLoopID /= "" let jsStackIdent = jumpState ++ "_" ++ show origLoopDepth
then asgn origLoopID exitLoop let jsStackDecl = Variable Local jumpStateType jsStackIdent []
else Null (Just $ Ident jumpState)
let jsStackRestore = If NoCheck
(BinOp Ne (Ident jumpState) jsReturn)
(asgn jumpState (Ident jsStackIdent))
Null
return $
if not afterHasJump then
loop comp stmt'
else if origLoopDepth == 0 then
Block Seq "" []
[ loop comp' body ]
else
Block Seq ""
[ jsStackDecl ]
[ loop comp' body
, jsStackRestore
] ]
where loopStateType = IntegerVector TBit Unspecified [(Number "0", Number "1")]
jumpStateType :: Type
jumpStateType = IntegerVector TBit Unspecified [(Number "0", Number "1")]
jumpState :: String
jumpState = "_sv2v_jump"
-- stop running the loop immediately (break or return) -- keep running the loop/function normally
exitLoop :: Expr jsNone :: Expr
exitLoop = Number "0" jsNone = Number "2'b00"
-- keep running the loop normally
runLoop :: Expr
runLoop = Number "1"
-- skip to the next iteration of the loop (continue) -- skip to the next iteration of the loop (continue)
continueLoop :: Expr jsContinue :: Expr
continueLoop = Number "2" jsContinue = Number "2'b01"
-- stop running the loop immediately (break)
jsBreak :: Expr
jsBreak = Number "2'b10"
-- stop running the function immediately (return)
jsReturn :: Expr
jsReturn = Number "2'b11"
assertMsg :: Bool -> String -> State Info () assertMsg :: Bool -> String -> State Info ()
......
...@@ -31,9 +31,15 @@ module top; ...@@ -31,9 +31,15 @@ module top;
$display("test1: %b %b", 3'b0z1, test1(3'b0z1)); $display("test1: %b %b", 3'b0z1, test1(3'b0z1));
end end
integer arr [] = { 32'd60, 32'd61, 32'd63 };
function test2; function test2;
input integer inp; input integer inp;
return inp inside { [16:23], [32:47] }; // TODO: Add support for array value ranges.
test2 = 0;
for (integer i = 0; i < 3; ++i)
if (inp == arr[i])
return 1'b1;
return test2 || inp inside { [16:23], [32:47] };
endfunction endfunction
initial begin initial begin
for (integer i = 0; i < 64; ++i) for (integer i = 0; i < 64; ++i)
......
...@@ -33,9 +33,17 @@ module top; ...@@ -33,9 +33,17 @@ module top;
$display("test1: %b %b", 3'b0z1, test1(3'b0z1)); $display("test1: %b %b", 3'b0z1, test1(3'b0z1));
end end
wire [0:2][31:0] arr;
assign arr = { 32'd60, 32'd61, 32'd63 };
function test2; function test2;
input integer inp; input integer inp;
test2 = (16 <= inp && inp <= 23) || (32 <= inp && inp <= 47); integer i;
begin
test2 = 0;
for (i = 0; i < 3; ++i)
test2 = test2 || (inp == arr[i]);
test2 = test2 || (16 <= inp && inp <= 23) || (32 <= inp && inp <= 47);
end
endfunction endfunction
initial begin : foobar initial begin : foobar
integer i; integer i;
......
...@@ -3,25 +3,25 @@ module top; ...@@ -3,25 +3,25 @@ module top;
task skip1; task skip1;
$display("HELLO skip1"); $display("HELLO skip1");
return; return;
$display("UNREACHABLE"); $display("UNREACHABLE ", `__LINE__);
endtask endtask
function void skip2; function void skip2;
$display("HELLO skip2"); $display("HELLO skip2");
return; return;
$display("UNREACHABLE"); $display("UNREACHABLE ", `__LINE__);
endfunction endfunction
function int skip3; function int skip3;
$display("HELLO skip3"); $display("HELLO skip3");
return 1; return 1;
$display("UNREACHABLE"); $display("UNREACHABLE ", `__LINE__);
endfunction endfunction
task skip4; task skip4;
for (int i = 0; i < 10; ++i) begin for (int i = 0; i < 10; ++i) begin
$display("HELLO skip4"); $display("HELLO skip4");
return; return;
$display("UNREACHABLE"); $display("UNREACHABLE ", `__LINE__);
end end
$display("UNREACHABLE"); $display("UNREACHABLE ", `__LINE__);
endtask endtask
task skip5; task skip5;
for (int i = 0; i < 10; ++i) begin for (int i = 0; i < 10; ++i) begin
...@@ -29,11 +29,36 @@ module top; ...@@ -29,11 +29,36 @@ module top;
for (int j = 0; j < 10; ++j) begin for (int j = 0; j < 10; ++j) begin
$display("HELLO skip5-2"); $display("HELLO skip5-2");
return; return;
$display("UNREACHABLE"); $display("UNREACHABLE ", `__LINE__);
end end
$display("UNREACHABLE"); $display("UNREACHABLE ", `__LINE__);
end end
$display("UNREACHABLE"); $display("UNREACHABLE ", `__LINE__);
endtask
task skip6;
for (int i = 0; i < 0; ++i) begin
$display("UNREACHABLE ", `__LINE__);
return;
end
$display("HELLO skip6");
endtask
task skip7;
begin
parameter x = 1;
$display("HELLO skip7");
if (x == 1) return;
$display("UNREACHABLE ", `__LINE__);
end
$display("UNREACHABLE ", `__LINE__);
endtask
task skip8;
begin
parameter x = 1;
$display("HELLO skip8-1");
if (x == 2) return;
$display("HELLO skip8-2");
end
$display("HELLO skip8-3");
endtask endtask
initial begin initial begin
skip1; skip1;
...@@ -41,20 +66,23 @@ module top; ...@@ -41,20 +66,23 @@ module top;
$display(skip3()); $display(skip3());
skip4; skip4;
skip5; skip5;
skip6;
skip7;
skip8;
end end
initial initial
for (int i = 0; i < 10; ++i) begin for (int i = 0; i < 10; ++i) begin
$display("Loop Y:", i); $display("Loop Y:", i);
continue; continue;
$display("UNREACHABLE"); $display("UNREACHABLE ", `__LINE__);
end end
initial initial
for (int i = 0; i < 10; ++i) begin for (int i = 0; i < 10; ++i) begin
$display("Loop Z:", i); $display("Loop Z:", i);
break; break;
$display("UNREACHABLE"); $display("UNREACHABLE ", `__LINE__);
end end
initial initial
...@@ -73,7 +101,7 @@ module top; ...@@ -73,7 +101,7 @@ module top;
else begin else begin
$display("Loop B-3:", i); $display("Loop B-3:", i);
continue; continue;
$display("UNREACHABLE"); $display("UNREACHABLE ", `__LINE__);
end end
$display("Loop B:", i); $display("Loop B:", i);
end end
......
...@@ -22,12 +22,28 @@ module top; ...@@ -22,12 +22,28 @@ module top;
$display("HELLO skip5-2"); $display("HELLO skip5-2");
end end
endtask endtask
task skip6;
$display("HELLO skip6");
endtask
task skip7;
$display("HELLO skip7");
endtask
task skip8;
begin
$display("HELLO skip8-1");
$display("HELLO skip8-2");
$display("HELLO skip8-3");
end
endtask
initial begin initial begin
skip1; skip1;
skip2; skip2;
$display(skip3(0)); $display(skip3(0));
skip4; skip4;
skip5; skip5;
skip6;
skip7;
skip8;
end end
initial begin : loop_y initial begin : loop_y
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment