title | date |
---|---|
A Quick-and-Dirty Explanation of MonadFix | 2019-01-15 |
This post aims to give a quick-and-dirty explanation of MonadFix. It is quick and dirty because MonadFix is complex and weird stuff (or at least it seems so), and I do not think I am able to present a comprehensive discussion. Instead, this post mainly consists of a simple example that hopefully can give one a rough idea of what MonadFix does and where it might be useful.
It is recommended that one has basic knowledge of how fix works before reading on, since I will be making comparisons between fix and mfix, the method in the MonadFix type class.
MonadFix is a type class containing a single method, mfix:
class Monad m => MonadFix m where
  mfix :: (a -> m a) -> m a
Note the similarity between the types of fix and mfix:
fix :: (a -> a) -> a
mfix :: (a -> m a) -> m a
mfix is basically the monadic version of fix. Their defining equations are also similar:
fix f = f (fix f)
mfix f = f =<< mfix f
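Note that this is not how fix and mfix are implemented. The equations should be read as specifications that convey the intuition, as opposed to implementations. In particular, an actual mfix performs the effects of f only once, which the equation for mfix does not convey. Unfolding the equations gives: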
fix f = f (f (f ...
mfix f = f =<< f =<< f =<< ...
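To make the unfolding of fix a bit more concrete, here is a minimal sketch (the names ones and firstThree are made up for this illustration). When f is lazy enough, the infinite chain f (f (f ... still produces a usable value:

```haskell
import Data.Function (fix)

-- fix (1 :) unfolds to 1 : (1 : (1 : ...)), an infinite list of ones.
ones :: [Integer]
ones = fix (1 :)

-- Laziness means only the demanded prefix of the list is ever built.
firstThree :: [Integer]
firstThree = take 3 ones
```

Loading this in ghci and evaluating firstThree gives [1,1,1].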
Next, I'll give a quick recap of fix (it won't be comprehensive by any means) before giving an example of mfix, since understanding the former helps with the latter.
fix is a combinator that returns the (least) fixed point of a function, and can be used to eliminate explicit recursion. For example, suppose we have a recursively defined value a of type A:
a :: A
a = <body that mentions a>
A recursive value like this can be defined non-recursively as fix f for some function f: if we define <body that mentions a> as f a, i.e.,
f :: A -> A
f x = <body that mentions x>
then we have a = f a, and so a is a fixed point of f. Note that a is a fixed point of f, not the fixed point, since f may have more than one fixed point. For example, take a to be the following function:
a :: Integer -> Integer
a = \x -> a (x+1)
Obviously a x evaluates to bottom (or ⊥) for every x. The corresponding f is:
f :: (Integer -> Integer) -> Integer -> Integer
f a = \x -> a (x+1)
Now, this f has infinitely many fixed points. Any function const n, where n is an integer or ⊥, is a fixed point of f. For instance,
f (const 42) x = const 42 (x+1) = 42 = const 42 x
for any x, which shows const 42 is a fixed point of f. The function a defined above is basically const ⊥, which is the least fixed point of f (it is "least" because ⊥ is considered to be "less" than any non-⊥ value, where "less" is not < on numbers, but a partial order defined on values of a given type). Since fix f always returns the least fixed point of f, we have a = fix f.
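The fixed-point claim for const n can be spot-checked in code. This is only a sketch: functions cannot be compared for equality in general, so the hypothetical helper agreesOn compares f g and g on a finite sample of arguments, which is evidence rather than proof. It also cannot say anything about const ⊥, since evaluating that would not terminate.

```haskell
f :: (Integer -> Integer) -> Integer -> Integer
f a = \x -> a (x + 1)

-- Hypothetical helper: do f g and g agree on every sampled argument?
agreesOn :: [Integer] -> (Integer -> Integer) -> Bool
agreesOn xs g = map (f g) xs == map g xs

-- const n looks like a fixed point of f for any integer n.
constIsFixedPoint :: Integer -> Bool
constIsFixedPoint n = agreesOn [-3 .. 3] (const n)

-- (+ 1) is not a fixed point: f (+ 1) x = x + 2, which differs from x + 1.
succIsFixedPoint :: Bool
succIsFixedPoint = agreesOn [-3 .. 3] (+ 1)
```

In ghci, constIsFixedPoint 42 evaluates to True and succIsFixedPoint to False.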
As another example, here are the original and fixed-point definitions of the factorial function:
import Data.Function (fix)
import Numeric.Natural (Natural)

facOriginal :: Natural -> Natural
facOriginal = \x -> if x == 0 then 1 else x * facOriginal (x-1)

facFix :: Natural -> Natural
facFix = fix g
  where
    g :: (Natural -> Natural) -> Natural -> Natural
    g f = \x -> if x == 0 then 1 else x * f (x-1)
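One way to see why fix g behaves like the factorial is to unroll g by hand a finite number of times. The sketch below is my own illustration, not part of the definitions above: unroll and facUnrolled are hypothetical names, and the starting function simply fails, standing in for ⊥.

```haskell
import Numeric.Natural (Natural)

g :: (Natural -> Natural) -> Natural -> Natural
g f = \x -> if x == 0 then 1 else x * f (x - 1)

-- Hypothetical helper: apply g n times to a function that always fails.
-- The result is defined only for arguments smaller than n.
unroll :: Int -> Natural -> Natural
unroll n = iterate g (\_ -> error "ran out of unrollings") !! n

-- Six unrollings are enough for the arguments 0 to 5.
facUnrolled :: Natural -> Natural
facUnrolled = unroll 6
```

Here facUnrolled 5 evaluates to 120, while facUnrolled 6 hits the error. fix g can be thought of as the limit of this process, with no bound on the number of unrollings.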
As mentioned before, mfix is the monadic version of fix. It is useful if we want to compute a monadic value whose body mentions the value inside the monad, i.e.,
aM :: M A
aM = <body that mentions a>
where M is some monad. Note that the body of aM doesn't mention aM itself (otherwise it would be ordinary recursion), but mentions the value a :: A inside aM :: M A. This is where mfix is useful.
Analogous to what we did with fix above, let's define fM to be the following function:
fM :: A -> M A
fM x = <body that mentions x>
Then we have aM = fM =<< aM, and so aM = mfix fM.
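Before moving to IO, here is a small sketch of the same pattern with M taken to be Maybe, which also has a MonadFix instance in base. The names onesM and firstThreeM are made up for this illustration. The body mentions xs, the list inside the Maybe, not the Maybe value itself:

```haskell
import Control.Monad.Fix (mfix)

-- xs stands for the list inside the final Just, so the list is
-- defined in terms of itself: 1 : 1 : 1 : ...
onesM :: Maybe [Integer]
onesM = mfix (\xs -> Just (1 : xs))

-- Only a finite prefix is demanded, so this terminates.
firstThreeM :: Maybe [Integer]
firstThreeM = fmap (take 3) onesM
```

Evaluating firstThreeM in ghci gives Just [1,1,1]. This works because the function passed to mfix does not inspect xs before returning Just; a function that pattern matched on xs first would not terminate.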
Now let's see a concrete example.
In this example, let the monad be IO. The goal is to write an IO action that asks the user to enter a boolean value, and depending on that boolean value, either returns the factorial function, or the function that computes the k-th triangular number (i.e., 1 + 2 + ... + k).
A "regular" implementation of this IO action would look like this:
fIO :: IO (Natural -> Natural)
fIO = do
  let fac = \x -> if x == 0 then 1 else x * fac (x-1)
      tri = \x -> if x == 0 then 0 else x + tri (x-1)
  b <- readLn
  pure $ if b then fac else tri
fIO defines two inner functions fac and tri, and returns one of them based on the boolean value. Test it in ghci by running fIO >>= pure . ($5). It will display 120 if you enter True, and 15 if you enter False.
Now, can fIO be defined in terms of fix gIO for some gIO, similar to what we did for the pure factorial function? That would be difficult, since we can't really make fIO a recursive function that calls itself. Perhaps the best we can do is this:
import System.IO.Unsafe (unsafePerformIO)

gIO :: IO (Natural -> Natural) -> IO (Natural -> Natural)
gIO fIO' = do
  b <- readLn
  pure $ \x ->
    if b
      then if x == 0 then 1 else x * unsafePerformIO fIO' (x-1)
      else if x == 0 then 0 else x + unsafePerformIO fIO' (x-1)
The unsafePerformIO is used to obtain a Natural -> Natural by running the IO action of type IO (Natural -> Natural), since we need a Natural -> Natural, i.e., the value inside fIO' as opposed to fIO' itself, to apply to (x-1).
And even if we are cool with unsafePerformIO, the action fix gIO is quite different from fIO, since we are running the IO action via unsafePerformIO on every recursive call. So if you want to compute the factorial of 5 via fix gIO by running
fix gIO >>= pure . ($5)
you'll need to enter True 6 times, whereas with fIO you only enter it once.
This is exactly the scenario where mfix comes into play. Again, note the type of mfix:
mfix :: (a -> m a) -> m a
The argument of type a gives you access to the value inside the m a, which is the result we are trying to obtain. Using mfix to implement fIO is nice and simple:
import Control.Monad.Fix (mfix)

fIOmfix :: IO (Natural -> Natural)
fIOmfix = mfix gIO'
  where
    gIO' :: (Natural -> Natural) -> IO (Natural -> Natural)
    gIO' f = do
      b <- readLn
      pure $ \x ->
        if b
          then if x == 0 then 1 else x * f (x-1)
          else if x == 0 then 0 else x + f (x-1)
Here f is the pure function inside fIOmfix. Since gIO' takes this pure function (unlike gIO), no unsafePerformIO is needed, and the IO action is performed only once. So mfix gIO' is equivalent to fIO.
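The "performed only once" claim can be checked without typing anything at a prompt. The following is a sketch under my own assumptions, not the code from this post: gCounted is a hypothetical variant of gIO' that takes the boolean as an argument instead of reading it, and bumps a counter every time the action runs; demo is a made-up driver.

```haskell
import Control.Monad.Fix (mfix)
import Data.IORef (IORef, modifyIORef, newIORef, readIORef)
import Numeric.Natural (Natural)

-- Hypothetical variant of gIO': the Bool is passed in rather than read
-- from the user, and each run of the action increments the counter.
gCounted :: IORef Int -> Bool -> (Natural -> Natural) -> IO (Natural -> Natural)
gCounted counter b f = do
  modifyIORef counter (+ 1)
  pure $ \x ->
    if b
      then if x == 0 then 1 else x * f (x - 1)
      else if x == 0 then 0 else x + f (x - 1)

-- Apply the function obtained via mfix to 5, and report how many times
-- the action ran.
demo :: Bool -> IO (Natural, Int)
demo b = do
  counter <- newIORef 0
  h <- mfix (gCounted counter b)
  let result = h 5
  -- Force the recursive calls before reading the counter.
  runs <- result `seq` readIORef counter
  pure (result, runs)
```

Running demo True in ghci gives (120,1) and demo False gives (15,1): the recursion happens entirely in the pure function, and the action itself runs a single time.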
GHC provides two pieces of syntactic sugar, mdo and rec (both require the RecursiveDo extension), either of which can be used to prettify functions like fIOmfix:
{-# LANGUAGE RecursiveDo #-}

fIOmdo :: IO (Natural -> Natural)
fIOmdo = mdo
  b <- readLn
  res <- pure $ \x ->
    if b
      then if x == 0 then 1 else x * res (x-1)
      else if x == 0 then 0 else x + res (x-1)
  pure res

fIOrec :: IO (Natural -> Natural)
fIOrec = do
  b <- readLn
  rec res <- pure $ \x ->
        if b
          then if x == 0 then 1 else x * res (x-1)
          else if x == 0 then 0 else x + res (x-1)
  pure res
In both fIOmdo and fIOrec, res is the pure Natural -> Natural function inside the result (IO (Natural -> Natural)) that we are trying to obtain. They are both equivalent to fIO and mfix gIO'.
There are subtle differences between mdo and rec, but I won't go into them here.
The following are helpful for getting a better understanding of how the magic of mfix works internally: