diff --git a/glean/db/Glean/Query/Typecheck/Unify.hs b/glean/db/Glean/Query/Typecheck/Unify.hs index d02766569..dddcdcc54 100644 --- a/glean/db/Glean/Query/Typecheck/Unify.hs +++ b/glean/db/Glean/Query/Typecheck/Unify.hs @@ -17,6 +17,7 @@ import Control.Monad.Except import Control.Monad.State import qualified Data.IntMap as IntMap import qualified Data.Map as Map +import qualified Data.Map.Merge.Lazy as Map import Data.Maybe import qualified Data.Text as Text import Compat.Prettyprinter hiding ((<>), enclose) @@ -90,11 +91,23 @@ unify a@(HasTy fa ra x) b@(HasTy fb rb y) (Just x, Just y) | x /= y -> unifyError a b (Nothing, _) -> return rb _otherwise -> return ra - mapM_ (uncurry unify) $ Map.intersectionWith (,) fa fb - z <- freshTyVarInt - let all = HasTy (Map.union fa fb) rec z - extend x all - extend y all + union <- Map.mergeA + Map.preserveMissing + Map.preserveMissing + (Map.zipWithAMatched $ \_ a b -> do unify a b; return a) + fa fb + -- if either a or b is the same as the unified type, avoid creating + -- a new type variable. + let size = Map.size union + if size == Map.size fa && ra == rec + then extend y a + else if size == Map.size fb && rb == rec + then extend x b + else do + z <- freshTyVarInt + let all = HasTy union rec z + extend x all + extend y all unify a@(HasTy _ (Just Sum) _) b@RecordTy{} = unifyError a b @@ -103,9 +116,8 @@ unify a@(HasTy m _ x) b@(RecordTy fs) = do case Map.lookup f m of Nothing -> return () Just ty' -> unify ty ty' - forM_ (Map.keys m) $ \n -> - when (n `notElem` map fieldDefName fs) $ - unifyError a b + when (not (Map.null (foldr (Map.delete . fieldDefName) m fs))) $ + unifyError a b extend x (RecordTy fs) unify a@(HasTy _ (Just Record) _) b@SumTy{} =