Monad Transformers

In the previous sections you learned about some handy monads: Option, IO, Reader, State and Except. You now know how to make a function use one of these, but not yet how to make a function use several monads at once.

For example, suppose you need a function that accesses some Reader context and may also throw an exception. This requires composing two monads, ReaderM and Except, and this is what monad transformers are for.

A monad transformer is fundamentally a wrapper type. It is generally parameterized by another monadic type. You can then run actions from the inner monad, while adding your own customized behavior for combining actions in this new monad. The common transformers add T to the end of an existing monad name. You will find OptionT, ExceptT, ReaderT and StateT, but there is no transformer for IO, so if you need IO it generally becomes the innermost wrapped monad.
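To make the "wrapper type" idea concrete, here is a minimal sketch of a reader-style transformer written from scratch. Every name in it (the Sketch namespace, MyReaderT, MyReaderT.read, MyReaderT.run, doubled) is hypothetical and exists only for illustration. The library's ReaderT follows the same general shape, but this is not its actual source.

```lean
namespace Sketch  -- hypothetical namespace, used only for this illustration

-- A reader-style transformer is a function from the context `ρ`
-- to an action in the wrapped monad `m`.
def MyReaderT (ρ : Type) (m : Type → Type) (α : Type) : Type :=
  ρ → m α

instance {ρ : Type} {m : Type → Type} [Monad m] : Monad (MyReaderT ρ m) where
  -- Ignore the context and defer to the inner monad's `pure`.
  pure a := fun _ => pure a
  -- Hand the same context to both steps; the inner monad sequences them.
  bind x f := fun r => do
    let a ← x r
    f a r

-- Reading the context is the one new capability the wrapper adds.
def MyReaderT.read {ρ : Type} {m : Type → Type} [Monad m] : MyReaderT ρ m ρ :=
  fun r => pure r

-- Running an action means supplying the context.
def MyReaderT.run {ρ : Type} {m : Type → Type} {α : Type}
    (x : MyReaderT ρ m α) (r : ρ) : m α :=
  x r

def doubled : MyReaderT Nat (Except String) Nat := do
  let n ← MyReaderT.read
  return n * 2

#eval doubled.run 21
-- Except.ok 42

end Sketch
```

The wrapper adds exactly one behavior (threading the context) and delegates everything else to the inner monad, which is why the inner monad can be swapped freely.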

In the following example we use ReaderT to provide some read-only context to a function, and this ReaderT transformer wraps an Except monad. If all goes well, requiredArgument returns the value of a required argument and optionalSwitch returns true if the optional switch is present.

abbrev Arguments := List String

def indexOf? [BEq α] (xs : List α) (s : α) (start := 0) : Option Nat :=
  match xs with
  | [] => none
  | a :: tail => if a == s then some start else indexOf? tail s (start + 1)

def requiredArgument (name : String) : ReaderT Arguments (Except String) String := do
  let args ← read
  let value := match indexOf? args name with
    | some i => if i + 1 < args.length then args[i + 1]! else ""
    | none => ""
  if value == "" then throw s!"Command line argument {name} missing"
  return value

def optionalSwitch (name : String) : ReaderT Arguments (Except String) Bool := do
  let args ← read
  return match (indexOf? args name) with
  | some _ => true
  | none => false

#eval requiredArgument "--input" |>.run ["--input", "foo"]
-- Except.ok "foo"

#eval requiredArgument "--input" |>.run ["foo", "bar"]
-- Except.error "Command line argument --input missing"

#eval optionalSwitch "--help" |>.run ["--help"]
-- Except.ok true

#eval optionalSwitch "--help" |>.run []
-- Except.ok false

Notice that throw was available from the inner Except monad. The cool thing is you can switch this around and get the exact same result using ExceptT as the outer monad transformer and ReaderM as the wrapped monad. Try changing the return type of requiredArgument to ExceptT String (ReaderM Arguments) String.
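One way the swapped version might look is sketched below. The SwappedOrder namespace is a hypothetical wrapper used only to avoid clashing with the earlier definitions, which are repeated so the block stands on its own. The function body is unchanged; only the return type and the way the layers are run differ. Running it with explicit calls to ExceptT.run, ReaderT.run and Id.run is a choice made for this sketch.

```lean
namespace SwappedOrder  -- hypothetical namespace, used only to avoid name clashes

abbrev Arguments := List String

def indexOf? [BEq α] (xs : List α) (s : α) (start := 0) : Option Nat :=
  match xs with
  | [] => none
  | a :: tail => if a == s then some start else indexOf? tail s (start + 1)

-- Same body as before; only the return type changes.
def requiredArgument (name : String) : ExceptT String (ReaderM Arguments) String := do
  let args ← read
  let value := match indexOf? args name with
    | some i => if i + 1 < args.length then args[i + 1]! else ""
    | none => ""
  if value == "" then throw s!"Command line argument {name} missing"
  return value

-- Peel the layers from the outside in: ExceptT first, then the reader, then Id.
#eval Id.run (ReaderT.run (ExceptT.run (requiredArgument "--input")) ["--input", "foo"])
-- Except.ok "foo"

end SwappedOrder
```

Here read is lifted from the inner ReaderM through ExceptT, while throw comes directly from the outer ExceptT, which is the mirror image of the original arrangement.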

Note: the |>. notation is described in Readers.

Adding more layers

Here's the best part about monad transformers. Since the result of a monad transformer is itself a monad, you can wrap it inside another transformer! Suppose you need to pass in some read only context like the command line arguments, update some read-write state (like program Config) and optionally throw an exception, then you could write this:

structure Config where
  help : Bool := false
  verbose : Bool := false
  input : String := ""
  deriving Repr

abbrev CliConfigM := StateT Config (ReaderT Arguments (Except String))

def parseArguments : CliConfigM Bool := do
  let mut config ← get
  if (← optionalSwitch "--help") then
    throw "Usage: example [--help] [--verbose] [--input <input file>]"
  config := { config with
    verbose := (← optionalSwitch "--verbose"),
    input := (← requiredArgument "--input") }
  set config
  return true

def main (args : List String) : IO Unit := do
  let config : Config := { input := "default" }
  match parseArguments |>.run config |>.run args with
  | Except.ok (_, c) => do
    IO.println s!"Processing input '{c.input}' with verbose={c.verbose}"
  | Except.error s => IO.println s

#eval main ["--help"]
-- Usage: example [--help] [--verbose] [--input <input file>]

#eval main ["--input", "foo"]
-- Processing input 'foo' with verbose=false

#eval main ["--verbose", "--input", "bar"]
-- Processing input 'bar' with verbose=true

In this example parseArguments runs in a stack of three monads: StateT on the outside, then ReaderT, with Except at the bottom. Notice the convention of abbreviating long monadic types with an alias like CliConfigM.

Monad Lifting

Lean makes it easy to compose functions that use different monads using a concept of automatic monad lifting. You already used lifting in the code above: optionalSwitch has type ReaderT Arguments (Except String) Bool, yet you called it from parseArguments, which runs in the bigger monad StateT Config (ReaderT Arguments (Except String)). This "just worked" because Lean did some magic with monad lifting.

To give you a simpler example of this, suppose you have the following function:

def divide (x : Float) (y : Float) : ExceptT String Id Float :=
  if y == 0 then
    throw "can't divide by zero"
  else
    pure (x / y)

#eval divide 6 3 -- Except.ok 2.000000
#eval divide 1 0 -- Except.error "can't divide by zero"

Notice here we used the ExceptT transformer, but we composed it with the Id identity monad. This is then the same as writing Except String Float since the identity monad does nothing.
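The claim that the two types coincide can be checked directly. The following is a small sketch; asExcept is a hypothetical name introduced only for this illustration.

```lean
-- `ExceptT ε m α` unfolds to `m (Except ε α)` and `Id α` unfolds to `α`,
-- so the two types are definitionally equal and `rfl` proves it.
example : ExceptT String Id Float = Except String Float := rfl

-- Consequently a value of one type is accepted where the other is expected,
-- with no conversion function in between.
def asExcept (x : ExceptT String Id Float) : Except String Float := x
```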

Now suppose you want to count the number of times divide is called and store the result in some global state:

def divideCounter (x : Float) (y : Float) : StateT Nat (ExceptT String Id) Float := do
  modify fun s => s + 1
  divide x y

#eval divideCounter 6 3 |>.run 0 -- Except.ok (2.000000, 1)
#eval divideCounter 1 0 |>.run 0 -- Except.error "can't divide by zero"

The modify function is a helper which makes it easier to use modifyGet from the StateM monad. But something interesting is happening here: divideCounter returns the value of divide, and although the types don't match, it still works. This is monad lifting in action.

You can see this more clearly with the following test:

def liftTest (x : Except String Float) :
    StateT Nat (Except String) Float := x

#eval liftTest (divide 5 1) |>.run 3 -- Except.ok (5.000000, 3)

Notice that liftTest returned x without doing anything to it, yet that matched the return type StateT Nat (Except String) Float. Monad lifting is provided by monad transformers. If you #print liftTest you will see that Lean implements this with a call to a function named monadLift from the MonadLift type class:

class MonadLift (m : Type u → Type v) (n : Type u → Type w) where
  monadLift : {α : Type u} → m α → n α

So monadLift is a function for lifting a computation from an inner Monad m α to an outer Monad n α. You could replace x in liftTest with monadLift x if you want to be explicit about it.
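As a sketch of that explicit style, the two definitions below spell out the lift by hand. The names liftTestExplicit and liftTestViaStateT are hypothetical; the first uses the exported monadLift function and the second calls StateT.lift directly.

```lean
-- Explicit lift through the type class machinery.
def liftTestExplicit (x : Except String Float) : StateT Nat (Except String) Float :=
  monadLift x

-- The same lift, naming the StateT-specific function directly.
def liftTestViaStateT (x : Except String Float) : StateT Nat (Except String) Float :=
  StateT.lift x

#eval liftTestExplicit (pure 5) |>.run 3
-- Except.ok (5.000000, 3)
```

In both cases the initial state 3 comes back unchanged, since lifting only passes the state through.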

The StateT monad transformer defines an instance of MonadLift like this:

@[inline] protected def lift {α : Type u} (t : m α) : StateT σ m α :=
  fun s => do let a ← t; pure (a, s)

instance : MonadLift m (StateT σ m) := ⟨StateT.lift⟩

This means that any monad m can be wrapped in a StateT monad by using the function fun s => do let a ← t; pure (a, s), which takes the state s, runs the inner monad action t, and returns its result paired with the state as (a, s), without making any changes to s.

Because MonadLift is a type class, Lean can automatically find the required monadLift instances in order to make your code compile and in this way it was able to find the StateT.lift function and use it to wrap the result of divide so that the correct type is returned from divideCounter.

If you have an instance MonadLift m n, that means there is a way to turn a computation that happens inside of m into one that happens inside of n and (this is the key part) usually without the instance itself creating any additional data that feeds into the computation. This means you can in principle declare lifting instances from any monad to any other monad. It does not, however, mean that you should do this in all cases. You can get a very nice report on how all this was done by adding the line set_option trace.Meta.synthInstance true in before divideCounter and moving your cursor to the end of the first line after do.
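Concretely, the trace option might be attached as in the sketch below. The TraceDemo namespace is hypothetical and only keeps these copies of divide and divideCounter from clashing with the earlier definitions. The trace output itself is not reproduced here because it depends on the Lean version.

```lean
namespace TraceDemo  -- hypothetical namespace, used only to avoid name clashes

def divide (x : Float) (y : Float) : ExceptT String Id Float :=
  if y == 0 then
    throw "can't divide by zero"
  else
    pure (x / y)

-- `in` scopes the option to the single declaration that follows.
set_option trace.Meta.synthInstance true in
def divideCounter (x : Float) (y : Float) : StateT Nat (ExceptT String Id) Float := do
  modify fun s => s + 1
  divide x y

end TraceDemo
```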

This was a lot of detail, but it is very important to understand how monad lifting works because it is used heavily in Lean programs.

Transitive lifting

There is also a transitive version of MonadLift called MonadLiftT which can lift multiple monad layers at once. In the following example we added another monad layer with ReaderT String ... and notice that x is also automatically lifted to match.

def liftTest2 (x : Except String Float) :
    ReaderT String (StateT Nat (Except String)) Float := x

#eval liftTest2 (divide 5 1) |>.run "" |>.run 3
-- Except.ok (5.000000, 3)

The ReaderT monadLift is even simpler than the one for StateT:

instance  : MonadLift m (ReaderT ρ m) where
  monadLift x := fun _ => x

This lift operation creates a function that accepts the required ReaderT input argument, but the inner monad doesn't know or care about ReaderT, so the monadLift function throws the argument away with the _ and then runs the inner monad action x. This is a perfectly legal implementation of the ReaderM monad.
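To see what the transitive lift amounts to, the sketch below performs the two single-layer lifts by hand, using MonadLift.monadLift once per layer. The name liftTwoSteps is hypothetical. This illustrates the idea rather than showing the term Lean actually generates.

```lean
def liftTwoSteps (x : Except String Float) :
    ReaderT String (StateT Nat (Except String)) Float :=
  -- Step 1: Except String → StateT Nat (Except String)
  let inner : StateT Nat (Except String) Float := MonadLift.monadLift x
  -- Step 2: StateT Nat (Except String) → ReaderT String (StateT Nat (Except String))
  MonadLift.monadLift inner

#eval liftTwoSteps (pure 5) |>.run "" |>.run 3
-- Except.ok (5.000000, 3)
```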

Add your own Custom MonadLift

This does not compile:

def main2 : IO Unit := do
  try
    let ret ← divideCounter 5 2 |>.run 0
    -- Error: type mismatch: (divideCounter 5 2).run 0 has type
    -- ExceptT String Id (Float × Nat) but is expected to have type IO ?m.8428
    IO.println (toString ret)
  catch e =>
    IO.println e

In addition to the type mismatch noted in the comment, Lean reports:

typeclass instance problem is stuck, it is often due to metavariables
  ToString ?m.4786

The reason is that divideCounter returns the big StateT Nat (ExceptT String Id) Float, and even after run supplies the state, what remains is an ExceptT String Id (Float × Nat). That type cannot be automatically lifted into IO, the monad of main2, unless you give it some help.

The following custom MonadLift solves this problem:

def liftIO (t : ExceptT String Id α) : IO α := do
  match t with
  | .ok r => EStateM.Result.ok r
  | .error s => EStateM.Result.error s

instance : MonadLift (ExceptT String Id) IO where
  monadLift := liftIO

def main3 : IO Unit := do
  try
    let ret ← divideCounter 5 2 |>.run 0
    IO.println (toString ret)
  catch e =>
    IO.println e

#eval main3 -- (2.500000, 1)

It turns out that the IO monad you see in your main function is based on the EStateM.Result type which is similar to the Except type but it has an additional return value. The liftIO function converts any Except String α into IO α by simply mapping the ok case of the Except to the Result.ok and the error case to the Result.error.
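If you would rather not work with EStateM.Result directly, a similar conversion can be sketched with the library function IO.ofExcept, which turns the error case into an IO exception built from toString of the error. This is an alternative to the liftIO shown above, not the approach this chapter uses, and the names fromExcept and demoOfExcept are hypothetical.

```lean
-- Convert an `Except String α` into an `IO α` using the library helper.
def fromExcept {α : Type} (t : Except String α) : IO α :=
  IO.ofExcept t

def demoOfExcept : IO Unit := do
  try
    let n ← fromExcept (Except.error "boom" : Except String Nat)
    IO.println (toString n)
  catch e =>
    IO.println s!"Unhandled exception: {e}"

#eval demoOfExcept
-- Unhandled exception: boom
```

Unlike the MonadLift instance, this requires calling fromExcept explicitly at each use, which can be preferable when you want the conversion to be visible in the code.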

Lifting ExceptT

In the previous Except section you saw functions that throw Except values. When you get all the way back up to your main function, which has type IO Unit, you have the same problem you had above, because Except String Float doesn't match IO even if you use a try/catch. The MonadLift instance defined above solves it here too:

def main4 : IO Unit := do
  try
    let ret ← divide 5 0
    IO.println (toString ret) -- lifting happens here.
  catch e =>
    IO.println s!"Unhandled exception: {e}"

#eval main4 -- Unhandled exception: can't divide by zero

Without the liftIO instance, the (toString ret) expression would not compile, failing with a similar error:

typeclass instance problem is stuck, it is often due to metavariables
  ToString ?m.6007

So the general lesson is that if you see an error like this when using monads, check for a missing MonadLift.

Summary

Now that you know how to combine your monads together, you're almost done with understanding the key concepts of monads! You could probably go out now and start writing some pretty nice code! But to truly master monads, you should know how to make your own, and there's one final concept that you should understand for that. This is the idea of type "laws". Each of the structures you've learned so far has a series of laws associated with it. And for your instances of these classes to make sense, they should follow the laws! Check out Monad Laws.