Skip to content

Commit

Permalink
make the arbitrary expr tests compile and run without errors
Browse files Browse the repository at this point in the history
  • Loading branch information
bristermitten committed Aug 17, 2023
1 parent 9947180 commit 7aca688
Show file tree
Hide file tree
Showing 9 changed files with 180 additions and 22 deletions.
133 changes: 127 additions & 6 deletions src/Elara/AST/Generic.hs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import Elara.AST.Select (LocatedAST, UnlocatedAST)
import Elara.AST.StripLocation (StripLocation (..))
import Elara.Data.Pretty
import GHC.TypeLits
import Relude.Extra (bimapF)
import TODO (todo)
import Prelude hiding (group)

Expand Down Expand Up @@ -296,6 +297,24 @@ instance ToMaybe (Maybe a) (Maybe a) where
instance {-# INCOHERENT #-} ToMaybe a (Maybe a) where
toMaybe = Just

-- Sometimes fields are wrapped in functors eg lists, we need a way of transcending them.
-- This class does that.
-- For example, let's say we have `cleanPattern :: Pattern ast1 -> Pattern ast2`, and `x :: Select ast1 "Pattern"`.
-- `x` could be `Pattern ast1`, `[Pattern ast1]`, `Maybe (Pattern ast1)`, or something else entirely.
-- `cleanPattern` will only work on the first of these, so we need a way of lifting it to the others. Obviously, this sounds like a Functor
-- but the problem is that `Pattern ast1` has the wrong kind.
class ApplyAsFunctorish i o a b where
applyAsFunctorish :: (a -> b) -> i -> o

instance Functor f => ApplyAsFunctorish (f a) (f b) a b where
applyAsFunctorish = fmap

instance ApplyAsFunctorish a b a b where
applyAsFunctorish f = f

instance ApplyAsFunctorish NoFieldValue NoFieldValue a b where
applyAsFunctorish _ = identity

-- | Unwraps 1 level of 'Maybe' from a type. Useful when a type family returns Maybe
type family UnwrapMaybe (a :: Kind.Type) = (k :: Kind.Type) where
UnwrapMaybe (Maybe a) = a
Expand Down Expand Up @@ -372,7 +391,12 @@ instance
pretty (ConstructorPattern c ps) = prettyConstructorPattern c ps
pretty (ListPattern l) = prettyListPattern l
pretty (ConsPattern p1 p2) = prettyConsPattern p1 p2
pretty other = error "aaaaaaa"
pretty WildcardPattern = "_"
pretty (IntegerPattern i) = pretty i
pretty (FloatPattern f) = pretty f
pretty (StringPattern s) = pretty '\"' <> pretty s <> pretty '\"'
pretty (CharPattern c) = "'" <> escapeChar c <> "'"
pretty UnitPattern = "()"

instance
( Pretty (ASTLocate ast (Type' ast))
Expand All @@ -396,21 +420,118 @@ instance

stripExprLocation ::
forall (ast1 :: LocatedAST) (ast2 :: UnlocatedAST).
((ASTLocate' ast1 ~ Located), ASTLocate' ast2 ~ Unlocated) =>
( ASTLocate' ast1 ~ Located
, ASTLocate' ast2 ~ Unlocated
, ApplyAsFunctorish (Select "LambdaPattern" ast1) (Select "LambdaPattern" ast2) (Pattern ast1) (Pattern ast2)
, _
) =>
Expr ast1 ->
Expr ast2
stripExprLocation (Expr (e :: ASTLocate ast1 (Expr' ast1), t)) =
let e' = fmapUnlocated @LocatedAST @ast1 stripExprLocation' e
in -- t' = fmapUnlocated @LocatedAST @ast1 _ t :: ASTLocate ast1 (Select "ExprType" ast2)
Expr (stripLocation e', todo t)
in Expr (stripLocation e', fmap stripTypeLocation t)
where
stripExprLocation' :: forall (ast1 :: LocatedAST) (ast2 :: UnlocatedAST). Expr' ast1 -> Expr' ast2
stripExprLocation' :: Expr' ast1 -> Expr' ast2
stripExprLocation' (Int i) = Int i
stripExprLocation' (Float f) = Float f
stripExprLocation' (String s) = String s
stripExprLocation' (Char c) = Char c
stripExprLocation' Unit = Unit
stripExprLocation' (Var v) = Var (todo)
stripExprLocation' (Var v) = Var (stripLocation v)
stripExprLocation' (Constructor c) = Constructor (stripLocation c)
stripExprLocation' (Lambda ps e) =
let ps' = stripLocation ps
ps'' =
applyAsFunctorish @(Select "LambdaPattern" ast1) @(Select "LambdaPattern" ast2) @(Pattern ast1) @(Pattern ast2)
stripPatternLocation
ps'
in Lambda ps'' (stripExprLocation e)
stripExprLocation' (FunctionCall e1 e2) = FunctionCall (stripExprLocation e1) (stripExprLocation e2)
stripExprLocation' (If e1 e2 e3) = If (stripExprLocation e1) (stripExprLocation e2) (stripExprLocation e3)
stripExprLocation' (BinaryOperator op e1 e2) = BinaryOperator (stripBinaryOperatorLocation op) (stripExprLocation e1) (stripExprLocation e2)
stripExprLocation' (List l) = List (stripExprLocation <$> l)
stripExprLocation' (Match e m) = Match (stripExprLocation e) (bimapF stripPatternLocation stripExprLocation m)
stripExprLocation' (LetIn v p e1 e2) =
let p' = stripLocation p
p'' =
applyAsFunctorish @(Select "LetPattern" ast1) @(Select "LetPattern" ast2) @(Pattern ast1) @(Pattern ast2)
stripPatternLocation
p'
in LetIn
(stripLocation v)
p''
(stripExprLocation e1)
(stripExprLocation e2)
stripExprLocation' (Let v p e) =
let p' = stripLocation p
p'' =
applyAsFunctorish @(Select "LetPattern" ast1) @(Select "LetPattern" ast2) @(Pattern ast1) @(Pattern ast2)
stripPatternLocation
p'
in Let (stripLocation v) p'' (stripExprLocation e)
stripExprLocation' (Block b) = Block (stripExprLocation <$> b)
stripExprLocation' (InParens e) = InParens (stripExprLocation e)
stripExprLocation' (Tuple t) = Tuple (stripExprLocation <$> t)

stripPatternLocation ::
forall (ast1 :: LocatedAST) (ast2 :: UnlocatedAST).
( (ASTLocate' ast1 ~ Located)
, ASTLocate' ast2 ~ Unlocated
, _
) =>
Pattern ast1 ->
Pattern ast2
stripPatternLocation (Pattern (p :: ASTLocate ast1 (Pattern' ast1), t)) =
let p' = fmapUnlocated @LocatedAST @ast1 stripPatternLocation' p
in Pattern (stripLocation p', fmap stripTypeLocation t)
where
stripPatternLocation' :: Pattern' ast1 -> Pattern' ast2
stripPatternLocation' (VarPattern v) = VarPattern (stripLocation v)
stripPatternLocation' (ConstructorPattern c ps) = ConstructorPattern (stripLocation c) (stripPatternLocation <$> ps)
stripPatternLocation' (ListPattern l) = ListPattern (stripPatternLocation <$> l)
stripPatternLocation' (ConsPattern p1 p2) = ConsPattern (stripPatternLocation p1) (stripPatternLocation p2)
stripPatternLocation' WildcardPattern = WildcardPattern
stripPatternLocation' (IntegerPattern i) = IntegerPattern i
stripPatternLocation' (FloatPattern f) = FloatPattern f
stripPatternLocation' (StringPattern s) = StringPattern s
stripPatternLocation' (CharPattern c) = CharPattern c
stripPatternLocation' UnitPattern = UnitPattern

stripBinaryOperatorLocation ::
forall (ast1 :: LocatedAST) (ast2 :: UnlocatedAST).
( (ASTLocate' ast1 ~ Located)
, ASTLocate' ast2 ~ Unlocated
, _
) =>
BinaryOperator ast1 ->
BinaryOperator ast2
stripBinaryOperatorLocation (MkBinaryOperator (op :: ASTLocate ast1 (BinaryOperator' ast1))) =
let op' = fmapUnlocated @LocatedAST @ast1 stripBinaryOperatorLocation' op
in MkBinaryOperator (stripLocation op')
where
stripBinaryOperatorLocation' :: BinaryOperator' ast1 -> BinaryOperator' ast2
stripBinaryOperatorLocation' (SymOp name) = SymOp (stripLocation name)
stripBinaryOperatorLocation' (Infixed name) = Infixed (stripLocation name)

stripTypeLocation ::
forall (ast1 :: LocatedAST) (ast2 :: UnlocatedAST).
( (ASTLocate' ast1 ~ Located)
, ASTLocate' ast2 ~ Unlocated
, _
) =>
Type ast1 ->
Type ast2
stripTypeLocation (Type (t :: ASTLocate ast1 (Type' ast1))) =
let t' = fmapUnlocated @LocatedAST @ast1 stripTypeLocation' t
in Type (stripLocation t')
where
stripTypeLocation' :: Type' ast1 -> Type' ast2
stripTypeLocation' (TypeVar name) = TypeVar (rUnlocate @LocatedAST @ast1 name)
stripTypeLocation' (FunctionType a b) = FunctionType (stripTypeLocation a) (stripTypeLocation b)
stripTypeLocation' UnitType = UnitType
stripTypeLocation' (TypeConstructorApplication a b) = TypeConstructorApplication (stripTypeLocation a) (stripTypeLocation b)

-- stripTypeLocation' (UserDefinedType name) = UserDefinedType (_ name)

{- =====================
Messy deriving stuff
Expand Down
17 changes: 8 additions & 9 deletions src/Elara/AST/StripLocation.hs
Original file line number Diff line number Diff line change
@@ -1,34 +1,33 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE UndecidableInstances #-}

module Elara.AST.StripLocation where

import Elara.AST.Region (Located (Located), SourceRegion)

class StripLocation a b | a -> b where
class StripLocation a b where
stripLocation :: a -> b

instance {-# OVERLAPPABLE #-} (a ~ b) => StripLocation a b where
stripLocation = identity

instance StripLocation (Located a) a where
stripLocation :: Located a -> a
stripLocation (Located _ a) = a

instance {-# INCOHERENT #-} StripLocation SourceRegion () where
stripLocation _ = ()

instance {-# OVERLAPPABLE #-} (StripLocation a1 b1, StripLocation b1 b2) => StripLocation a1 b2 where
stripLocation b = stripLocation (stripLocation b)

-- We could provide a general Functor instance but the overlapping tends to cause problems

instance (StripLocation a a', StripLocation b b') => StripLocation (a, b) (a', b') where
instance {-# OVERLAPPABLE #-} (StripLocation a a', StripLocation b b') => StripLocation (a, b) (a', b') where
stripLocation (a, b) = (stripLocation a, stripLocation b)

instance (StripLocation a a') => StripLocation (Maybe a) (Maybe a') where
instance {-# OVERLAPPABLE #-} (StripLocation a a') => StripLocation (Maybe a) (Maybe a') where
stripLocation = fmap stripLocation

instance (StripLocation a a') => StripLocation [a] [a'] where
instance {-# OVERLAPPABLE #-} (StripLocation a a') => StripLocation [a] [a'] where
stripLocation = fmap stripLocation

instance (StripLocation a a') => StripLocation (NonEmpty a) (NonEmpty a') where
instance {-# OVERLAPPABLE #-} (StripLocation a a') => StripLocation (NonEmpty a) (NonEmpty a') where
stripLocation = fmap stripLocation
2 changes: 1 addition & 1 deletion src/Elara/AST/Unlocated.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ type family Replace (needle :: LocatedAST) (replacement :: UnlocatedAST) (haysta
Replace needle replacement (DeclarationBody needle) = DeclarationBody replacement
Replace needle replacement (DeclarationBody' needle) = DeclarationBody' replacement
Replace needle replacement [list] = [Replace needle replacement list]
-- Replace needle replacement (Select f needle) = Select f replacement
Replace needle replacement (Maybe maybe) = Maybe (Replace needle replacement maybe)
Replace needle replacement other = other
2 changes: 1 addition & 1 deletion src/Elara/Parse/Expression.hs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import Text.Megaparsec (MonadParsec (eof), customFailure, sepEndBy)
import Prelude hiding (Op)

locatedExpr :: HParser FrontendExpr' -> HParser FrontendExpr
locatedExpr = fmap (\x -> Expr (x, _)) . (H.parse . located . H.toParsec)
locatedExpr = fmap (\x -> Expr (x, Nothing)) . (H.parse . located . H.toParsec)

exprParser :: HParser FrontendExpr
exprParser =
Expand Down
6 changes: 3 additions & 3 deletions src/Elara/ToCore.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ import Data.Map qualified as M
import Data.Traversable (for)
import Elara.AST.Generic as AST
import Elara.AST.Module (Module (Module))
import Elara.AST.Name (NameLike (..), Qualified (..), TypeName)
import Elara.AST.Name (NameLike (..), Qualified (..), TypeName, VarName)
import Elara.AST.Region (Located (Located), SourceRegion, unlocated)
import Elara.AST.Select (LocatedAST (Typed))
import Elara.AST.StripLocation
import Elara.AST.Typed
import Elara.AST.VarRef (UnlocatedVarRef, VarRef' (Global, Local), varRefVal)
import Elara.AST.VarRef (UnlocatedVarRef, VarRef, VarRef' (Global, Local), varRefVal)
import Elara.Core as Core
import Elara.Core.Module (CoreDeclaration (..), CoreModule (..))
import Elara.Data.Pretty (Pretty (..))
Expand Down Expand Up @@ -152,7 +152,7 @@ toCore le@(Expr (Located _ e, t)) = toCore' e
AST.Var (Located _ v) -> do
t' <- typeToCore t

pure $ Core.Var (Core.Id (nameText <$> stripLocation v) t')
pure $ Core.Var (Core.Id (nameText <$> stripLocation @(VarRef VarName) @(UnlocatedVarRef VarName) v) t')
AST.Constructor v -> do
ctor <- lookupCtor v
pure $ Core.Var (conToVar ctor)
Expand Down
5 changes: 5 additions & 0 deletions src/Elara/TypeInfer/Infer.hs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ import Polysemy.Error (Error, throw)
import Polysemy.State (State)
import Polysemy.State qualified as State
import Prettyprinter qualified as Pretty
import Print (showPretty)

-- | Type-checking state
data Status = Status
Expand Down Expand Up @@ -1251,6 +1252,9 @@ infer (Syntax.Expr (Located location e0, _)) = case e0 of
Syntax.Unit -> do
let t = (Type.Scalar{scalar = Monotype.Unit, ..})
pure $ Expr (Located location Unit, t)
Syntax.Char c -> do
let t = (Type.Scalar{scalar = Monotype.Char, ..})
pure $ Expr (Located location (Char c), t)
Syntax.LetIn name NoFieldValue val body -> do
-- TODO: infer whether a let-in is recursive or not
-- insert a new unsolved type variable for the let-in to make recursive let-ins possible
Expand Down Expand Up @@ -1281,6 +1285,7 @@ infer (Syntax.Expr (Located location e0, _)) = case e0 of

let t = (Type.Tuple{tupleArguments = Syntax.typeOf <$> elementTypes, ..})
pure $ Expr (Located location (Syntax.Tuple elementTypes), t)
other -> error $ "infer: " <> showPretty other

-- -- Anno
-- Syntax.Annotation{..} -> do
Expand Down
9 changes: 9 additions & 0 deletions src/Elara/TypeInfer/Monotype.hs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,15 @@ data Scalar
-- >>> pretty Text
-- Text
Text
| -- | Char type
--
-- >>> pretty Char
-- Char
Char
| -- | Unit type
--
-- >>> pretty Unit
-- Unit
Unit
deriving stock (Eq, Generic, Show, Data)

Expand All @@ -92,6 +100,7 @@ instance Pretty Scalar where
pretty Integer = "Integer"
pretty Text = "Text"
pretty Unit = "Unit"
pretty Char = "Char"

instance ToJSON Scalar

Expand Down
26 changes: 25 additions & 1 deletion test/Infer.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import Test.Hspec
import Prelude hiding (fail)

spec :: Spec
spec = describe "Infers types correctly" $ do
spec = describe "Infers types correctly" $ parallel $ do
simpleTypes
functionTypes

Expand All @@ -20,6 +20,30 @@ simpleTypes = describe "Infers simple types correctly" $ do
Scalar () Scalar.Integer -> pass
o -> fail o

it "Infers Unit literals correctly" $ do
(t, fail) <- inferSpec "()" "()"
case t of
Scalar () Scalar.Unit -> pass
o -> fail o

it "Infers Real literals correctly" $ do
(t, fail) <- inferSpec "1.0" "Real"
case t of
Scalar () Scalar.Real -> pass
o -> fail o

it "Infers Text literals correctly" $ do
(t, fail) <- inferSpec "\"hello\"" "Text"
case t of
Scalar () Scalar.Text -> pass
o -> fail o

it "Infers Char literals correctly" $ do
(t, fail) <- inferSpec "'c'" "Text"
case t of
Scalar () Scalar.Char -> pass
o -> fail o

functionTypes :: Spec
functionTypes = describe "Infers function types correctly" $ do
it "Infers identity function correctly" $ do
Expand Down
2 changes: 1 addition & 1 deletion test/Parse.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import Test.QuickCheck

spec :: Spec
spec = do
-- quickCheckSpec
quickCheckSpec
pass

quickCheckSpec :: Spec
Expand Down

0 comments on commit 7aca688

Please sign in to comment.