From 3732d4c6f2c2c7daa377808fdb7a0304e2a7951d Mon Sep 17 00:00:00 2001 From: Martijn Bastiaan Date: Sat, 21 Oct 2023 18:01:04 +0200 Subject: [PATCH] Add ghc-9.2 support --- .github/workflows/ci.yml | 16 +- cabal.project | 6 + circuit-notation.cabal | 16 +- example/Testing.hs | 22 +- src/CircuitNotation.hs | 438 +++++++++++++++++++++--------- src/GHC/Types/Unique/Map.hs | 213 +++++++++++++++ src/GHC/Types/Unique/Map/Extra.hs | 19 ++ 7 files changed, 579 insertions(+), 151 deletions(-) create mode 100644 src/GHC/Types/Unique/Map.hs create mode 100644 src/GHC/Types/Unique/Map/Extra.hs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2cd8dbd..97bd51f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,9 +2,7 @@ name: CI # Trigger the workflow on push or pull request, but only for the master branch on: - pull_request: push: - branches: [master] concurrency: group: ${{ github.head_ref || github.run_id }} @@ -18,7 +16,7 @@ jobs: container: image: 'nixos/nix:2.3.6' steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Build run: | @@ -31,17 +29,18 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - cabal: ["3.2"] + cabal: ["3.6"] ghc: - "8.6.5" - "8.10.7" - "9.0.2" + - "9.2.8" steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 # if: github.event.action == 'opened' || github.event.action == 'synchronize' || github.event.ref == 'refs/heads/master' - - uses: haskell/actions/setup@v1 + - uses: haskell-actions/setup@v2 id: setup-haskell-cabal name: Setup Haskell with: @@ -50,21 +49,22 @@ jobs: - name: Freeze run: | + cabal update cabal freeze - uses: actions/cache@v1 name: Cache ~/.cabal/store with: path: ${{ steps.setup-haskell-cabal.outputs.cabal-store }} - key: ${{ runner.os }}-${{ matrix.ghc }}-${{ hashFiles('cabal.project.freeze') }} + key: ${{ runner.os }}-${{ matrix.ghc }}-${{ hashFiles('cabal.project.freeze', 'cabal.project') }} restore-keys: | ${{ runner.os }}-${{ matrix.ghc }}- - name: Build run: | - cabal configure --enable-tests --enable-benchmarks --test-show-details=direct cabal build all --write-ghc-environment-files=always ghc -iexample Example + ghc -iexample Testing - name: Test run: | diff --git a/cabal.project b/cabal.project index e6fdbad..6d8696a 100644 --- a/cabal.project +++ b/cabal.project @@ -1 +1,7 @@ packages: . + +source-repository-package + type: git + location: https://github.com/clash-lang/clash-compiler.git + tag: 5b055fb3fcdaf6e2b89cb86486d7280fc781fa84 + subdir: clash-prelude \ No newline at end of file diff --git a/circuit-notation.cabal b/circuit-notation.cabal index 381d506..8eaf897 100644 --- a/circuit-notation.cabal +++ b/circuit-notation.cabal @@ -14,21 +14,29 @@ cabal-version: >=1.10 library exposed-modules: CircuitNotation Circuit - -- other-modules: + + if impl(ghc < 9.2) + other-modules: + GHC.Types.Unique.Map + + other-modules: + GHC.Types.Unique.Map.Extra + -- other-extensions: build-depends: base >=4.12 , clash-prelude >= 1.0 , containers , data-default - , ghc (>=8.6 && <8.8) || (>=8.10 && < 9.2) - , syb + , ghc (>=8.6 && <8.8) || (>=8.10 && < 9.4) , lens , mtl - , pretty , parsec + , pretty , pretty-show + , syb , template-haskell + , unordered-containers hs-source-dirs: src default-language: Haskell2010 ghc-options: -Wall diff --git a/example/Testing.hs b/example/Testing.hs index 60749d4..b8c8535 100644 --- a/example/Testing.hs +++ b/example/Testing.hs @@ -10,23 +10,25 @@ For testing the circuit notation. -} -{-# LANGUAGE Arrows #-} -{-# LANGUAGE BlockArguments #-} -{-# LANGUAGE GADTs #-} --- {-# LANGUAGE NoMonomorphismRestriction #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE DataKinds #-} - --- For testing: -{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE CPP #-} +#if __GLASGOW_HASKELL__ < 810 +{-# LANGUAGE Arrows #-} +#endif +{-# LANGUAGE BlockArguments #-} +{-# LANGUAGE DataKinds#-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# OPTIONS -fplugin=CircuitNotation #-} {-# OPTIONS -Wno-unused-local-binds #-} {-# OPTIONS -Wno-missing-signatures #-} -module Example where +module Testing where import Circuit +import Clash.Prelude hiding (undefined) +import Clash.Signal.Internal -- import Data.Default -- no c = diff --git a/src/CircuitNotation.hs b/src/CircuitNotation.hs index a3f3feb..66c91d7 100644 --- a/src/CircuitNotation.hs +++ b/src/CircuitNotation.hs @@ -17,6 +17,7 @@ Notation for describing the 'Circuit' type. {-# LANGUAGE ImplicitParams #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PackageImports #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} @@ -25,6 +26,9 @@ Notation for describing the 'Circuit' type. {-# OPTIONS_GHC -Wno-unused-top-binds #-} +-- TODO: Fix warnings introduced by GHC 9.2 +{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-} + module CircuitNotation ( plugin , mkPlugin @@ -39,29 +43,43 @@ import Data.Default import Data.Maybe (fromMaybe) #if __GLASGOW_HASKELL__ >= 900 #else -import SrcLoc +import SrcLoc hiding (noLoc) #endif import System.IO.Unsafe import Data.Typeable -- ghc import qualified Language.Haskell.TH as TH +import qualified GHC + +#if __GLASGOW_HASKELL__ >= 902 +import GHC.Types.SourceError (throwOneError) +import qualified GHC.Driver.Env as GHC +import qualified GHC.Types.SourceText as GHC +import qualified GHC.Types.SourceError as GHC +import qualified GHC.Driver.Ppr as GHC +#elif __GLASGOW_HASKELL__ >= 900 +import GHC.Driver.Types (throwOneError) +import qualified GHC.Driver.Types as GHC +#else +import HscTypes (throwOneError) +#endif + +#if __GLASGOW_HASKELL__ == 900 +import qualified GHC.Parser.Annotation as GHC +#endif #if __GLASGOW_HASKELL__ >= 900 import GHC.Data.Bag import GHC.Data.FastString (mkFastString, unpackFS) -import GHC.Driver.Types (throwOneError) import GHC.Plugins (PromotionFlag(NotPromoted)) -import GHC.Types.SrcLoc +import GHC.Types.SrcLoc hiding (noLoc) import qualified GHC.Data.FastString as GHC import qualified GHC.Driver.Plugins as GHC import qualified GHC.Driver.Session as GHC -import qualified GHC.Driver.Types as GHC -import qualified GHC.Parser.Annotation as GHC import qualified GHC.Types.Basic as GHC import qualified GHC.Types.Name.Occurrence as OccName import qualified GHC.Types.Name.Reader as GHC -import qualified GHC.Types.SrcLoc as GHC import qualified GHC.Utils.Error as Err import qualified GHC.Utils.Outputable as GHC import qualified GHC.Utils.Outputable as Outputable @@ -70,7 +88,6 @@ import Bag import qualified ErrUtils as Err import FastString (mkFastString, unpackFS) import qualified GhcPlugins as GHC -import HscTypes (throwOneError) import qualified OccName import qualified Outputable #endif @@ -78,6 +95,9 @@ import qualified Outputable #if __GLASGOW_HASKELL__ > 808 import qualified GHC.ThToHs as Convert import GHC.Hs +#if __GLASGOW_HASKELL__ >= 902 + hiding (locA) +#endif #else import qualified Convert import HsSyn hiding (noExt) @@ -93,13 +113,17 @@ import BasicTypes (PromotionFlag( NotPromoted )) import GHC.Builtin.Types (eqTyCon_RDR) #endif +#if __GLASGOW_HASKELL__ >= 902 +import "ghc" GHC.Types.Unique.Map +#else +import GHC.Types.Unique.Map +#endif + +import GHC.Types.Unique.Map.Extra + -- clash-prelude import Clash.Prelude (Signal, Vec((:>), Nil)) --- containers -import Data.Map (Map) -import qualified Data.Map as Map - -- lens import qualified Control.Lens as L import Control.Lens.Operators @@ -143,19 +167,70 @@ isFletching = isSomeVar "-<" imap :: (Int -> a -> b) -> [a] -> [b] imap f = zipWith f [0 ..] -#if __GLASGOW_HASKELL__ > 808 +-- Utils for backwards compat ------------------------------------------ +#if __GLASGOW_HASKELL__ >= 902 +type MsgDoc = Outputable.SDoc +type ErrMsg = Err.MsgEnvelope Err.DecoratedSDoc + +locA :: SrcSpanAnn' a -> SrcSpan +locA = GHC.locA + +noAnnSortKey :: AnnSortKey +noAnnSortKey = NoAnnSortKey +#else +type MsgDoc = Err.MsgDoc +type ErrMsg = Err.ErrMsg +type SrcSpanAnnA = SrcSpan +type SrcSpanAnnL = SrcSpan + +noSrcSpanA :: SrcSpan +noSrcSpanA = noSrcSpan + +noAnnSortKey :: NoExtField +noAnnSortKey = noExtField + +emptyComments :: NoExtField +emptyComments = noExtField + +locA :: a -> a +locA = id +#endif + +#if __GLASGOW_HASKELL__ > 900 +noExt :: EpAnn ann +noExt = EpAnnNotUsed +#elif __GLASGOW_HASKELL__ > 808 noExt :: NoExtField noExt = noExtField #else noExt :: NoExt noExt = NoExt + +noExtField :: NoExt +noExtField = NoExt + +type NoExtField = NoExt +#endif + +mkErrMsg :: GHC.DynFlags -> SrcSpan -> Outputable.PrintUnqualified -> Outputable.SDoc -> ErrMsg +#if __GLASGOW_HASKELL__ >= 902 +mkErrMsg _ = Err.mkMsgEnvelope +#else +mkErrMsg = Err.mkErrMsg +#endif + +mkLongErrMsg :: GHC.DynFlags -> SrcSpan -> Outputable.PrintUnqualified -> Outputable.SDoc -> Outputable.SDoc -> ErrMsg +#if __GLASGOW_HASKELL__ >= 902 +mkLongErrMsg _ = Err.mkLongMsgEnvelope +#else +mkLongErrMsg = Err.mkLongErrMsg #endif -- Types --------------------------------------------------------------- -- | The name given to a 'port', i.e. the name of a variable either to the left of a '<-' or to the -- right of a '-<'. -data PortName = PortName SrcSpan GHC.FastString +data PortName = PortName SrcSpanAnnA GHC.FastString deriving (Eq) instance Show PortName where @@ -163,14 +238,14 @@ instance Show PortName where data PortDescription a = Tuple [PortDescription a] - | Vec SrcSpan [PortDescription a] + | Vec SrcSpanAnnA [PortDescription a] | Ref a | RefMulticast a - | Lazy SrcSpan (PortDescription a) + | Lazy SrcSpanAnnA (PortDescription a) | SignalExpr (LHsExpr GhcPs) | SignalPat (LPat GhcPs) | PortType (LHsType GhcPs) (PortDescription a) - | PortErr SrcSpan Err.MsgDoc + | PortErr SrcSpanAnnA MsgDoc deriving (Foldable, Functor, Traversable) _Ref :: L.Prism' (PortDescription a) a @@ -196,7 +271,7 @@ data Binding exp l = Binding deriving (Functor) data CircuitState dec exp nm = CircuitState - { _cErrors :: Bag Err.ErrMsg + { _cErrors :: Bag ErrMsg , _counter :: Int -- ^ unique counter for generated variables , _circuitSlaves :: PortDescription nm @@ -209,13 +284,13 @@ data CircuitState dec exp nm = CircuitState -- ^ @out <- circuit <- in@ statements , _circuitMasters :: PortDescription nm -- ^ ports bound at the first lambda of a circuit - , _portVarTypes :: Map GHC.FastString (SrcSpan, LHsType GhcPs) + , _portVarTypes :: UniqMap GHC.FastString (SrcSpanAnnA, LHsType GhcPs) -- ^ types of single variable ports , _portTypes :: [(LHsType GhcPs, PortDescription nm)] -- ^ type of more 'complicated' things (very far from vigorous) , _uniqueCounter :: Int -- ^ counter to keep internal variables "unique" - , _circuitLoc :: SrcSpan + , _circuitLoc :: SrcSpanAnnA -- ^ span of the circuit expression } @@ -223,7 +298,9 @@ L.makeLenses 'CircuitState -- | The monad used when running a single circuit. newtype CircuitM a = CircuitM (StateT (CircuitState (LHsBind GhcPs) (LHsExpr GhcPs) PortName) GHC.Hsc a) - deriving (Functor, Applicative, Monad, MonadIO, MonadState (CircuitState (LHsBind GhcPs) (LHsExpr GhcPs) PortName)) + deriving (Functor, Applicative, Monad, MonadIO, MonadState (CircuitState (GenLocated SrcSpanAnnA (HsBindLR GhcPs GhcPs)) (GenLocated SrcSpanAnnA (HsExpr GhcPs)) PortName)) + +-- , MonadState (CircuitState (LHsBind GhcPs) (LHsExpr GhcPs) PortName) instance GHC.HasDynFlags CircuitM where getDynFlags = (CircuitM . lift) GHC.getDynFlags @@ -238,73 +315,105 @@ runCircuitM (CircuitM m) = do , _circuitLets = [] , _circuitBinds = [] , _circuitMasters = Tuple [] - , _portVarTypes = Map.empty + , _portVarTypes = emptyUniqMap , _portTypes = [] , _uniqueCounter = 1 - , _circuitLoc = noSrcSpan + , _circuitLoc = noSrcSpanA } (a, s) <- runStateT m emptyCircuitState let errs = _cErrors s unless (isEmptyBag errs) $ liftIO . throwIO $ GHC.mkSrcErr errs pure a + errM :: SrcSpan -> String -> CircuitM () errM loc msg = do dflags <- GHC.getDynFlags let errMsg = Err.mkLocMessageAnn Nothing Err.SevFatal loc (Outputable.text msg) - cErrors %= consBag (Err.mkErrMsg dflags loc Outputable.alwaysQualify errMsg) + cErrors %= consBag (mkErrMsg dflags loc Outputable.alwaysQualify errMsg) -- ghc helpers --------------------------------------------------------- -- It's very possible that most of these are already in the ghc library in some form. It's not the -- easiest library to discover these kind of functions. -conPatIn :: (p ~ GhcPs) => Located GHC.RdrName -> HsConPatDetails p -> Pat p +#if __GLASGOW_HASKELL__ >= 902 +conPatIn :: GenLocated SrcSpanAnnN GHC.RdrName -> HsConPatDetails GhcPs -> Pat GhcPs +#else +conPatIn :: Located GHC.RdrName -> HsConPatDetails GhcPs -> Pat GhcPs +#endif #if __GLASGOW_HASKELL__ >= 900 -conPatIn loc con = ConPat noExtField loc con +conPatIn loc con = ConPat noExt loc con #else conPatIn loc con = ConPatIn loc con #endif +#if __GLASGOW_HASKELL__ >= 902 +noEpAnn :: GenLocated SrcSpan e -> GenLocated (SrcAnn ann) e +noEpAnn (L l e) = L (SrcSpanAnn EpAnnNotUsed l) e + +noLoc :: e -> GenLocated (SrcAnn ann) e +noLoc = noEpAnn . GHC.noLoc +#else +noLoc :: e -> Located e +noLoc = GHC.noLoc +#endif + tupP :: p ~ GhcPs => [LPat p] -> LPat p tupP [pat] = pat tupP pats = noLoc $ TuplePat noExt pats GHC.Boxed -vecP :: p ~ GhcPs => SrcSpan -> [LPat p] -> LPat p +vecP :: SrcSpanAnnA -> [LPat GhcPs] -> LPat GhcPs vecP srcLoc = \case - [] -> go srcLoc [] - as -> L srcLoc $ ParPat noExt $ go srcLoc as + [] -> go [] + as -> L srcLoc $ ParPat noExt $ go as where - go loc (p@(L l _):pats) = L loc $ conPatIn (L l (thName '(:>))) (InfixCon p (go loc pats)) - go loc [] = L loc $ WildPat noExt + go :: [LPat GhcPs] -> LPat GhcPs + go (p@(L l0 _):pats) = + let +#if __GLASGOW_HASKELL__ >= 902 + l1 = l0 `seq` noSrcSpanA +#else + l1 = l0 +#endif + in + L srcLoc $ conPatIn (L l1 (thName '(:>))) (InfixCon p (go pats)) + go [] = L srcLoc $ WildPat noExtField -varP :: p ~ GhcPs => SrcSpan -> String -> LPat p -varP loc nm = L loc $ VarPat noExt (L loc $ var nm) +varP :: SrcSpanAnnA -> String -> LPat GhcPs +varP loc nm = L loc $ VarPat noExtField (noLoc $ var nm) -tildeP :: p ~ GhcPs => SrcSpan -> LPat p -> LPat p +tildeP :: SrcSpanAnnA -> LPat GhcPs -> LPat GhcPs tildeP loc lpat = L loc (LazyPat noExt lpat) -tupT :: p ~ GhcPs => [LHsType p] -> LHsType p +hsBoxedTuple :: GHC.HsTupleSort +#if __GLASGOW_HASKELL__ >= 902 +hsBoxedTuple = HsBoxedOrConstraintTuple +#else +hsBoxedTuple = HsBoxedTuple +#endif + +tupT :: [LHsType GhcPs] -> LHsType GhcPs tupT [ty] = ty -tupT tys = noLoc $ HsTupleTy noExt HsBoxedTuple tys +tupT tys = noLoc $ HsTupleTy noExt hsBoxedTuple tys -vecT :: p ~ GhcPs => SrcSpan -> [LHsType p] -> LHsType p +vecT :: SrcSpanAnnA -> [LHsType GhcPs] -> LHsType GhcPs vecT s [] = L s $ HsParTy noExt (conT s (thName ''Vec) `appTy` tyNum s 0 `appTy` (varT s (genLocName s "vec"))) vecT s tys = L s $ HsParTy noExt (conT s (thName ''Vec) `appTy` tyNum s (length tys) `appTy` head tys) -tyNum :: p ~ GhcPs => SrcSpan -> Int -> LHsType p -tyNum s i = L s (HsTyLit noExt (HsNumTy GHC.NoSourceText (fromIntegral i))) +tyNum :: SrcSpanAnnA -> Int -> LHsType GhcPs +tyNum s i = L s (HsTyLit noExtField (HsNumTy GHC.NoSourceText (fromIntegral i))) -appTy :: p ~ GhcPs => LHsType p -> LHsType p -> LHsType p -appTy a b = L noSrcSpan (HsAppTy noExt a (parenthesizeHsType GHC.appPrec b)) +appTy :: LHsType GhcPs -> LHsType GhcPs -> LHsType GhcPs +appTy a b = noLoc (HsAppTy noExtField a (parenthesizeHsType GHC.appPrec b)) -appE :: p ~ GhcPs => LHsExpr p -> LHsExpr p -> LHsExpr p -appE fun arg = L noSrcSpan $ HsApp noExt fun (parenthesizeHsExpr GHC.appPrec arg) +appE :: LHsExpr GhcPs -> LHsExpr GhcPs -> LHsExpr GhcPs +appE fun arg = L noSrcSpanA $ HsApp noExt fun (parenthesizeHsExpr GHC.appPrec arg) -varE :: p ~ GhcPs => SrcSpan -> GHC.RdrName -> LHsExpr p -varE loc rdr = L loc (HsVar noExt (L loc rdr)) +varE :: SrcSpanAnnA -> GHC.RdrName -> LHsExpr GhcPs +varE loc rdr = L loc (HsVar noExtField (noLoc rdr)) -parenE :: p ~ GhcPs => LHsExpr p -> LHsExpr p +parenE :: LHsExpr GhcPs -> LHsExpr GhcPs parenE e@(L l _) = L l (HsPar noExt e) var :: String -> GHC.RdrName @@ -316,7 +425,7 @@ tyVar = GHC.Unqual . OccName.mkTyVarOcc tyCon :: String -> GHC.RdrName tyCon = GHC.Unqual . OccName.mkTcOcc -vecE :: p ~ GhcPs => SrcSpan -> [LHsExpr p] -> LHsExpr p +vecE :: SrcSpanAnnA -> [LHsExpr GhcPs] -> LHsExpr GhcPs vecE srcLoc = \case [] -> go srcLoc [] as -> parenE $ go srcLoc as @@ -324,11 +433,15 @@ vecE srcLoc = \case go loc (e@(L l _):es) = L loc $ OpApp noExt e (varE l (thName '(:>))) (go loc es) go loc [] = varE loc (thName 'Nil) -tupE :: p ~ GhcPs => SrcSpan -> [LHsExpr p] -> LHsExpr p +tupE :: p ~ GhcPs => SrcSpanAnnA -> [LHsExpr p] -> LHsExpr p tupE _ [ele] = ele tupE loc elems = L loc $ ExplicitTuple noExt tupArgs GHC.Boxed where +#if __GLASGOW_HASKELL__ >= 902 + tupArgs = map (Present noExt) elems +#else tupArgs = map (\arg@(L l _) -> L l (Present noExt arg)) elems +#endif unL :: Located a -> a unL (L _ a) = a @@ -347,14 +460,26 @@ portTypeSigM = \case Tuple ps -> tupT <$> mapM portTypeSigM ps Vec s ps -> vecT s <$> mapM portTypeSigM ps Ref (PortName loc fs) -> do - L.use (portVarTypes . L.at fs) <&> \case - Nothing -> varT loc (GHC.unpackFS fs <> "Ty") - Just (_sigLoc, sig) -> sig + L.use portVarTypes >>= \pvt -> + case lookupUniqMap pvt fs of + Nothing -> + let + -- GHC >= 9.2 interprets any type variable name starting with a "_" as + -- a wildcard and throws an error suggesting a concrete type. To prevent + -- this error from cropping up, we prefix it with "dflt" if we detect an + -- underscore. Note that we see "_" in cases where the user wants to ignore + -- a certain protocol, hence then name "dflt". + s0 = GHC.unpackFS fs + s1 | '_':_ <- s0 = "dflt" <> s0 + | otherwise = s0 + in + pure $ varT loc (s1 <> "Ty") + Just (_sigLoc, sig) -> pure sig RefMulticast p -> portTypeSigM (Ref p) PortErr loc msgdoc -> do dflags <- GHC.getDynFlags unsafePerformIO . throwOneError $ - Err.mkLongErrMsg dflags loc Outputable.alwaysQualify (Outputable.text "portTypeSig") msgdoc + mkLongErrMsg dflags (locA loc) Outputable.alwaysQualify (Outputable.text "portTypeSig") msgdoc Lazy _ p -> portTypeSigM p SignalExpr (L l _) -> do n <- uniqueCounter <<+= 1 @@ -365,8 +490,10 @@ portTypeSigM = \case PortType _ p -> portTypeSigM p -- | Generate a "unique" name by appending the location as a string. -genLocName :: SrcSpan -> String -> String -#if __GLASGOW_HASKELL__ >= 900 +genLocName :: SrcSpanAnnA -> String -> String +#if __GLASGOW_HASKELL__ >= 902 +genLocName (locA -> GHC.RealSrcSpan rss _) prefix = +#elif __GLASGOW_HASKELL__ >= 900 genLocName (GHC.RealSrcSpan rss _) prefix = #else genLocName (GHC.RealSrcSpan rss) prefix = @@ -376,7 +503,7 @@ genLocName (GHC.RealSrcSpan rss) prefix = genLocName _ prefix = prefix -- | Extract a simple lambda into inputs and body. -simpleLambda :: HsExpr p -> Maybe ([LPat p], LHsExpr p) +simpleLambda :: HsExpr GhcPs -> Maybe ([LPat GhcPs], LHsExpr GhcPs) simpleLambda expr = do HsLam _ (MG _x alts _origin) <- Just expr L _ [L _ (Match _matchX _matchContext matchPats matchGr)] <- Just alts @@ -387,9 +514,9 @@ simpleLambda expr = do -- | Create a simple let binding. letE :: p ~ GhcPs - => SrcSpan + => SrcSpanAnnA -- ^ location for top level let bindings - -> [LSig GhcPs] + -> [LSig p] -- ^ type signatures -> [LHsBind p] -- ^ let bindings @@ -398,32 +525,43 @@ letE -> LHsExpr p letE loc sigs binds expr = L loc (HsLet noExt localBinds expr) where +#if __GLASGOW_HASKELL__ >= 902 + localBinds :: HsLocalBinds GhcPs + localBinds = HsValBinds noExt valBinds +#else localBinds :: LHsLocalBindsLR GhcPs GhcPs localBinds = L loc $ HsValBinds noExt valBinds +#endif valBinds :: HsValBindsLR GhcPs GhcPs - valBinds = ValBinds noExt hsBinds sigs + valBinds = ValBinds noAnnSortKey hsBinds sigs hsBinds :: LHsBindsLR GhcPs GhcPs hsBinds = listToBag binds -- | Simple construction of a lambda expression -lamE :: p ~ GhcPs => [LPat p] -> LHsExpr p -> LHsExpr p -lamE pats expr = noLoc $ HsLam noExt mg +lamE :: [LPat GhcPs] -> LHsExpr GhcPs -> LHsExpr GhcPs +lamE pats expr = noLoc $ HsLam noExtField mg where - mg = MG noExt matches GHC.Generated + mg :: MatchGroup GhcPs (GenLocated SrcSpanAnnA (HsExpr GhcPs)) + mg = MG noExtField matches GHC.Generated - matches :: Located [LMatch GhcPs (LHsExpr GhcPs)] + matches :: GenLocated SrcSpanAnnL [GenLocated SrcSpanAnnA (Match GhcPs (GenLocated SrcSpanAnnA (HsExpr GhcPs)))] matches = noLoc $ [singleMatch] - singleMatch :: LMatch GhcPs (LHsExpr GhcPs) + singleMatch :: GenLocated SrcSpanAnnA (Match GhcPs (GenLocated SrcSpanAnnA (HsExpr GhcPs))) singleMatch = noLoc $ Match noExt LambdaExpr pats grHss - grHss :: GRHSs GhcPs (LHsExpr GhcPs) - grHss = GRHSs noExt [grHs] (noLoc $ EmptyLocalBinds noExt) + grHss :: GRHSs GhcPs (GenLocated SrcSpanAnnA (HsExpr GhcPs)) + grHss = GRHSs emptyComments [grHs] $ +#if __GLASGOW_HASKELL__ >= 902 + (EmptyLocalBinds noExtField) +#else + (noLoc (EmptyLocalBinds noExtField)) +#endif - grHs :: LGRHS GhcPs (LHsExpr GhcPs) - grHs = noLoc $ GRHS noExt [] expr + grHs :: GenLocated SrcSpan (GRHS GhcPs (GenLocated SrcSpanAnnA (HsExpr GhcPs))) + grHs = L noSrcSpan $ GRHS noExt [] expr -- | Kinda hacky function to get a string name for named ports. fromRdrName :: GHC.RdrName -> GHC.FastString @@ -453,10 +591,7 @@ parseCircuit = \case e -> circuitBody e -- | The main part of a circuit expression. Either a do block or simple rearranging case. -circuitBody - :: p ~ GhcPs - => LHsExpr p - -> CircuitM () +circuitBody :: LHsExpr GhcPs -> CircuitM () circuitBody = \case -- strip out parenthesis L _ (HsPar _ lexp) -> circuitBody lexp @@ -490,7 +625,7 @@ circuitBody = \case bodyBinding (Just ref) (bod) circuitMasters .= ref - stmt -> errM finLoc ("Unhandled final stmt " <> show (Data.toConstr stmt)) + stmt -> errM (locA finLoc) ("Unhandled final stmt " <> show (Data.toConstr stmt)) -- the simple case without do notation L loc master -> do @@ -499,17 +634,20 @@ circuitBody = \case -- | Handle a single statement. handleStmtM - :: (p ~ GhcPs, loc ~ SrcSpan, idL ~ GhcPs, idR ~ GhcPs) - => Located (StmtLR idL idR (LHsExpr p)) + :: GenLocated SrcSpanAnnA (StmtLR GhcPs GhcPs (LHsExpr GhcPs)) -> CircuitM () handleStmtM (L loc stmt) = case stmt of +#if __GLASGOW_HASKELL__ >= 902 + LetStmt _xlet letBind -> +#else LetStmt _xlet (L _ letBind) -> +#endif -- a regular let bindings case letBind of HsValBinds _ (ValBinds _ valBinds sigs) -> do circuitLets <>= bagToList valBinds circuitTypes <>= sigs - _ -> errM loc ("Unhandled let statement" <> show (Data.toConstr letBind)) + _ -> errM (locA loc) ("Unhandled let statement" <> show (Data.toConstr letBind)) BodyStmt _xbody body _idr _idr' -> bodyBinding Nothing body #if __GLASGOW_HASKELL__ >= 900 @@ -518,7 +656,7 @@ handleStmtM (L loc stmt) = case stmt of BindStmt _xbody bind body _idr _idr' -> #endif bodyBinding (Just $ bindSlave bind) body - _ -> errM loc "Unhandled stmt" + _ -> errM (locA loc) "Unhandled stmt" -- | Turn patterns to the left of a @<-@ into a PortDescription. bindSlave :: p ~ GhcPs => LPat p -> PortDescription PortName @@ -526,7 +664,9 @@ bindSlave (L loc expr) = case expr of VarPat _ (L _ rdrName) -> Ref (PortName loc (fromRdrName rdrName)) TuplePat _ lpat _ -> Tuple $ fmap bindSlave lpat ParPat _ lpat -> bindSlave lpat -#if __GLASGOW_HASKELL__ >= 900 +#if __GLASGOW_HASKELL__ >= 902 + ConPat _ (L _ (GHC.Unqual occ)) (PrefixCon [] [lpat]) +#elif __GLASGOW_HASKELL__ >= 900 ConPat _ (L _ (GHC.Unqual occ)) (PrefixCon [lpat]) #else ConPatIn (L _ (GHC.Unqual occ)) (PrefixCon [lpat]) @@ -554,23 +694,32 @@ bindSlave (L loc expr) = case expr of (Err.mkLocMessageAnn Nothing Err.SevFatal - loc + (locA loc) (Outputable.text $ "Unhandled pattern " <> show (Data.toConstr pat)) ) -- | Turn expressions to the right of a @-<@ into a PortDescription. -bindMaster :: p ~ GhcPs => LHsExpr p -> PortDescription PortName +bindMaster :: LHsExpr GhcPs -> PortDescription PortName bindMaster (L loc expr) = case expr of - HsVar _xvar (L vloc rdrName) + HsVar _xvar (L _vloc rdrName) | rdrName == thName '() -> Tuple [] - | rdrName == thName '[] -> Vec vloc [] - | otherwise -> Ref (PortName vloc (fromRdrName rdrName)) + | rdrName == thName '[] -> Vec loc [] -- XXX: vloc? + | otherwise -> Ref (PortName loc (fromRdrName rdrName)) -- XXX: vloc? HsApp _xapp (L _ (HsVar _ (L _ (GHC.Unqual occ)))) sig | OccName.occNameString occ == "Signal" -> SignalExpr sig ExplicitTuple _ tups _ -> let +#if __GLASGOW_HASKELL__ >= 902 + vals = fmap (\(Present _ e) -> e) tups +#else vals = fmap (\(L _ (Present _ e)) -> e) tups +#endif in Tuple $ fmap bindMaster vals - ExplicitList _ _syntaxExpr exprs -> Vec loc $ fmap bindMaster exprs +#if __GLASGOW_HASKELL__ >= 902 + ExplicitList _ exprs -> +#else + ExplicitList _ _syntaxExpr exprs -> +#endif + Vec loc $ fmap bindMaster exprs #if __GLASGOW_HASKELL__ < 810 HsArrApp _xapp (L _ (HsVar _ (L _ (GHC.Unqual occ)))) sig _ _ | OccName.occNameString occ == "Signal" -> SignalExpr sig @@ -589,16 +738,15 @@ bindMaster (L loc expr) = case expr of (Err.mkLocMessageAnn Nothing Err.SevFatal - loc + (locA loc) (Outputable.text $ "Unhandled expression " <> show (Data.toConstr expr)) ) -- | Create a binding expression bodyBinding - :: (p ~ GhcPs, loc ~ SrcSpan) - => Maybe (PortDescription PortName) + :: Maybe (PortDescription PortName) -- ^ the bound variable, this can be Nothing if there is no @<-@ (a circuit with no slaves) - -> GenLocated loc (HsExpr p) + -> GenLocated SrcSpanAnnA (HsExpr GhcPs) -- ^ the statement with an optional @-<@ -> CircuitM () bodyBinding mInput lexpr@(L loc expr) = do @@ -620,7 +768,7 @@ bodyBinding mInput lexpr@(L loc expr) = do #endif _ -> case mInput of - Nothing -> errM loc "standalone expressions are not allowed (are Arrows enabled?)" + Nothing -> errM (locA loc) "standalone expressions are not allowed (are Arrows enabled?)" Just input -> circuitBinds <>= [Binding { bCircuit = lexpr , bOut = Tuple [] @@ -638,24 +786,24 @@ checkCircuit = do binds <- L.use circuitBinds let portNames d = L.toListOf (L.cosmos . _Ref . L.to (f d)) - f :: Dir -> PortName -> (GHC.FastString, ([SrcSpan], [SrcSpan])) + f :: Dir -> PortName -> (GHC.FastString, ([SrcSpanAnnA], [SrcSpanAnnA])) f Slave (PortName srcLoc portName) = (portName, ([srcLoc], [])) f Master (PortName srcLoc portName) = (portName, ([], [srcLoc])) bindingNames = \b -> portNames Master (bOut b) <> portNames Slave (bIn b) topNames = portNames Slave slaves <> portNames Master masters - nameMap = Map.fromListWith mappend $ topNames <> concatMap bindingNames binds + nameMap = listToUniqMap_C mappend $ topNames <> concatMap bindingNames binds - duplicateMasters <- concat <$> L.iforM nameMap \name occ -> + duplicateMasters <- concat <$> forM (nonDetUniqMapToList nameMap) \(name, occ) -> case occ of ([_], [_]) -> pure [] (ss, ms) -> do unless (head (unpackFS name) == '_') $ do - when (null ms) $ errM (head ss) $ "Slave port " <> show name <> " has no associated master" - when (null ss) $ errM (head ms) $ "Master port " <> show name <> " has no associated slave" + when (null ms) $ errM (locA (head ss)) $ "Slave port " <> show name <> " has no associated master" + when (null ss) $ errM (locA (head ms)) $ "Master port " <> show name <> " has no associated slave" -- would be nice to show locations of all occurrences here, not sure how to do that while -- keeping ghc api when (length ss > 1) $ - errM (head ss) $ "Slave port " <> show name <> " defined " <> show (length ss) <> " times" + errM (locA (head ss)) $ "Slave port " <> show name <> " defined " <> show (length ss) <> " times" -- if master is defined multiple times, we broadcast it if length ms > 1 @@ -681,16 +829,21 @@ data Direc = Fwd | Bwd deriving Show bindWithSuffix :: (p ~ GhcPs, ?nms :: ExternalNames) => GHC.DynFlags -> Direc -> PortDescription PortName -> LPat p bindWithSuffix dflags dir = \case - Tuple ps -> tildeP noSrcSpan $ tupP $ fmap (bindWithSuffix dflags dir) ps + Tuple ps -> tildeP noSrcSpanA $ tupP $ fmap (bindWithSuffix dflags dir) ps Vec s ps -> vecP s $ fmap (bindWithSuffix dflags dir) ps Ref (PortName loc fs) -> varP loc (GHC.unpackFS fs <> "_" <> show dir) RefMulticast (PortName loc fs) -> case dir of - Bwd -> L loc (WildPat noExt) + Bwd -> L loc (WildPat noExtField) Fwd -> varP loc (GHC.unpackFS fs <> "_" <> show dir) PortErr loc msgdoc -> unsafePerformIO . throwOneError $ - Err.mkLongErrMsg dflags loc Outputable.alwaysQualify (Outputable.text "Unhandled bind") msgdoc + mkLongErrMsg dflags (locA loc) Outputable.alwaysQualify (Outputable.text "Unhandled bind") msgdoc Lazy loc p -> tildeP loc $ bindWithSuffix dflags dir p +#if __GLASGOW_HASKELL__ >= 902 + -- XXX: propagate location + SignalExpr (L _ _) -> nlWildPat +#else SignalExpr (L l _) -> L l (WildPat noExt) +#endif SignalPat lpat -> lpat PortType _ p -> bindWithSuffix dflags dir p @@ -715,11 +868,11 @@ bindOutputs dflags direc slaves masters = noLoc $ conPatIn (noLoc (fwdBwdCon ?nm expWithSuffix :: (p ~ GhcPs, ?nms :: ExternalNames) => Direc -> PortDescription PortName -> LHsExpr p expWithSuffix dir = \case - Tuple ps -> tupE noSrcSpan $ fmap (expWithSuffix dir) ps + Tuple ps -> tupE noSrcSpanA $ fmap (expWithSuffix dir) ps Vec s ps -> vecE s $ fmap (expWithSuffix dir) ps Ref (PortName loc fs) -> varE loc (var $ GHC.unpackFS fs <> "_" <> show dir) RefMulticast (PortName loc fs) -> case dir of - Bwd -> varE noSrcSpan (trivialBwd ?nms) + Bwd -> varE noSrcSpanA (trivialBwd ?nms) Fwd -> varE loc (var $ GHC.unpackFS fs <> "_" <> show dir) -- laziness only affects the pattern side Lazy _ p -> expWithSuffix dir p @@ -736,7 +889,7 @@ createInputs -> PortDescription PortName -- ^ master ports -> LHsExpr p -createInputs dir slaves masters = noLoc $ OpApp noExt s2m (varE noSrcSpan (fwdBwdCon ?nms)) m2s +createInputs dir slaves masters = noLoc $ OpApp noExt s2m (varE noSrcSpanA (fwdBwdCon ?nms)) m2s where m2s = expWithSuffix (revDirec dir) masters s2m = expWithSuffix dir slaves @@ -745,22 +898,30 @@ decFromBinding :: (p ~ GhcPs, ?nms :: ExternalNames) => GHC.DynFlags -> Int -> B decFromBinding dflags i Binding {..} = do let bindPat = bindOutputs dflags Bwd bIn bOut inputExp = createInputs Fwd bOut bIn - bod = varE noSrcSpan (var $ "run" <> show i) `appE` bCircuit `appE` inputExp + bod = varE noSrcSpanA (var $ "run" <> show i) `appE` bCircuit `appE` inputExp in patBind bindPat bod -patBind :: p ~ GhcPs => LPat p -> LHsExpr p -> HsBind p +patBind :: LPat GhcPs -> LHsExpr GhcPs -> HsBind GhcPs patBind lhs expr = PatBind noExt lhs rhs ([], []) where - rhs = GRHSs noExt [gr] (noLoc $ EmptyLocalBinds noExt) - gr = L (getLoc expr) (GRHS noExt [] expr) + rhs :: GRHSs GhcPs (GenLocated SrcSpanAnnA (HsExpr GhcPs)) + rhs = GRHSs emptyComments [gr] $ +#if __GLASGOW_HASKELL__ >= 902 + EmptyLocalBinds noExtField +#else + noLoc (EmptyLocalBinds noExtField) +#endif + + gr :: GenLocated SrcSpan (GRHS GhcPs (GenLocated SrcSpanAnnA (HsExpr GhcPs))) + gr = L (locA (getLoc expr)) (GRHS noExt [] expr) -circuitConstructor :: (p ~ GhcPs, ?nms :: ExternalNames) => SrcSpan -> LHsExpr p +circuitConstructor :: (?nms :: ExternalNames) => SrcSpanAnnA -> LHsExpr GhcPs circuitConstructor loc = varE loc (circuitCon ?nms) -runCircuitFun :: (p ~ GhcPs, ?nms :: ExternalNames) => SrcSpan -> LHsExpr p +runCircuitFun :: (?nms :: ExternalNames) => SrcSpanAnnA -> LHsExpr GhcPs runCircuitFun loc = varE loc (runCircuitName ?nms) -constVar :: p ~ GhcPs => SrcSpan -> LHsExpr p +constVar :: SrcSpanAnnA -> LHsExpr GhcPs constVar loc = varE loc (thName 'const) deepShowD :: Data.Data a => a -> String @@ -784,17 +945,17 @@ hsFunTy = arrTy :: p ~ GhcPs => LHsType p -> LHsType p -> LHsType p arrTy a b = noLoc $ hsFunTy (parenthesizeHsType GHC.funPrec a) (parenthesizeHsType GHC.funPrec b) -varT :: SrcSpan -> String -> LHsType GhcPs -varT loc nm = L loc (HsTyVar noExt NotPromoted (L loc (tyVar nm))) +varT :: SrcSpanAnnA -> String -> LHsType GhcPs +varT loc nm = L loc (HsTyVar noExt NotPromoted (noLoc (tyVar nm))) -conT :: SrcSpan -> GHC.RdrName -> LHsType GhcPs -conT loc nm = L loc (HsTyVar noExt NotPromoted (L loc nm)) +conT :: SrcSpanAnnA -> GHC.RdrName -> LHsType GhcPs +conT loc nm = L loc (HsTyVar noExt NotPromoted (noLoc nm)) circuitTy :: (p ~ GhcPs, ?nms :: ExternalNames) => LHsType p -> LHsType p -> LHsType p -circuitTy a b = (conT noSrcSpan (circuitTyCon ?nms)) `appTy` a `appTy` b +circuitTy a b = conT noSrcSpanA (circuitTyCon ?nms) `appTy` a `appTy` b circuitTTy :: (p ~ GhcPs, ?nms :: ExternalNames) => LHsType p -> LHsType p -> LHsType p -circuitTTy a b = (conT noSrcSpan (circuitTTyCon ?nms)) `appTy` a `appTy` b +circuitTTy a b = conT noSrcSpanA (circuitTTyCon ?nms) `appTy` a `appTy` b -- a b -> (Circuit a b -> CircuitT a b) mkRunCircuitTy :: (p ~ GhcPs, ?nms :: ExternalNames) => LHsType p -> LHsType p -> LHsType p @@ -812,12 +973,13 @@ gatherTypes gatherTypes = L.traverseOf_ L.cosmos addTypes where addTypes = \case - PortType ty (Ref (PortName loc fs)) -> portVarTypes . L.at fs ?= (loc, ty) + PortType ty (Ref (PortName loc fs)) -> + portVarTypes %= \pvt -> alterUniqMap (const (Just (loc, ty))) pvt fs PortType ty p -> portTypes <>= [(ty, p)] _ -> pure () -tyEq :: p ~ GhcPs => SrcSpan -> LHsType p -> LHsType p -> LHsType p -tyEq l a b = L l $ HsOpTy noExt a (noLoc eqTyCon_RDR) b +tyEq :: LHsType GhcPs -> LHsType GhcPs -> LHsType GhcPs +tyEq a b = noLoc $ HsOpTy noExtField a (noLoc eqTyCon_RDR) b -- eqTyCon is a special name that has to be exactly correct for ghc to recognise it. In 8.6 this -- lives in PrelNames and is called eqTyCon_RDR, in later ghcs it's from TysWiredIn. @@ -845,7 +1007,7 @@ circuitQQExpM = do res = createInputs Bwd slaves masters body :: LHsExpr GhcPs - body = letE noSrcSpan letTypes decs res + body = letE noSrcSpanA letTypes decs res -- see [inference-helper] mapM_ @@ -866,29 +1028,38 @@ circuitQQExpM = do allTypes <- L.use portTypes - context <- mapM (\(ty, p) -> tyEq noSrcSpan <$> (portTypeSigM p) <*> pure ty) allTypes + context <- mapM (\(ty, p) -> tyEq <$> portTypeSigM p <*> pure ty) allTypes -- the full signature loc <- L.use circuitLoc let inferenceHelperName = genLocName loc "inferenceHelper" inferenceSig :: LHsSigType GhcPs +#if __GLASGOW_HASKELL__ >= 902 + inferenceSig = noLoc $ + HsSig + noExtField + (HsOuterImplicit noExtField) + (noLoc $ HsQualTy noExtField (Just (noLoc context)) runCircuitsType) +#else inferenceSig = HsIB noExt (noLoc $ HsQualTy noExt (noLoc context) runCircuitsType) +#endif + inferenceHelperTy = TypeSig noExt [noLoc (var inferenceHelperName)] - (HsWC noExt inferenceSig) + (HsWC noExtField inferenceSig) let numBinds = length binds - runCircuitExprs = lamE [varP noSrcSpan "f"] $ - circuitConstructor noSrcSpan `appE` + runCircuitExprs = lamE [varP noSrcSpanA "f"] $ + circuitConstructor noSrcSpanA `appE` noLoc (HsPar noExt - (varE noSrcSpan (var "f") `appE` tupE noSrcSpan (replicate numBinds (runCircuitFun noSrcSpan)))) - runCircuitBinds = tupP $ map (\i -> varP noSrcSpan ("run" <> show i)) [0 .. numBinds-1] + (varE noSrcSpanA (var "f") `appE` tupE noSrcSpanA (replicate numBinds (runCircuitFun noSrcSpanA)))) + runCircuitBinds = tupP $ map (\i -> varP noSrcSpanA ("run" <> show i)) [0 .. numBinds-1] - let c = letE noSrcSpan + let c = letE noSrcSpanA [noLoc inferenceHelperTy] - [noLoc $ patBind (varP noSrcSpan inferenceHelperName) (runCircuitExprs)] - (varE noSrcSpan (var inferenceHelperName) `appE` lamE [runCircuitBinds, pats] body) + [noLoc $ patBind (varP noSrcSpanA inferenceHelperName) (runCircuitExprs)] + (varE noSrcSpanA (var inferenceHelperName) `appE` lamE [runCircuitBinds, pats] body) -- ppr c pure c @@ -1007,6 +1178,16 @@ mkPlugin nms = GHC.defaultPlugin , GHC.pluginRecompile = \_cliOptions -> pure GHC.NoForceRecompile } +warningMsg :: Outputable.SDoc -> GHC.Hsc () +warningMsg sdoc = do + dflags <- GHC.getDynFlags +#if __GLASGOW_HASKELL__ >= 902 + logger <- GHC.getLogger + liftIO $ Err.warningMsg logger dflags sdoc +#else + liftIO $ Err.warningMsg dflags sdoc +#endif + -- | The actual implementation. pluginImpl :: (?nms :: ExternalNames) => [GHC.CommandLineOption] -> GHC.ModSummary -> GHC.HsParsedModule -> GHC.Hsc GHC.HsParsedModule pluginImpl cliOptions _modSummary m = do @@ -1014,10 +1195,9 @@ pluginImpl cliOptions _modSummary m = do debug <- case cliOptions of [] -> pure False ["debug"] -> pure True - _ -> do dflags <- GHC.getDynFlags - liftIO $ Err.warningMsg dflags $ Outputable.text $ - "CircuitNotation: unknown cli options " <> show cliOptions - pure False + _ -> do + warningMsg $ Outputable.text $ "CircuitNotation: unknown cli options " <> show cliOptions + pure False hpm_module' <- do transform debug (GHC.hpm_module m) let module' = m { GHC.hpm_module = hpm_module' } diff --git a/src/GHC/Types/Unique/Map.hs b/src/GHC/Types/Unique/Map.hs new file mode 100644 index 0000000..9bf20cd --- /dev/null +++ b/src/GHC/Types/Unique/Map.hs @@ -0,0 +1,213 @@ +{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE DeriveDataTypeable #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE CPP #-} +{-# OPTIONS_GHC -Wall #-} + +-- Like 'UniqFM', these are maps for keys which are Uniquable. +-- Unlike 'UniqFM', these maps also remember their keys, which +-- makes them a much better drop in replacement for 'Data.Map.Map'. +-- +-- Key preservation is right-biased. +module GHC.Types.Unique.Map ( + UniqMap(..), + emptyUniqMap, + isNullUniqMap, + unitUniqMap, + listToUniqMap, + listToUniqMap_C, + addToUniqMap, + addListToUniqMap, + addToUniqMap_C, + addToUniqMap_Acc, + alterUniqMap, + addListToUniqMap_C, + adjustUniqMap, + delFromUniqMap, + delListFromUniqMap, + plusUniqMap, + plusUniqMap_C, + plusMaybeUniqMap_C, + plusUniqMapList, + minusUniqMap, + intersectUniqMap, + disjointUniqMap, + mapUniqMap, + filterUniqMap, + partitionUniqMap, + sizeUniqMap, + elemUniqMap, + lookupUniqMap, + lookupWithDefaultUniqMap, + anyUniqMap, + allUniqMap +) where + +#if __GLASGOW_HASKELL__ < 900 +import Unique +import UniqFM +import Outputable +#else +import GHC.Types.Unique.FM +import GHC.Types.Unique +import GHC.Utils.Outputable +#endif + +import Data.Semigroup as Semi ( Semigroup(..) ) +import Data.Coerce +import Data.Maybe +import Data.Data + +-- | Maps indexed by 'Uniquable' keys +#if __GLASGOW_HASKELL__ < 900 +newtype UniqMap k a = UniqMap (UniqFM (k, a)) +#else +newtype UniqMap k a = UniqMap (UniqFM k (k, a)) +#endif + deriving (Data, Eq, Functor) +type role UniqMap nominal representational + +instance Semigroup (UniqMap k a) where + (<>) = plusUniqMap + +instance Monoid (UniqMap k a) where + mempty = emptyUniqMap + mappend = (Semi.<>) + +instance (Outputable k, Outputable a) => Outputable (UniqMap k a) where + ppr (UniqMap m) = + brackets $ fsep $ punctuate comma $ + [ ppr k <+> text "->" <+> ppr v + | (k, v) <- eltsUFM m ] + +liftC :: (a -> a -> a) -> (k, a) -> (k, a) -> (k, a) +liftC f (_, v) (k', v') = (k', f v v') + +emptyUniqMap :: UniqMap k a +emptyUniqMap = UniqMap emptyUFM + +isNullUniqMap :: UniqMap k a -> Bool +isNullUniqMap (UniqMap m) = isNullUFM m + +unitUniqMap :: Uniquable k => k -> a -> UniqMap k a +unitUniqMap k v = UniqMap (unitUFM k (k, v)) + +listToUniqMap :: Uniquable k => [(k,a)] -> UniqMap k a +listToUniqMap kvs = UniqMap (listToUFM [ (k,(k,v)) | (k,v) <- kvs]) + +listToUniqMap_C :: Uniquable k => (a -> a -> a) -> [(k,a)] -> UniqMap k a +listToUniqMap_C f kvs = UniqMap $ + listToUFM_C (liftC f) [ (k,(k,v)) | (k,v) <- kvs] + +addToUniqMap :: Uniquable k => UniqMap k a -> k -> a -> UniqMap k a +addToUniqMap (UniqMap m) k v = UniqMap $ addToUFM m k (k, v) + +addListToUniqMap :: Uniquable k => UniqMap k a -> [(k,a)] -> UniqMap k a +addListToUniqMap (UniqMap m) kvs = UniqMap $ + addListToUFM m [(k,(k,v)) | (k,v) <- kvs] + +addToUniqMap_C :: Uniquable k + => (a -> a -> a) + -> UniqMap k a + -> k + -> a + -> UniqMap k a +addToUniqMap_C f (UniqMap m) k v = UniqMap $ + addToUFM_C (liftC f) m k (k, v) + +addToUniqMap_Acc :: Uniquable k + => (b -> a -> a) + -> (b -> a) + -> UniqMap k a + -> k + -> b + -> UniqMap k a +addToUniqMap_Acc exi new (UniqMap m) k0 v0 = UniqMap $ + addToUFM_Acc (\b (k, v) -> (k, exi b v)) + (\b -> (k0, new b)) + m k0 v0 + +alterUniqMap :: Uniquable k + => (Maybe a -> Maybe a) + -> UniqMap k a + -> k + -> UniqMap k a +alterUniqMap f (UniqMap m) k = UniqMap $ + alterUFM (fmap (k,) . f . fmap snd) m k + +addListToUniqMap_C + :: Uniquable k + => (a -> a -> a) + -> UniqMap k a + -> [(k, a)] + -> UniqMap k a +addListToUniqMap_C f (UniqMap m) kvs = UniqMap $ + addListToUFM_C (liftC f) m + [(k,(k,v)) | (k,v) <- kvs] + +adjustUniqMap + :: Uniquable k + => (a -> a) + -> UniqMap k a + -> k + -> UniqMap k a +adjustUniqMap f (UniqMap m) k = UniqMap $ + adjustUFM (\(_,v) -> (k,f v)) m k + +delFromUniqMap :: Uniquable k => UniqMap k a -> k -> UniqMap k a +delFromUniqMap (UniqMap m) k = UniqMap $ delFromUFM m k + +delListFromUniqMap :: Uniquable k => UniqMap k a -> [k] -> UniqMap k a +delListFromUniqMap (UniqMap m) ks = UniqMap $ delListFromUFM m ks + +plusUniqMap :: UniqMap k a -> UniqMap k a -> UniqMap k a +plusUniqMap (UniqMap m1) (UniqMap m2) = UniqMap $ plusUFM m1 m2 + +plusUniqMap_C :: (a -> a -> a) -> UniqMap k a -> UniqMap k a -> UniqMap k a +plusUniqMap_C f (UniqMap m1) (UniqMap m2) = UniqMap $ + plusUFM_C (liftC f) m1 m2 + +plusMaybeUniqMap_C :: (a -> a -> Maybe a) -> UniqMap k a -> UniqMap k a -> UniqMap k a +plusMaybeUniqMap_C f (UniqMap m1) (UniqMap m2) = UniqMap $ + plusMaybeUFM_C (\(_, v) (k', v') -> fmap (k',) (f v v')) m1 m2 + +plusUniqMapList :: [UniqMap k a] -> UniqMap k a +plusUniqMapList xs = UniqMap $ plusUFMList (coerce xs) + +minusUniqMap :: UniqMap k a -> UniqMap k b -> UniqMap k a +minusUniqMap (UniqMap m1) (UniqMap m2) = UniqMap $ minusUFM m1 m2 + +intersectUniqMap :: UniqMap k a -> UniqMap k b -> UniqMap k a +intersectUniqMap (UniqMap m1) (UniqMap m2) = UniqMap $ intersectUFM m1 m2 + +disjointUniqMap :: UniqMap k a -> UniqMap k b -> Bool +disjointUniqMap (UniqMap m1) (UniqMap m2) = disjointUFM m1 m2 + +mapUniqMap :: (a -> b) -> UniqMap k a -> UniqMap k b +mapUniqMap f (UniqMap m) = UniqMap $ mapUFM (fmap f) m -- (,) k instance + +filterUniqMap :: (a -> Bool) -> UniqMap k a -> UniqMap k a +filterUniqMap f (UniqMap m) = UniqMap $ filterUFM (f . snd) m + +partitionUniqMap :: (a -> Bool) -> UniqMap k a -> (UniqMap k a, UniqMap k a) +partitionUniqMap f (UniqMap m) = + coerce $ partitionUFM (f . snd) m + +sizeUniqMap :: UniqMap k a -> Int +sizeUniqMap (UniqMap m) = sizeUFM m + +elemUniqMap :: Uniquable k => k -> UniqMap k a -> Bool +elemUniqMap k (UniqMap m) = elemUFM k m + +lookupUniqMap :: Uniquable k => UniqMap k a -> k -> Maybe a +lookupUniqMap (UniqMap m) k = fmap snd (lookupUFM m k) + +lookupWithDefaultUniqMap :: Uniquable k => UniqMap k a -> a -> k -> a +lookupWithDefaultUniqMap (UniqMap m) a k = fromMaybe a (fmap snd (lookupUFM m k)) + +anyUniqMap :: (a -> Bool) -> UniqMap k a -> Bool +anyUniqMap f (UniqMap m) = anyUFM (f . snd) m + +allUniqMap :: (a -> Bool) -> UniqMap k a -> Bool +allUniqMap f (UniqMap m) = allUFM (f . snd) m diff --git a/src/GHC/Types/Unique/Map/Extra.hs b/src/GHC/Types/Unique/Map/Extra.hs new file mode 100644 index 0000000..c3d0ca7 --- /dev/null +++ b/src/GHC/Types/Unique/Map/Extra.hs @@ -0,0 +1,19 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE PackageImports #-} + +module GHC.Types.Unique.Map.Extra where + +#if __GLASGOW_HASKELL__ >= 902 +import "ghc" GHC.Types.Unique.Map +#else +import GHC.Types.Unique.Map +#endif + +#if __GLASGOW_HASKELL__ >= 900 +import GHC.Types.Unique.FM (nonDetEltsUFM) +#elif __GLASGOW_HASKELL__ <= 810 +import UniqFM (nonDetEltsUFM) +#endif + +nonDetUniqMapToList :: UniqMap key a -> [(key, a)] +nonDetUniqMapToList (UniqMap u) = nonDetEltsUFM u