-- hnix/src/Nix/Type/Infer.hs
--
-- 710 lines
-- 22 KiB
-- Haskell

{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# OPTIONS_GHC -Wno-name-shadowing #-}
module Nix.Type.Infer
( Constraint(..)
, TypeError(..)
, InferError(..)
, Subst(..)
, inferTop
)
where
import Control.Applicative
import Control.Arrow
import Control.Monad.Catch
import Control.Monad.Except
import Control.Monad.Fail
import Control.Monad.Logic
import Control.Monad.Reader
import Control.Monad.Ref
import Control.Monad.ST
import Control.Monad.State.Strict
import Data.Fix ( cata )
import Data.Foldable
import qualified Data.HashMap.Lazy as M
import Data.List ( delete
, find
, nub
, intersect
, (\\)
)
import Data.Map ( Map )
import qualified Data.Map as Map
import Data.Maybe ( fromJust )
import qualified Data.Set as Set
import Data.Text ( Text )
import Nix.Atoms
import Nix.Convert
import Nix.Eval ( MonadEval(..) )
import qualified Nix.Eval as Eval
import Nix.Expr.Types
import Nix.Expr.Types.Annotated
import Nix.Fresh
import Nix.String
import Nix.Scope
-- import Nix.Thunk
-- import Nix.Thunk.Basic
import qualified Nix.Type.Assumption as As
import Nix.Type.Env
import qualified Nix.Type.Env as Env
import Nix.Type.Type
import Nix.Utils
import Nix.Value.Monad
import Nix.Var
-------------------------------------------------------------------------------
-- Classes
-------------------------------------------------------------------------------
-- | Inference monad.
--
-- A reader over a pair of (1) the set of type variables introduced for
-- enclosing lambda parameters (extended by 'extendMSet') and (2) the lexical
-- 'Scopes' of judgments, layered over a fresh-name counter ('InferState') and
-- 'InferError' failures. The @s@ parameter is the 'ST' thread used by
-- 'runInfer'.
newtype InferT s m a = InferT
  { getInfer ::
      ReaderT (Set.Set TVar, Scopes (InferT s m) (Judgment s))
        (StateT InferState (ExceptT InferError m)) a
  }
  deriving
    ( Functor
    , Applicative
    , Alternative
    , Monad
    , MonadPlus
    , MonadFix
    , MonadReader (Set.Set TVar, Scopes (InferT s m) (Judgment s))
    , MonadFail
    , MonadState InferState
    , MonadError InferError
    )

-- | Lift a base-monad action through the reader, state and except layers.
instance MonadTrans (InferT s) where
  lift = InferT . lift . lift . lift
-- instance MonadThunkId m => MonadThunkId (InferT s m) where
-- type ThunkId (InferT s m) = ThunkId m
-- | Inference state: how many fresh type variables have been handed out.
newtype InferState = InferState { count :: Int }

-- | Starting state for inference; no type variables allocated yet.
initInfer :: InferState
initInfer = InferState 0
-- | Constraints collected during inference and discharged by 'solve'.
data Constraint
  = EqConst Type Type
  -- ^ The two types must unify.
  | ExpInstConst Type Scheme
  -- ^ The type must be an instance of the scheme; 'solve' instantiates the
  -- scheme with fresh variables and replaces this with an 'EqConst'.
  | ImpInstConst Type (Set.Set TVar) Type
  -- ^ The first type must be an instance of the second type generalized
  -- over everything except the given variable set; 'solve' rewrites it to an
  -- 'ExpInstConst' using 'generalize'.
  deriving (Show, Eq, Ord)
-- | A substitution: a finite map from type variables to replacement types.
-- The derived 'Semigroup'/'Monoid' instances are those of 'Map' (a
-- left-biased union), which is not substitution composition; use 'compose'
-- to compose substitutions.
newtype Subst = Subst (Map TVar Type)
  deriving (Eq, Ord, Show, Semigroup, Monoid)
-- | Things a substitution can be applied to.
class Substitutable a where
  apply :: Subst -> a -> a

-- | A variable is replaced by its image, which must itself be a variable.
-- A non-variable image fails lazily when the result is demanded.
instance Substitutable TVar where
  apply (Subst m) v =
    let TVar v' = Map.findWithDefault (TVar v) v m in v'

-- | Replace variables structurally; unmapped variables are left alone.
instance Substitutable Type where
  apply (Subst m) = go
   where
    go ty = case ty of
      TCon c           -> TCon c
      TVar v           -> Map.findWithDefault ty v m
      TSet open fields -> TSet open (M.map go fields)
      TList elems      -> TList (map go elems)
      dom :~> cod      -> go dom :~> go cod
      TMany alts       -> TMany (map go alts)

-- | Bound variables of a scheme are shielded from the substitution.
instance Substitutable Scheme where
  apply (Subst m) (Forall bound body) =
    Forall bound (apply (Subst (foldr Map.delete m bound)) body)

instance Substitutable Constraint where
  apply sub = \case
    EqConst lhs rhs -> EqConst (apply sub lhs) (apply sub rhs)
    ExpInstConst ty sch -> ExpInstConst (apply sub ty) (apply sub sch)
    ImpInstConst lhs mono rhs ->
      ImpInstConst (apply sub lhs) (apply sub mono) (apply sub rhs)

instance Substitutable a => Substitutable [a] where
  apply sub = map (apply sub)

instance (Ord a, Substitutable a) => Substitutable (Set.Set a) where
  apply sub = Set.map (apply sub)
-- | Collect the free type variables of a value.
class FreeTypeVars a where
  ftv :: a -> Set.Set TVar

instance FreeTypeVars Type where
  ftv = \case
    TCon{}        -> Set.empty
    TVar v        -> Set.singleton v
    TSet _ fields -> foldMap ftv fields
    TList elems   -> foldMap ftv elems
    dom :~> cod   -> ftv dom <> ftv cod
    TMany alts    -> foldMap ftv alts

instance FreeTypeVars TVar where
  ftv v = Set.singleton v

-- | Quantified variables are not free.
instance FreeTypeVars Scheme where
  ftv (Forall bound body) = Set.difference (ftv body) (Set.fromList bound)

instance FreeTypeVars a => FreeTypeVars [a] where
  ftv = foldMap ftv

instance (Ord a, FreeTypeVars a) => FreeTypeVars (Set.Set a) where
  ftv = foldMap ftv
-- | Type variables that a constraint can still affect; used by
-- 'nextSolvable' to decide when an implicit-instance constraint may be
-- generalized.
class ActiveTypeVars a where
  atv :: a -> Set.Set TVar

instance ActiveTypeVars Constraint where
  atv = \case
    EqConst lhs rhs -> Set.union (ftv lhs) (ftv rhs)
    ImpInstConst lhs mono rhs ->
      Set.union (ftv lhs) (Set.intersection (ftv mono) (ftv rhs))
    ExpInstConst ty sch -> Set.union (ftv ty) (ftv sch)

instance ActiveTypeVars a => ActiveTypeVars [a] where
  atv = foldMap atv
-- | Errors reported by unification and the constraint solver.
data TypeError
  = UnificationFail Type Type
  -- ^ The two types cannot be unified.
  | InfiniteType TVar Type
  -- ^ Occurs-check failure: the variable occurs inside the type it would be
  -- bound to.
  | UnboundVariables [Text]
  -- ^ Names used by the expression that the environment does not bind.
  | Ambigious [Constraint]
  -- ^ Not constructed in the code visible here. (The constructor name is
  -- misspelled but exported, so it is kept as is.)
  | UnificationMismatch [Type] [Type]
  -- ^ 'unifyMany' ran out of types on one side before the other.
  deriving (Eq, Show)
-- | Failures of the inference monad: type errors, an aborted search
-- ('mempty'), or an exception raised while evaluating.
data InferError
  = TypeInferenceErrors [TypeError]
  | TypeInferenceAborted
  | forall s. Exception s => EvaluationError s

deriving instance Show InferError

instance Exception InferError

-- | Abort inference with a single 'TypeError'.
typeError :: MonadError InferError m => TypeError -> m ()
typeError = throwError . TypeInferenceErrors . (: [])

-- | Left-biased: when errors are combined, the first one wins.
instance Semigroup InferError where
  (<>) = const

instance Monoid InferError where
  mempty = TypeInferenceAborted
  mappend = (<>)
-------------------------------------------------------------------------------
-- Inference
-------------------------------------------------------------------------------
-- | Run the inference monad from an empty set of lambda-bound type
-- variables, empty scopes and the initial fresh-variable counter.
runInfer' :: MonadInfer m => InferT s m a -> m (Either InferError a)
runInfer' (InferT action) =
  runExceptT (evalStateT (runReaderT action (Set.empty, emptyScopes)) initInfer)

-- | Run inference purely in 'ST', with a fresh-id supply starting at 1.
runInfer :: (forall s . InferT s (FreshIdT Int (ST s)) a) -> Either InferError a
runInfer m =
  runST $ newVar (1 :: Int) >>= \counter -> runFreshIdT counter (runInfer' m)
-- | Infer the type of an expression in an environment.
--
-- Fails with 'UnboundVariables' if the expression assumes names the
-- environment lacks. Otherwise adds one 'ExpInstConst' per (assumed type,
-- environment scheme) pair for each bound name, solves all constraints and
-- returns every solution as a substitution plus the substituted type.
inferType
  :: forall s m . MonadInfer m => Env -> NExpr -> InferT s m [(Subst, Type)]
inferType env ex = do
  Judgment as cs t <- infer ex
  -- Set.toList already yields distinct names, so no extra dedup is needed.
  let unbound = Set.toList $
        Set.fromList (As.keys as) `Set.difference` Set.fromList (Env.keys env)
  unless (null unbound) $ typeError (UnboundVariables unbound)
  let envConstraints =
        [ ExpInstConst ty scheme
        | (name, schemes) <- Env.toList env
        , scheme <- schemes
        , ty <- As.lookup name as
        ]
      solved = runSolver $ do
        subst <- solve (cs ++ envConstraints)
        pure (subst, apply subst t)
  st <- get
  either (throwError . TypeInferenceErrors) pure (evalState solved st)
-- | Solve for the toplevel type of an expression in a given environment,
-- returning one closed, normalized scheme per solution.
inferExpr :: Env -> NExpr -> Either InferError [Scheme]
inferExpr env ex = fmap toScheme <$> runInfer (inferType env ex)
 where
  toScheme (subst, ty) = closeOver (apply subst ty)

-- | Canonicalize and return the polymorphic toplevel type.
closeOver :: Type -> Scheme
closeOver ty = normalizeScheme (generalize Set.empty ty)
-- | Run an action with a type variable added to the set of lambda-bound
-- type variables kept in the reader environment.
extendMSet :: Monad m => TVar -> InferT s m a -> InferT s m a
extendMSet x (InferT action) = InferT $ local (first (Set.insert x)) action
-- | An infinite supply of distinct names: "a" .. "z", then "aa" .. "zz",
-- then "aaa", and so on.
letters :: [String]
letters = concatMap (\n -> replicateM n ['a' .. 'z']) [1 ..]
-- | Allocate the next unused name from 'letters' and bump the counter.
freshTVar :: MonadState InferState m => m TVar
freshTVar = do
  n <- gets count
  put (InferState (n + 1))
  pure (TV (letters !! n))

-- | A fresh type variable, as a 'Type'.
fresh :: MonadState InferState m => m Type
fresh = fmap TVar freshTVar
-- | Replace a scheme's quantified variables with fresh type variables.
instantiate :: MonadState InferState m => Scheme -> m Type
instantiate (Forall vars body) = do
  replacements <- traverse (const fresh) vars
  pure $ apply (Subst (Map.fromList (zip vars replacements))) body
-- | Quantify over every free variable of the type that is not in the given
-- set.
generalize :: Set.Set TVar -> Type -> Scheme
generalize free t = Forall (Set.toList (Set.difference (ftv t) free)) t
-- | Constraints for a unary operator. @u1@ is the operator's type at the
-- use site, @operand :~> result@ (see 'evalUnary').
unops :: Type -> NUnaryOp -> [Constraint]
unops u1 op = [EqConst u1 (signature op)]
 where
  signature NNot = typeFun [typeBool, typeBool]
  signature NNeg =
    TMany [typeFun [typeInt, typeInt], typeFun [typeFloat, typeFloat]]
-- | Constraints for a binary operator.
--
-- @u1@ is the operator's type at the use site, @left :~> right :~> result@
-- (see 'evalBinary'). Overloaded operators are constrained to a 'TMany' of
-- alternative signatures; the solver tries each alternative in turn.
binops :: Type -> NBinaryOp -> [Constraint]
binops u1 = \case
  NApp -> [] -- this is handled separately
  -- Equality tells you nothing about the types, because any two types are
  -- allowed.
  NEq -> []
  NNEq -> []
  NGt -> inequality
  NGte -> inequality
  NLt -> inequality
  NLte -> inequality
  NAnd -> [EqConst u1 (typeFun [typeBool, typeBool, typeBool])]
  NOr -> [EqConst u1 (typeFun [typeBool, typeBool, typeBool])]
  NImpl -> [EqConst u1 (typeFun [typeBool, typeBool, typeBool])]
  NConcat ->
    [ EqConst
        u1
        (TMany
          [ typeFun [typeList, typeList, typeList]
          , typeFun [typeList, typeNull, typeList]
          , typeFun [typeNull, typeList, typeList]
          ]
        )
    ]
  NUpdate ->
    [ EqConst
        u1
        (TMany
          [ typeFun [typeSet, typeSet, typeSet]
          , typeFun [typeSet, typeNull, typeSet]
          , typeFun [typeNull, typeSet, typeSet]
          ]
        )
    ]
  NPlus ->
    [ EqConst
        u1
        (TMany
          [ typeFun [typeInt, typeInt, typeInt]
          , typeFun [typeFloat, typeFloat, typeFloat]
          , typeFun [typeInt, typeFloat, typeFloat]
          , typeFun [typeFloat, typeInt, typeFloat]
          , typeFun [typeString, typeString, typeString]
          , typeFun [typePath, typePath, typePath]
          -- In Nix, path + string yields a path and string + path yields a
          -- string. (string + string never yields a path.)
          , typeFun [typePath, typeString, typePath]
          , typeFun [typeString, typePath, typeString]
          ]
        )
    ]
  NMinus -> arithmetic
  NMult -> arithmetic
  NDiv -> arithmetic
 where
  -- Numbers compare numerically (ints and floats may be mixed); strings and
  -- paths compare lexicographically with values of the same kind.
  inequality =
    [ EqConst
        u1
        (TMany
          [ typeFun [typeInt, typeInt, typeBool]
          , typeFun [typeFloat, typeFloat, typeBool]
          , typeFun [typeInt, typeFloat, typeBool]
          , typeFun [typeFloat, typeInt, typeBool]
          , typeFun [typeString, typeString, typeBool]
          , typeFun [typePath, typePath, typeBool]
          ]
        )
    ]
  arithmetic =
    [ EqConst
        u1
        (TMany
          [ typeFun [typeInt, typeInt, typeInt]
          , typeFun [typeFloat, typeFloat, typeFloat]
          , typeFun [typeInt, typeFloat, typeFloat]
          , typeFun [typeFloat, typeInt, typeFloat]
          ]
        )
    ]
-- | Lift an action from the base monad through every layer of 'InferT'.
liftInfer :: Monad m => m a -> InferT s m a
liftInfer = InferT . lift . lift . lift

-- | References are those of the base monad.
instance MonadRef m => MonadRef (InferT s m) where
  type Ref (InferT s m) = Ref m
  newRef x = liftInfer $ newRef x
  readRef x = liftInfer $ readRef x
  writeRef x y = liftInfer $ writeRef x y

-- | Delegate to the base monad's atomic primitive. The previous
-- implementation did a separate read and then a modify. That was not atomic,
-- and it evaluated the user function twice: once for the result and once
-- for the new value.
instance MonadAtomicRef m => MonadAtomicRef (InferT s m) where
  atomicModifyRef x f = liftInfer $ atomicModifyRef x f
-- newtype JThunkT s m = JThunk (NThunkF (InferT s m) (Judgment s))
-- | Exceptions thrown with 'throwM' are carried in 'EvaluationError'.
instance Monad m => MonadThrow (InferT s m) where
  throwM = throwError . EvaluationError

-- | Handle an 'EvaluationError' whose payload converts (via 'fromException')
-- to the exception type the handler expects.
--
-- Every other error is re-thrown unchanged instead of crashing with 'error'.
-- That covers an evaluation error of another exception type, a type
-- inference error, and 'TypeInferenceAborted'. Re-throwing unmatched
-- exceptions is what 'MonadCatch' requires.
instance Monad m => MonadCatch (InferT s m) where
  catch action handler = catchError action $ \case
    EvaluationError e
      | Just matched <- fromException (toException e) -> handler matched
    err -> throwError err
-- | Requirements on the base monad of 'InferT'. (A @MonadThunkId m@
-- constraint is currently disabled.)
type MonadInfer m = ({- MonadThunkId m, -} MonadVar m, MonadFix m)

-- | Judgments are plain values. Deferring is a no-op, demanding applies the
-- continuation directly, and 'inform' hands the callback the judgment as an
-- already-computed action.
instance Monad m => MonadValue (Judgment s) (InferT s m) where
  defer action = action
  demand j k = k j
  inform j f = f (pure j)
{-
instance MonadInfer m
=> MonadThunk (JThunkT s m) (InferT s m) (Judgment s) where
thunk = fmap JThunk . thunk
thunkId (JThunk x) = thunkId x
queryM (JThunk x) b f = queryM x b f
-- If we have a thunk loop, we just don't know the type.
force (JThunk t) f = catch (force t f)
$ \(_ :: ThunkLoop) ->
f =<< Judgment As.empty [] <$> fresh
-- If we have a thunk loop, we just don't know the type.
forceEff (JThunk t) f = catch (forceEff t f)
$ \(_ :: ThunkLoop) ->
f =<< Judgment As.empty [] <$> fresh
-}
-- | Interpret evaluation as type inference. 'infer' folds 'Eval.eval' over
-- the syntax tree in 'InferT', so each syntactic form produces a 'Judgment'
-- (assumptions, constraints, type) instead of a runtime value.
instance MonadInfer m => MonadEval (Judgment s) (InferT s m) where
  -- A free variable gets a fresh type variable and is recorded as an
  -- assumption, which 'inferType' later checks against the environment.
  freeVariable var = do
    tv <- fresh
    return $ Judgment (As.singleton var tv) [] tv

  -- Holes are treated exactly like free variables.
  synHole var = do
    tv <- fresh
    return $ Judgment (As.singleton var tv) [] tv

  -- If we fail to look up an attribute, we just don't know the type.
  attrMissing _ _ = Judgment As.empty [] <$> fresh

  -- The judgment for an evaluated symbol is passed through unchanged.
  evaledSym _ = pure

  -- The current position is a set (TSet False) with file/line/col fields.
  evalCurPos = return $ Judgment As.empty [] $ TSet False $ M.fromList
    [("file", typePath), ("line", typeInt), ("col", typeInt)]

  -- Literal constants have fixed types and no assumptions or constraints.
  evalConstant c = return $ Judgment As.empty [] (go c)
   where
    go = \case
      NInt _ -> typeInt
      NFloat _ -> typeFloat
      NBool _ -> typeBool
      NNull -> typeNull

  -- String and path literals likewise have fixed types.
  evalString = const $ return $ Judgment As.empty [] typeString
  evalLiteralPath = const $ return $ Judgment As.empty [] typePath
  evalEnvPath = const $ return $ Judgment As.empty [] typePath

  -- Operators: the result is a fresh type variable, and 'unops' / 'binops'
  -- constrain the operator type built from the operand types and that
  -- variable.
  evalUnary op (Judgment as1 cs1 t1) = do
    tv <- fresh
    return $ Judgment as1 (cs1 ++ unops (t1 :~> tv) op) tv

  evalBinary op (Judgment as1 cs1 t1) e2 = do
    Judgment as2 cs2 t2 <- e2
    tv <- fresh
    return $ Judgment (as1 `As.merge` as2)
                      (cs1 ++ cs2 ++ binops (t1 :~> t2 :~> tv) op)
                      tv

  -- 'with' is handled by the generic evaluator helper.
  evalWith = Eval.evalWithAttrSet

  -- The condition must be Bool and both branches must have the same type;
  -- the then-branch type is used as the result.
  evalIf (Judgment as1 cs1 t1) t f = do
    Judgment as2 cs2 t2 <- t
    Judgment as3 cs3 t3 <- f
    return $ Judgment
      (as1 `As.merge` as2 `As.merge` as3)
      (cs1 ++ cs2 ++ cs3 ++ [EqConst t1 typeBool, EqConst t2 t3])
      t2

  -- The asserted expression must be Bool; the body's type is the result.
  evalAssert (Judgment as1 cs1 t1) body = do
    Judgment as2 cs2 t2 <- body
    return
      $ Judgment (as1 `As.merge` as2) (cs1 ++ cs2 ++ [EqConst t1 typeBool]) t2

  -- Application: the function's type must equal @argument :~> result@ for a
  -- fresh result variable.
  evalApp (Judgment as1 cs1 t1) e2 = do
    Judgment as2 cs2 t2 <- e2
    tv <- fresh
    return $ Judgment (as1 `As.merge` as2)
                      (cs1 ++ cs2 ++ [EqConst t1 (t2 :~> tv)])
                      tv

  -- Simple lambda: the parameter gets a fresh type variable, which is added
  -- via 'extendMSet' while the body is inferred. Assumptions about the
  -- parameter are then removed and each is equated with that variable.
  evalAbs (Param x) k = do
    a <- freshTVar
    let tv = TVar a
    ((), Judgment as cs t) <- extendMSet
      a
      (k (pure (Judgment (As.singleton x tv) [] tv)) (\_ b -> ((), ) <$> b))
    return $ Judgment (as `As.remove` x)
                      (cs ++ [ EqConst t' tv | t' <- As.lookup x as ])
                      (tv :~> t)

  -- Parameter-set lambda: each named parameter gets a fresh type variable,
  -- and the body sees the argument as a 'TSet True' of those variables. The
  -- domain of the resulting function type is a set whose flag follows
  -- @variadic@, with field types taken from the @args@ the continuation
  -- returns.
  -- NOTE(review): the second component of each parameter (presumably its
  -- default expression) and @_mname@ (presumably the @name binding) are
  -- ignored here.
  evalAbs (ParamSet ps variadic _mname) k = do
    js <- fmap concat $ forM ps $ \(name, _) -> do
      tv <- fresh
      pure [(name, tv)]

    let (env, tys) =
          (\f -> foldl' f (As.empty, M.empty) js) $ \(as1, t1) (k, t) ->
            (as1 `As.merge` As.singleton k t, M.insert k t t1)
        arg = pure $ Judgment env [] (TSet True tys)
        call = k arg $ \args b -> (args, ) <$> b
        names = map fst js

    -- Every parameter's type variable is added via 'extendMSet' while the
    -- body runs.
    (args, Judgment as cs t) <- foldr (\(_, TVar a) -> extendMSet a) call js

    ty <- TSet variadic <$> traverse (inferredType <$>) args

    return $ Judgment
      (foldl' As.remove as names)
      (cs ++ [ EqConst t' (tys M.! x) | x <- names, t' <- As.lookup x as ])
      (ty :~> t)

  -- Evaluation errors abort inference as 'EvaluationError'.
  evalError = throwError . EvaluationError
-- | The result of inferring a sub-expression. The @s@ parameter is a
-- phantom that matches the one on 'InferT'.
data Judgment s = Judgment
  { assumptions :: As.Assumption
  -- ^ Types assumed for free variables, to be checked against the
  -- environment or discharged by an enclosing binder.
  , typeConstraints :: [Constraint]
  -- ^ Constraints collected so far; they are solved later by 'solve'.
  , inferredType :: Type
  -- ^ The expression's type before the constraints are solved.
  }
  deriving Show
-- | A judgment never yields a concrete string: 'fromValueMay' always gives
-- 'Nothing', and 'fromValue' is not expected to be called.
instance Monad m => FromValue NixString (InferT s m) (Judgment s) where
  fromValueMay = const (pure Nothing)
  fromValue = const (error "Unused")

-- | A judgment whose type is a set is viewed as an attribute set of field
-- judgments with no assumptions or constraints. Source positions are not
-- tracked, and 'fromValue' yields empty maps for non-set types.
instance MonadInfer m
  => FromValue (AttrSet (Judgment s), AttrSet SourcePos)
               (InferT s m) (Judgment s) where
  fromValueMay = \case
    Judgment _ _ (TSet _ fields) ->
      pure $ Just (fmap (Judgment As.empty []) fields, M.empty)
    _ -> pure Nothing
  fromValue j = maybe (M.empty, M.empty) id <$> fromValueMay j

-- | An attribute set of judgments becomes one judgment. Assumptions are
-- merged, constraints concatenated, and the type is a 'TSet True' of the
-- field types. Source positions are ignored.
instance MonadInfer m
  => ToValue (AttrSet (Judgment s), AttrSet SourcePos)
             (InferT s m) (Judgment s) where
  toValue (xs, _) = do
    merged <- foldrM addAssumptions As.empty xs
    constraints <- concat <$> traverse (`demand` (pure . typeConstraints)) xs
    fieldTypes <- traverse (`demand` (pure . inferredType)) xs
    pure (Judgment merged constraints (TSet True fieldTypes))
   where
    addAssumptions x acc =
      demand x $ \j -> pure $ As.merge (assumptions j) acc

-- | A list of judgments becomes one judgment whose type lists the element
-- types in order.
instance MonadInfer m => ToValue [Judgment s] (InferT s m) (Judgment s) where
  toValue xs = do
    merged <- foldrM addAssumptions As.empty xs
    constraints <- concat <$> traverse (`demand` (pure . typeConstraints)) xs
    elemTypes <- traverse (`demand` (pure . inferredType)) xs
    pure (Judgment merged constraints (TList elemTypes))
   where
    addAssumptions x acc =
      demand x $ \j -> pure $ As.merge (assumptions j) acc

-- | Any boolean is simply of type Bool.
instance MonadInfer m => ToValue Bool (InferT s m) (Judgment s) where
  toValue = const (pure (Judgment As.empty [] typeBool))
-- | Infer a judgment by folding 'Eval.eval' over the syntax tree in
-- 'InferT'.
infer :: MonadInfer m => NExpr -> InferT s m (Judgment s)
infer = cata Eval.eval

-- | Infer the bindings in order. Each binding's schemes are added to the
-- environment seen by the bindings after it. Stops at the first error.
inferTop :: Env -> [(Text, NExpr)] -> Either InferError Env
inferTop = foldlM step
 where
  step env (name, ex) = extend env . (name, ) <$> inferExpr env ex
-- | Rename a type's variables to "a", "b", ... in order of first
-- occurrence, and quantify over exactly those. The scheme's original binder
-- list is ignored.
normalizeScheme :: Scheme -> Scheme
normalizeScheme (Forall _ body) = Forall (map snd renaming) (rename body)
 where
  renaming = zip (nub (occurrences body)) (map TV letters)

  -- Variables in left-to-right order, with repeats.
  occurrences = \case
    TVar v           -> [v]
    dom :~> cod      -> occurrences dom ++ occurrences cod
    TCon _           -> []
    TSet _ fields    -> concatMap occurrences (M.elems fields)
    TList elems      -> concatMap occurrences elems
    TMany alts       -> concatMap occurrences alts

  rename = \case
    dom :~> cod      -> rename dom :~> rename cod
    TCon c           -> TCon c
    TSet open fields -> TSet open (M.map rename fields)
    TList elems      -> TList (map rename elems)
    TMany alts       -> TMany (map rename alts)
    TVar v           -> maybe (error "type variable not in signature") TVar
                              (Prelude.lookup v renaming)
-------------------------------------------------------------------------------
-- Constraint Solver
-------------------------------------------------------------------------------
-- | Backtracking constraint solver. 'LogicT' explores the alternatives of
-- 'TMany' types. Because the error list lives in the 'StateT' beneath
-- 'LogicT', errors from failed branches accumulate across backtracking;
-- 'runSolver' reports them only when no branch succeeds.
newtype Solver m a = Solver (LogicT (StateT [TypeError] m) a)
  deriving (Functor, Applicative, Alternative, Monad, MonadPlus,
            MonadLogic, MonadState [TypeError])

instance MonadTrans Solver where
  lift = Solver . lift . lift

-- | 'throwError' records the error and fails only the current branch
-- ('mzero'), so other alternatives are still explored. 'catchError' is not
-- supported.
instance Monad m => MonadError TypeError (Solver m) where
  throwError err = Solver $ lift (modify (err :)) >> mzero
  catchError _ _ = error "This is never used"
-- | Run the solver and collect every solution. If there are none, return
-- the deduplicated errors recorded by the failed branches.
runSolver :: Monad m => Solver m a -> m (Either [TypeError] [a])
runSolver (Solver s) = toResult <$> runStateT (observeAllT s) []
 where
  toResult (solutions, errs)
    | null solutions = Left (nub errs)
    | otherwise      = Right solutions
-- | The empty substitution
emptySubst :: Subst
emptySubst = Subst Map.empty

-- | Compose substitutions: applying @s1 `compose` s2@ is the same as
-- applying @s2@ and then @s1@. Bindings from @s2@ (with @s1@ applied to
-- their right-hand sides) win over bindings of @s1@ for the same variable.
compose :: Subst -> Subst -> Subst
compose outer@(Subst m1) (Subst m2) =
  Subst (Map.union (Map.map (apply outer) m2) m1)
-- | Unify two lists of types pairwise, threading each substitution into the
-- rest of the pairs. If one list runs out first, fail with
-- 'UnificationMismatch' carrying the unmatched remainders.
unifyMany :: Monad m => [Type] -> [Type] -> Solver m Subst
unifyMany [] [] = pure emptySubst
unifyMany (l : ls) (r : rs) = do
  headSubst <- unifies l r
  tailSubst <- unifyMany (apply headSubst ls) (apply headSubst rs)
  pure (compose tailSubst headSubst)
unifyMany ls rs = throwError (UnificationMismatch ls rs)
-- | True when every element equals its successor, i.e. the list is
-- homogeneous. Empty and singleton lists qualify.
allSameType :: [Type] -> Bool
allSameType xs = and (zipWith (==) xs (drop 1 xs))
-- | Unify two types into a substitution. Clauses are tried in order:
--
-- * identical types need no substitution;
-- * a type variable on either side is bound with 'bind';
-- * if both lists are homogeneous, they are unified through their first
--   elements (trivially if either is empty); otherwise lists of equal length
--   are unified element-wise, and any other pair of lists fails;
-- * attribute-set types are compared by key sets only, and their field
--   types are never unified;
-- * function types are unified component-wise;
-- * a 'TMany' on either side tries each alternative in turn ('>>-'),
--   backtracking on failure.
--
-- NOTE(review): the two mixed @TSet False@ / @TSet True@ clauses use
-- opposite subset tests. With the False set on the left, the True set's keys
-- must be a subset of the False set's keys. With the False set on the right,
-- it is the other way round. The outcome can therefore depend on argument
-- order; confirm which direction is intended. The @intersect ... ==@ tests
-- also compare key lists in 'M.keys' order.
unifies :: Monad m => Type -> Type -> Solver m Subst
unifies t1 t2 | t1 == t2 = return emptySubst
unifies (TVar v) t = v `bind` t
unifies t (TVar v) = v `bind` t
unifies (TList xs) (TList ys)
  | allSameType xs && allSameType ys = case (xs, ys) of
      (x : _, y : _) -> unifies x y
      _ -> return emptySubst
  | length xs == length ys = unifyMany xs ys
-- We assume that lists of different lengths containing various types cannot
-- be unified.
unifies t1@(TList _ ) t2@(TList _ ) = throwError $ UnificationFail t1 t2
unifies ( TSet True _) ( TSet True _) = return emptySubst
unifies (TSet False b) (TSet True s)
  | M.keys b `intersect` M.keys s == M.keys s = return emptySubst
unifies (TSet True s) (TSet False b)
  | M.keys b `intersect` M.keys s == M.keys b = return emptySubst
-- Two False sets: the right-hand set's keys must be a subset of the left's.
unifies (TSet False s) (TSet False b) | null (M.keys b \\ M.keys s) =
  return emptySubst
unifies (t1 :~> t2) (t3 :~> t4) = unifyMany [t1, t2] [t3, t4]
unifies (TMany t1s) t2 = considering t1s >>- unifies ?? t2
unifies t1 (TMany t2s) = considering t2s >>- unifies t1
unifies t1 t2 = throwError $ UnificationFail t1 t2
-- | Bind a type variable to a type. Binding a variable to itself is a
-- no-op, and binding it to a type that contains it fails the occurs check.
bind :: Monad m => TVar -> Type -> Solver m Subst
bind v ty
  | ty == TVar v     = pure emptySubst
  | occursCheck v ty = throwError (InfiniteType v ty)
  | otherwise        = pure (Subst (Map.singleton v ty))

-- | Does the variable occur free in the value?
occursCheck :: FreeTypeVars a => TVar -> a -> Bool
occursCheck v = Set.member v . ftv
-- | Pick the first constraint that can be solved now, paired with the rest.
--
-- Equality and explicit-instance constraints are always solvable. An
-- implicit-instance constraint is solvable only when none of the variables
-- it would generalize over are active in the remaining constraints.
-- Crashes (via 'fromJust') if no constraint qualifies.
nextSolvable :: [Constraint] -> (Constraint, [Constraint])
nextSolvable xs = fromJust $ find isSolvable [ (x, delete x xs) | x <- xs ]
 where
  isSolvable (c, rest) = case c of
    EqConst{}      -> True
    ExpInstConst{} -> True
    ImpInstConst _ mono t2 ->
      Set.null (Set.intersection (Set.difference (ftv t2) mono) (atv rest))

-- | Turn a list into a stream of alternatives for the backtracking solver.
considering :: [a] -> Solver m a
considering xs = Solver (LogicT (\onValue onEnd -> foldr onValue onEnd xs))
-- | Solve a constraint set into a substitution, one solvable constraint at
-- a time (see 'nextSolvable').
--
-- * Equalities are unified, and the result is applied to the remaining
--   constraints before recursing.
-- * Implicit-instance constraints are generalized into explicit ones.
-- * Explicit-instance constraints are instantiated with fresh variables and
--   turned into equalities.
--
-- Results are combined with the fair bind '>>-', so 'TMany' alternatives
-- interleave.
solve :: MonadState InferState m => [Constraint] -> Solver m Subst
solve [] = pure emptySubst
solve cs = case nextSolvable cs of
  (EqConst t1 t2, rest) ->
    unifies t1 t2 >>- \su1 ->
      solve (apply su1 rest) >>- \su2 ->
        pure (su2 `compose` su1)
  (ImpInstConst t1 ms t2, rest) ->
    solve (ExpInstConst t1 (generalize ms t2) : rest)
  (ExpInstConst t s, rest) -> do
    s' <- lift (instantiate s)
    solve (EqConst t s' : rest)
-- | Lexical scopes for 'InferT' live in its reader environment, so every
-- method delegates to the corresponding reader-based helper.
instance Monad m => Scoped (Judgment s) (InferT s m) where
  lookupVar name = lookupVarReader name
  pushScopes scopes action = pushScopesReader scopes action
  clearScopes action = clearScopesReader @(InferT s m) @(Judgment s) action
  currentScopes = currentScopesReader