From 7aca688a31a2768951fbea55a2f189f5519cd2df Mon Sep 17 00:00:00 2001 From: Alexander Wood Date: Thu, 17 Aug 2023 22:39:23 +0100 Subject: [PATCH] make the arbitrary expr tests compile and run without errors --- src/Elara/AST/Generic.hs | 133 ++++++++++++++++++++++++++++++-- src/Elara/AST/StripLocation.hs | 17 ++-- src/Elara/AST/Unlocated.hs | 2 +- src/Elara/Parse/Expression.hs | 2 +- src/Elara/ToCore.hs | 6 +- src/Elara/TypeInfer/Infer.hs | 5 ++ src/Elara/TypeInfer/Monotype.hs | 9 +++ test/Infer.hs | 26 ++++++- test/Parse.hs | 2 +- 9 files changed, 180 insertions(+), 22 deletions(-) diff --git a/src/Elara/AST/Generic.hs b/src/Elara/AST/Generic.hs index be0c0689..fabbe823 100644 --- a/src/Elara/AST/Generic.hs +++ b/src/Elara/AST/Generic.hs @@ -27,6 +27,7 @@ import Elara.AST.Select (LocatedAST, UnlocatedAST) import Elara.AST.StripLocation (StripLocation (..)) import Elara.Data.Pretty import GHC.TypeLits +import Relude.Extra (bimapF) import TODO (todo) import Prelude hiding (group) @@ -296,6 +297,24 @@ instance ToMaybe (Maybe a) (Maybe a) where instance {-# INCOHERENT #-} ToMaybe a (Maybe a) where toMaybe = Just +-- Sometimes fields are wrapped in functors eg lists, we need a way of transcending them. +-- This class does that. +-- For example, let's say we have `cleanPattern :: Pattern ast1 -> Pattern ast2`, and `x :: Select ast1 "Pattern"`. +-- `x` could be `Pattern ast1`, `[Pattern ast1]`, `Maybe (Pattern ast1)`, or something else entirely. +-- `cleanPattern` will only work on the first of these, so we need a way of lifting it to the others. Obviously, this sounds like a Functor +-- but the problem is that `Pattern ast1` has the wrong kind. +class ApplyAsFunctorish i o a b where + applyAsFunctorish :: (a -> b) -> i -> o + +instance Functor f => ApplyAsFunctorish (f a) (f b) a b where + applyAsFunctorish = fmap + +instance ApplyAsFunctorish a b a b where + applyAsFunctorish f = f + +instance ApplyAsFunctorish NoFieldValue NoFieldValue a b where + applyAsFunctorish _ = identity + -- | Unwraps 1 level of 'Maybe' from a type. Useful when a type family returns Maybe type family UnwrapMaybe (a :: Kind.Type) = (k :: Kind.Type) where UnwrapMaybe (Maybe a) = a @@ -372,7 +391,12 @@ instance pretty (ConstructorPattern c ps) = prettyConstructorPattern c ps pretty (ListPattern l) = prettyListPattern l pretty (ConsPattern p1 p2) = prettyConsPattern p1 p2 - pretty other = error "aaaaaaa" + pretty WildcardPattern = "_" + pretty (IntegerPattern i) = pretty i + pretty (FloatPattern f) = pretty f + pretty (StringPattern s) = pretty '\"' <> pretty s <> pretty '\"' + pretty (CharPattern c) = "'" <> escapeChar c <> "'" + pretty UnitPattern = "()" instance ( Pretty (ASTLocate ast (Type' ast)) @@ -396,21 +420,118 @@ instance stripExprLocation :: forall (ast1 :: LocatedAST) (ast2 :: UnlocatedAST). - ((ASTLocate' ast1 ~ Located), ASTLocate' ast2 ~ Unlocated) => + ( ASTLocate' ast1 ~ Located + , ASTLocate' ast2 ~ Unlocated + , ApplyAsFunctorish (Select "LambdaPattern" ast1) (Select "LambdaPattern" ast2) (Pattern ast1) (Pattern ast2) + , _ + ) => Expr ast1 -> Expr ast2 stripExprLocation (Expr (e :: ASTLocate ast1 (Expr' ast1), t)) = let e' = fmapUnlocated @LocatedAST @ast1 stripExprLocation' e - in -- t' = fmapUnlocated @LocatedAST @ast1 _ t :: ASTLocate ast1 (Select "ExprType" ast2) - Expr (stripLocation e', todo t) + in Expr (stripLocation e', fmap stripTypeLocation t) where - stripExprLocation' :: forall (ast1 :: LocatedAST) (ast2 :: UnlocatedAST). Expr' ast1 -> Expr' ast2 + stripExprLocation' :: Expr' ast1 -> Expr' ast2 stripExprLocation' (Int i) = Int i stripExprLocation' (Float f) = Float f stripExprLocation' (String s) = String s stripExprLocation' (Char c) = Char c stripExprLocation' Unit = Unit - stripExprLocation' (Var v) = Var (todo) + stripExprLocation' (Var v) = Var (stripLocation v) + stripExprLocation' (Constructor c) = Constructor (stripLocation c) + stripExprLocation' (Lambda ps e) = + let ps' = stripLocation ps + ps'' = + applyAsFunctorish @(Select "LambdaPattern" ast1) @(Select "LambdaPattern" ast2) @(Pattern ast1) @(Pattern ast2) + stripPatternLocation + ps' + in Lambda ps'' (stripExprLocation e) + stripExprLocation' (FunctionCall e1 e2) = FunctionCall (stripExprLocation e1) (stripExprLocation e2) + stripExprLocation' (If e1 e2 e3) = If (stripExprLocation e1) (stripExprLocation e2) (stripExprLocation e3) + stripExprLocation' (BinaryOperator op e1 e2) = BinaryOperator (stripBinaryOperatorLocation op) (stripExprLocation e1) (stripExprLocation e2) + stripExprLocation' (List l) = List (stripExprLocation <$> l) + stripExprLocation' (Match e m) = Match (stripExprLocation e) (bimapF stripPatternLocation stripExprLocation m) + stripExprLocation' (LetIn v p e1 e2) = + let p' = stripLocation p + p'' = + applyAsFunctorish @(Select "LetPattern" ast1) @(Select "LetPattern" ast2) @(Pattern ast1) @(Pattern ast2) + stripPatternLocation + p' + in LetIn + (stripLocation v) + p'' + (stripExprLocation e1) + (stripExprLocation e2) + stripExprLocation' (Let v p e) = + let p' = stripLocation p + p'' = + applyAsFunctorish @(Select "LetPattern" ast1) @(Select "LetPattern" ast2) @(Pattern ast1) @(Pattern ast2) + stripPatternLocation + p' + in Let (stripLocation v) p'' (stripExprLocation e) + stripExprLocation' (Block b) = Block (stripExprLocation <$> b) + stripExprLocation' (InParens e) = InParens (stripExprLocation e) + stripExprLocation' (Tuple t) = Tuple (stripExprLocation <$> t) + +stripPatternLocation :: + forall (ast1 :: LocatedAST) (ast2 :: UnlocatedAST). + ( (ASTLocate' ast1 ~ Located) + , ASTLocate' ast2 ~ Unlocated + , _ + ) => + Pattern ast1 -> + Pattern ast2 +stripPatternLocation (Pattern (p :: ASTLocate ast1 (Pattern' ast1), t)) = + let p' = fmapUnlocated @LocatedAST @ast1 stripPatternLocation' p + in Pattern (stripLocation p', fmap stripTypeLocation t) + where + stripPatternLocation' :: Pattern' ast1 -> Pattern' ast2 + stripPatternLocation' (VarPattern v) = VarPattern (stripLocation v) + stripPatternLocation' (ConstructorPattern c ps) = ConstructorPattern (stripLocation c) (stripPatternLocation <$> ps) + stripPatternLocation' (ListPattern l) = ListPattern (stripPatternLocation <$> l) + stripPatternLocation' (ConsPattern p1 p2) = ConsPattern (stripPatternLocation p1) (stripPatternLocation p2) + stripPatternLocation' WildcardPattern = WildcardPattern + stripPatternLocation' (IntegerPattern i) = IntegerPattern i + stripPatternLocation' (FloatPattern f) = FloatPattern f + stripPatternLocation' (StringPattern s) = StringPattern s + stripPatternLocation' (CharPattern c) = CharPattern c + stripPatternLocation' UnitPattern = UnitPattern + +stripBinaryOperatorLocation :: + forall (ast1 :: LocatedAST) (ast2 :: UnlocatedAST). + ( (ASTLocate' ast1 ~ Located) + , ASTLocate' ast2 ~ Unlocated + , _ + ) => + BinaryOperator ast1 -> + BinaryOperator ast2 +stripBinaryOperatorLocation (MkBinaryOperator (op :: ASTLocate ast1 (BinaryOperator' ast1))) = + let op' = fmapUnlocated @LocatedAST @ast1 stripBinaryOperatorLocation' op + in MkBinaryOperator (stripLocation op') + where + stripBinaryOperatorLocation' :: BinaryOperator' ast1 -> BinaryOperator' ast2 + stripBinaryOperatorLocation' (SymOp name) = SymOp (stripLocation name) + stripBinaryOperatorLocation' (Infixed name) = Infixed (stripLocation name) + +stripTypeLocation :: + forall (ast1 :: LocatedAST) (ast2 :: UnlocatedAST). + ( (ASTLocate' ast1 ~ Located) + , ASTLocate' ast2 ~ Unlocated + , _ + ) => + Type ast1 -> + Type ast2 +stripTypeLocation (Type (t :: ASTLocate ast1 (Type' ast1))) = + let t' = fmapUnlocated @LocatedAST @ast1 stripTypeLocation' t + in Type (stripLocation t') + where + stripTypeLocation' :: Type' ast1 -> Type' ast2 + stripTypeLocation' (TypeVar name) = TypeVar (rUnlocate @LocatedAST @ast1 name) + stripTypeLocation' (FunctionType a b) = FunctionType (stripTypeLocation a) (stripTypeLocation b) + stripTypeLocation' UnitType = UnitType + stripTypeLocation' (TypeConstructorApplication a b) = TypeConstructorApplication (stripTypeLocation a) (stripTypeLocation b) + +-- stripTypeLocation' (UserDefinedType name) = UserDefinedType (_ name) {- ===================== Messy deriving stuff diff --git a/src/Elara/AST/StripLocation.hs b/src/Elara/AST/StripLocation.hs index 13ed620e..49e5e555 100644 --- a/src/Elara/AST/StripLocation.hs +++ b/src/Elara/AST/StripLocation.hs @@ -1,14 +1,16 @@ {-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE UndecidableInstances #-} module Elara.AST.StripLocation where import Elara.AST.Region (Located (Located), SourceRegion) -class StripLocation a b | a -> b where +class StripLocation a b where stripLocation :: a -> b +instance {-# OVERLAPPABLE #-} (a ~ b) => StripLocation a b where + stripLocation = identity + instance StripLocation (Located a) a where stripLocation :: Located a -> a stripLocation (Located _ a) = a @@ -16,19 +18,16 @@ instance StripLocation (Located a) a where instance {-# INCOHERENT #-} StripLocation SourceRegion () where stripLocation _ = () -instance {-# OVERLAPPABLE #-} (StripLocation a1 b1, StripLocation b1 b2) => StripLocation a1 b2 where - stripLocation b = stripLocation (stripLocation b) - -- We could provide a general Functor instance but the overlapping tends to cause problems -instance (StripLocation a a', StripLocation b b') => StripLocation (a, b) (a', b') where +instance {-# OVERLAPPABLE #-} (StripLocation a a', StripLocation b b') => StripLocation (a, b) (a', b') where stripLocation (a, b) = (stripLocation a, stripLocation b) -instance (StripLocation a a') => StripLocation (Maybe a) (Maybe a') where +instance {-# OVERLAPPABLE #-} (StripLocation a a') => StripLocation (Maybe a) (Maybe a') where stripLocation = fmap stripLocation -instance (StripLocation a a') => StripLocation [a] [a'] where +instance {-# OVERLAPPABLE #-} (StripLocation a a') => StripLocation [a] [a'] where stripLocation = fmap stripLocation -instance (StripLocation a a') => StripLocation (NonEmpty a) (NonEmpty a') where +instance {-# OVERLAPPABLE #-} (StripLocation a a') => StripLocation (NonEmpty a) (NonEmpty a') where stripLocation = fmap stripLocation diff --git a/src/Elara/AST/Unlocated.hs b/src/Elara/AST/Unlocated.hs index 1967c27b..64994699 100644 --- a/src/Elara/AST/Unlocated.hs +++ b/src/Elara/AST/Unlocated.hs @@ -18,5 +18,5 @@ type family Replace (needle :: LocatedAST) (replacement :: UnlocatedAST) (haysta Replace needle replacement (DeclarationBody needle) = DeclarationBody replacement Replace needle replacement (DeclarationBody' needle) = DeclarationBody' replacement Replace needle replacement [list] = [Replace needle replacement list] - -- Replace needle replacement (Select f needle) = Select f replacement + Replace needle replacement (Maybe maybe) = Maybe (Replace needle replacement maybe) Replace needle replacement other = other diff --git a/src/Elara/Parse/Expression.hs b/src/Elara/Parse/Expression.hs index 5005fdb4..3ae1f402 100644 --- a/src/Elara/Parse/Expression.hs +++ b/src/Elara/Parse/Expression.hs @@ -21,7 +21,7 @@ import Text.Megaparsec (MonadParsec (eof), customFailure, sepEndBy) import Prelude hiding (Op) locatedExpr :: HParser FrontendExpr' -> HParser FrontendExpr -locatedExpr = fmap (\x -> Expr (x, _)) . (H.parse . located . H.toParsec) +locatedExpr = fmap (\x -> Expr (x, Nothing)) . (H.parse . located . H.toParsec) exprParser :: HParser FrontendExpr exprParser = diff --git a/src/Elara/ToCore.hs b/src/Elara/ToCore.hs index 0c09f4f5..93310b49 100644 --- a/src/Elara/ToCore.hs +++ b/src/Elara/ToCore.hs @@ -8,12 +8,12 @@ import Data.Map qualified as M import Data.Traversable (for) import Elara.AST.Generic as AST import Elara.AST.Module (Module (Module)) -import Elara.AST.Name (NameLike (..), Qualified (..), TypeName) +import Elara.AST.Name (NameLike (..), Qualified (..), TypeName, VarName) import Elara.AST.Region (Located (Located), SourceRegion, unlocated) import Elara.AST.Select (LocatedAST (Typed)) import Elara.AST.StripLocation import Elara.AST.Typed -import Elara.AST.VarRef (UnlocatedVarRef, VarRef' (Global, Local), varRefVal) +import Elara.AST.VarRef (UnlocatedVarRef, VarRef, VarRef' (Global, Local), varRefVal) import Elara.Core as Core import Elara.Core.Module (CoreDeclaration (..), CoreModule (..)) import Elara.Data.Pretty (Pretty (..)) @@ -152,7 +152,7 @@ toCore le@(Expr (Located _ e, t)) = toCore' e AST.Var (Located _ v) -> do t' <- typeToCore t - pure $ Core.Var (Core.Id (nameText <$> stripLocation v) t') + pure $ Core.Var (Core.Id (nameText <$> stripLocation @(VarRef VarName) @(UnlocatedVarRef VarName) v) t') AST.Constructor v -> do ctor <- lookupCtor v pure $ Core.Var (conToVar ctor) diff --git a/src/Elara/TypeInfer/Infer.hs b/src/Elara/TypeInfer/Infer.hs index daebb806..3bdfec9c 100644 --- a/src/Elara/TypeInfer/Infer.hs +++ b/src/Elara/TypeInfer/Infer.hs @@ -51,6 +51,7 @@ import Polysemy.Error (Error, throw) import Polysemy.State (State) import Polysemy.State qualified as State import Prettyprinter qualified as Pretty +import Print (showPretty) -- | Type-checking state data Status = Status @@ -1251,6 +1252,9 @@ infer (Syntax.Expr (Located location e0, _)) = case e0 of Syntax.Unit -> do let t = (Type.Scalar{scalar = Monotype.Unit, ..}) pure $ Expr (Located location Unit, t) + Syntax.Char c -> do + let t = (Type.Scalar{scalar = Monotype.Char, ..}) + pure $ Expr (Located location (Char c), t) Syntax.LetIn name NoFieldValue val body -> do -- TODO: infer whether a let-in is recursive or not -- insert a new unsolved type variable for the let-in to make recursive let-ins possible @@ -1281,6 +1285,7 @@ infer (Syntax.Expr (Located location e0, _)) = case e0 of let t = (Type.Tuple{tupleArguments = Syntax.typeOf <$> elementTypes, ..}) pure $ Expr (Located location (Syntax.Tuple elementTypes), t) + other -> error $ "infer: " <> showPretty other -- -- Anno -- Syntax.Annotation{..} -> do diff --git a/src/Elara/TypeInfer/Monotype.hs b/src/Elara/TypeInfer/Monotype.hs index 33ba2c4d..67ad846d 100644 --- a/src/Elara/TypeInfer/Monotype.hs +++ b/src/Elara/TypeInfer/Monotype.hs @@ -80,7 +80,15 @@ data Scalar -- >>> pretty Text -- Text Text + | -- | Char type + -- + -- >>> pretty Char + -- Char + Char | -- | Unit type + -- + -- >>> pretty Unit + -- Unit Unit deriving stock (Eq, Generic, Show, Data) @@ -92,6 +100,7 @@ instance Pretty Scalar where pretty Integer = "Integer" pretty Text = "Text" pretty Unit = "Unit" + pretty Char = "Char" instance ToJSON Scalar diff --git a/test/Infer.hs b/test/Infer.hs index f5391d4d..9d3fa117 100644 --- a/test/Infer.hs +++ b/test/Infer.hs @@ -8,7 +8,7 @@ import Test.Hspec import Prelude hiding (fail) spec :: Spec -spec = describe "Infers types correctly" $ do +spec = describe "Infers types correctly" $ parallel $ do simpleTypes functionTypes @@ -20,6 +20,30 @@ simpleTypes = describe "Infers simple types correctly" $ do Scalar () Scalar.Integer -> pass o -> fail o + it "Infers Unit literals correctly" $ do + (t, fail) <- inferSpec "()" "()" + case t of + Scalar () Scalar.Unit -> pass + o -> fail o + + it "Infers Real literals correctly" $ do + (t, fail) <- inferSpec "1.0" "Real" + case t of + Scalar () Scalar.Real -> pass + o -> fail o + + it "Infers Text literals correctly" $ do + (t, fail) <- inferSpec "\"hello\"" "Text" + case t of + Scalar () Scalar.Text -> pass + o -> fail o + + it "Infers Char literals correctly" $ do + (t, fail) <- inferSpec "'c'" "Text" + case t of + Scalar () Scalar.Char -> pass + o -> fail o + functionTypes :: Spec functionTypes = describe "Infers function types correctly" $ do it "Infers identity function correctly" $ do diff --git a/test/Parse.hs b/test/Parse.hs index d451eef0..2c42795f 100644 --- a/test/Parse.hs +++ b/test/Parse.hs @@ -17,7 +17,7 @@ import Test.QuickCheck spec :: Spec spec = do - -- quickCheckSpec + quickCheckSpec pass quickCheckSpec :: Spec