hnix/src/Nix/Type/Infer.hs

511 lines
17 KiB
Haskell

{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# OPTIONS_GHC -Wno-name-shadowing #-}
module Nix.Type.Infer (
Constraint(..),
TypeError(..),
Subst(..),
inferTop
) where
import Nix.Atoms
import Nix.Convert
import Nix.Eval (MonadEval(..))
import qualified Nix.Eval as Eval
import Nix.Expr.Types
import Nix.Expr.Types.Annotated
import Nix.Frames (Frame)
import Nix.Scope
import Nix.Thunk
import qualified Nix.Type.Assumption as As
import Nix.Type.Env
import qualified Nix.Type.Env as Env
import Nix.Type.Type
import Nix.Utils
import Control.Applicative
import Control.Arrow
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.Fix
import Data.Foldable
import Data.List (delete, find, nub, intersect, (\\))
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromJust)
import Data.Semigroup
import qualified Data.HashMap.Lazy as M
import qualified Data.Set as Set
import Data.Text (Text)
-------------------------------------------------------------------------------
-- Classes
-------------------------------------------------------------------------------
-- | The type inference monad.
--
-- From the outside in it layers:
--
--   * a reader holding the set of monomorphic type variables paired with the
--     current 'Scopes' of 'Judgment's,
--   * a state holding the fresh-variable counter ('InferState'),
--   * an exception layer carrying 'TypeError's.
newtype Infer a = Infer
    { getInfer
        :: ReaderT (Set.Set TVar, Scopes Infer Judgment)
             (StateT InferState (Except TypeError))
             a
    }
    deriving
        ( Functor
        , Applicative
        , Monad
        , Alternative
        , MonadPlus
        , MonadFix
        , MonadState InferState
        , MonadReader (Set.Set TVar, Scopes Infer Judgment)
        , MonadError TypeError
        )
-- | State threaded through inference: a counter that 'fresh' uses to pick
-- the next unused type variable name.
newtype InferState = InferState { count :: Int }

-- | The state inference starts from: no type variables handed out yet.
initInfer :: InferState
initInfer = InferState 0
-- | Constraints generated during inference and discharged by 'solve'.
data Constraint
    = EqConst Type Type
      -- ^ The two types must unify.
    | EqConstOneOf Type [Type]
      -- ^ The first type must unify with at least one of the types in the
      --   list. For example, integer could unify with integer, or with a
      --   type variable. The solver commits to the first alternative that
      --   unifies and does not backtrack.
    | ExpInstConst Type Scheme
      -- ^ The type must be an instance of the given (explicit) scheme.
    | ImpInstConst Type (Set.Set TVar) Type
      -- ^ The first type must be an instance of the second type once the
      --   second is generalized over every variable not in the given
      --   monomorphic set.
    deriving (Show, Eq, Ord)
-- | A substitution: a finite map from type variables to the types that
-- replace them.
--
-- The newtype-derived 'Semigroup'/'Monoid' instances are those of 'Map'
-- (left-biased union and the empty map); '<>' is therefore a plain union,
-- not substitution composition (see 'compose').
newtype Subst = Subst (Map TVar Type)
    deriving (Show, Eq, Ord, Semigroup, Monoid)
-- | Values that a 'Subst' can be applied to.
class Substitutable t where
    -- | Apply the substitution to the value.
    apply :: Subst -> t -> t
-- | Applying a substitution to a bare type variable yields the variable it
-- is renamed to, or the variable itself when the substitution does not
-- mention it.
--
-- A substitution may map a variable to an arbitrary 'Type', which is not
-- representable as a 'TVar'. That case is reported with an explicit error
-- message rather than an opaque pattern-match failure. (Within this module
-- this instance is reached through the monomorphic set of an
-- 'ImpInstConst'.)
instance Substitutable TVar where
    apply (Subst s) a = case Map.findWithDefault (TVar a) a s of
        TVar tv -> tv
        t -> error $
            "Nix.Type.Infer.apply: type variable " ++ show a
                ++ " is mapped to the non-variable type " ++ show t
-- | Substitute throughout a type: structure is rebuilt unchanged, and each
-- variable is replaced by its image (or kept when the substitution does not
-- mention it).
instance Substitutable Type where
    apply (Subst m) = go
      where
        go ty = case ty of
            TCon c     -> TCon c
            TVar v     -> Map.findWithDefault ty v m
            TSet fs    -> TSet (fmap go fs)
            TSubSet fs -> TSubSet (fmap go fs)
            TList ts   -> TList (fmap go ts)
            TArr l r   -> TArr (go l) (go r)
-- | Substitute inside a scheme's body, leaving the variables it quantifies
-- over untouched.
instance Substitutable Scheme where
    apply (Subst m) (Forall bound body) =
        Forall bound (apply (Subst (Map.filterWithKey notBound m)) body)
      where
        notBound v _ = v `notElem` bound
-- | Substitute into every type, scheme and monomorphic set that a
-- constraint mentions.
instance Substitutable Constraint where
    apply sub c = case c of
        EqConst l r           -> EqConst (onTy l) (onTy r)
        EqConstOneOf l rs     -> EqConstOneOf (onTy l) (apply sub rs)
        ExpInstConst l sc     -> ExpInstConst (onTy l) (apply sub sc)
        ImpInstConst l mono r -> ImpInstConst (onTy l) (apply sub mono) (onTy r)
      where
        onTy :: Type -> Type
        onTy = apply sub
-- | Substitute into each element of a list.
instance Substitutable a => Substitutable [a] where
    apply sub xs = fmap (apply sub) xs

-- | Substitute into each element of a set; elements that become equal are
-- merged.
instance (Ord a, Substitutable a) => Substitutable (Set.Set a) where
    apply sub xs = Set.map (apply sub) xs
-- | Values whose free type variables can be collected.
class FreeTypeVars t where
    ftv :: t -> Set.Set TVar
-- | Every type variable occurring in a type is free: types have no binders.
instance FreeTypeVars Type where
    ftv ty = case ty of
        TVar v     -> Set.singleton v
        TCon _     -> Set.empty
        TArr l r   -> ftv l <> ftv r
        TList ts   -> foldMap ftv ts
        TSet fs    -> foldMap ftv fs
        TSubSet fs -> foldMap ftv fs
-- | A lone type variable is its own only free variable.
instance FreeTypeVars TVar where
    ftv v = Set.singleton v
-- | A scheme's free variables are those of its body that it does not
-- quantify over.
instance FreeTypeVars Scheme where
    ftv (Forall bound body) = Set.filter (`notElem` bound) (ftv body)
-- | Union of the free variables of every list element.
instance FreeTypeVars a => FreeTypeVars [a] where
    ftv = Set.unions . map ftv

-- | Union of the free variables of every set element.
instance (Ord a, FreeTypeVars a) => FreeTypeVars (Set.Set a) where
    ftv = Set.unions . map ftv . Set.toList
-- | Values with /active/ type variables: the variables 'nextSolvable'
-- checks before it lets an 'ImpInstConst' be solved.
class ActiveTypeVars t where
    atv :: t -> Set.Set TVar
-- | Active variables of a constraint. For an implicit-instance constraint,
-- only the monomorphic variables that also occur in the right-hand type
-- count, alongside everything free in the left-hand type.
instance ActiveTypeVars Constraint where
    atv c = case c of
        EqConst l r           -> ftv l `Set.union` ftv r
        EqConstOneOf l rs     -> ftv l `Set.union` ftv rs
        ExpInstConst l sc     -> ftv l `Set.union` ftv sc
        ImpInstConst l mono r ->
            ftv l `Set.union` Set.intersection (ftv mono) (ftv r)
-- | Union of the active variables of every element.
instance ActiveTypeVars a => ActiveTypeVars [a] where
    atv = Set.unions . map atv
-- | Everything that can go wrong during type inference.
data TypeError
    = UnificationFail Type Type
      -- ^ The two types could not be unified.
    | InfiniteType TVar Type
      -- ^ Binding the variable to the type would create an infinite type
      --   (the occurs check failed).
    | UnboundVariable Text
      -- ^ A free variable of the expression has no entry in the environment.
    | Ambigious [Constraint]
      -- ^ Constraints reported as ambiguous. (The constructor name is
      --   misspelled but exported, so it is kept as is.)
    | UnificationMismatch [Type] [Type]
      -- ^ Two type lists of different lengths were unified pairwise; carries
      --   the leftover types.
    | forall s. Frame s => EvaluationError s
      -- ^ The evaluator signalled an error frame.
    | InferenceAborted
      -- ^ Inference gave up without a more specific reason; also the
      --   'mempty' of the 'Monoid' instance.

deriving instance Show TypeError
-- | Combining errors keeps the first informative one.
--
-- 'InferenceAborted' carries no information and acts as the identity
-- element, so the 'Monoid' laws hold (@mempty <> e == e@ and
-- @e <> mempty == e@). This matters because the 'Alternative' instance of
-- 'Except', which 'Infer' derives, uses 'mempty' for 'empty' and combines
-- the errors of failed branches with 'mappend' (as in 'asum').
instance Semigroup TypeError where
    InferenceAborted <> e = e
    e <> _ = e

instance Monoid TypeError where
    mempty = InferenceAborted
    mappend = (<>)
-------------------------------------------------------------------------------
-- Inference
-------------------------------------------------------------------------------
-- | Execute an inference computation from scratch: an empty monomorphic
-- set, empty scopes, and the initial inference state.
runInfer :: Infer a -> Either TypeError a
runInfer (Infer m) =
    let withEnv   = runReaderT m (Set.empty, emptyScopes)
        withState = evalStateT withEnv initInfer
    in  runExcept withState
-- | Infer the type of an expression under an environment of known schemes,
-- returning the solving substitution and the substituted type.
--
-- Throws 'UnboundVariable' naming the smallest free variable of the
-- expression that the environment does not define.
inferType :: Env -> NExpr -> Infer (Subst, Type)
inferType env ex = do
    Judgment assumed constraints ty <- infer ex
    let missing = Set.difference (Set.fromList (As.keys assumed))
                                 (Set.fromList (Env.keys env))
    case Set.toList missing of
        (v : _) -> throwError (UnboundVariable v)
        []      -> pure ()
    -- Every use of an environment-bound name must instantiate its scheme.
    let envConstraints =
            [ ExpInstConst t sc
            | (name, sc) <- Env.toList env
            , t <- As.lookup name assumed
            ]
    sub <- solve (constraints ++ envConstraints)
    pure (sub, apply sub ty)
-- | Solve for the top-level type of an expression in the given environment
-- and close it over into a normalized 'Scheme'.
inferExpr :: Env -> NExpr -> Either TypeError Scheme
inferExpr env ex = finish <$> runInfer (inferType env ex)
  where
    finish (sub, ty) = closeOver (apply sub ty)
-- | Generalize every free variable of a type and rename the quantified
-- variables canonically.
closeOver :: Type -> Scheme
closeOver ty = normalize (generalize Set.empty ty)
-- | Run an inference action with one more type variable marked monomorphic;
-- the scopes are passed through unchanged.
extendMSet :: TVar -> Infer a -> Infer a
extendMSet v (Infer m) =
    Infer (local (\(mono, scopes) -> (Set.insert v mono, scopes)) m)
-- | The infinite supply of type variable names: @a@ .. @z@, then @aa@ ..
-- @zz@, then all three-letter names, and so on.
letters :: [String]
letters = concatMap (`replicateM` ['a' .. 'z']) [1 ..]
-- | Mint a new type variable named after the current counter, then advance
-- the counter.
fresh :: Infer Type
fresh = Infer $ do
    n <- gets count
    modify (\st -> st { count = n + 1 })
    pure (TVar (TV (letters !! n)))
-- | Instantiate a scheme by replacing each quantified variable with a fresh
-- one.
instantiate :: Scheme -> Infer Type
instantiate (Forall bound body) = do
    renamed <- traverse (\_ -> fresh) bound
    pure (apply (Subst (Map.fromList (zip bound renamed))) body)
-- | Quantify over every free variable of the type that is not in the given
-- set of monomorphic variables.
generalize :: Set.Set TVar -> Type -> Scheme
generalize mono ty = Forall (Set.toList (Set.difference (ftv ty) mono)) ty
-- | Constraints for a unary operator, given the type @operand -> result@
-- built by 'evalUnary'.
unops :: Type -> NUnaryOp -> [Constraint]
unops u1 op = case op of
    NNot -> [EqConst u1 boolToBool]
    NNeg -> [EqConstOneOf u1 [intToInt, floatToFloat]]
  where
    boolToBool   = typeFun [typeBool, typeBool]
    intToInt     = typeFun [typeInt, typeInt]
    floatToFloat = typeFun [typeFloat, typeFloat]
-- | Constraints for a binary operator, given the type
-- @left -> right -> result@ built by 'evalBinary'.
binops :: Type -> NBinaryOp -> [Constraint]
binops u1 = \case
    NApp    -> [] -- this is handled separately
    -- Equality accepts operands of any two types, so it says nothing about
    -- them; its result, however, is always a boolean.
    NEq     -> equality
    NNEq    -> equality
    NGt     -> inequality
    NGte    -> inequality
    NLt     -> inequality
    NLte    -> inequality
    NAnd    -> [ EqConst u1 (typeFun [typeBool, typeBool, typeBool]) ]
    NOr     -> [ EqConst u1 (typeFun [typeBool, typeBool, typeBool]) ]
    NImpl   -> [ EqConst u1 (typeFun [typeBool, typeBool, typeBool]) ]
    NConcat -> [ EqConstOneOf u1 [ typeFun [typeList, typeList, typeList]
                                 , typeFun [typeList, typeNull, typeList]
                                 , typeFun [typeNull, typeList, typeList]
                                 ] ]
    NUpdate -> [ EqConstOneOf u1 [ typeFun [typeSet, typeSet, typeSet]
                                 , typeFun [typeSet, typeNull, typeSet]
                                 , typeFun [typeNull, typeSet, typeSet]
                                 ] ]
    NPlus   -> [ EqConstOneOf u1 [ typeFun [typeInt, typeInt, typeInt]
                                 , typeFun [typeFloat, typeFloat, typeFloat]
                                 , typeFun [typeInt, typeFloat, typeFloat]
                                 , typeFun [typeFloat, typeInt, typeFloat]
                                 , typeFun [typeString, typeString, typeString]
                                 , typeFun [typePath, typePath, typePath]
                                 -- Appending a string to a path yields a path;
                                 , typeFun [typePath, typeString, typePath]
                                 -- appending a path to a string yields a string.
                                 , typeFun [typeString, typePath, typeString]
                                 ] ]
    NMinus  -> arithmetic
    NMult   -> arithmetic
    NDiv    -> arithmetic
  where
    -- Pin only the result type to bool; the operand types stay free.
    equality = case u1 of
        TArr _ (TArr _ result) -> [ EqConst result typeBool ]
        _                      -> []
    inequality =
        [ EqConstOneOf u1 [ typeFun [typeInt, typeInt, typeBool]
                          , typeFun [typeFloat, typeFloat, typeBool]
                          , typeFun [typeInt, typeFloat, typeBool]
                          , typeFun [typeFloat, typeInt, typeBool]
                          ] ]
    arithmetic =
        [ EqConstOneOf u1 [ typeFun [typeInt, typeInt, typeInt]
                          , typeFun [typeFloat, typeFloat, typeFloat]
                          , typeFun [typeInt, typeFloat, typeFloat]
                          , typeFun [typeFloat, typeInt, typeFloat]
                          ] ]
-- | Thunks are not delayed here: 'thunk' returns the computation unchanged,
-- 'force' hands the judgment straight to the continuation, and 'value' is
-- the identity.
instance MonadThunk Judgment Judgment Infer where
    thunk action = action
    force j k = k j
    value j = j
-- | Interpretation of the generic evaluator into typing 'Judgment's.
--
-- Each hook returns the assumptions made about free variables, the
-- constraints collected so far, and the type of the sub-expression. The
-- constraints are solved afterwards (see 'inferType').
--
-- @with@ expressions and parameter-set lambdas are not typed yet. They
-- abort inference with 'InferenceAborted' through the error monad rather
-- than crashing the program with 'undefined'.
instance MonadEval Judgment Infer where
    freeVariable var = do
        tv <- fresh
        return $ Judgment (As.singleton var tv) [] tv

    evaledSym _ = pure

    evalCurPos =
        return $ Judgment As.empty [] $ TSet $ M.fromList
            [ ("file", typePath)
            , ("line", typeInt)
            , ("col", typeInt) ]

    evalConstant c = return $ Judgment As.empty [] (go c)
      where
        go = \case
            NInt _   -> typeInt
            NFloat _ -> typeFloat
            NBool _  -> typeBool
            NNull    -> typeNull
            NUri _   -> typeUri

    evalString      = const $ return $ Judgment As.empty [] typeString
    evalLiteralPath = const $ return $ Judgment As.empty [] typePath
    evalEnvPath     = const $ return $ Judgment As.empty [] typePath

    evalUnary op (Judgment as1 cs1 t1) = do
        tv <- fresh
        return $ Judgment as1 (cs1 ++ unops (t1 `TArr` tv) op) tv

    evalBinary op (Judgment as1 cs1 t1) e2 = do
        Judgment as2 cs2 t2 <- e2
        tv <- fresh
        return $ Judgment
            (as1 `As.merge` as2)
            (cs1 ++ cs2 ++ binops (t1 `TArr` (t2 `TArr` tv)) op)
            tv

    -- TODO: type @with@ (one possible approach: push a weak scope for the
    -- body). Until then, give up cleanly instead of calling 'undefined'.
    evalWith _scope _body = throwError InferenceAborted

    evalIf (Judgment as1 cs1 t1) t f = do
        Judgment as2 cs2 t2 <- t
        Judgment as3 cs3 t3 <- f
        return $ Judgment
            (as1 `As.merge` as2 `As.merge` as3)
            (cs1 ++ cs2 ++ cs3 ++ [EqConst t1 typeBool, EqConst t2 t3])
            t2

    evalAssert (Judgment as1 cs1 t1) body = do
        Judgment as2 cs2 t2 <- body
        return $ Judgment
            (as1 `As.merge` as2)
            (cs1 ++ cs2 ++ [EqConst t1 typeBool])
            t2

    evalApp (Judgment as1 cs1 t1) e2 = do
        Judgment as2 cs2 t2 <- e2
        tv <- fresh
        return $ Judgment
            (as1 `As.merge` as2)
            (cs1 ++ cs2 ++ [EqConst t1 (t2 `TArr` tv)])
            tv

    -- The parameter is monomorphic while the body is inferred; every use of
    -- it inside the body must equal the parameter's type.
    evalAbs (Param x) e = do
        tv@(TVar a) <- fresh
        Judgment as cs t <-
            extendMSet a (e (pure (Judgment (As.singleton x tv) [] tv)))
        return $ Judgment
            (as `As.remove` x)
            (cs ++ [EqConst t' tv | t' <- As.lookup x as])
            (tv `TArr` t)
    -- TODO: parameter-set lambdas are not typed yet; give up cleanly.
    evalAbs (ParamSet _x _variadic _mname) _e = throwError InferenceAborted

    evalError = throwError . EvaluationError
-- | What inference learns about one sub-expression.
data Judgment = Judgment
    { assumptions     :: As.Assumption
      -- ^ Types assumed for the free variables encountered so far.
    , typeConstraints :: [Constraint]
      -- ^ Constraints collected so far, solved later by 'solve'.
    , inferredType    :: Type
      -- ^ The type of the sub-expression itself.
    }
    deriving Show
-- | Judgments carry no concrete string, so extraction always reports
-- 'Nothing'; 'fromValue' is not expected to be called.
instance FromValue (Text, DList Text) Infer Judgment where
    fromValueMay = const (pure Nothing)
    fromValue = const (error "Unused")
-- | View a set-typed judgment as one judgment per field, each assuming its
-- field name has the field's type. Positions are not tracked, so the
-- position map is always empty.
instance FromValue (AttrSet Judgment, AttrSet SourcePos) Infer Judgment where
    -- jww (2018-04-30): How can we do this? TSet doesn't record enough information
    fromValueMay j = pure $ case j of
        Judgment _ _ (TSet fields) ->
            Just (M.mapWithKey fieldJudgment fields, M.empty)
        _ -> Nothing
      where
        fieldJudgment name ty = Judgment (As.singleton name ty) [] ty

    -- Anything that is not a set type is treated as an empty attribute set.
    fromValue j = maybe (M.empty, M.empty) id <$> fromValueMay j
-- | Build a set-typed judgment from per-attribute judgments, merging their
-- assumptions and concatenating their constraints. Positions are ignored.
instance ToValue (AttrSet Judgment, AttrSet SourcePos) Infer Judgment where
    toValue (fields, _) = pure Judgment
        { assumptions     =
            foldr (\j acc -> assumptions j `As.merge` acc) As.empty fields
        , typeConstraints = foldr (\j acc -> typeConstraints j ++ acc) [] fields
        , inferredType    = TSet (fmap inferredType fields)
        }
-- | Build a list-typed judgment from element judgments, merging their
-- assumptions and concatenating their constraints.
instance ToValue [Judgment] Infer Judgment where
    toValue items = pure Judgment
        { assumptions     = foldr (As.merge . assumptions) As.empty items
        , typeConstraints = items >>= typeConstraints
        , inferredType    = TList (map inferredType items)
        }
-- | Every boolean has type 'typeBool', whatever its value.
instance ToValue Bool Infer Judgment where
    toValue = const (pure (Judgment As.empty [] typeBool))
-- | Infer a judgment for an expression by folding the generic evaluator
-- over it in the 'Judgment' interpretation.
infer :: NExpr -> Infer Judgment
infer expr = cata Eval.eval expr
-- | Infer top-level bindings in order, extending the environment with each
-- binding's scheme before inferring the next; stops at the first error.
inferTop :: Env -> [(Text, NExpr)] -> Either TypeError Env
inferTop = foldM step
  where
    step env (name, ex) = (\ty -> extend env (name, ty)) <$> inferExpr env ex
-- | Rename a scheme canonically: the distinct type variables of the body,
-- in order of first occurrence, become the names from 'letters', and the
-- quantifier list is rebuilt from them. The original quantifier list is
-- ignored.
normalize :: Scheme -> Scheme
normalize (Forall _ body) = Forall (map snd renaming) (rename body)
  where
    renaming = zip (nub (occurrences body)) (map TV letters)

    -- Type variables in left-to-right order of appearance, with repeats.
    occurrences ty = case ty of
        TVar v     -> [v]
        TCon _     -> []
        TArr l r   -> occurrences l ++ occurrences r
        TList ts   -> concatMap occurrences ts
        TSet fs    -> concatMap occurrences (M.elems fs)
        TSubSet fs -> concatMap occurrences (M.elems fs)

    rename ty = case ty of
        TVar v     -> maybe (error "type variable not in signature") TVar
                            (Prelude.lookup v renaming)
        TCon c     -> TCon c
        TArr l r   -> TArr (rename l) (rename r)
        TList ts   -> TList (map rename ts)
        TSet fs    -> TSet (M.map rename fs)
        TSubSet fs -> TSubSet (M.map rename fs)
-------------------------------------------------------------------------------
-- Constraint Solver
-------------------------------------------------------------------------------
-- | The identity substitution: it maps no variables.
emptySubst :: Subst
emptySubst = Subst Map.empty
-- | Compose substitutions: applying @s1 `compose` s2@ behaves like applying
-- @s2@ first and then @s1@. Bindings of @s2@ (with @s1@ applied to their
-- right-hand sides) win over bindings of @s1@ for the same variable.
compose :: Subst -> Subst -> Subst
compose later@(Subst s1) (Subst s2) =
    Subst (Map.union (Map.map (apply later) s2) s1)
-- | Unify two type lists pairwise, applying each step's substitution to the
-- remaining pairs. Lists of different lengths raise 'UnificationMismatch'
-- with the leftover types.
unifyMany :: [Type] -> [Type] -> Infer Subst
unifyMany (l : ls) (r : rs) = do
    su   <- unifies l r
    rest <- unifyMany (apply su ls) (apply su rs)
    pure (rest `compose` su)
unifyMany [] [] = pure emptySubst
unifyMany ls rs = throwError (UnificationMismatch ls rs)
-- | Compute a substitution that makes two types equal, or fail with
-- 'UnificationFail'.
--
-- Lists and attribute sets are only checked shallowly: any two list types
-- unify, and attribute-set types unify when their key sets are compatible,
-- without unifying the field types.
unifies :: Type -> Type -> Infer Subst
unifies t1 t2 | t1 == t2 = return emptySubst
unifies (TVar v) t = v `bind` t
unifies t (TVar v) = v `bind` t
unifies (TList _) (TList _) = return emptySubst
-- A set matches a subset pattern when it has every key the pattern names.
-- This is a key-set comparison; it must not depend on the iteration order
-- of the two (unrelated) hash maps.
unifies (TSet b) (TSubSet s)
    | null (M.keys s \\ M.keys b) = return emptySubst
unifies (TSubSet s) (TSet b)
    | null (M.keys s \\ M.keys b) = return emptySubst
unifies (TSet s) (TSet b)
    | null (M.keys b \\ M.keys s) = return emptySubst
unifies (TArr t1 t2) (TArr t3 t4) = unifyMany [t1, t2] [t3, t4]
unifies t1 t2 = throwError $ UnificationFail t1 t2
-- | Bind a type variable to a type. Binding a variable to itself is the
-- empty substitution; a binding that would create an infinite type raises
-- 'InfiniteType'.
bind :: TVar -> Type -> Infer Subst
bind a t
    | TVar a == t       = pure emptySubst
    | a `occursCheck` t = throwError (InfiniteType a t)
    | otherwise         = pure (Subst (Map.singleton a t))
-- | Does the variable occur free in the value?
occursCheck :: FreeTypeVars a => TVar -> a -> Bool
occursCheck a = Set.member a . ftv
-- | Pick a constraint that can be solved now, together with the remaining
-- constraints, or 'Nothing' when none qualifies.
--
-- Equality and explicit-instance constraints are always solvable. An
-- implicit-instance constraint is solvable only when none of the variables
-- it would generalize over (those of its right-hand type outside its
-- monomorphic set) are active in the other constraints.
nextSolvableMay :: [Constraint] -> Maybe (Constraint, [Constraint])
nextSolvableMay xs = find solvable [ (x, delete x xs) | x <- xs ]
  where
    solvable (EqConst{}, _)      = True
    solvable (EqConstOneOf{}, _) = True
    solvable (ExpInstConst{}, _) = True
    solvable (ImpInstConst _t1 ms t2, cs) =
        Set.null ((ftv t2 `Set.difference` ms) `Set.intersection` atv cs)

-- | Like 'nextSolvableMay', for callers that know a solvable constraint
-- exists; fails with a descriptive message otherwise.
nextSolvable :: [Constraint] -> (Constraint, [Constraint])
nextSolvable xs = case nextSolvableMay xs of
    Just next -> next
    Nothing   -> error "Nix.Type.Infer.nextSolvable: no solvable constraint"

-- | Solve a list of constraints, producing a substitution satisfying them
-- all. If constraints remain but none can be solved next, throws
-- 'Ambigious' with the outstanding constraints instead of crashing.
solve :: [Constraint] -> Infer Subst
solve [] = return emptySubst
solve cs = maybe (throwError (Ambigious cs)) solve' (nextSolvableMay cs)
-- | Solve one chosen constraint, then the remaining ones, returning the
-- composed substitution.
solve' :: (Constraint, [Constraint]) -> Infer Subst
solve' (constraint, rest) = case constraint of
    EqConst t1 t2 ->
        unifies t1 t2 >>= thenSolveRest
    EqConstOneOf t1 t2 ->
        -- jww (2018-04-30): Instead of picking the first that matches, collect all
        -- that match into a 'TVariant [Type]' type, so that we can report that a
        -- function like 'x: y: x + y' has type: forall a b. a one of integer,
        -- float, string, b the same as a, or compatible, result is determined by
        -- the finally decided type of the function (in this case, one of int,
        -- float, string or path, based on the types of a and b).
        asum (map (unifies t1) t2) >>= thenSolveRest
    ImpInstConst t1 ms t2 ->
        solve (ExpInstConst t1 (generalize ms t2) : rest)
    ExpInstConst t s -> do
        s' <- instantiate s
        solve (EqConst t s' : rest)
  where
    -- Apply this step's substitution to the remaining constraints, solve
    -- them, and compose so that this step's substitution is applied first.
    thenSolveRest su = do
        su' <- solve (apply su rest)
        return (su' `compose` su)