Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Give evaluator acces to inscope let-bindings #2571

Merged
merged 3 commits into from
Nov 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions clash-ghc/src-ghc/Clash/GHC/LoadInterfaceFiles.hs
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,11 @@ loadExprFromTyThing :: CoreSyn.CoreBndr -> GHC.TyThing -> Maybe CoreSyn.CoreExpr
loadExprFromTyThing bndr tyThing = case tyThing of
GHC.AnId _id | Var.isId _id ->
let _idInfo = Var.idInfo _id
#if MIN_VERSION_ghc(9,4,0)
unfolding = IdInfo.realUnfoldingInfo _idInfo
#else
unfolding = IdInfo.unfoldingInfo _idInfo
#endif
in case unfolding of
CoreSyn.CoreUnfolding {} ->
Just (CoreSyn.unfoldingTemplate unfolding)
Expand Down
5 changes: 3 additions & 2 deletions clash-lib/src/Clash/Core/Evaluator/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,20 @@ import Clash.Pretty (ClashPretty(..), fromPretty, showDoc)
whnf'
:: Evaluator
-> BindingMap
-> VarEnv Term
-> TyConMap
-> PrimHeap
-> Supply
-> InScopeSet
-> Bool
-> Term
-> (PrimHeap, PureHeap, Term)
whnf' eval bm tcm ph ids is isSubj e =
whnf' eval bm lh tcm ph ids is isSubj e =
toResult $ whnf eval tcm isSubj m
where
toResult x = (mHeapPrim x, mHeapLocal x, mTerm x)

m = Machine ph gh emptyVarEnv [] ids is e
m = Machine ph gh lh [] ids is e
gh = mapVarEnv bindingTerm bm

-- | Evaluate to WHNF given an existing Heap and Stack
Expand Down
2 changes: 1 addition & 1 deletion clash-lib/src/Clash/Core/Pretty.hs
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ pprPrecCast prec e ty1 ty2 = do
pprPrecLetrec :: Monad m => Rational -> Bool -> [(Id, Term)] -> Term -> m ClashDoc
pprPrecLetrec prec isRec xes body = do
let bndrs = fst <$> xes
body' <- annotate (AnnContext $ LetBody bndrs) <$> pprPrec noPrec body
body' <- annotate (AnnContext $ LetBody xes) <$> pprPrec noPrec body
xes' <- mapM (\(x,e) -> do
x' <- pprBndr LetBind x
e' <- pprPrec noPrec e
Expand Down
4 changes: 2 additions & 2 deletions clash-lib/src/Clash/Core/Term.hs
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ data CoreContext
-- ^ Function position of a type application
| LetBinding Id [Id]
-- ^ RHS of a Let-binder with the sibling LHS'
| LetBody [Id]
| LetBody [LetBinding]
-- ^ Body of a Let-binding with the bound LHS'
| LamBody Id
-- ^ Body of a lambda-term with the abstracted variable
Expand Down Expand Up @@ -303,7 +303,7 @@ instance Eq CoreContext where
-- NB: we do not see inside the argument here
(TyAppC, TyAppC) -> True
(LetBinding i is, LetBinding i' is') -> i == i' && is == is'
(LetBody is, LetBody is') -> is == is'
(LetBody is, LetBody is') -> map fst is == map fst is'
(LamBody i, LamBody i') -> i == i'
(TyLamBody tv, TyLamBody tv') -> tv == tv'
(CaseAlt p, CaseAlt p') -> p == p'
Expand Down
8 changes: 8 additions & 0 deletions clash-lib/src/Clash/Core/VarEnv.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ module Clash.Core.VarEnv
, delVarEnvList
, unionVarEnv
, unionVarEnvWith
, differenceVarEnv
-- ** Element-wise operations
-- *** Mapping
, mapVarEnv
Expand Down Expand Up @@ -227,6 +228,13 @@ unionVarEnvWith
-> VarEnv a
unionVarEnvWith = UniqMap.unionWith

-- | Filter the first varenv to only contain keys which are not in the second varenv.
differenceVarEnv
:: VarEnv a
-> VarEnv a
-> VarEnv a
differenceVarEnv = UniqMap.difference

-- | Create an environment given a list of var-value pairs
mkVarEnv
:: [(Var a,b)]
Expand Down
2 changes: 1 addition & 1 deletion clash-lib/src/Clash/Normalize/Transformations/DEC.hs
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ collectGlobals' is0 substitution seen e@(collectArgsTicks -> (fun, args@(_:_), t
let (ids1,ids2) = splitSupply ids
uniqSupply Lens..= ids2
gh <- Lens.use globalHeap
let eval = (Lens.view Lens._3) . whnf' evaluate bndrs tcm gh ids1 is0 False
let eval = (Lens.view Lens._3) . whnf' evaluate bndrs mempty tcm gh ids1 is0 False
let eTy = inferCoreTypeOf tcm e
untran <- isUntranslatableType False eTy
case untran of
Expand Down
4 changes: 2 additions & 2 deletions clash-lib/src/Clash/Normalize/Transformations/EtaExpand.hs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ etaExpansionTL (TransformContext is0 ctx) (Lam bndr e) = do
return $ Lam bndr e'

etaExpansionTL (TransformContext is0 ctx) (Let (NonRec i x) e) = do
let ctx' = TransformContext (extendInScopeSet is0 i) (LetBody [i] : ctx)
let ctx' = TransformContext (extendInScopeSet is0 i) (LetBody [(i,x)] : ctx)
e' <- etaExpansionTL ctx' e
case stripLambda e' of
(bs@(_:_),e2) -> do
Expand All @@ -81,7 +81,7 @@ etaExpansionTL (TransformContext is0 ctx) (Let (NonRec i x) e) = do

etaExpansionTL (TransformContext is0 ctx) (Let (Rec xes) e) = do
let bndrs = map fst xes
ctx' = TransformContext (extendInScopeSetList is0 bndrs) (LetBody bndrs : ctx)
ctx' = TransformContext (extendInScopeSetList is0 bndrs) (LetBody xes : ctx)
e' <- etaExpansionTL ctx' e
case stripLambda e' of
(bs@(_:_),e2) -> do
Expand Down
2 changes: 1 addition & 1 deletion clash-lib/src/Clash/Rewrite/Combinators.hs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ allR trans (TransformContext is c) (Cast e ty1 ty2) =

allR trans (TransformContext is c) (Letrec xes e) = do
xes' <- traverse rewriteBind xes
e' <- trans (TransformContext is' (LetBody bndrs:c)) e
e' <- trans (TransformContext is' (LetBody xes:c)) e
return (Letrec xes' e')
where
bndrs = map fst xes
Expand Down
18 changes: 13 additions & 5 deletions clash-lib/src/Clash/Rewrite/Util.hs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ import Clash.Core.Var
import Clash.Core.VarEnv
(InScopeSet, extendInScopeSet, extendInScopeSetList, mkInScopeSet,
uniqAway, uniqAway', mapVarEnv, eltsVarEnv, unitVarSet, emptyVarEnv,
mkVarEnv, eltsVarSet, elemVarEnv, lookupVarEnv, extendVarEnv, elemVarSet)
mkVarEnv, eltsVarSet, elemVarEnv, lookupVarEnv, extendVarEnv, elemVarSet,
differenceVarEnv)
import Clash.Data.UniqMap (UniqMap)
import qualified Clash.Data.UniqMap as UniqMap
import Clash.Debug
Expand Down Expand Up @@ -730,19 +731,26 @@ whnfRW
-> Term
-> Rewrite extra
-> RewriteMonad extra Term
whnfRW isSubj ctx@(TransformContext is0 _) e rw = do
whnfRW isSubj ctx@(TransformContext is0 hist) e rw = do
tcm <- Lens.view tcCache
bndrs <- Lens.use bindings
eval <- Lens.view evaluator
ids <- Lens.use uniqSupply
let (ids1,ids2) = splitSupply ids
uniqSupply Lens..= ids2
gh <- Lens.use globalHeap
let lh = localBinders mempty hist

case whnf' eval bndrs tcm gh ids1 is0 isSubj e of
case whnf' eval bndrs lh tcm gh ids1 is0 isSubj e of
(!gh1,ph,v) -> do
globalHeap Lens..= gh1
bindPureHeap tcm ph rw ctx v
bindPureHeap tcm (ph `differenceVarEnv` lh) rw ctx v
where
localBinders acc [] = acc
localBinders !acc (h:hs) = case h of
LetBody ls -> localBinders (acc <> mkVarEnv ls) hs
_ -> localBinders acc hs

{-# SCC whnfRW #-}

-- | Binds variables on the PureHeap over the result of the rewrite
Expand Down Expand Up @@ -791,7 +799,7 @@ bindPureHeap tcm heap rw ctx0@(TransformContext is0 hist) e = do
where
heapIds = map fst bndrs
is1 = extendInScopeSetList is0 heapIds
ctx = TransformContext is1 (LetBody heapIds : hist)
ctx = TransformContext is1 (LetBody bndrs : hist)

bndrs = map toLetBinding $ UniqMap.toList heap

Expand Down
18 changes: 9 additions & 9 deletions clash-prelude/src/Clash/Sized/Internal/Index.hs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ import {-# SOURCE #-} Clash.Sized.Internal.BitVector (BitVector (BV), high, low,
import qualified Clash.Sized.Internal.BitVector as BV
import Clash.Promoted.Nat (SNat(..), snatToNum, natToInteger, leToPlusKN)
import Clash.XException
(ShowX (..), NFDataX (..), errorX, showsPrecXWith, rwhnfX)
(ShowX (..), NFDataX (..), errorX, showsPrecXWith, rwhnfX, seqX)

{- $setup
>>> import Clash.Sized.Internal.Index
Expand Down Expand Up @@ -379,9 +379,9 @@ times# :: Index m -> Index n -> Index (((m - 1) * (n - 1)) + 1)
times# (I a) (I b) = I (a * b)

instance (KnownNat n, 1 <= n) => SaturatingNum (Index n) where
satAdd SatWrap !a !b =
satAdd SatWrap a b =
case natToInteger @n of
1 -> fromInteger# 0
1 -> a +# b
_ -> leToPlusKN @1 @n $
case plus# a b of
z | let m = fromInteger# (natToInteger @n)
Expand Down Expand Up @@ -419,9 +419,9 @@ instance (KnownNat n, 1 <= n) => SaturatingNum (Index n) where
then fromInteger# 0
else a -# b

satMul SatWrap !a !b =
satMul SatWrap a b =
case natToInteger @n of
1 -> fromInteger# 0
1 -> a *# b
2 -> case a of {0 -> 0; _ -> b}
_ -> leToPlusKN @1 @n $
case times# a b of
Expand All @@ -446,19 +446,19 @@ instance (KnownNat n, 1 <= n) => SaturatingNum (Index n) where
, z > m -> maxBound#
z -> resize# z

satSucc SatError !a =
satSucc SatError a =
case natToInteger @n of
1 -> errorX "Index.satSucc: overflow"
1 -> a `seqX` errorX "Index.satSucc: overflow"
_ -> satAdd SatError a $ fromInteger# 1
satSucc satMode !a =
case natToInteger @n of
1 -> fromInteger# 0
_ -> satAdd satMode a $ fromInteger# 1
{-# INLINE satSucc #-}

satPred SatError !a =
satPred SatError a =
case natToInteger @n of
1 -> errorX "Index.satPred: underflow"
1 -> a `seqX` errorX "Index.satPred: underflow"
_ -> satSub SatError a $ fromInteger# 1
satPred satMode !a =
case natToInteger @n of
Expand Down
2 changes: 1 addition & 1 deletion clash-term/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ instance Diff Term where
(Letrec bnds body, LetBinding i' _) ->
Letrec (mapBindings i' bnds) body
(Letrec bnds t, LetBody is) ->
if (fst <$> bnds) == is
if (fst <$> bnds) == (fst <$> is)
then Letrec bnds (go t)
else error "Ctx.LetBody: different bindings"
(Lam i t, LamBody i') ->
Expand Down
Loading