Skip to content

Commit

Permalink
allow multi-master with trivial Bwd direction e.g. Signal
Browse files Browse the repository at this point in the history
  • Loading branch information
jonfowler committed Nov 28, 2023
1 parent 618e375 commit c25d656
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 44 deletions.
2 changes: 1 addition & 1 deletion circuit-notation.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,4 @@ Test-Suite library-testsuite
type: exitcode-stdio-1.0
main-is: unittests.hs
hs-source-dirs: tests
build-depends: base, circuit-notation
build-depends: base, circuit-notation, clash-prelude >= 1.0
4 changes: 2 additions & 2 deletions shell.nix
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ stdenv.mkDerivation {

buildInputs = [
ghc
cabal-install
haskellPackages.ghcid
# cabal-install
# haskellPackages.ghcid
haskellPackages.stylish-haskell
];

Expand Down
37 changes: 30 additions & 7 deletions src/Circuit.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,18 @@ This file contains the 'Circuit' type, that the notation describes.
-}

{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ViewPatterns #-}

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-}

module Circuit where

import Clash.Prelude (Domain, Signal, Vec(..))
import Data.Default
import Clash.Prelude

type family Fwd a
type family Bwd a
Expand Down Expand Up @@ -70,3 +69,27 @@ type instance Bwd (Signal dom a) = ()
-- | Circuit type.
newtype Circuit a b = Circuit { runCircuit :: CircuitT a b }
type CircuitT a b = (Fwd a :-> Bwd b) -> (Bwd a :-> Fwd b)

class TrivialBwd a where
unitBwd :: a

instance TrivialBwd () where
unitBwd = ()

instance (TrivialBwd a, KnownNat n) => TrivialBwd (Vec n a) where
unitBwd = repeat unitBwd

instance (TrivialBwd a, TrivialBwd b) => TrivialBwd (a,b) where
unitBwd = (unitBwd, unitBwd)

instance (TrivialBwd a, TrivialBwd b, TrivialBwd c) => TrivialBwd (a,b,c) where
unitBwd = (unitBwd, unitBwd, unitBwd)

instance (TrivialBwd a, TrivialBwd b, TrivialBwd c, TrivialBwd d) => TrivialBwd (a,b,c,d) where
unitBwd = (unitBwd, unitBwd, unitBwd, unitBwd)

instance (TrivialBwd a, TrivialBwd b, TrivialBwd c, TrivialBwd d, TrivialBwd e) => TrivialBwd (a,b,c,d,e) where
unitBwd = (unitBwd, unitBwd, unitBwd, unitBwd, unitBwd)

instance (TrivialBwd a, TrivialBwd b, TrivialBwd c, TrivialBwd d, TrivialBwd e, TrivialBwd f) => TrivialBwd (a,b,c,d,e,f) where
unitBwd = (unitBwd, unitBwd, unitBwd, unitBwd, unitBwd, unitBwd)
91 changes: 57 additions & 34 deletions src/CircuitNotation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ data PortDescription a
= Tuple [PortDescription a]
| Vec SrcSpan [PortDescription a]
| Ref a
| RefMulticast a
| Lazy SrcSpan (PortDescription a)
| SignalExpr (LHsExpr GhcPs)
| SignalPat (LPat GhcPs)
Expand Down Expand Up @@ -349,6 +350,7 @@ portTypeSigM = \case
L.use (portVarTypes . L.at fs) <&> \case
Nothing -> varT loc (GHC.unpackFS fs <> "Ty")
Just (_sigLoc, sig) -> sig
RefMulticast p -> portTypeSigM (Ref p)
PortErr loc msgdoc -> do
dflags <- GHC.getDynFlags
unsafePerformIO . throwOneError $
Expand Down Expand Up @@ -643,35 +645,59 @@ checkCircuit = do
topNames = portNames Slave slaves <> portNames Master masters
nameMap = Map.fromListWith mappend $ topNames <> concatMap bindingNames binds

L.iforM_ nameMap \name occ ->
duplicateMasters <- concat <$> L.iforM nameMap \name occ ->
case occ of
([_], [_]) -> pure ()
([_], [_]) -> pure []
(ss, ms) -> do
unless (head (unpackFS name) == '_') $ do
when (null ms) $ errM (head ss) $ "Slave port " <> show name <> " has no associated master"
when (null ss) $ errM (head ms) $ "Master port " <> show name <> " has no associated slave"
-- would be nice to show locations of all occurrences here, not sure how to do that while
-- keeping ghc api
when (length ms > 1) $
errM (head ms) $ "Master port " <> show name <> " defined " <> show (length ms) <> " times"
when (length ss > 1) $
errM (head ss) $ "Slave port " <> show name <> " defined " <> show (length ss) <> " times"

-- if master is defined multiple times, we try to broadcast it
if length ms > 1
then pure [name]
else pure []

let
modifyMulticast = \case
Ref p@(PortName _ a) | a `elem` duplicateMasters -> RefMulticast p
p -> p

-- update relevant master ports to be multicast
circuitSlaves %= L.transform modifyMulticast
circuitMasters %= L.transform modifyMulticast
circuitBinds . L.mapped %= \b -> b
{ bIn = L.transform modifyMulticast (bIn b),
bOut = L.transform modifyMulticast (bOut b)
}

-- Creating ------------------------------------------------------------

bindWithSuffix :: (p ~ GhcPs, ?nms :: ExternalNames) => GHC.DynFlags -> String -> PortDescription PortName -> LPat p
bindWithSuffix dflags suffix = \case
Tuple ps -> tildeP noSrcSpan $ tupP $ fmap (bindWithSuffix dflags suffix) ps
Vec s ps -> vecP s $ fmap (bindWithSuffix dflags suffix) ps
Ref (PortName loc fs) -> varP loc (GHC.unpackFS fs <> suffix)
data Direc = Fwd | Bwd deriving Show

bindWithSuffix :: (p ~ GhcPs, ?nms :: ExternalNames) => GHC.DynFlags -> Direc -> PortDescription PortName -> LPat p
bindWithSuffix dflags dir = \case
Tuple ps -> tildeP noSrcSpan $ tupP $ fmap (bindWithSuffix dflags dir) ps
Vec s ps -> vecP s $ fmap (bindWithSuffix dflags dir) ps
Ref (PortName loc fs) -> varP loc (GHC.unpackFS fs <> "_" <> show dir)
RefMulticast (PortName loc fs) -> case dir of
Bwd -> L loc (WildPat noExt)
Fwd -> varP loc (GHC.unpackFS fs <> "_" <> show dir)
PortErr loc msgdoc -> unsafePerformIO . throwOneError $
Err.mkLongErrMsg dflags loc Outputable.alwaysQualify (Outputable.text "Unhandled bind") msgdoc
Lazy loc p -> tildeP loc $ bindWithSuffix dflags suffix p
Lazy loc p -> tildeP loc $ bindWithSuffix dflags dir p
SignalExpr (L l _) -> L l (WildPat noExt)
SignalPat lpat -> lpat
PortType _ p -> bindWithSuffix dflags suffix p
PortType _ p -> bindWithSuffix dflags dir p

data Direc = Fwd | Bwd
revDirec :: Direc -> Direc
revDirec = \case
Fwd -> Bwd
Bwd -> Fwd

bindOutputs
:: (p ~ GhcPs, ?nms :: ExternalNames)
Expand All @@ -682,26 +708,25 @@ bindOutputs
-> PortDescription PortName
-- ^ master ports
-> LPat p
bindOutputs dflags Fwd slaves masters = noLoc $ conPatIn (noLoc (fwdBwdCon ?nms)) (InfixCon m2s s2m)
bindOutputs dflags direc slaves masters = noLoc $ conPatIn (noLoc (fwdBwdCon ?nms)) (InfixCon m2s s2m)
where
m2s = bindWithSuffix dflags "_Fwd" masters
s2m = bindWithSuffix dflags "_Bwd" slaves
bindOutputs dflags Bwd slaves masters = noLoc $ conPatIn (noLoc (fwdBwdCon ?nms)) (InfixCon m2s s2m)
where
m2s = bindWithSuffix dflags "_Bwd" masters
s2m = bindWithSuffix dflags "_Fwd" slaves

expWithSuffix :: p ~ GhcPs => String -> PortDescription PortName -> LHsExpr p
expWithSuffix suffix = \case
Tuple ps -> tupE noSrcSpan $ fmap (expWithSuffix suffix) ps
Vec s ps -> vecE s $ fmap (expWithSuffix suffix) ps
Ref (PortName loc fs) -> varE loc (var $ GHC.unpackFS fs <> suffix)
m2s = bindWithSuffix dflags direc masters
s2m = bindWithSuffix dflags (revDirec direc) slaves

expWithSuffix :: (p ~ GhcPs, ?nms :: ExternalNames) => Direc -> PortDescription PortName -> LHsExpr p
expWithSuffix dir = \case
Tuple ps -> tupE noSrcSpan $ fmap (expWithSuffix dir) ps
Vec s ps -> vecE s $ fmap (expWithSuffix dir) ps
Ref (PortName loc fs) -> varE loc (var $ GHC.unpackFS fs <> "_" <> show dir)
RefMulticast (PortName loc fs) -> case dir of
Bwd -> varE noSrcSpan (trivialBwd ?nms)
Fwd -> varE loc (var $ GHC.unpackFS fs <> "_" <> show dir)
-- laziness only affects the pattern side
Lazy _ p -> expWithSuffix suffix p
Lazy _ p -> expWithSuffix dir p
PortErr _ _ -> error "expWithSuffix PortErr!"
SignalExpr lexpr -> lexpr
SignalPat (L l _) -> tupE l []
PortType _ p -> expWithSuffix suffix p
PortType _ p -> expWithSuffix dir p

createInputs
:: (p ~ GhcPs, ?nms :: ExternalNames)
Expand All @@ -711,14 +736,10 @@ createInputs
-> PortDescription PortName
-- ^ master ports
-> LHsExpr p
createInputs Fwd slaves masters = noLoc $ OpApp noExt s2m (varE noSrcSpan (fwdBwdCon ?nms)) m2s
where
m2s = expWithSuffix "_Bwd" masters
s2m = expWithSuffix "_Fwd" slaves
createInputs Bwd slaves masters = noLoc $ OpApp noExt s2m (varE noSrcSpan (fwdBwdCon ?nms)) m2s
createInputs dir slaves masters = noLoc $ OpApp noExt s2m (varE noSrcSpan (fwdBwdCon ?nms)) m2s
where
m2s = expWithSuffix "_Fwd" masters
s2m = expWithSuffix "_Bwd" slaves
m2s = expWithSuffix (revDirec dir) masters
s2m = expWithSuffix dir slaves

decFromBinding :: (p ~ GhcPs, ?nms :: ExternalNames) => GHC.DynFlags -> Int -> Binding (LHsExpr p) PortName -> HsBind p
decFromBinding dflags i Binding {..} = do
Expand Down Expand Up @@ -1025,6 +1046,7 @@ data ExternalNames = ExternalNames
, circuitTTyCon :: GHC.RdrName
, runCircuitName :: GHC.RdrName
, fwdBwdCon :: GHC.RdrName
, trivialBwd :: GHC.RdrName
}

defExternalNames :: ExternalNames
Expand All @@ -1034,4 +1056,5 @@ defExternalNames = ExternalNames
, circuitTTyCon = GHC.Unqual (OccName.mkTcOcc "CircuitT")
, runCircuitName = GHC.Unqual (OccName.mkVarOcc "runCircuit")
, fwdBwdCon = GHC.Unqual (OccName.mkDataOcc ":->")
, trivialBwd = GHC.Unqual (OccName.mkVarOcc "unitBwd")
}
15 changes: 15 additions & 0 deletions tests/unittests.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,20 @@

module Main where

import Circuit
import Clash.Prelude

main :: IO ()
main = pure ()

testIdCircuit :: Circuit (Signal dom Bool) (Signal dom Bool)
testIdCircuit = circuit $ \x -> x

-- test that signals can be duplicated
testDupCircuit :: Circuit (Signal dom Bool) (Signal dom Bool, Signal dom Bool)
testDupCircuit = circuit $ \x -> (x, x)

testDup2Circuit :: Circuit (Signal dom Bool) (Signal dom Bool, Signal dom Bool, Signal dom Bool)
testDup2Circuit = circuit $ \x -> do
y <- idC -< x

Check failure on line 24 in tests/unittests.hs

View workflow job for this annotation

GitHub Actions / ubuntu-latest / ghc 8.10.7

Arrow command found where an expression was expected:

Check failure on line 24 in tests/unittests.hs

View workflow job for this annotation

GitHub Actions / ubuntu-latest / ghc 9.0.2

Arrow command found where an expression was expected:
idC -< (y, y, x)

Check failure on line 25 in tests/unittests.hs

View workflow job for this annotation

GitHub Actions / ubuntu-latest / ghc 8.10.7

Arrow command found where an expression was expected:

Check failure on line 25 in tests/unittests.hs

View workflow job for this annotation

GitHub Actions / ubuntu-latest / ghc 9.0.2

Arrow command found where an expression was expected:

0 comments on commit c25d656

Please sign in to comment.