January 16th, 2008:

Simple Type Inference in Haskell

In this article I show how to build a simple Hindley-Milner type-inference engine in Haskell. I used it a while back in a MiniML frontend/interpreter. Because the full code is large, I will only show the type-inference engine and not the rest of the interpreter. I’ll show some simple uses here; in future articles I will show how the engine can be used more practically in a real compiler, or possibly extend the type-checker with new concepts.

For this code I used three custom monads that I made, which are actually quite useful in general. Each goes in its own module. Since this means the article is not a single literate Haskell file, I will only use literate style for the code belonging to the actual type-checker module. Just copy and paste the other three files into appropriately named .hs files. The first monad is very simple and was actually written by Cale Gibbard. It takes as input an infinite list of data (for instance, variable names) and hands out a fresh symbol each time it is asked. All the code is under a simple permissive license (I mention this explicitly because the first monad is Cale’s work). The original code comes from MonadSupply. I extended it with a few more monad instances. Of interest is the function ‘makeSupply’, which was also added: it generates an infinite list of unique symbols from a list of prefixes and a list of suffixes that are appended to existing symbols to make new ones.

{-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-}
module MonadSupply where
import Control.Monad
import Control.Monad.State
import Control.Monad.Reader

newtype SupplyT s m a = SupplyT { unSupplyT :: StateT [s] m a}
  deriving (Functor, Monad, MonadTrans, MonadIO)

newtype Supply s a = Supply (SupplyT s Maybe a)
  deriving (Functor, Monad, MonadSupply s)

class Monad m => MonadSupply s m | m -> s where
  supply :: m s

instance Monad m => MonadSupply s (SupplyT s m) where
  supply = SupplyT $ do
    (x:xs) <- get
    put xs
    return x

instance MonadState s m => MonadState s (SupplyT x m) where
  get = SupplyT . lift $ get
  put v = SupplyT . lift $ put v

instance MonadReader r m => MonadReader r (SupplyT x m) where
  ask = SupplyT . lift $ ask
  local f a = SupplyT . local f . unSupplyT $ a

evalSupplyT (SupplyT s) supp = evalStateT s supp
evalSupply (Supply s) supp = evalSupplyT s supp

runSupplyT (SupplyT s) supp = runStateT s supp
runSupply (Supply s) supp = runSupplyT s supp

--- makeSupply
--- Makes an infinite list
-- makeSupply :: [[a]] -> [[a]] -> [[a]]
-- makeSupply inits tails =
-- let vars = inits ++ (map concat . sequence $ [vars, tails]) in vars
--- Cleaner version supplied by TuringTest
makeSupply :: [[a]] -> [[a]] -> [[a]]
makeSupply inits tails =
  let vars = inits ++ (liftM2 (++) vars tails) in vars
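
To get a feel for what makeSupply produces, here is a small standalone sketch. It repeats the definition so that it compiles with only base installed; the helper `fresh` and the value `firstNames` are names made up for this illustration and are not part of the module.

```haskell
import Control.Monad (liftM2)

-- Same definition as in MonadSupply, repeated so the snippet stands alone.
makeSupply :: [[a]] -> [[a]] -> [[a]]
makeSupply inits tails =
  let vars = inits ++ liftM2 (++) vars tails in vars

-- Hypothetical helper: split one fresh name off the front of a supply.
-- This mirrors what 'supply' does with the list stored inside SupplyT.
fresh :: [s] -> (s, [s])
fresh (x:xs) = (x, xs)
fresh []     = error "fresh: supply exhausted"

-- The list is infinite, so we only ever look at a finite prefix of it.
firstNames :: [String]
firstNames = take 6 (makeSupply ["a", "b"] ["'"])
-- ["a","b","a'","b'","a''","b''"]
```

The prefixes come out first, and every later name is an earlier name with one of the suffixes appended, so no name is ever produced twice as long as the prefixes are distinct and do not themselves end in a suffix.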

The second monad is an environment monad which I made. It’s basically a fancy wrapper around a state monad that stores key-value pairs.

{-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-}
module MonadEnv where
import Control.Monad
import Control.Monad.State
import Control.Monad.Reader
import qualified Data.Map as M

type Memory k v = M.Map k v

newtype EnvT k v m a = EnvT { unEnvT :: StateT (Memory k v) m a }
  deriving (Functor, Monad, MonadTrans, MonadIO)

newtype Env k v a = Env (EnvT k v Maybe a)
  deriving (Functor, Monad, MonadEnv k v)

class (Ord k, Monad m) => MonadEnv k v m | m -> k, m -> v where
  getfor :: k -> m (Maybe v)
  putfor :: k -> v -> m ()

instance (Ord k, Monad m) => MonadEnv k v (EnvT k v m) where
  getfor k = EnvT $ do
    s <- get
    return $ M.lookup k s
  putfor k v = EnvT $ do
    modify (\s -> M.insert k v s)

instance MonadState s m => MonadState s (EnvT k v m) where
  get = EnvT . lift $ get
  put v = EnvT . lift $ put v

instance MonadReader r m => MonadReader r (EnvT k v m) where
  ask = EnvT . lift $ ask
  local f a = EnvT . local f . unEnvT $ a

evalEnvT (EnvT s) ctx = evalStateT s ctx
evalEnv (Env s) ctx = evalEnvT s ctx
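
Stripped of the transformer plumbing, the environment is just a Data.Map that gets threaded through the computation. The sketch below shows the same two operations as plain functions; `getforPure`, `putforPure` and `demoEnv` are hypothetical names used only here.

```haskell
import qualified Data.Map as M

type Memory k v = M.Map k v

-- Pure counterpart of getfor: look a key up, Nothing if it is unbound.
getforPure :: Ord k => k -> Memory k v -> Maybe v
getforPure = M.lookup

-- Pure counterpart of putfor: bind a key, replacing any older binding.
putforPure :: Ord k => k -> v -> Memory k v -> Memory k v
putforPure = M.insert

demoEnv :: (Maybe String, Maybe String)
demoEnv =
  let m1 = putforPure "a" "Int" M.empty
      m2 = putforPure "a" "Bool" m1   -- the later binding overwrites the earlier one
  in (getforPure "a" m2, getforPure "b" m2)
-- (Just "Bool", Nothing)
```

EnvT does exactly this, except that the map lives in a StateT so you never have to pass m1 and m2 around by hand.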

The final monad is used to scope behaviour during type-checking. It is not required, per se, for the pure typechecker, as you could put it on top of the typechecker monad in your compiler monad. However, I think the three fit nicely together (and I admit to being too lazy to rewrite the code to split the monads). It is a scoping environment. I hope this is not too much code for one blog article. I plan to release a few more monads I have made in separate articles, so this feels like a colossal article, but I’m rather intent on showing this type-checker :) .

{-# OPTIONS_GHC -fglasgow-exts -fallow-undecidable-instances #-}
module MonadScope where
import Control.Monad
import Control.Monad.State
import Control.Monad.Reader
import qualified Data.Map as M

type Context k v = M.Map k v
type Subst k v = Context k v -> Context k v

newtype ScopeT k v m a = ScopeT { unScopeT :: ReaderT (Context k v) m a }
  deriving (Functor, Monad, MonadTrans, MonadIO)

newtype Scope k v a = Scope (ScopeT k v Maybe a)
  deriving (Functor, Monad, MonadScope k v)

class (Ord k, Monad m) => MonadScope k v m | m -> k, m -> v where
  find :: k -> m (Maybe v)
  scope :: Subst k v -> m a -> m a

instance (Ord k, Monad m) => MonadScope k v (ScopeT k v m) where
  find k = ScopeT $ do
    s <- ask
    return $ M.lookup k s
  scope s act = ScopeT $ do
    local s . unScopeT $ act

instance MonadState s m => MonadState s (ScopeT k v m) where
  get = ScopeT . lift $ get
  put v = ScopeT . lift $ put v

evalScopeT (ScopeT s) ctx = runReaderT s ctx
evalScope (Scope s) ctx = evalScopeT s ctx
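
The important property of the scoping monad is that a change made with ‘scope’ is only visible inside the action it wraps. The following sketch models that with a plain function from a context instead of a ReaderT; `Scoped`, `findS`, `scopeS` and `demoScope` are made-up names for this illustration.

```haskell
import qualified Data.Map as M

type Context k v = M.Map k v

-- A computation that reads a context: a transformer-free stand-in for ScopeT.
type Scoped k v a = Context k v -> a

-- Counterpart of find: look an identifier up in the current context.
findS :: Ord k => k -> Scoped k v (Maybe v)
findS = M.lookup

-- Counterpart of scope: run the inner action in a modified context.
-- The caller's context is left untouched.
scopeS :: (Context k v -> Context k v) -> Scoped k v a -> Scoped k v a
scopeS f act = act . f

demoScope :: (Maybe Int, Maybe Int)
demoScope =
  let ctx   = M.empty :: Context String Int
      inner = scopeS (M.insert "x" 1) (findS "x") ctx  -- "x" is bound in here
      outer = findS "x" ctx                            -- but not out here
  in (inner, outer)
-- (Just 1, Nothing)
```

This is the behaviour you want for (let x = … in …): the binding for x exists while checking the body and disappears afterwards.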

Now that we have the above three modules out of the way, we can start looking at the typechecker. We start with the definition of the module and the imports of these three modules, as well as a few other basic necessities.

> {-# LANGUAGE FlexibleContexts, GeneralizedNewtypeDeriving, TypeSynonymInstances, MultiParamTypeClasses #-}
> module TypeChecker where
> import MonadSupply
> import MonadScope
> import MonadEnv
> import Data.List(nub, (\\), intersperse)
> import qualified Data.Map as M
> import Control.Monad.Trans

Basic type definitions

Now that we have that in place, it is important to specify types. A typechecker has two kinds of types inside: monotypes and polytypes. The difference is that a free variable in a monotype may only be unified with one specific type, while the quantified variables of a polytype are parametrically polymorphic. More specifically, if we take the function “map :: (a -> b) -> [a] -> [b]”, this is actually “map :: forall a b. (a -> b) -> [a] -> [b]”. Thus it has a polytype. All top-level definitions and let-bound variables in Haskell are poly-typed. If map were mono-typed, then the first use of map would unify ‘a’ and ‘b’ with specific types and every other invocation would have to respect those types. So a polytype is a monotype that has been quantified over several variables. As for monotypes, there are two options: either the type is purely a type variable (TyVar), or it is a type constructor applied to several parameters (TyConst). For convenience’s sake, we also define a function type (TyFun). In essence it’s just a TyConst with “->” as its Const, but it’s always nice to make special cases ;) . For the actual type variables, we will be using strings.
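
You can see the difference between the two kinds of types in Haskell itself. In this small sketch (the names `letBound` and `idLet` are only for illustration), the let-bound identity function gets a polytype and can be used at two different types, while the lambda-bound version in the comment only gets a monotype and is rejected.

```haskell
-- 'idLet' is let-bound, so it is generalized to forall a. a -> a
-- and may be used at Char and at Bool in the same expression.
letBound :: (Char, Bool)
letBound = let idLet = \x -> x in (idLet 'c', idLet True)

-- A lambda-bound variable is monotyped. Uncommenting the next line gives a
-- type error in plain Haskell, because 'f' would have to be both
-- Char -> Char and Bool -> Bool at once:
-- lambdaBound = (\f -> (f 'c', f True)) (\x -> x)
```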

Besides the above basic ‘types’, we also introduce a few useful type-classes. Additionally, we define a few standard type creators (I know, bad word, but ‘type constructors’ would be ambiguous) to build some standard types. Some of the functions are a bit simplistic and could be replaced by a simple call to the data constructor. These exist because originally my types were a bit more complicated: I had a typeclass for types since I wanted to be able to add extra information to types, so the type of types was actually ‘MonoType mt’, and I would tie the knot by fixing the type declaration over whatever record contained the extra info. Ignore that and just be glad that the functions all have nicely similar names :) .

> --- Type variable
> type Var = String
> --- Type constructor
> type Const = String
> --- MonoTypes
> data MonoType =
>     TyVar Var
>   | TyFun MonoType MonoType
>   | TyConst Const [MonoType]
>   deriving (Eq, Ord)
> --- PolyTypes
> data PolyType =
>     TyPoly [Var] MonoType
>   deriving (Eq, Ord)
> instance Show MonoType where
>   show (TyVar x) = x
>   show (TyFun ta tb) = "(" ++ show ta ++ ") -> " ++ show tb
>   show (TyConst "(,)" tis) = "(" ++ (concat . intersperse ", " . map show $ tis) ++ ")"
>   show (TyConst n []) = n
>   show (TyConst x tis) = x ++ " " ++ (unwords . map show $ tis)
> instance Show PolyType where
>   show (TyPoly ids mt) = (if null ids then "" else "forall " ++ unwords ids ++ ".") ++ show mt
> class HasVars a where
>   freeVars :: a -> [Var]
>   occurs :: Var -> a -> Bool
>   occurs x t = x `elem` freeVars t
> instance HasVars MonoType where
>   freeVars (TyVar a) = [a]
>   freeVars (TyFun ta tb) = nub $ freeVars ta ++ freeVars tb
>   freeVars (TyConst _ ts) = nub . concatMap freeVars $ ts
>   occurs x (TyVar y) = x == y
>   occurs x (TyFun ta tb) = occurs x ta || occurs x tb
>   occurs x (TyConst _ ts) = or . map (occurs x) $ ts
> instance HasVars PolyType where
>   freeVars (TyPoly vs t) = freeVars t \\ vs
> toPoly :: MonoType -> PolyType
> toPoly x = TyPoly (freeVars x) x
> fromPoly :: PolyType -> MonoType
> fromPoly (TyPoly fvs t) = t
> unitConstructor = "()"
> tupleConstructor = "(,)"
> mkTyVar :: Var -> MonoType
> mkTyVar x = TyVar x
> mkTyFun :: MonoType -> MonoType -> MonoType
> mkTyFun ta tb = TyFun ta tb
> mkTyConst :: Const -> [MonoType] -> MonoType
> mkTyConst c ts = TyConst c ts
> mkTupleType :: [MonoType] -> MonoType
> mkTupleType ts = mkTyConst tupleConstructor ts
> mkUnitType :: MonoType
> mkUnitType = mkTyConst unitConstructor []
> mkBoolType :: MonoType
> mkBoolType = mkTyConst "Bool" []
> mkIntType :: MonoType
> mkIntType = mkTyConst "Int" []
> mkFunType :: [MonoType] -> MonoType
> mkFunType (a:b:r) = mkTyFun a $ mkFunType (b:r)
> mkFunType [a] = a
> mkFunType [] = error "mkFunType: Trying to make a function with 0 args"
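
As a quick check of how freeVars and toPoly interact, here is a cut-down standalone copy of the definitions applied to the type of map. It is only a sketch: it uses derived Show instances instead of the pretty-printing ones above, replaces the HasVars class with two plain functions (`freeVars` and the made-up `freeVarsPoly`), and assumes "[]" as the name of the list type constructor, which the article does not fix anywhere.

```haskell
import Data.List (nub, (\\))

type Var = String

data MonoType
  = TyVar Var
  | TyFun MonoType MonoType
  | TyConst String [MonoType]
  deriving (Eq, Show)

data PolyType = TyPoly [Var] MonoType
  deriving (Eq, Show)

-- Free variables of a monotype: every variable that occurs in it, once.
freeVars :: MonoType -> [Var]
freeVars (TyVar a)      = [a]
freeVars (TyFun ta tb)  = nub (freeVars ta ++ freeVars tb)
freeVars (TyConst _ ts) = nub (concatMap freeVars ts)

-- Free variables of a polytype: the quantified ones no longer count.
freeVarsPoly :: PolyType -> [Var]
freeVarsPoly (TyPoly vs t) = freeVars t \\ vs

-- Quantify over everything that is free.
toPoly :: MonoType -> PolyType
toPoly t = TyPoly (freeVars t) t

-- (a -> b) -> [a] -> [b], with "[]" as an assumed list constructor name.
mapType :: MonoType
mapType =
  TyFun (TyFun (TyVar "a") (TyVar "b"))
        (TyFun (TyConst "[]" [TyVar "a"]) (TyConst "[]" [TyVar "b"]))

-- freeVars mapType              == ["a","b"]
-- toPoly mapType                == TyPoly ["a","b"] mapType
-- freeVarsPoly (toPoly mapType) == []
```

Note that toPoly quantifies over every free variable, which is only right for a closed, top-level type; it takes no surrounding context into account.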

Type Checker Monad

Now that we’ve defined some standard functions and smart constructors, we could play around with them, but I will leave that to the reader. Instead, let’s move on to the meat of this blog entry. For this we define a new monad, the typer monad. The typer monad is actually a stack of monad transformers. At the bottom we put the supply monad transformer (SupplyT), which ensures we have an unlimited supply of fresh names that the typer monad can use. On top of this we put an environment monad (EnvT) that maps type variables to the appropriate MonoTypes. Finally, because we want let-binding in whatever language we may use, we add a scoping monad (ScopeT), so that whenever the language uses (let x = … in …) we can have a polytype for ‘x’ and do not need to worry about the monomorphism restriction. With this stack we define a custom monad class (TyperMonad) that has two basic operations besides the operations the stack already gives us and the regular monad interface. These two actions let us unify two types and normalize a type based on what has been unified so far. For instance, if we unify (TyVar “a”) with (TyConst “Int” []), then whenever we normalize a type that has (TyVar “a”) in it, we want to be sure that it becomes the Int type. We will show a simple example at the bottom of how this is used.
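
To make the idea of normalizing concrete before we get to the monadic version, here is a pure sketch that takes the substitution as an explicit Data.Map argument. The names `Subst`, `normalizePure`, `demoSubst` and `demoType` are invented for this example, and the sketch assumes the substitution contains no cycles (which is what the occurs check in unify is for); otherwise it would not terminate.

```haskell
import qualified Data.Map as M

type Var = String

data MonoType
  = TyVar Var
  | TyFun MonoType MonoType
  | TyConst String [MonoType]
  deriving (Eq, Show)

-- What has been unified so far: type variable -> type.
type Subst = M.Map Var MonoType

-- Replace every bound variable by what it stands for, following chains
-- such as a -> b -> Int all the way to the end.
normalizePure :: Subst -> MonoType -> MonoType
normalizePure s t@(TyVar x)    = maybe t (normalizePure s) (M.lookup x s)
normalizePure s (TyFun a b)    = TyFun (normalizePure s a) (normalizePure s b)
normalizePure s (TyConst c ts) = TyConst c (map (normalizePure s) ts)

-- a is bound to b, and b is bound to Int; c is unbound.
demoSubst :: Subst
demoSubst = M.fromList [("a", TyVar "b"), ("b", TyConst "Int" [])]

demoType :: MonoType
demoType = normalizePure demoSubst (TyFun (TyVar "a") (TyFun (TyVar "a") (TyVar "c")))
-- TyFun (TyConst "Int" []) (TyFun (TyConst "Int" []) (TyVar "c"))
```

The monadic normalize does the same walk, but reads the map from the EnvT layer instead of taking it as a parameter.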

> -- Ident is the type for identifiers in a program
> type Ident = String
> -------------------------------------------------------------------------------
> --- TyperState:
> --- State of typer monad
> --- Keeps a map of tyvar -> MonoType substitutions
> --- Keeps a list of free tyvars
> --- Requires: tyvars used as keys do not appear in the values
> --- unify needs to ensure this at all times
> --------------------------------------------------------------------------------
> type TyperContext = M.Map Ident PolyType
> type TyperSubst = TyperContext -> TyperContext
> newtype TyperT m a = TyperT {
>     runTyperT :: ScopeT Ident PolyType (EnvT Var MonoType (SupplyT Var m)) a
>   } deriving (Functor, Monad, MonadIO)
> instance MonadTrans TyperT where
>   lift = TyperT . lift . lift . lift
> instance (Monad m) => MonadSupply Var (TyperT m) where
>   supply = TyperT . lift . lift $ supply
> instance (Monad m) => MonadScope Ident PolyType (TyperT m) where
>   find = TyperT . find
>   scope f act = TyperT . scope f . runTyperT $ act
> class (Monad m, MonadScope Ident PolyType m, MonadSupply Var m) => TyperMonad m where
>   unify :: MonoType -> MonoType -> m ()
>   normalize :: MonoType -> m MonoType

With the above interface for the TyperMonad given, it is time to actually define it. This is the real meat of the typer monad.

> problem :: (Monad m) => String -> String -> m ()
> problem phase message = fail $ phase ++ ": " ++ message
> instance (Monad m) => TyperMonad (TyperT m) where
>   ------------------------------------------------------------------------------
>   --- unify
>   ------------------------------------------------------------------------------
>   unify type1 type2 = TyperT $ do
>       type1' <- runTyperT $ normalize type1
>       type2' <- runTyperT $ normalize type2
>       unifier type1' type2'
>     where unifier (TyVar x) ty@(TyVar y) =
>             if x == y
>               then return ()
>               else lift $ putfor x ty
>           unifier (TyVar x) ty =
>             if x `occurs` ty
>               then problem "TypeChecker" ("Trying to unify recursive type " ++ show x ++ " = " ++ show ty)
>               else lift $ putfor x ty
>           unifier tx (TyVar y) =
>             if y `occurs` tx
>               then problem "TypeChecker" ("Trying to unify recursive type " ++ show y ++ " = " ++ show tx)
>               else lift $ putfor y tx
>           unifier tx@(TyConst nx px) ty@(TyConst ny py) =
>             if nx == ny && length px == length py
>               -- IMPORTANT: use unify and not unifier here,
>               -- or the code will recurse for:
>               -- foo x y = x y; bar x = foo bar bar
>               then runTyperT . sequence_ $ zipWith unify px py
>               else problem "TypeChecker" (show tx ++ " can not be inferred to " ++ show ty)
>           unifier (TyFun tax tbx) (TyFun tay tby) = do
>             runTyperT $ unify tax tay
>             runTyperT $ unify tbx tby
>           unifier tx ty =
>             problem "TypeChecker" (show tx ++ " can not be inferred to " ++ show ty)
>   ------------------------------------------------------------------------------
>   --- normalize
>   --- Normalizes a type by looking it up if it's a tyvar
>   --- Applies a small optimization, it basically removes any extra unnecessary
>   --- mappings. For instance:
>   --- [a->b, b->Int]: normalize a
>   --- [a->b, b->Int]: b' <- normalize b
>   --- [a->Int, b->Int]: return b'
>   ------------------------------------------------------------------------------
>   normalize tx@(TyVar x) = TyperT $ do
>     tx' <- lift $ getfor x
>     case tx' of
>       Nothing -> return tx
>       Just ty -> do
>         ty' <- runTyperT $ normalize ty
>         if ty == ty'
>           then return ty'
>           else do
>             lift $ putfor x ty'
>             return ty'
>   normalize tx@(TyConst n tl) = TyperT $ do
>     tl' <- runTyperT $ mapM normalize tl
>     return $ TyConst n tl'
>   normalize tx@(TyFun ta tb) = TyperT $ do
>     ta' <- runTyperT $ normalize ta
>     tb' <- runTyperT $ normalize tb
>     return $ TyFun ta' tb'
> ------------------------------------------------------------------------------
> evalTyperT :: (Monad m) => TyperContext -> TyperT m a -> m a
> evalTyperT ctx action =
>   let vars = makeSupply (map return ['a'..'z']) ["'"] in
>   evalSupplyT (evalEnvT (evalScopeT (runTyperT action) ctx) M.empty) vars
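
If the monad stack gets in the way of seeing the algorithm, the following pure sketch may help. It is not the code above: it threads the substitution explicitly and reports errors with Either instead of fail, and the names `Subst`, `apply`, `occursIn` and `unifyPure` are all invented for this illustration. The structure of the cases is the same, though: normalize both sides first, bind variables after an occurs check, and recurse through the full unify (not the inner helper) so that each sub-unification sees the bindings made by the previous one.

```haskell
import qualified Data.Map as M

type Var = String

data MonoType
  = TyVar Var
  | TyFun MonoType MonoType
  | TyConst String [MonoType]
  deriving (Eq, Show)

type Subst = M.Map Var MonoType

-- Resolve every bound variable (the pure analogue of normalize).
apply :: Subst -> MonoType -> MonoType
apply s t@(TyVar x)    = maybe t (apply s) (M.lookup x s)
apply s (TyFun a b)    = TyFun (apply s a) (apply s b)
apply s (TyConst c ts) = TyConst c (map (apply s) ts)

occursIn :: Var -> MonoType -> Bool
occursIn x (TyVar y)      = x == y
occursIn x (TyFun a b)    = occursIn x a || occursIn x b
occursIn x (TyConst _ ts) = any (occursIn x) ts

-- Extend the substitution so that both types become equal, or explain why not.
unifyPure :: MonoType -> MonoType -> Subst -> Either String Subst
unifyPure t1 t2 s = go (apply s t1) (apply s t2)
  where
    go (TyVar x) (TyVar y) | x == y = Right s
    go (TyVar x) t = bind x t
    go t (TyVar y) = bind y t
    -- Recurse via unifyPure so the second call sees the first call's bindings.
    go (TyFun a1 b1) (TyFun a2 b2) = unifyPure a1 a2 s >>= unifyPure b1 b2
    go (TyConst c1 ts1) (TyConst c2 ts2)
      | c1 == c2 && length ts1 == length ts2 =
          foldl (\acc (a, b) -> acc >>= unifyPure a b) (Right s) (zip ts1 ts2)
    go a b = Left (show a ++ " cannot be unified with " ++ show b)

    -- Occurs check: refuse to build an infinite type such as a = a -> Int.
    bind x t
      | x `occursIn` t = Left ("recursive type " ++ x ++ " = " ++ show t)
      | otherwise      = Right (M.insert x t s)
```

For example, unifying (a -> b) with (Int -> a) from an empty substitution binds both a and b to Int, while unifying a with (a -> Int) is rejected by the occurs check.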

I will let you grok the code. If any bits are not clear, feel free to tell me so and I will expand upon them. Now we can easily do some experiments:

evalTyperT M.empty (do{unify (TyVar "a") mkIntType; normalize (mkFunType [TyVar "a", TyVar "a", TyVar "b"])})
=> (Int) -> (Int) -> b

You might notice that I included some extra machinery, namely the scoping and all that. In a future blog post I will show how this is actually used to typecheck a simple language (like MiniML). But for now I will leave it at that, as this has already grown into quite a big blog post.

If you have any questions or comments, please feel free to post them at the bottom and I will make the necessary changes if required to clarify or improve the presentation.