In this article I show how to build a simple Hindley-Milner type-inference engine in Haskell. I used it a while back in a MiniML frontend/interpreter. Because the full code is large, I will only show the type-inference engine and not the rest of the interpreter. I’ll show some simple uses here, and in future articles I will show how this can be used more practically in a real compiler, or perhaps extend the type-checker with new concepts.

For this code I used three custom monads, which are actually quite useful in general. Each goes in its own module. As a result this article is not a single literate Haskell file, so I only use literate style for the code of the actual type-checker module. Just copy-paste the other three modules into appropriately named .hs files.

The first monad is very simple and was actually written by Cale Gibbard. It takes an infinite list of symbols (for instance, variable names) and, whenever asked, hands out the next fresh symbol from that list. All of the code is under a simple permissive license (I mention this explicitly because the first monad is Cale’s work). The original code comes from MonadSupply; I extended it with a few more monad instances. Of interest is the function ‘makeSupply’, which I also added: it makes it easy to generate an infinite list of unique symbols from a list of prefixes and a list of tails that are appended to existing symbols to form new ones.

```haskell
{-# LANGUAGE GeneralizedNewtypeDeriving, MultiParamTypeClasses, FunctionalDependencies #-}
{-# LANGUAGE FlexibleInstances, UndecidableInstances #-}

module MonadSupply where

import Control.Monad
import Control.Monad.Trans
import Control.Monad.State
import Control.Monad.Reader

newtype SupplyT s m a = SupplyT { unSupplyT :: StateT [s] m a }
  deriving (Functor, Applicative, Monad, MonadTrans, MonadIO, MonadFail)

newtype Supply s a = Supply (SupplyT s Maybe a)
  deriving (Functor, Applicative, Monad, MonadSupply s)

class Monad m => MonadSupply s m | m -> s where
  supply :: m s

instance Monad m => MonadSupply s (SupplyT s m) where
  -- Hand out the head of the (infinite) supply and keep the rest.
  supply = SupplyT $ do
    xs <- get
    case xs of
      (x:rest) -> put rest >> return x
      []       -> error "supply: ran out of symbols"

instance MonadState s m => MonadState s (SupplyT x m) where
  get = SupplyT . lift $ get
  put v = SupplyT . lift $ put v

instance MonadReader r m => MonadReader r (SupplyT x m) where
  ask = SupplyT . lift $ ask
  local f a = SupplyT . local f . unSupplyT $ a

evalSupplyT :: Monad m => SupplyT s m a -> [s] -> m a
evalSupplyT (SupplyT s) supp = evalStateT s supp

evalSupply :: Supply s a -> [s] -> Maybe a
evalSupply (Supply s) supp = evalSupplyT s supp

runSupplyT :: SupplyT s m a -> [s] -> m (a, [s])
runSupplyT (SupplyT s) supp = runStateT s supp

runSupply :: Supply s a -> [s] -> Maybe (a, [s])
runSupply (Supply s) supp = runSupplyT s supp

--------------------------------------------------------------------------------
--- makeSupply
--- Makes an infinite list of symbols from a list of prefixes and a list of tails
--------------------------------------------------------------------------------
-- makeSupply :: [[a]] -> [[a]] -> [[a]]
-- makeSupply inits tails =
--   let vars = inits ++ (map concat . sequence $ [vars, tails]) in vars
--------------------------------------------------------------------------------
--- Cleaner version supplied by TuringTest
--------------------------------------------------------------------------------
makeSupply :: [[a]] -> [[a]] -> [[a]]
makeSupply inits tails =
  let vars = inits ++ liftM2 (++) vars tails in vars
```
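
To get a feel for what makeSupply produces, here is a small standalone sketch: a copy of the definition above plus two example supplies. `letters` is built exactly like the supply `evalTyperT` uses further down; `xyNames`, with its prefixes and tails, is just an illustrative choice.

```haskell
import Control.Monad (liftM2)

-- Same definition as makeSupply above: first all prefixes, then every
-- symbol produced so far extended with each tail, and so on.
makeSupply :: [[a]] -> [[a]] -> [[a]]
makeSupply inits tails =
  let vars = inits ++ liftM2 (++) vars tails in vars

-- a .. z, then a' .. z', then a'' .. z'', ...
letters :: [String]
letters = makeSupply (map return ['a' .. 'z']) ["'"]

-- Illustrative supply: x, y, x1, x2, y1, y2, x11, x12, x21, ...
xyNames :: [String]
xyNames = makeSupply ["x", "y"] ["1", "2"]
```

Because `vars` is defined in terms of itself, the list is produced lazily: each new symbol is an earlier symbol with a tail appended, so only as much of the list as you demand is ever built.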

The second monad is an environment monad which I made. It’s basically a fancy wrapper around a state monad that stores key-value pairs.

```haskell
{-# LANGUAGE GeneralizedNewtypeDeriving, MultiParamTypeClasses, FunctionalDependencies #-}
{-# LANGUAGE FlexibleInstances, UndecidableInstances #-}

module MonadEnv where

import Control.Monad
import Control.Monad.Trans
import Control.Monad.State
import Control.Monad.Reader
import qualified Data.Map as M

type Memory k v = M.Map k v

newtype EnvT k v m a = EnvT { unEnvT :: StateT (Memory k v) m a }
  deriving (Functor, Applicative, Monad, MonadTrans, MonadIO, MonadFail)

newtype Env k v a = Env (EnvT k v Maybe a)
  deriving (Functor, Applicative, Monad, MonadEnv k v)

class (Ord k, Monad m) => MonadEnv k v m | m -> k, m -> v where
  getfor :: k -> m (Maybe v)
  putfor :: k -> v -> m ()

instance (Ord k, Monad m) => MonadEnv k v (EnvT k v m) where
  getfor k = EnvT $ do
    s <- get
    return $ M.lookup k s
  putfor k v = EnvT $ do
    modify (\s -> M.insert k v s)

instance MonadState s m => MonadState s (EnvT k v m) where
  get = EnvT . lift $ get
  put v = EnvT . lift $ put v

instance MonadReader r m => MonadReader r (EnvT k v m) where
  ask = EnvT . lift $ ask
  local f a = EnvT . local f . unEnvT $ a

evalEnvT :: Monad m => EnvT k v m a -> Memory k v -> m a
evalEnvT (EnvT s) ctx = evalStateT s ctx

evalEnv :: Env k v a -> Memory k v -> Maybe a
evalEnv (Env s) ctx = evalEnvT s ctx
```
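
The behaviour of getfor and putfor is easiest to see on a stripped-down version. The sketch below assumes we can drop the transformer and the MonadEnv class and simply use mtl’s `State` over a `Data.Map`; it is not the EnvT code above, only the same idea.

```haskell
import Control.Monad.State (State, evalState, gets, modify)
import qualified Data.Map as M

-- Simplified stand-in for EnvT: plain State over a Data.Map
-- (assumption: no transformer, no MonadEnv class).
type Env k v = State (M.Map k v)

getfor :: Ord k => k -> Env k v (Maybe v)
getfor k = gets (M.lookup k)

putfor :: Ord k => k -> v -> Env k v ()
putfor k v = modify (M.insert k v)

-- A later putfor overwrites an earlier one; a missing key gives Nothing.
envDemo :: (Maybe Int, Maybe Int, Maybe Int)
envDemo = flip evalState M.empty $ do
  putfor "x" 1
  first <- getfor "x"
  putfor "x" 2
  second <- getfor "x"
  missing <- getfor "y"
  return (first, second, missing)
```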

The final monad I used is for scoping behaviour during type-checking. It is not strictly required for a pure type-checker, since you could put it on top of the type-checker monad in your compiler’s monad instead. However, I think the three fit together nicely (and I admit to being too lazy to rewrite the code to split the monads). It is a scoping environment: a wrapper around a reader monad, so bindings added with ‘scope’ are only visible inside the action they wrap. I hope this is not too much code for one blog article. I plan to release a few more monads I have made in separate articles, so this feels like a colossal article, but I’m rather intent on showing this type-checker.

```haskell
{-# LANGUAGE GeneralizedNewtypeDeriving, MultiParamTypeClasses, FunctionalDependencies #-}
{-# LANGUAGE FlexibleInstances, UndecidableInstances #-}

module MonadScope where

import Control.Monad
import Control.Monad.Trans
import Control.Monad.State
import Control.Monad.Reader
import qualified Data.Map as M

type Context k v = M.Map k v
type Subst k v = Context k v -> Context k v

newtype ScopeT k v m a = ScopeT { unScopeT :: ReaderT (Context k v) m a }
  deriving (Functor, Applicative, Monad, MonadTrans, MonadIO, MonadFail)

newtype Scope k v a = Scope (ScopeT k v Maybe a)
  deriving (Functor, Applicative, Monad, MonadScope k v)

class (Ord k, Monad m) => MonadScope k v m | m -> k, m -> v where
  find :: k -> m (Maybe v)
  scope :: Subst k v -> m a -> m a

instance (Ord k, Monad m) => MonadScope k v (ScopeT k v m) where
  find k = ScopeT $ do
    s <- ask
    return $ M.lookup k s
  scope s act = ScopeT $ do
    local s . unScopeT $ act

instance MonadState s m => MonadState s (ScopeT k v m) where
  get = ScopeT . lift $ get
  put v = ScopeT . lift $ put v

evalScopeT :: ScopeT k v m a -> Context k v -> m a
evalScopeT (ScopeT s) ctx = runReaderT s ctx

evalScope :: Scope k v a -> Context k v -> Maybe a
evalScope (Scope s) ctx = evalScopeT s ctx
```
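
The point of basing ScopeT on a reader rather than a state monad is that a binding added with scope disappears again once the wrapped action finishes. Here is a minimal sketch of that behaviour, assuming a plain mtl `Reader` over a `Data.Map` stands in for ScopeT:

```haskell
import Control.Monad.Reader (Reader, asks, local, runReader)
import qualified Data.Map as M

-- Simplified stand-in for ScopeT (assumption: plain Reader, no class).
type Scope k v = Reader (M.Map k v)

find :: Ord k => k -> Scope k v (Maybe v)
find k = asks (M.lookup k)

-- Run an action under a modified context; the change is not visible afterwards.
scope :: (M.Map k v -> M.Map k v) -> Scope k v a -> Scope k v a
scope = local

scopeDemo :: (Maybe String, Maybe String)
scopeDemo = flip runReader M.empty $ do
  inner <- scope (M.insert "x" "Int") (find "x")
  outer <- find "x"
  return (inner, outer)
```

`inner` sees the binding for "x", while `outer`, run after the scoped action, does not. With the state-based EnvT, an insertion would still be visible afterwards.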

Now that we have the above three modules out of the way, we can start looking at the type-checker. Obviously we start with the module declaration and the imports of these three modules, as well as a few other basic necessities.

> {-# LANGUAGE FlexibleContexts, FlexibleInstances, GeneralizedNewtypeDeriving, TypeSynonymInstances, MultiParamTypeClasses #-}
>
> module TypeChecker where
>
> import MonadSupply
> import MonadScope
> import MonadEnv
> import Data.List (nub, (\\), intersperse)
> import qualified Data.Map as M
> import Control.Monad.Trans

## Basic type definitions

Now that we have that in place, it is important to specify types. A type-checker works with two kinds of types: monotypes and polytypes. The difference is that each free variable in a monotype can only be unified with one specific type, while the quantified variables of a polytype are parametrically polymorphic. For example, the function “map :: (a -> b) -> [a] -> [b]” really has the type “map :: forall a b. (a -> b) -> [a] -> [b]”, so it has a polytype. In Haskell, top-level definitions and let-bound variables are generally given polytypes. If map were monotyped, the first use of map would fix ‘a’ and ‘b’ to specific types, and every other use would have to respect those types. So a polytype is a monotype that has been quantified over some of its variables.

As for monotypes, there are two basic options: either the type is just a type variable (TyVar), or it is a type constructor applied to some parameters (TyConst). For convenience’s sake, we also define a function type (TyFun). In essence it is just a TyConst whose Const is “->”, but it is nice to have it as a special case. For the type variables themselves, we use strings.
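
One way to see the difference in practice: before a polytype is used, each of its quantified variables is replaced by a fresh type variable, so every use gets its own copy. The sketch below uses minimal copies of the types defined below and a hypothetical `instantiate` helper (not part of this article’s code) that takes fresh names from a supply list.

```haskell
import qualified Data.Map as M

-- Minimal copies of the MonoType and PolyType definitions.
data MonoType
  = TyVar String
  | TyFun MonoType MonoType
  | TyConst String [MonoType]
  deriving (Eq, Show)

data PolyType = TyPoly [String] MonoType
  deriving (Eq, Show)

-- Hypothetical helper: rename every quantified variable to a fresh name
-- from the supply; returns the monotype and the unused rest of the supply.
instantiate :: PolyType -> [String] -> (MonoType, [String])
instantiate (TyPoly vs t) names = (rename t, rest)
  where
    (fresh, rest) = splitAt (length vs) names
    mapping = M.fromList (zip vs fresh)
    rename (TyVar x) = TyVar (M.findWithDefault x x mapping)
    rename (TyFun a b) = TyFun (rename a) (rename b)
    rename (TyConst c ts) = TyConst c (map rename ts)

-- map :: forall a b. (a -> b) -> [a] -> [b], writing [a] as "List a".
mapType :: PolyType
mapType =
  TyPoly ["a", "b"]
    (TyFun (TyFun (TyVar "a") (TyVar "b"))
           (TyFun (TyConst "List" [TyVar "a"]) (TyConst "List" [TyVar "b"])))
```

Instantiating `mapType` twice gives one copy over t0/t1 and another over t2/t3, so two uses of map at different types never have to agree; a monotype used twice would share a single ‘a’ and ‘b’.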

Besides these basic types, we also introduce a useful type class (HasVars) and a few standard type creationers (I know, bad word, but ‘type constructors’ would be ambiguous) to build some standard types. Some of these functions are a bit simplistic and could be replaced by a direct call to the data constructor. They exist because my types were originally more complicated: I had a type class for types, since I wanted to be able to attach extra information to them, so the type of types was actually ‘MonoType mt’, and I would tie the knot by fixing the type declaration over whatever record contained the extra info. You can ignore all that and just be glad that the functions all have nicely similar names.

> --- Type variable
> type Var = String
>
> --- Type constructor
> type Const = String
>
> --- MonoTypes
> data MonoType =
>     TyVar Var
>   | TyFun MonoType MonoType
>   | TyConst Const [MonoType]
>   deriving (Eq, Ord)
>
> --- PolyTypes
> data PolyType =
>     TyPoly [Var] MonoType
>   deriving (Eq, Ord)
>
> instance Show MonoType where
>   show (TyVar x) = x
>   show (TyFun ta tb) = "(" ++ show ta ++ ") -> " ++ show tb
>   show (TyConst "(,)" tis) = "(" ++ (concat . intersperse ", " . map show $ tis) ++ ")"
>   show (TyConst n []) = n
>   show (TyConst x tis) = x ++ " " ++ (unwords . map show $ tis)
>
> instance Show PolyType where
>   show (TyPoly ids mt) = (if null ids then "" else "forall " ++ unwords ids ++ ".") ++ show mt
>
> class HasVars a where
>   freeVars :: a -> [Var]
>   occurs :: Var -> a -> Bool
>   occurs x t = x `elem` freeVars t
>
> instance HasVars MonoType where
>   freeVars (TyVar a) = [a]
>   freeVars (TyFun ta tb) = nub $ freeVars ta ++ freeVars tb
>   freeVars (TyConst _ ts) = nub . concatMap freeVars $ ts
>   occurs x (TyVar y) = x == y
>   occurs x (TyFun ta tb) = occurs x ta || occurs x tb
>   occurs x (TyConst _ ts) = or . map (occurs x) $ ts
>
> instance HasVars PolyType where
>   freeVars (TyPoly vs t) = freeVars t \\ vs
>
> toPoly :: MonoType -> PolyType
> toPoly x = TyPoly (freeVars x) x
> fromPoly :: PolyType -> MonoType
> fromPoly (TyPoly fvs t) = t
>
> unitConstructor = "()"
> tupleConstructor = "(,)"
>
> mkTyVar :: Var -> MonoType
> mkTyVar x = TyVar x
> mkTyFun :: MonoType -> MonoType -> MonoType
> mkTyFun ta tb = TyFun ta tb
> mkTyConst :: Const -> [MonoType] -> MonoType
> mkTyConst c ts = TyConst c ts
>
> mkTupleType :: [MonoType] -> MonoType
> mkTupleType ts = mkTyConst tupleConstructor ts
> mkUnitType :: MonoType
> mkUnitType = mkTyConst unitConstructor []
> mkBoolType :: MonoType
> mkBoolType = mkTyConst "Bool" []
> mkIntType :: MonoType
> mkIntType = mkTyConst "Int" []
>
> mkFunType :: [MonoType] -> MonoType
> mkFunType (a:b:r) = mkTyFun a $ mkFunType (b:r)
> mkFunType [a] = a
> mkFunType [] = error "mkFunType: Trying to make a function with 0 args"
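
A quick check of how freeVars and toPoly interact, using minimal standalone copies of the definitions above (plain functions named `freeVarsMono` and `freeVarsPoly` instead of the HasVars class, purely to keep the sketch short):

```haskell
import Data.List (nub, (\\))

data MonoType
  = TyVar String
  | TyFun MonoType MonoType
  | TyConst String [MonoType]
  deriving (Eq, Show)

data PolyType = TyPoly [String] MonoType
  deriving (Eq, Show)

-- Same logic as the HasVars MonoType instance, as a plain function.
freeVarsMono :: MonoType -> [String]
freeVarsMono (TyVar a) = [a]
freeVarsMono (TyFun ta tb) = nub (freeVarsMono ta ++ freeVarsMono tb)
freeVarsMono (TyConst _ ts) = nub (concatMap freeVarsMono ts)

-- Quantified variables are not free in a polytype.
freeVarsPoly :: PolyType -> [String]
freeVarsPoly (TyPoly vs t) = freeVarsMono t \\ vs

-- Quantify over every free variable, as toPoly does.
toPoly :: MonoType -> PolyType
toPoly t = TyPoly (freeVarsMono t) t

-- (a) -> (b) -> a, the type of const
constType :: MonoType
constType = TyFun (TyVar "a") (TyFun (TyVar "b") (TyVar "a"))
```

toPoly quantifies over every free variable of its argument. In full Hindley-Milner, let-generalization only quantifies over variables that are not also free in the surrounding typing context, so a caller of toPoly would have to take care of that.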

## Type Checker Monad

Now that we’ve defined some standard functions and smart constructors, we could play around with them, but I will leave that to the reader. Instead, let’s move on to the meat of this blog entry. For this we define a new monad, the typer monad, which is actually a stack of monad transformers. At the bottom sits the supply monad transformer (SupplyT), which gives the typer monad an unlimited supply of fresh names. On top of this, we put an environment monad transformer (EnvT) that maps type variables to the MonoTypes they have been unified with. Finally, because we want let-bindings in whatever language we type-check, we add a scoping monad transformer (ScopeT): whenever the language uses (let x = … in …), we can give ‘x’ a polytype, so that it can be used at different types in the body instead of being restricted to a single monotype. For this stack we define a type class (TyperMonad) with two operations beyond the regular monad interface and the operations the stack already provides: one unifies two types, and the other normalizes a type according to what has been unified so far. For instance, if we unify (TyVar “a”) with (TyConst “Int” []), then whenever we normalize a type that contains (TyVar “a”), it should become the Int type. We will show simple examples at the bottom of how this is used.
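
The contract of unify and normalize can also be written without the monad stack. The following is a monad-free sketch under a few assumptions: the substitution is a plain `Data.Map` threaded by hand instead of living in EnvT, failures are returned as `Left` instead of going through `fail`, and normalize does not write anything back. The names mirror the ones used here, but this is not the implementation that follows.

```haskell
import Control.Monad (foldM)
import qualified Data.Map as M

data MonoType
  = TyVar String
  | TyFun MonoType MonoType
  | TyConst String [MonoType]
  deriving (Eq, Show)

-- Bindings from type variables to types, playing the role of EnvT's state.
type Subst = M.Map String MonoType

-- Chase variable bindings until only unbound variables remain.
normalize :: Subst -> MonoType -> MonoType
normalize s t@(TyVar x) = maybe t (normalize s) (M.lookup x s)
normalize s (TyFun a b) = TyFun (normalize s a) (normalize s b)
normalize s (TyConst c ts) = TyConst c (map (normalize s) ts)

occurs :: String -> MonoType -> Bool
occurs x (TyVar y) = x == y
occurs x (TyFun a b) = occurs x a || occurs x b
occurs x (TyConst _ ts) = any (occurs x) ts

-- Unify two types, extending the substitution or reporting why it failed.
unify :: MonoType -> MonoType -> Subst -> Either String Subst
unify t1 t2 s = go (normalize s t1) (normalize s t2)
  where
    go (TyVar x) (TyVar y) | x == y = Right s
    go (TyVar x) t = bind x t
    go t (TyVar y) = bind y t
    -- Recurse through unify (not go) so later pairs see earlier bindings.
    go (TyFun a b) (TyFun c d) = unify a c s >>= unify b d
    go (TyConst n ps) (TyConst m qs)
      | n == m && length ps == length qs =
          foldM (\acc (p, q) -> unify p q acc) s (zip ps qs)
    go a b = Left (show a ++ " can not be inferred to " ++ show b)

    bind x t
      | occurs x t = Left ("recursive type " ++ x ++ " = " ++ show t)
      | otherwise = Right (M.insert x t s)
```

Unifying "a" with Int and then normalizing (a) -> (a) -> b gives (Int) -> (Int) -> b, the same result as the experiment at the end of this article, while unifying "a" with (a) -> b fails the occurs check.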

> -- Ident is the type for identifiers in a program
> type Ident = String
>
> --------------------------------------------------------------------------------
> --- TyperState:
> --- State of typer monad
> --- Keeps a map of tyvar -> MonoType substitutions
> --- Keeps a list of free tyvars
> --- Requires: tyvars used as keys do not appear in the values
> --- unify needs to ensure this at all times
> --------------------------------------------------------------------------------
> type TyperContext = M.Map Ident PolyType
> type TyperSubst = TyperContext -> TyperContext
>
> newtype TyperT m a = TyperT {
>     runTyperT :: ScopeT Ident PolyType (EnvT Var MonoType (SupplyT Var m)) a
>   } deriving (Functor, Applicative, Monad, MonadIO)
>
> instance MonadTrans TyperT where
>   lift = TyperT . lift . lift . lift
>
> instance (Monad m) => MonadSupply Var (TyperT m) where
>   supply = TyperT . lift . lift $ supply
>
> instance (Monad m) => MonadScope Ident PolyType (TyperT m) where
>   find = TyperT . find
>   scope f act = TyperT . scope f . runTyperT $ act
>
> class (Monad m, MonadScope Ident PolyType m, MonadSupply Var m) => TyperMonad m where
>   unify :: MonoType -> MonoType -> m ()
>   normalize :: MonoType -> m MonoType

With the TyperMonad interface given above, it is time to actually implement it for TyperT. This is the real meat of the typer monad.

>
> problem :: (MonadFail m) => String -> String -> m ()
> problem phase message = fail $ phase ++ ": " ++ message
>
> instance (MonadFail m) => TyperMonad (TyperT m) where
>   ------------------------------------------------------------------------------
>   --- unify
>   ------------------------------------------------------------------------------
>   unify type1 type2 = TyperT $ do
>       type1' <- runTyperT $ normalize type1
>       type2' <- runTyperT $ normalize type2
>       unifier type1' type2'
>     where
>       unifier (TyVar x) ty@(TyVar y) =
>         if x == y
>           then return ()
>           else lift $ putfor x ty
>       unifier (TyVar x) ty =
>         if x `occurs` ty
>           then problem "TypeChecker" ("Trying to unify recursive type " ++ show x ++ " = " ++ show ty)
>           else lift $ putfor x ty
>       unifier tx (TyVar y) =
>         if y `occurs` tx
>           then problem "TypeChecker" ("Trying to unify recursive type " ++ show y ++ " = " ++ show tx)
>           else lift $ putfor y tx
>       unifier tx@(TyConst nx px) ty@(TyConst ny py) =
>         if nx == ny && length px == length py
>           -- IMPORTANT: use unify and not unifier here,
>           -- or the code will recurse for:
>           --   foo x y = x y; bar x = foo bar bar
>           then runTyperT . sequence_ $ zipWith unify px py
>           else problem "TypeChecker" (show tx ++ " can not be inferred to " ++ show ty)
>       unifier (TyFun tax tbx) (TyFun tay tby) = do
>         runTyperT $ unify tax tay
>         runTyperT $ unify tbx tby
>       unifier tx ty =
>         problem "TypeChecker" (show tx ++ " can not be inferred to " ++ show ty)
>   ------------------------------------------------------------------------------
>   --- normalize
>   --- Normalizes a type by looking it up if it's a tyvar
>   --- Applies a small optimization, it basically removes any extra unnecessary
>   --- mappings. For instance:
>   --- [a->b, b->Int]: normalize a
>   --- [a->b, b->Int]: b' <- normalize b
>   --- [a->Int, b->Int]: return b'
>   ------------------------------------------------------------------------------
>   normalize tx@(TyVar x) = TyperT $ do
>     tx' <- lift $ getfor x
>     case tx' of
>       Nothing -> return tx
>       Just ty -> do
>         ty' <- runTyperT $ normalize ty
>         if ty == ty'
>           then return ty'
>           else do
>             lift $ putfor x ty'
>             return ty'
>   normalize tx@(TyConst n tl) = TyperT $ do
>     tl' <- runTyperT $ mapM normalize tl
>     return $ TyConst n tl'
>   normalize tx@(TyFun ta tb) = TyperT $ do
>     ta' <- runTyperT $ normalize ta
>     tb' <- runTyperT $ normalize tb
>     return $ TyFun ta' tb'
>   ------------------------------------------------------------------------------
>
> evalTyperT :: (Monad m) => TyperContext -> TyperT m a -> m a
> evalTyperT ctx action =
>   let vars = makeSupply (map return ['a'..'z']) ["'"] in
>   evalSupplyT (evalEnvT (evalScopeT (runTyperT action) ctx) M.empty) vars
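
The write-back optimization described in the comment above normalize can be tried in isolation. This sketch assumes a pure function `normalizeC` (a hypothetical name) that returns the normalized type together with the updated substitution, instead of updating EnvT’s state:

```haskell
import Data.List (mapAccumL)
import qualified Data.Map as M

data MonoType
  = TyVar String
  | TyFun MonoType MonoType
  | TyConst String [MonoType]
  deriving (Eq, Show)

type Subst = M.Map String MonoType

-- Normalize a type; whenever a variable's binding normalizes to something
-- different, overwrite that binding with the normalized type.
normalizeC :: Subst -> MonoType -> (MonoType, Subst)
normalizeC s t@(TyVar x) =
  case M.lookup x s of
    Nothing -> (t, s)
    Just ty ->
      let (ty', s') = normalizeC s ty
      in if ty == ty' then (ty', s') else (ty', M.insert x ty' s')
normalizeC s (TyFun a b) =
  let (a', s1) = normalizeC s a
      (b', s2) = normalizeC s1 b
  in (TyFun a' b', s2)
normalizeC s (TyConst c ts) =
  let step cur t = let (t', cur') = normalizeC cur t in (cur', t')
      (s', ts') = mapAccumL step s ts
  in (TyConst c ts', s')
```

Starting from [a->b, b->Int], normalizing a returns Int and rewrites the binding for a to Int, matching the [a->Int, b->Int] example in the comment.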

I will let you grok the code. If any bits are not clear, feel free to tell me so and I will expand upon them. Now we can easily do some experiments:

```
evalTyperT M.empty (do { unify (TyVar "a") mkIntType; normalize (mkFunType [TyVar "a", TyVar "a", TyVar "b"]) })
=> (Int) -> (Int) -> b
```

You might notice that I included some extra machinery, namely the scoping. In a future blog post I will show how this is actually used to type-check a simple language (like MiniML). But for now I will leave it at that, as this has already grown into quite a big blog post.

If you have any questions or comments, please feel free to post them at the bottom and I will make the necessary changes if required to clarify or improve the presentation.