Skip to content

Commit

Permalink
add a let polymorphism test & tuple types
Browse files Browse the repository at this point in the history
  • Loading branch information
bristermitten committed Aug 17, 2023
1 parent c4ff6e8 commit 59958f7
Show file tree
Hide file tree
Showing 10 changed files with 99 additions and 21 deletions.
6 changes: 3 additions & 3 deletions build/Main.core.elr
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module Main
{ Main.y : r_2 -> r_2
= let id_0 : r_3 -> r_3
= \(x_1 : r_2) -> x_1 in id_0 }
{ Main.y : l0_2 -> l0_2
= let id_0 : l0_5 -> l0_5
= \(x_1 : l0_2) -> x_1 in id_0 id_0 }
2 changes: 1 addition & 1 deletion build/Main.shunted.elr
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ module Main exposing (..)
import Prelude exposing (..)

let Main.y =
(let id_0 = \x_1 -> x_1 in id_0 id_0 )
(let id_0 = \x_1 -> x_1 in (id_0 1 , id_0 () ) )
10 changes: 6 additions & 4 deletions build/Main.typed.elr
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ module Main exposing (..)

import Prelude exposing (..)

def Main.y : forall (r : Type) . r -> r
def Main.y : (Integer, Unit)
let Main.y =
(let id_0 = \x_1 -> x_1 : r : r -> r in id_0 : r ->
r) : forall (r : Type) .
r -> r
(let id_0 = \x_1 -> x_1 : b0 : forall (b0 : Type) .
b0 -> b0 in (id_0 : forall (d : Type) .
d -> d
1 : Integer : Integer, id_0 : forall (d : Type) . d -> d
() : Unit : Unit) : (Integer, Unit)) : (Integer, Unit)
4 changes: 4 additions & 0 deletions src/Elara/AST/Generic.hs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ data Expr' (ast :: a)
newtype Expr (ast :: a) = Expr (ASTLocate ast (Expr' ast), Select "ExprType" ast)
deriving (Generic)

typeOf :: forall ast. Expr ast -> Select "ExprType" ast
typeOf (Expr (_, t)) = t

data Pattern' (ast :: a)
= VarPattern (ASTLocate ast (Select "VarPat" ast))
| ConstructorPattern (ASTLocate ast (Select "ConPat" ast)) [Pattern ast]
Expand Down Expand Up @@ -342,6 +345,7 @@ instance
pretty (Let v p e) = prettyLetExpr v (maybeToList $ toMaybe p :: [a2]) e
pretty (Block b) = prettyBlockExpr b
pretty (InParens e) = parens (pretty e)
pretty (Tuple t) = prettyTupleExpr t

instance
( Pretty a1
Expand Down
3 changes: 3 additions & 0 deletions src/Elara/AST/Pretty.hs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ prettyBinaryOperatorExpr e1 o e2 = parens (pretty e1 <+> pretty o <+> pretty e2)
prettyListExpr :: (Pretty a) => [a] -> Doc AnsiStyle
prettyListExpr l = list (pretty <$> l)

prettyTupleExpr :: (Pretty a) => NonEmpty a -> Doc AnsiStyle
prettyTupleExpr l = parens (hsep (punctuate "," (pretty <$> toList l)))

prettyMatchExpr :: (Pretty a1, Pretty a2, Foldable t) => a1 -> t a2 -> Doc AnsiStyle
prettyMatchExpr e m = parens ("match" <+> pretty e <+> "with" <+> prettyBlockExpr m)

Expand Down
5 changes: 4 additions & 1 deletion src/Elara/TypeInfer.hs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import Polysemy hiding (transform)
import Polysemy.Error (Error, mapError, throw)
import Polysemy.State
import Print
import qualified Data.List.NonEmpty as NonEmpty

inferModule ::
forall r.
Expand Down Expand Up @@ -266,6 +267,8 @@ completeExpression ctx (Expr (y', t)) = do
(Infer.Scalar{}, Infer.Scalar{}) -> pass -- Scalars are always the same
(Infer.Custom{typeArguments = unsolvedArgs}, Infer.Custom{typeArguments = solvedArgs}) -> do
traverse_ (uncurry unify) (zip unsolvedArgs solvedArgs)
(Infer.Tuple{tupleArguments = unsolvedArgs}, Infer.Tuple{tupleArguments = solvedArgs}) -> do
traverse_ (uncurry unify) (NonEmpty.zip unsolvedArgs solvedArgs)
other -> error (showPretty other)

stripForAlls :: Type SourceRegion -> Type SourceRegion
Expand All @@ -277,7 +280,6 @@ completeExpression ctx (Expr (y', t)) = do
subst Infer.UnsolvedType{existential} solved = do
let annotation = SolvedType existential (toMonoType solved)
push annotation
pass
subst _ _ = pass

toMonoType :: Type SourceRegion -> Mono.Monotype
Expand All @@ -287,4 +289,5 @@ completeExpression ctx (Expr (y', t)) = do
Infer.List{type_} -> Mono.List (toMonoType type_)
Infer.UnsolvedType{existential} -> Mono.UnsolvedType existential
Infer.VariableType{name = v} -> Mono.VariableType v
Infer.Custom{name = n, typeArguments = args} -> Mono.Custom n (toMonoType <$> args)
other -> error $ "toMonoType: " <> showPretty other
36 changes: 33 additions & 3 deletions src/Elara/TypeInfer/Infer.hs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import Elara.TypeInfer.Existential (Existential)
import Elara.TypeInfer.Monotype (Monotype)
import Elara.TypeInfer.Type (Type (..))

import Data.List.NonEmpty qualified as NE
import Data.Map qualified as Map
import Data.Text qualified as Text
import Data.Traversable (for)
Expand All @@ -50,7 +51,6 @@ import Polysemy.Error (Error, throw)
import Polysemy.State (State)
import Polysemy.State qualified as State
import Prettyprinter qualified as Pretty
import Print (debugPretty, showPretty)

-- | Type-checking state
data Status = Status
Expand Down Expand Up @@ -155,7 +155,7 @@ wellFormedType ::
Context SourceRegion ->
Type SourceRegion ->
Sem r ()
wellFormedType _Γ type0 =
wellFormedType _Γ type0 =do
case type0 of
-- UvarWF
Type.VariableType{..}
Expand Down Expand Up @@ -203,6 +203,7 @@ wellFormedType _Γ type0 =
| Context.Variable Domain.Alternatives a `elem`-> traverse_ (\(_, _A) -> wellFormedType _Γ _A) kAs
| otherwise -> throw (UnboundAlternatives location a)
Type.Scalar{} -> pass
Type.Tuple _ ts -> traverse_ (wellFormedType _Γ) ts

{- | This corresponds to the judgment:
Expand Down Expand Up @@ -305,6 +306,12 @@ subtype _A0 _B0 = do
| s0 == s1 -> pass
(Type.Optional{type_ = _A}, Type.Optional{type_ = _B}) -> subtype _A _B
(Type.List{type_ = _A}, Type.List{type_ = _B}) -> subtype _A _B
(Type.Tuple{tupleArguments = typesA}, Type.Tuple{tupleArguments = typesB}) -> do
when (length typesA /= length typesB) do
error "Tuple types must have the same number of elements"

for_ (NE.zip typesA typesB) \(a, b) -> do
subtype a b
-- This is where you need to add any non-trivial subtypes. For example,
-- the following three rules specify that `Natural` is a subtype of
-- `Integer`, which is in turn a subtype of `Real`.
Expand Down Expand Up @@ -1241,6 +1248,9 @@ infer (Syntax.Expr (Located location e0, _)) = case e0 of
Syntax.String s -> do
let t = (Type.Scalar{scalar = Monotype.Text, ..})
pure $ Expr (Located location (String s), t)
Syntax.Unit -> do
let t = (Type.Scalar{scalar = Monotype.Unit, ..})
pure $ Expr (Located location Unit, t)
Syntax.LetIn name NoFieldValue val body -> do
-- TODO: infer whether a let-in is recursive or not
-- insert a new unsolved type variable for the let-in to make recursive let-ins possible
Expand All @@ -1261,6 +1271,16 @@ infer (Syntax.Expr (Located location e0, _)) = case e0 of

body'@(Expr (_, bodyType)) <- infer body
pure $ Expr (Located location (LetIn name NoFieldValue val' body'), bodyType)
Syntax.Tuple elements -> do
let process element = do
<- get

infer element

elementTypes <- traverse process elements

let t = (Type.Tuple{tupleArguments = Syntax.typeOf <$> elementTypes, ..})
pure $ Expr (Located location (Syntax.Tuple elementTypes), t)

-- -- Anno
-- Syntax.Annotation{..} -> do
Expand Down Expand Up @@ -1772,7 +1792,7 @@ check ::
ShuntedExpr ->
Type SourceRegion ->
Sem r TypedExpr
check expr t = do
check expr@(Expr (Located exprLoc _, _)) t = do
let x = expr ^. _Unwrapped . _1 . unlocated
check' x t
where
Expand Down Expand Up @@ -1819,6 +1839,16 @@ check expr t = do
check' e Type.Forall{..} = scoped (Context.Variable domain name) do
check' e type_

check' (Syntax.Tuple elements) Type.Tuple{tupleArguments} = do
let process (element, type_) = do
<- get

check element (Context.solveType _Γ type_)

y <- traverse process (NE.zip elements tupleArguments)

pure $ Expr (Located exprLoc (Syntax.Tuple y), t)

-- Sub
check' _ _B = do
_A@(Syntax.Expr (_, _At)) <- infer expr
Expand Down
33 changes: 33 additions & 0 deletions src/Elara/TypeInfer/Type.hs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ data Type s
-- >>> pretty @(Type ()) (Custom () "Maybe" ["a"])
-- Maybe a
Custom {location :: s, name :: Text, typeArguments :: [Type s]}
| -- | A tuple
--
-- >>> pretty @(Type ()) (Tuple () ["a", "b"])
-- (a, b)
Tuple {location :: s, tupleArguments :: NonEmpty (Type s)}
deriving stock (Eq, Functor, Generic, Show, Data)

instance (Show c, ToJSON c) => ToJSON (Type c) where
Expand Down Expand Up @@ -149,6 +154,9 @@ instance Plated (Type s) where
Custom{typeArguments = oldTypeArguments, ..} -> do
newTypeArguments <- traverse onType oldTypeArguments
pure Custom{typeArguments = newTypeArguments, ..}
Tuple{tupleArguments = oldTypeArguments, ..} -> do
newTypeArguments <- traverse onType oldTypeArguments
pure Tuple{tupleArguments = newTypeArguments, ..}

-- | A potentially polymorphic record type
data Record s = Fields [(Text, Type s)] RemainingFields
Expand Down Expand Up @@ -296,6 +304,8 @@ substituteType a n _A type_ =
Scalar{..}
Custom{typeArguments = oldTypeArguments, ..} ->
Custom{typeArguments = fmap (substituteType a n _A) oldTypeArguments, ..}
Tuple{tupleArguments = oldTypeArguments, ..} ->
Tuple{tupleArguments = fmap (substituteType a n _A) oldTypeArguments, ..}

{- | Replace all occurrences of a variable within one `Type` with another `Type`,
given the variable's label and index
Expand Down Expand Up @@ -346,6 +356,8 @@ substituteFields ρ0 n r@(Fields kτs ρ1) type_ =
Scalar{..}
Custom{typeArguments = oldTypeArguments, ..} ->
Custom{typeArguments = fmap (substituteFields ρ0 n r) oldTypeArguments, ..}
Tuple{tupleArguments = oldTypeArguments, ..} ->
Tuple{tupleArguments = fmap (substituteFields ρ0 n r) oldTypeArguments, ..}

{- | Replace all occurrences of a variable within one `Type` with another `Type`,
given the variable's label and index
Expand Down Expand Up @@ -396,6 +408,8 @@ substituteAlternatives ρ0 n r@(Alternatives kτs ρ1) type_ =
Scalar{..}
Custom{typeArguments = oldTypeArguments, ..} ->
Custom{typeArguments = fmap (substituteAlternatives ρ0 n r) oldTypeArguments, ..}
Tuple{tupleArguments = oldTypeArguments, ..} ->
Tuple{tupleArguments = fmap (substituteAlternatives ρ0 n r) oldTypeArguments, ..}

{- | Count how many times the given `Existential` `Type` variable appears within
a `Type`
Expand Down Expand Up @@ -603,6 +617,25 @@ prettyPrimitiveType Custom{..} =
Pretty.align
( foldMap (\t -> prettyQuantifiedType t <> Pretty.hardline) typeArguments
)
prettyPrimitiveType Tuple{..} =
Pretty.group (Pretty.flatAlt long short)
where
short =
punctuation "("
<> prettyQuantifiedType (head tupleArguments)
<> foldMap (\t -> punctuation "," <> " " <> prettyQuantifiedType t) (tail tupleArguments)
<> punctuation ")"

long =
Pretty.align
( punctuation "("
<> " "
<> prettyQuantifiedType (head tupleArguments)
<> foldMap (\t -> punctuation "," <> " " <> prettyQuantifiedType t) (tail tupleArguments)
<> Pretty.hardline
<> punctuation ")"
)

prettyPrimitiveType other =
Pretty.group (Pretty.flatAlt long short)
where
Expand Down
11 changes: 5 additions & 6 deletions test/Infer.hs
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,8 @@ functionTypes = describe "Infers function types correctly" $ do
) | a == a' && b == b' && b == b'' -> pass
o -> fail o

-- it "Infers polymorphic lets correctly" $ do
-- (t, fail) <- inferSpec "let id = \\x -> x in (id 1, id ())" "(Int, ())"
-- case t of
-- VariableType' "x" -> pass
-- -- (Tuple' [VariableType' "Int", Unit']) -> pass
-- o -> fail o
it "Infers polymorphic lets correctly" $ do
(t, fail) <- inferSpec "let id = \\x -> x in (id 1, id ())" "(Int, ())"
case t of
Tuple' (Scalar () Scalar.Integer :| [Scalar () Scalar.Unit]) -> pass
o -> fail o
10 changes: 7 additions & 3 deletions test/Infer/Common.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import Elara.TypeInfer.Error (TypeInferenceError)
import Elara.TypeInfer.Infer (initialStatus)
import Elara.TypeInfer.Infer qualified as Infer
import Elara.TypeInfer.Type (Type (..))
import Elara.TypeInfer.Type qualified as Type
import Error.Diagnose (Diagnostic, TabSize (..), WithUnicode (WithUnicode), defaultStyle, printDiagnostic)
import Parse.Common (lexAndParse)
import Polysemy
Expand All @@ -46,6 +47,10 @@ pattern Function' a b = Function () a b
pattern VariableType' :: Text -> Type ()
pattern VariableType' name = VariableType () name


pattern Tuple' :: NonEmpty (Type ()) -> Type ()
pattern Tuple' ts = Type.Tuple () ts

completeInference :: Member (State Infer.Status) r => TypedExpr -> Sem r TypedExpr
completeInference x = do
ctx <- Infer.getAll
Expand Down Expand Up @@ -96,14 +101,13 @@ diagShouldSucceed (d, x) = do
Just x -> pure x
Nothing -> error "Expected successful inference"

typeOf :: forall {a} {ast :: a} {b}. StripLocation (Select "ExprType" ast) b => Expr ast -> b
typeOf (Expr (_, t)) = stripLocation t


typeOf' :: MonadIO m => Text -> m (Type ())
typeOf' msg = do
x <- liftIO $ runInferPipeline msg
y <- liftIO $ diagShouldSucceed x
pure $ typeOf y
pure $ stripLocation $ typeOf y

failTypeMismatch :: Pretty a1 => String -> String -> a1 -> IO a2
failTypeMismatch name expected actual =
Expand Down

0 comments on commit 59958f7

Please sign in to comment.