Skip to content

Commit

Permalink
Generalize delayedFold to arbitrary vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
t-wallet committed Aug 29, 2024
1 parent 5927123 commit c027864
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 29 deletions.
92 changes: 70 additions & 22 deletions clash-prelude/src/Clash/Explicit/Signal/Delayed.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ Maintainer : Christiaan Baaij <christiaan.baaij@gmail.com>
{-# LANGUAGE Trustworthy #-}

{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fconstraint-solver-iterations=10 #-}
{-# OPTIONS_HADDOCK show-extensions #-}

module Clash.Explicit.Signal.Delayed
Expand All @@ -42,22 +44,24 @@ module Clash.Explicit.Signal.Delayed
)
where

import Prelude ((.), (<$>), (<*>), id, Num(..))
import Prelude ((.), ($), (<$>), id, Num(..), Maybe(..), fmap, liftA2)

import Data.Coerce (coerce)
import Data.Kind (Type)
import Data.Proxy (Proxy (..))
import Data.Singletons (Apply, TyFun, type (@@))
import GHC.TypeLits (KnownNat, Nat, type (+), type (^), type (*))
import Data.Type.Equality ((:~:)(Refl))
import GHC.TypeLits (sameNat, Div, Mod, KnownNat, Nat, type (+), type (*), type (<=))
import GHC.TypeLits.Extra (CLog)

import Clash.Magic (clashCompileError)
import Clash.Sized.Vector
import Clash.Signal.Delayed.Internal
(DSignal(..), dfromList, dfromList_lazy, fromSignal, toSignal,
unsafeFromSignal, antiDelay, feedback, forward)
import qualified Clash.Signal.Delayed.Bundle as D

import Clash.Explicit.Signal
(KnownDomain, Clock, Domain, Reset, Signal, Enable, register, delay, bundle, unbundle)
import Clash.Promoted.Nat (SNat (..), snatToInteger)
import Clash.Promoted.Nat (SNat (..), SNatLE (..), compareSNat, snatToInteger)
import Clash.XException (NFDataX)

{- $setup
Expand Down Expand Up @@ -230,12 +234,9 @@ delayI
-> DSignal dom (n+d) a
delayI dflt = delayN (SNat :: SNat d) dflt

data DelayedFold (dom :: Domain) (n :: Nat) (delay :: Nat) (a :: Type) (f :: TyFun Nat Type) :: Type
type instance Apply (DelayedFold dom n delay a) k = DSignal dom (n + (delay*k)) a

-- | Tree fold over a 'Vec' of 'DSignal's with a combinatorial function,
-- and delaying @delay@ cycles after each application.
-- Values at times 0..(delay*k)-1 are set to a default.
-- Values at times 0..(delay * CLog 2 n)-1 are set to a default.
--
-- @
-- countingSignals :: Vec 4 (DSignal dom 0 Int)
Expand All @@ -248,11 +249,12 @@ type instance Apply (DelayedFold dom n delay a) k = DSignal dom (n + (delay*k))
-- >>> printX $ sampleN 8 (delayedFold d2 (-1) (*) enableGen systemClockGen countingSignals)
-- [-1,-1,1,1,0,1,16,81]
delayedFold
:: forall dom n delay k a
:: forall dom d delay n a
. ( NFDataX a
, KnownDomain dom
, KnownNat delay
, KnownNat k )
, KnownNat n
, 1 <= n )
=> SNat delay
-- ^ Delay applied after each step
-> a
Expand All @@ -261,14 +263,60 @@ delayedFold
-- ^ Fold operation to apply
-> Enable dom
-> Clock dom
-> Vec (2^k) (DSignal dom n a)
-- ^ Vector input of size 2^k
-> DSignal dom (n + (delay * k)) a
-- ^ Output Signal delayed by (delay * k)
delayedFold _ dflt op ena clk = dtfold (Proxy :: Proxy (DelayedFold dom n delay a)) id go
where
go :: SNat l
-> DelayedFold dom n delay a @@ l
-> DelayedFold dom n delay a @@ l
-> DelayedFold dom n delay a @@ (l+1)
go SNat x y = delayI dflt ena clk (op <$> x <*> y)
-> Vec n (DSignal dom d a)
-- ^ Vector input of size @n@
-> DSignal dom (d + delay * CLog 2 n) a
-- ^ Output Signal delayed by @delay * CLog 2 n@
delayedFold SNat initial f ena clk inps = case sameNat (SNat @1) (SNat @n) of
Just Refl -> head inps
_ -> case (modProof, strictlyPosDivRu, divMulProof) of
(SNatLE, SNatLE, Just Refl) ->
case sameNat (SNat @(1 + CLog 2 (n `Div` 2 + n `Mod` 2))) (SNat @(CLog 2 n)) of
Just Refl -> delayedFold (SNat @delay) initial f ena clk newLayer
where
newLayer = D.unbundle $
step @(n `Div` 2) @(n `Mod` 2) @d @delay (SNat @(n `Div` 2)) initial f ena clk (D.bundle inps)
_ -> clashCompileError
"delayedFold0: absurd, report this to the clash-compiler team: https://github.com/clash-lang/clash-compiler/issues"
_ -> clashCompileError
"delayedFold1: absurd, report this to the clash-compiler team: https://github.com/clash-lang/clash-compiler/issues"
where
modProof = compareSNat (SNat @(n `Mod` 2)) (SNat @1)
strictlyPosDivRu = compareSNat (SNat @1) (SNat @(n `Div` 2 + n `Mod` 2))
divMulProof = sameNat (SNat @n) (SNat @(2 * (n `Div` 2) + n `Mod` 2))

-- | A single layer of the pipelined fold
step :: forall (m :: Nat) (p :: Nat) (d :: Nat) (delay :: Nat) (dom :: Domain) (a :: Type).
KnownNat p
=> KnownNat delay
=> KnownDomain dom
=> p <= 1
=> NFDataX a
=> SNat m
-> a
-> (a -> a -> a)
-> Enable dom
-> Clock dom
-> DSignal dom d (Vec (2 * m + p) a)
-> DSignal dom (d + delay) (Vec (m + p) a)
step SNat initial f ena clk inps =
let
layerCalc :: DSignal dom d (Vec (2 * m) a) -> DSignal dom d (Vec m a)
layerCalc = fmap (map applyF . unconcatI)

applyF :: Vec 2 a -> a
applyF (a `Cons` b `Cons` _) = f a b
in
case (sameNat (SNat @p) (SNat @0), sameNat (SNat @p) (SNat @1)) of
-- Size of the input vector is even
(Just Refl, Nothing) ->
delayI (repeat initial) ena clk (layerCalc inps)
-- Size of the input vector is odd
(Nothing, Just Refl) ->
delayI (repeat initial) ena clk $
liftA2
(++)
(singleton . head <$> inps)
(layerCalc (tail <$> inps))
_ -> clashCompileError
"delayedFold step: absurd, report this to the clash-compiler team: https://github.com/clash-lang/clash-compiler/issues"
14 changes: 8 additions & 6 deletions clash-prelude/src/Clash/Signal/Delayed.hs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ module Clash.Signal.Delayed
where

import GHC.TypeLits
(KnownNat, type (^), type (+), type (*))
(KnownNat, type (+), type (*), type (<=))
import GHC.TypeLits.Extra (CLog)

import Clash.Signal.Delayed.Internal
(DSignal(..), dfromList, dfromList_lazy, fromSignal, toSignal,
Expand Down Expand Up @@ -192,7 +193,7 @@ delayI dflt = hideClock (hideEnable (E.delayI dflt))

-- | Tree fold over a 'Vec' of 'DSignal's with a combinatorial function,
-- and delaying @delay@ cycles after each application.
-- Values at times 0..(delay*k)-1 are set to a default.
-- Values at times 0..(delay * CLog 2 n)-1 are set to a default.
--
-- @
-- countingSignals :: Vec 4 (DSignal dom 0 Int)
Expand All @@ -205,20 +206,21 @@ delayI dflt = hideClock (hideEnable (E.delayI dflt))
-- >>> printX $ sampleN @System 8 (toSignal (delayedFold d2 (-1) (*) countingSignals))
-- [-1,-1,1,1,0,1,16,81]
delayedFold
:: forall dom n delay k a
:: forall dom d delay n a
. ( HiddenClock dom
, HiddenEnable dom
, NFDataX a
, KnownNat delay
, KnownNat k )
, KnownNat n
, 1 <= n)
=> SNat delay
-- ^ Delay applied after each step
-> a
-- ^ Initial value
-> (a -> a -> a)
-- ^ Fold operation to apply
-> Vec (2^k) (DSignal dom n a)
-> Vec n (DSignal dom d a)
-- ^ Vector input of size 2^k
-> DSignal dom (n + (delay * k)) a
-> DSignal dom (d + (delay * CLog 2 n)) a
-- ^ Output Signal delayed by (delay * k)
delayedFold d dflt f = hideClock (hideEnable (E.delayedFold d dflt f))
2 changes: 1 addition & 1 deletion clash-prelude/src/Clash/Signal/Delayed/Bundle.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import GHC.TypeLits (KnownNat)
import Prelude hiding (head, map, tail)

import Clash.Signal.Internal (Domain)
import Clash.Signal.Delayed (DSignal, toSignal, unsafeFromSignal)
import Clash.Signal.Delayed.Internal (DSignal, toSignal, unsafeFromSignal)
import qualified Clash.Signal.Bundle as B

import Clash.Sized.BitVector (Bit, BitVector)
Expand Down

0 comments on commit c027864

Please sign in to comment.