diff --git a/circuit-notation.cabal b/circuit-notation.cabal index 9f9a0a6..05992d5 100644 --- a/circuit-notation.cabal +++ b/circuit-notation.cabal @@ -38,4 +38,4 @@ Test-Suite library-testsuite type: exitcode-stdio-1.0 main-is: unittests.hs hs-source-dirs: tests - build-depends: base, circuit-notation + build-depends: base, circuit-notation, clash-prelude >= 1.0 diff --git a/shell.nix b/shell.nix index ed4af77..97f2f83 100644 --- a/shell.nix +++ b/shell.nix @@ -8,8 +8,8 @@ stdenv.mkDerivation { buildInputs = [ ghc - cabal-install - haskellPackages.ghcid + # cabal-install + # haskellPackages.ghcid haskellPackages.stylish-haskell ]; diff --git a/src/Circuit.hs b/src/Circuit.hs index 9b2c62c..1b09ec6 100644 --- a/src/Circuit.hs +++ b/src/Circuit.hs @@ -11,19 +11,18 @@ This file contains the 'Circuit' type, that the notation describes. -} {-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE NoImplicitPrelude #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE ViewPatterns #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE GADTs #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE TypeOperators #-} module Circuit where -import Clash.Prelude (Domain, Signal, Vec(..)) -import Data.Default +import Clash.Prelude type family Fwd a type family Bwd a @@ -70,3 +69,27 @@ type instance Bwd (Signal dom a) = () -- | Circuit type. newtype Circuit a b = Circuit { runCircuit :: CircuitT a b } type CircuitT a b = (Fwd a :-> Bwd b) -> (Bwd a :-> Fwd b) + +class TrivialBwd a where + unitBwd :: a + +instance TrivialBwd () where + unitBwd = () + +instance (TrivialBwd a, KnownNat n) => TrivialBwd (Vec n a) where + unitBwd = repeat unitBwd + +instance (TrivialBwd a, TrivialBwd b) => TrivialBwd (a,b) where + unitBwd = (unitBwd, unitBwd) + +instance (TrivialBwd a, TrivialBwd b, TrivialBwd c) => TrivialBwd (a,b,c) where + unitBwd = (unitBwd, unitBwd, unitBwd) + +instance (TrivialBwd a, TrivialBwd b, TrivialBwd c, TrivialBwd d) => TrivialBwd (a,b,c,d) where + unitBwd = (unitBwd, unitBwd, unitBwd, unitBwd) + +instance (TrivialBwd a, TrivialBwd b, TrivialBwd c, TrivialBwd d, TrivialBwd e) => TrivialBwd (a,b,c,d,e) where + unitBwd = (unitBwd, unitBwd, unitBwd, unitBwd, unitBwd) + +instance (TrivialBwd a, TrivialBwd b, TrivialBwd c, TrivialBwd d, TrivialBwd e, TrivialBwd f) => TrivialBwd (a,b,c,d,e,f) where + unitBwd = (unitBwd, unitBwd, unitBwd, unitBwd, unitBwd, unitBwd) diff --git a/src/CircuitNotation.hs b/src/CircuitNotation.hs index d5f4b10..030e3ca 100644 --- a/src/CircuitNotation.hs +++ b/src/CircuitNotation.hs @@ -165,6 +165,7 @@ data PortDescription a = Tuple [PortDescription a] | Vec SrcSpan [PortDescription a] | Ref a + | RefMulticast a | Lazy SrcSpan (PortDescription a) | SignalExpr (LHsExpr GhcPs) | SignalPat (LPat GhcPs) @@ -349,6 +350,7 @@ portTypeSigM = \case L.use (portVarTypes . L.at fs) <&> \case Nothing -> varT loc (GHC.unpackFS fs <> "Ty") Just (_sigLoc, sig) -> sig + RefMulticast p -> portTypeSigM (Ref p) PortErr loc msgdoc -> do dflags <- GHC.getDynFlags unsafePerformIO . throwOneError $ @@ -643,35 +645,59 @@ checkCircuit = do topNames = portNames Slave slaves <> portNames Master masters nameMap = Map.fromListWith mappend $ topNames <> concatMap bindingNames binds - L.iforM_ nameMap \name occ -> + duplicateMasters <- concat <$> L.iforM nameMap \name occ -> case occ of - ([_], [_]) -> pure () + ([_], [_]) -> pure [] (ss, ms) -> do unless (head (unpackFS name) == '_') $ do when (null ms) $ errM (head ss) $ "Slave port " <> show name <> " has no associated master" when (null ss) $ errM (head ms) $ "Master port " <> show name <> " has no associated slave" -- would be nice to show locations of all occurrences here, not sure how to do that while -- keeping ghc api - when (length ms > 1) $ - errM (head ms) $ "Master port " <> show name <> " defined " <> show (length ms) <> " times" when (length ss > 1) $ errM (head ss) $ "Slave port " <> show name <> " defined " <> show (length ss) <> " times" + -- if master is defined multiple times, we try to broadcast it + if length ms > 1 + then pure [name] + else pure [] + + let + modifyMulticast = \case + Ref p@(PortName _ a) | a `elem` duplicateMasters -> RefMulticast p + p -> p + + -- update relevant master ports to be multicast + circuitSlaves %= L.transform modifyMulticast + circuitMasters %= L.transform modifyMulticast + circuitBinds . L.mapped %= \b -> b + { bIn = L.transform modifyMulticast (bIn b), + bOut = L.transform modifyMulticast (bOut b) + } + -- Creating ------------------------------------------------------------ -bindWithSuffix :: (p ~ GhcPs, ?nms :: ExternalNames) => GHC.DynFlags -> String -> PortDescription PortName -> LPat p -bindWithSuffix dflags suffix = \case - Tuple ps -> tildeP noSrcSpan $ tupP $ fmap (bindWithSuffix dflags suffix) ps - Vec s ps -> vecP s $ fmap (bindWithSuffix dflags suffix) ps - Ref (PortName loc fs) -> varP loc (GHC.unpackFS fs <> suffix) +data Direc = Fwd | Bwd deriving Show + +bindWithSuffix :: (p ~ GhcPs, ?nms :: ExternalNames) => GHC.DynFlags -> Direc -> PortDescription PortName -> LPat p +bindWithSuffix dflags dir = \case + Tuple ps -> tildeP noSrcSpan $ tupP $ fmap (bindWithSuffix dflags dir) ps + Vec s ps -> vecP s $ fmap (bindWithSuffix dflags dir) ps + Ref (PortName loc fs) -> varP loc (GHC.unpackFS fs <> "_" <> show dir) + RefMulticast (PortName loc fs) -> case dir of + Bwd -> L loc (WildPat noExt) + Fwd -> varP loc (GHC.unpackFS fs <> "_" <> show dir) PortErr loc msgdoc -> unsafePerformIO . throwOneError $ Err.mkLongErrMsg dflags loc Outputable.alwaysQualify (Outputable.text "Unhandled bind") msgdoc - Lazy loc p -> tildeP loc $ bindWithSuffix dflags suffix p + Lazy loc p -> tildeP loc $ bindWithSuffix dflags dir p SignalExpr (L l _) -> L l (WildPat noExt) SignalPat lpat -> lpat - PortType _ p -> bindWithSuffix dflags suffix p + PortType _ p -> bindWithSuffix dflags dir p -data Direc = Fwd | Bwd +revDirec :: Direc -> Direc +revDirec = \case + Fwd -> Bwd + Bwd -> Fwd bindOutputs :: (p ~ GhcPs, ?nms :: ExternalNames) @@ -682,26 +708,25 @@ bindOutputs -> PortDescription PortName -- ^ master ports -> LPat p -bindOutputs dflags Fwd slaves masters = noLoc $ conPatIn (noLoc (fwdBwdCon ?nms)) (InfixCon m2s s2m) +bindOutputs dflags direc slaves masters = noLoc $ conPatIn (noLoc (fwdBwdCon ?nms)) (InfixCon m2s s2m) where - m2s = bindWithSuffix dflags "_Fwd" masters - s2m = bindWithSuffix dflags "_Bwd" slaves -bindOutputs dflags Bwd slaves masters = noLoc $ conPatIn (noLoc (fwdBwdCon ?nms)) (InfixCon m2s s2m) - where - m2s = bindWithSuffix dflags "_Bwd" masters - s2m = bindWithSuffix dflags "_Fwd" slaves - -expWithSuffix :: p ~ GhcPs => String -> PortDescription PortName -> LHsExpr p -expWithSuffix suffix = \case - Tuple ps -> tupE noSrcSpan $ fmap (expWithSuffix suffix) ps - Vec s ps -> vecE s $ fmap (expWithSuffix suffix) ps - Ref (PortName loc fs) -> varE loc (var $ GHC.unpackFS fs <> suffix) + m2s = bindWithSuffix dflags direc masters + s2m = bindWithSuffix dflags (revDirec direc) slaves + +expWithSuffix :: (p ~ GhcPs, ?nms :: ExternalNames) => Direc -> PortDescription PortName -> LHsExpr p +expWithSuffix dir = \case + Tuple ps -> tupE noSrcSpan $ fmap (expWithSuffix dir) ps + Vec s ps -> vecE s $ fmap (expWithSuffix dir) ps + Ref (PortName loc fs) -> varE loc (var $ GHC.unpackFS fs <> "_" <> show dir) + RefMulticast (PortName loc fs) -> case dir of + Bwd -> varE noSrcSpan (trivialBwd ?nms) + Fwd -> varE loc (var $ GHC.unpackFS fs <> "_" <> show dir) -- laziness only affects the pattern side - Lazy _ p -> expWithSuffix suffix p + Lazy _ p -> expWithSuffix dir p PortErr _ _ -> error "expWithSuffix PortErr!" SignalExpr lexpr -> lexpr SignalPat (L l _) -> tupE l [] - PortType _ p -> expWithSuffix suffix p + PortType _ p -> expWithSuffix dir p createInputs :: (p ~ GhcPs, ?nms :: ExternalNames) @@ -711,14 +736,10 @@ createInputs -> PortDescription PortName -- ^ master ports -> LHsExpr p -createInputs Fwd slaves masters = noLoc $ OpApp noExt s2m (varE noSrcSpan (fwdBwdCon ?nms)) m2s - where - m2s = expWithSuffix "_Bwd" masters - s2m = expWithSuffix "_Fwd" slaves -createInputs Bwd slaves masters = noLoc $ OpApp noExt s2m (varE noSrcSpan (fwdBwdCon ?nms)) m2s +createInputs dir slaves masters = noLoc $ OpApp noExt s2m (varE noSrcSpan (fwdBwdCon ?nms)) m2s where - m2s = expWithSuffix "_Fwd" masters - s2m = expWithSuffix "_Bwd" slaves + m2s = expWithSuffix (revDirec dir) masters + s2m = expWithSuffix dir slaves decFromBinding :: (p ~ GhcPs, ?nms :: ExternalNames) => GHC.DynFlags -> Int -> Binding (LHsExpr p) PortName -> HsBind p decFromBinding dflags i Binding {..} = do @@ -1025,6 +1046,7 @@ data ExternalNames = ExternalNames , circuitTTyCon :: GHC.RdrName , runCircuitName :: GHC.RdrName , fwdBwdCon :: GHC.RdrName + , trivialBwd :: GHC.RdrName } defExternalNames :: ExternalNames @@ -1034,4 +1056,5 @@ defExternalNames = ExternalNames , circuitTTyCon = GHC.Unqual (OccName.mkTcOcc "CircuitT") , runCircuitName = GHC.Unqual (OccName.mkVarOcc "runCircuit") , fwdBwdCon = GHC.Unqual (OccName.mkDataOcc ":->") + , trivialBwd = GHC.Unqual (OccName.mkVarOcc "unitBwd") } diff --git a/tests/unittests.hs b/tests/unittests.hs index 87b18f2..5a12596 100644 --- a/tests/unittests.hs +++ b/tests/unittests.hs @@ -6,5 +6,20 @@ module Main where +import Circuit +import Clash.Prelude + main :: IO () main = pure () + +testIdCircuit :: Circuit (Signal dom Bool) (Signal dom Bool) +testIdCircuit = circuit $ \x -> x + +-- test that signals can be duplicated +testDupCircuit :: Circuit (Signal dom Bool) (Signal dom Bool, Signal dom Bool) +testDupCircuit = circuit $ \x -> (x, x) + +testDup2Circuit :: Circuit (Signal dom Bool) (Signal dom Bool, Signal dom Bool, Signal dom Bool) +testDup2Circuit = circuit $ \x -> do + y <- idC -< x + idC -< (y, y, x)