Different ways to understand a monad

As soon as you start talking about functional programming, monads pop up as something that you have to know. However, hardly anyone is good at explaining what a monad is. That is why we’ll try to build some intuition about it before we define it.

When functor is not enough

I have a square function:

def f_c(c: Double)(x: Double): Double = x*x + c

Actually, it takes 2 parameters, but I am interested in the square functions obtained when I fix the c parameter to each of these values:

List(-1.0, 0.0, 1.0)

I am interested in finding out all arguments for which any of the square functions returns 0 (all possible roots).

So, I apply c and obtain my square functions:

List(-1.0, 0.0, 1.0).map(f_c)

Now, I’ll apply another function which will return roots of my square functions:

def getRoots(f: Double => Double): List[Double]

A square function might have 0, 1 or 2 roots. That’s why getRoots has to return a List. So, I ended up with:

val roots = List(-1.0, 0.0, 1.0)
  .map(f_c)
  .map(getRoots)

The thing is, I obtained List[List[Double]] when I wanted List[Double]. I can always .flatten it.

roots.flatten: List[Double]

However, I am missing something that would perform map and flatten in one step. As it happens, there is such a thing. It’s called .flatMap:

val roots = List(-1.0, 0.0, 1.0)
  .map(f_c)
  .flatMap(getRoots)
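
To make the equivalence concrete, here is a minimal runnable sketch. The getRoots above takes an arbitrary function, which cannot be solved analytically, so this sketch assumes a hypothetical helper rootsFor that computes the real roots of x*x + c directly from c.

```scala
import scala.math.sqrt

object RootsDemo {
  // Hypothetical helper (not from the article): real roots of x*x + c.
  def rootsFor(c: Double): List[Double] =
    if (c > 0.0) Nil                 // x*x + c never reaches 0
    else if (c == 0.0) List(0.0)     // a single root
    else {
      val r = sqrt(-c)
      List(-r, r)                    // two symmetric roots
    }

  val cs: List[Double] = List(-1.0, 0.0, 1.0)

  // map followed by flatten...
  val viaMapFlatten: List[Double] = cs.map(rootsFor).flatten
  // ...gives the same result as a single flatMap.
  val viaFlatMap: List[Double] = cs.flatMap(rootsFor)
}
```

Both values come out as List(-1.0, 1.0, 0.0): two roots for c = -1.0, one for c = 0.0, none for c = 1.0.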

Conveniently flatMappable things

Many containers - Option (at most single-element container), List, Vector (containers remembering order of elements and allowing many occurrences), Set (container that, usually, doesn’t care about order and allows at most 1 occurrence of the same element) - have this convenient thing called flatMap.

The fact that it is not a map gives us some flexibility. We could for instance implement filtering with it:

val isEven: Int => Boolean = i => i % 2 == 0
List(0, 1, 2, 3, 4).flatMap {
  case i if isEven(i) => List(i)
  case _              => List.empty
}
// List(0, 2, 4)

On the other hand, we can use it to expand results as well:

import math.sqrt
List(-2.0, 0.0, 2.0).flatMap {
  case d if d < 0.0 => List.empty
  case 0.0          => List(0.0)
  case d            => List(sqrt(d), -sqrt(d))
}
// List(0.0, 1.4142135623730951, -1.4142135623730951)

So, when it comes to transforming containers, flatMap frees us from the restriction that we always take 1 thing, return 1 thing, and end up with a container with the same number of elements (Set is an exception, but it’s also not a lawful functor).
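
The remark about Set can be checked with a tiny sketch: mapping a List keeps the element count, while mapping a Set may shrink it, because equal results collapse into one element.

```scala
object SetMapDemo {
  // List.map is 1-to-1: three elements in, three elements out.
  val asList: List[Int] = List(1, 2, 3).map(_ % 2)

  // Set.map may lose elements: 1 % 2 and 3 % 2 are both 1.
  val asSet: Set[Int] = Set(1, 2, 3).map(_ % 2)
}
```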

Circuit breaking

At some point we might face validation. Or at least error handling. Let’s say we use Option to pass an element on as long as each stage of the computation succeeds, and turn it into None the first time it fails.

val credentials: Option[(String, String)] = ???
def getUserID(credentials: (String, String)): Option[String] = ???
def getUserData(userID: String): Option[String] = ???

We can already see that .map would not be enough to handle it, unless we were happy to end up with Option[Option[Option[String]]]. If we flattened it after each step, we would end up with the same result as if we had run .flatMap, so we can do just that.

credentials.flatMap(getUserID).flatMap(getUserData)
// Option[String]

As we can see, Option can be used to build a pipeline - if credentials are empty, we end up with None at the end; if getUserID returns None (because the credentials are invalid or not entitled to return such details), we also end up with None. Finally, even if we have a userID, but there is no data related to it (or fetching it failed), we also end up with None. Only if everything is fine will we get Some(userData).
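
A runnable sketch of these four outcomes follows. The in-memory Maps standing in for the user and data stores, and the fetch helper, are assumptions made purely for illustration.

```scala
object OptionPipelineDemo {
  // Hypothetical in-memory "databases", for illustration only.
  private val userIds: Map[(String, String), String] =
    Map(("alice", "secret") -> "id-1", ("bob", "hunter2") -> "id-2")
  private val userData: Map[String, String] =
    Map("id-1" -> "Alice's data")

  def getUserID(credentials: (String, String)): Option[String] =
    userIds.get(credentials)
  def getUserData(userID: String): Option[String] =
    userData.get(userID)

  def fetch(credentials: Option[(String, String)]): Option[String] =
    credentials.flatMap(getUserID).flatMap(getUserData)

  val ok         = fetch(Some("alice" -> "secret"))  // every stage succeeds
  val noData     = fetch(Some("bob" -> "hunter2"))   // getUserData returns None
  val badLogin   = fetch(Some("alice" -> "wrong"))   // getUserID returns None
  val noInput    = fetch(None)                       // nothing to start with
}
```

Only ok ends up as Some("Alice's data"); the three other values are None, and none of them tells us which stage failed.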

Such behavior, when your pipeline stops at the first error and then simply passes the error on without executing the following steps, is called circuit breaking. It works as an analogy to circuit breaking with electric wires - if you cross uninsulated wires, the electric current will not flow through the whole circuit, but will take a shortcut.

OK, Option works, but is it convenient? When some error occurs we return None, which doesn’t carry any information about the issue. Couldn’t we solve that?

As a matter of fact, we can. Option is an at-most-single-element collection. When we have Some, we map over the 1 element of the collection; when we have None - an empty collection - we do nothing. But if we replaced None with something else, something that could contain an error but at the same time be treated as an empty collection, then we would be able to circuit break and preserve the information about the error. Happily for us, there are some data structures that do exactly that.

Either[L, R] can have 2 possible values: Left(_ : L) or Right(_ : R). Since Scala 2.12 it is Right-biased, which means that we treat Right as if it was Some - its value can be mapped or flatMapped. Left is treated as if it was None - it is carried on without a change, so on the first Left value our pipeline circuit-breaks.
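
As a small sketch of the Right-biased behavior (it assumes Scala 2.12 or newer; parse, positive and validate are hypothetical names made up for this example):

```scala
object EitherPipelineDemo {
  def parse(s: String): Either[String, Int] =
    try Right(s.toInt)
    catch { case _: NumberFormatException => Left(s"'$s' is not a number") }

  def positive(i: Int): Either[String, Int] =
    if (i > 0) Right(i) else Left(s"$i is not positive")

  // The first Left stops the pipeline and is carried to the end unchanged.
  def validate(s: String): Either[String, Int] =
    parse(s).flatMap(positive).map(_ * 2)

  val ok        = validate("21")
  val notNumber = validate("abc")
  val negative  = validate("-1")
}
```

Unlike with Option, each failed run tells us which stage rejected the input.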

Before Scala 2.12, you had to use a projection mechanism to tell Either explicitly, which side should be treated as empty and which as non-empty.

either
  .right.map(r => doSth(r))
  .left.map(l => doSthElse(l))
  .right.flatMap { r =>
    if (condition(r)) Right(ok)
    else Left(error)
  }
  .left.flatMap { l =>
    recover(l)
  }

It made composition a little bit messy, especially since everyone wanted to use .right by default.

Currently you are still able to use this syntax, but it is used mostly as .left.flatMap, when you want to recover from an error.
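
A minimal sketch of such a recovery is shown below. The recover rule (only a "timeout" error can be recovered from) is an assumption invented for the example.

```scala
object EitherRecoverDemo {
  val failed: Either[String, Int] = Left("timeout")
  val broken: Either[String, Int] = Left("disk on fire")
  val fine: Either[String, Int]   = Right(1)

  // Hypothetical recovery rule: a timeout falls back to 0, anything else stays an error.
  def recover(error: String): Either[String, Int] =
    if (error == "timeout") Right(0) else Left(error)

  val recovered  = failed.left.flatMap(recover)  // Left is turned back into a Right
  val stillError = broken.left.flatMap(recover)  // Left stays a Left
  val untouched  = fine.left.flatMap(recover)    // Right is passed through as-is
}
```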

Nice! We can circuit break. We have a way of preserving error information. We can recover from error. And, since it’s Either, we can also decide on a type of our error. But is it always needed?

It appears that in practice, when we deal with the JVM, quite often the most useful error format is Throwable. Errors, Exceptions, RuntimeExceptions… Many Java libraries use them as the only right way of handling errors (which I find sad: Exceptions turn into RuntimeExceptions, because why pollute the method signature, and then you lose any information about potential errors. Eventually the only sane way is to catch all Exceptions, add special handling for the ones you figured out from the code, and add generic handling for everything else that the authors forgot to mention in the documentation and that you couldn’t find in the sources, because the code is a giant mess). Additionally, because we are talking about Exceptions, we would like to have some way of automatically catching them and wrapping them in an empty container. scala.util.Try serves exactly this purpose:

import scala.util.{Try, Success, Failure}
Try {
  // dangerous computation
} match {
  case Success(result) => // handle result
  case Failure(error)  => // handle exception
}

Because it is kind of like Either with the left type fixed to Throwable (and automatic exception-catching), Success is treated like Some or Right: during mapping or flatMapping its value is piped through. Failure is our None or Left - an empty container from the point of view of a pipeline, but still something that carries information about the error. It is also nice enough to provide a way of handling errors:

Try(doSth)
  .map(doSthElse(_))
  .recover {
    case error: Throwable => correctValue
                             // needs value
  }
  .flatMap { value =>
    if (condition(value)) Success(x)
    else Failure(new Exception("error"))
  }
  .recoverWith {
    case error: Exception => Success(correctValue)
                             // needs Try value
  }

Both .recover and .recoverWith take PartialFunction, so we might decide to handle some exceptions and not handle others.
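
The effect of a partial handler can be seen in a short sketch: an exception matched by the PartialFunction is recovered, while one that is not matched passes through as a Failure. The divide helper is a hypothetical name used only to trigger an ArithmeticException.

```scala
import scala.util.Try

object TryRecoverDemo {
  // Hypothetical helper; dividing by zero throws ArithmeticException.
  def divide(a: Int, b: Int): Int = a / b

  // NumberFormatException is covered, so we get back on the success track.
  val handled: Try[Int] =
    Try("abc".toInt).recover { case _: NumberFormatException => 0 }

  // ArithmeticException is not covered, so the Failure is passed on untouched.
  val unhandled: Try[Int] =
    Try(divide(1, 0)).recover { case _: NumberFormatException => 0 }

  // recoverWith expects a Try, so the recovery itself is allowed to fail.
  val recoveredWith: Try[Int] =
    Try(divide(1, 0)).recoverWith { case _: ArithmeticException => Try("42".toInt) }
}
```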

So, is Try the thing we could use everywhere for error handling? Well, not quite. As soon as we start performing long computations - e.g. querying a database - we would like to somehow handle the fact that the value might not be immediately there. And we might want to post each stage to some thread pool, so that these asynchronous computations scale and won’t overrun your computer with thousands of threads.

Async and IO

Let us imagine we take Try and make it asynchronous. That is, when we pass some recipe for a value (or error) inside, it won’t block until the result is ready, but return immediately. You’ll get a container which will eventually contain one value (or error), but you won’t be able to do anything that assumes that the value is already there.

However, you could .map or .flatMap it - in such a case, once a value is available it will be piped into the next stage of the pipeline and, again, you’ll get a currently-empty-but-eventually-filled container. Kind of like mailboxes in a courier company - one mailbox is waiting for a package to arrive, and once it arrives it is dispatched to another mailbox, until it reaches the destination.

What we described here is a Future. It requires some context to perform its calculations, a context which defines the thread pool in which the computations will be run - an ExecutionContext - and as long as you have one, it will be used to asynchronously compute each stage of your asynchronous pipeline:

import scala.concurrent.Future
import scala.concurrent.ExecutionContext.Implicits.global

val resultF = Future {
  // some calculation that might fail
  // and takes some time
}.map { value =>
  // another computation
}
.recover {
  case error: Throwable => correctValue
}.flatMap { value =>
  if (condition(value)) Future.successful(correctValue)
  else Future.failed(new Exception("error"))
}.recoverWith {
  case error: Exception => Future.successful(correctValue)
}

When we look at this, we see that Future is almost identical to Try. It even has .recover and .recoverWith having the same behavior! The only difference is that the value of resultF is not necessarily there, when we obtain it. We cannot pattern-match on it expecting it to be Success or Failure - it makes sense, it was designed for async, the computations are most likely still running. But we can do it asynchronously:

resultF.onComplete {
  case Success(value) => // handle success
  case Failure(error) => // handle error
}

If we only care about the success, we might treat it as if it was an Option (or any other collection) and use .foreach:

resultF.foreach { value =>
  // do stuff
}
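
When experimenting in a REPL or writing a test, one blunt way to look at the final value is to block on the Future with Await.result. The sketch below assumes the global ExecutionContext and a 5-second timeout; compute and blockingResult are hypothetical names. Blocking like this defeats the purpose of async code, so treat it as a debugging aid rather than a pattern.

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
import scala.concurrent.ExecutionContext.Implicits.global

object FutureAwaitDemo {
  // Parsing runs on the thread pool; a failed parse is recovered to 0.
  def compute(input: String): Future[Int] =
    Future(input.toInt)
      .recover { case _: NumberFormatException => 0 }
      .map(_ + 1)

  // Demo/test only: waits for the pipeline to finish and unwraps the value.
  def blockingResult(input: String): Int =
    Await.result(compute(input), 5.seconds)
}
```

Here blockingResult("41") gives 42, while blockingResult("abc") goes through .recover and gives 1.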

The thing that you have to remember is that a Future is posted to the thread pool right after its creation - if it doesn’t have to wait in a queue, it starts being calculated immediately (if you have trouble remembering that, remember the slogan: Future begins now). While this is usually what you want when you use it, some people find it troublesome when it comes to reasoning about e.g. referential transparency and side-effects:

def resultInt(): Future[Int] = Future {
  scala.util.Random.nextInt(10)
}

resultInt().flatMap { str =>
  resultInt().map { str2 =>
    println(s"1: $str, 2: $str2")
  }
}
// 1: 3, 2: 7

val intF = resultInt()
intF.flatMap { str =>
  intF.map { str2 =>
    println(s"1: $str, 2: $str2")
  }
}
// 1: 4, 2: 4

As we can see, referential transparency is broken: Future memoizes the calculated value, so replacing a function returning a Future with its value might change the result of the calculation. That is why many people prefer to build the whole pipeline of tasks first, and only then give it a thread pool to run on. Such a pipeline would differ from a Future-based one: it wouldn’t run computations immediately, only store recipes for each stage. It would also not memoize the result of each task (cache it for all future calls) - if you need the value of a task twice, it would be evaluated twice. But only once you run it.

import monix.eval.Task
import monix.execution.Scheduler.Implicits.global

def resultInt(): Task[Int] = Task {
  scala.util.Random.nextInt(10)
}

resultInt().flatMap { str =>
  resultInt().map { str2 =>
    println(s"1: $str, 2: $str2")
  }
}.runToFuture
// 1: 4, 2: 6

val intF = resultInt()
intF.flatMap { str =>
  intF.map { str2 =>
    println(s"1: $str, 2: $str2")
  }
}.runToFuture
// 1: 7, 2: 0

In this example we can see that each use of intF is evaluated separately, so we can safely replace a function returning a Task with its value, and referential transparency is preserved. Task as an idea was first implemented in Scalaz (as a wrapper for Scalaz’s own Future), then another implementation was done in Monix (which in turn is married to the Cats ecosystem).
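
To see the "store recipes, run later" idea without any library, here is a deliberately tiny sketch. Recipe is a hypothetical class made up for this example, not Monix’s Task: it has no async, no error handling and no stack safety. A counter replaces the random number so the behavior is deterministic.

```scala
import java.util.concurrent.atomic.AtomicInteger

// Hypothetical `Recipe`: a description of a computation that runs only on demand.
final class Recipe[A](thunk: () => A) {
  def run(): A = thunk()
  def map[B](f: A => B): Recipe[B] =
    new Recipe(() => f(thunk()))
  def flatMap[B](f: A => Recipe[B]): Recipe[B] =
    new Recipe(() => f(thunk()).run())
}

object RecipeDemo {
  val evaluations = new AtomicInteger(0)

  // Every run bumps the counter, so we can count evaluations.
  val next: Recipe[Int] = new Recipe(() => evaluations.incrementAndGet())

  // Building the pipeline evaluates nothing...
  val both: Recipe[(Int, Int)] = next.flatMap(a => next.map(b => (a, b)))
  val beforeRun: Int = evaluations.get()

  // ...and running it evaluates `next` once per use, with no memoization.
  val result: (Int, Int) = both.run()
}
```

beforeRun is 0 and result is (1, 2): the same value next was used twice and evaluated twice, which is the behavior the Task example relies on.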

Let us reiterate. We started out with containers/collections and figured out, that .map doesn’t cover all the cases. Then we started using at-most-single-element collections for handling errors, then enriched them so that we could keep the information about errors. Finally, we made these at-most-single-element-collections-with-error-information asynchronous. It seems that for async and circuit breaking, the metaphor of container started wearing off. So let’s try another one.

Imagine you have two railroads - one for trains carrying errors and one for trains carrying successful values.

Success
    Failure
x      
#   #
#   #
#   #

As long as nothing bad happens, your train continues its journey down the success railway. But if something goes wrong at some point, instead of derailing it turns into the railway siding of failure.

Success
    Failure
x
#   #        .flatMap(x => Failure(y))
#\  #    
#\\ #
# \\#
#  \#
    y

Here .flatMap becomes a cross-roads where we decide if we want to make a turn. Of course, we cannot go very far staying on a railway siding. That is why pipelines that deal with error handling provide some way of getting back on track.

Success
    Failure
    y
#   #        .recover(y => z)
#  /#
# //#    
#// #
#/  #
z

(BTW: It has to be noted that in Scala async is very often related to IO. Basically, all IO-handling is done in asynchronous structures like Future, or Task or ZIO, while all asynchronous structures assume that async computation involves side effects and might fail. It is not a rule - you have a synchronous Try or Scalaz’s Future, which doesn’t handle exceptions - but you should be aware of the relation).

This metaphor doesn’t work very well if we go back to containers, e.g. Lists. We might try to get creative:

List(1, 2, 3)

1   2   3
#   #   #    .flatMap(x => List(x, 2*x))
#\  #\  #\  
#\\ #\\ #\\ 
# # # # # #
1 2 2 4 3 6

List(1, 2, 2, 4, 3, 6)

where we would model each value as a separate railway, that might branch at .flatMap, and the results are all the targets that could be reached by all possible roads.

If you think that is all that you can achieve with .flatMap, you are wrong. Functional programmers are dangerously creative beasts, and they found a few more uses for it.

Purely functional state

Let’s say you want to preserve the state across the computations. One way to do it would be by using global variables:

import scala.collection.immutable.Queue

var stack = Queue[String]()

def push(s: String): Unit = { stack = stack.enqueue(s) }
def pop(): Option[String] = stack.dequeueOption match {
  case Some((s, newStack)) =>
    stack = newStack
    Some(s)
  case None => None
}

def computation(a: String, b: String) = {
  push(a)
  push(b)
  pop().getOrElse("") + pop().getOrElse("")
}

However, we will quite quickly run into problems with such an approach: it is hard to test, it is unreliable in a concurrent environment, and one cannot have 2 different queues at the same time. So, maybe instead of using globals, we will pass the state into the function and return the new state next to the normal output.

type State = Queue[String]

def push(s: String, state: State): State =
  state.enqueue(s)
def pop(state: State): (Option[String], State) =
  state.dequeueOption match {
    case Some((s, newState)) =>
      Some(s) -> newState
    case None =>
      None -> state
  }

def computation(a: String, b: String, state: State) = {
  val state2 = push(a, state)
  val state3 = push(b, state2)
  val (result1, state4) = pop(state3)
  val (result2, state5) = pop(state4)
  (result1.getOrElse("") + result2.getOrElse("")) -> state5
}

This is much more testable and works great with concurrency. It is also a pain to work with, and it is quite easy to mess up passing a new state. Let’s try to think of a way of passing the state around in some less intrusive way. Maybe some wrapper that would do this for us? Such a wrapper would have to keep both the current state and the result of the current computation, and it would have to swap the state when it is overridden.

final case class State[S, A](f: S => (S, A)) {
  def run(initial: S): (S, A) = f(initial)
}

Right now State wrapper is only concerned about passing on current state and deriving value out of it. Returning value could be implemented using closures:

def returnStr(a: String) =
  State[Queue[String], String] { state =>
    state -> a
  }

It might come in handy in the push implementation - we have to put a: String into the pipeline somehow! But we also need a way of piping the transformations. We want to turn one State into another and we want to have full control over the created State. It sounds like flatMap:

final case class State[S, A](runF: S => (S, A)) {
  def run(initial: S): (S, A) = runF(initial)
  
  def flatMap[B](g: A => State[S, B]): State[S, B] =
    State[S, B] { state1 =>
      val (state2, a) = runF(state1)
      g(a).run(state2)
    }
}

Let’s think what happened here.

  • We have a State[S, A] (which hides a computation S => (S, A)).
  • We want to compose it with A => State[S, B] (which is actually A => S => (S, B)) to obtain State[S, B] (inside: S => (S, B)).
  • In the process, we have to produce A from S (possibly updating S) and then pass it to A => S => (S, B) to obtain S => (S, B).
  • When we apply the S returned by S => (S, A) to S => (S, B), we receive the final (S, B) (which might update S again).
  • Considering that we started with S coming from outside and ended up with (S, B), we successfully managed to create a composition of both state transitions.

The A => State[S, B] does the real trick here - looking at the signature we still see the same X[A].flatMap[B](a => X[B]) pattern, but A => State[S, B] is actually a A => S => (S, B) in disguise - so we provide both values potentially needed for calculating the next step!

Let’s try to implement push and pop to confirm that it works:

def push(a: String) = State[Queue[String], Unit] { queue =>
  queue.enqueue(a) -> (())
}
val pop = State[Queue[String], Option[String]] { queue =>
  queue.dequeueOption match {
    case Some((s, newQueue)) =>
      newQueue -> Some(s)
    case None =>
      queue -> None
  }
}

def computation(a: String, b: String) = {
  push(a).flatMap { _ => 
    push(b).flatMap { _ =>
      pop.flatMap { r1 =>
        pop.flatMap { r2 =>
          State[Queue[String], String] { state =>
            state -> (r1.getOrElse("") + r2.getOrElse(""))
          }
        }
      }
    }
  }
}

computation("x", "y").run(Queue())
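
For reference, here is the whole example assembled into one self-contained sketch that can be compiled and run. The alias Q and the helper pure are hypothetical additions to shorten the code; the nesting of flatMaps is slightly flattened, which does not change the result.

```scala
import scala.collection.immutable.Queue

final case class State[S, A](runF: S => (S, A)) {
  def run(initial: S): (S, A) = runF(initial)

  def flatMap[B](g: A => State[S, B]): State[S, B] =
    State[S, B] { state1 =>
      val (state2, a) = runF(state1)  // run this step
      g(a).run(state2)                // feed its value and state into the next one
    }
}

object StateQueueDemo {
  type Q = Queue[String]

  def push(a: String): State[Q, Unit] =
    State[Q, Unit](q => (q.enqueue(a), ()))

  val pop: State[Q, Option[String]] =
    State[Q, Option[String]] { q =>
      q.dequeueOption match {
        case Some((s, rest)) => (rest, Some(s))
        case None            => (q, None)
      }
    }

  // Wraps a plain value without touching the state.
  def pure[A](a: A): State[Q, A] = State[Q, A](q => (q, a))

  def computation(a: String, b: String): State[Q, String] =
    push(a).flatMap(_ => push(b)).flatMap(_ => pop).flatMap { r1 =>
      pop.flatMap(r2 => pure(r1.getOrElse("") + r2.getOrElse("")))
    }

  val result: (Q, String) = computation("x", "y").run(Queue.empty)
}
```

Since a Queue is first-in-first-out, the two pops return "x" and then "y", so result is an empty queue paired with "xy".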

Putting aside the thousands of parentheses (which we’ll deal with later), it looks pretty good! We can compose state transformations, we are purely functional, and things are easy to test. We could possibly add a few utility functions for some common cases, like updating the state without updating the value, setting the value/state to a specific value etc. All of that is done in the Cats/Scalaz implementations, in State’s companion objects:

utility                                    Cats     Scalaz
() => State { s => s -> s }                get      get/init
(s: S) => State { _ => s -> () }           set      put
(a: A) => State { s => s -> a }            pure     state
(f: S => S) => State { s => f(s) -> () }   modify   modify
(f: S => T) => State { s => s -> f(s) }    inspect  gets
Polymorphic state

One can think that such a pure way of handling state would be useful in many cases, e.g. during program initialization - you know, parsing arguments, reading configs and wiring services together.

import com.typesafe.config.{ Config, ConfigFactory }
final case class AppConfig(/*...*/)
trait Services { /*...*/ }

def readConfigs(config: Config): State[AppConfig, Unit]
def applyArgs(args: Array[String]): State[AppConfig, Unit]
val initializeServices: State[AppConfig, Services]

readConfigs(ConfigFactory.load).flatMap { _ =>
  applyArgs(args).flatMap { _ =>
    initializeServices
  }
}.run(AppConfig())

There is a small issue with such an approach, though. If you want to make sure that you always:

  • start, by reading config files into memory
  • override config files with parameters explicitly passed into program
  • pass into services config that can be considered final

then this approach makes it too easy to swap order of operations. Maybe in some earlier step you need some information, that is not needed later on? (The content of config files and CLI arguments are good candidates). Perhaps at some point you calculate things that you would like to be available from now on, but if you are reusing the same state type, you would have to introduce some fields that would store some garbage for earlier stages. This all looks like limitations that would result in inelegant and clumsy design.

Or you can simply allow for the state type to change across the pipeline. You know - change S => (S, A) into S1 => (S2, A) and adjust the composition accordingly. Such version of State which would trace State type versions over the pipeline is called IndexedState:

final case class IndexedState[S1, S2, A](runF: S1 => (S2, A)) {
  def run(s1: S1): (S2, A) = runF(s1)
  
  def flatMap[S3, B](g: A => IndexedState[S2, S3, B]): IndexedState[S1, S3, B] =
    IndexedState[S1, S3, B] { state1 =>
      val (state2, a) = runF(state1)
      g(a).run(state2)
    }
}

At this point State is just a special case of IndexedState:

type State[S, A] = IndexedState[S, S, A]

With IndexedState it would be easy to force order of processing configs:

final case class Uninitialized(
  configs: Config,
  args: Array[String],
  initialValues: AppConfig
)

final case class ConfigParsed(
  args: Array[String],
  updatedConfig: AppConfig
)

val parseConfig: IndexedState[Uninitialized, ConfigParsed, Unit]
val parseArgs: IndexedState[ConfigParsed, AppConfig, Unit]
val initializeServices: IndexedState[AppConfig, (AppConfig, Services), Unit]

If you define program initialization in such way, you would have to go out of your way to compose things in a wrong way.

parseConfig.flatMap { _ =>
  parseArgs.flatMap { _ =>
    initializeServices
  }
}.run(Uninitialized(config, args, AppConfig()))
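
The same idea in a self-contained, runnable form is sketched below. It assumes flatMap takes A => IndexedState[S2, S3, B], so that each step must start from the state type the previous one ended with. RawArgs, Parsed and Running are hypothetical stand-ins for the config types above.

```scala
final case class IndexedState[S1, S2, A](runF: S1 => (S2, A)) {
  def run(s1: S1): (S2, A) = runF(s1)

  // The next step must start from S2, the state type this step ends with.
  def flatMap[S3, B](g: A => IndexedState[S2, S3, B]): IndexedState[S1, S3, B] =
    IndexedState[S1, S3, B] { state1 =>
      val (state2, a) = runF(state1)
      g(a).run(state2)
    }
}

object IndexedStateDemo {
  // Hypothetical stages of initialization, each with its own state type.
  final case class RawArgs(args: List[String])
  final case class Parsed(port: Int)
  final case class Running(port: Int)

  val parse: IndexedState[RawArgs, Parsed, Unit] =
    IndexedState[RawArgs, Parsed, Unit](raw => (Parsed(raw.args.head.toInt), ()))

  val start: IndexedState[Parsed, Running, String] =
    IndexedState[Parsed, Running, String](p => (Running(p.port), s"listening on ${p.port}"))

  val program: IndexedState[RawArgs, Running, String] =
    parse.flatMap(_ => start)

  // start.flatMap(_ => parse) is rejected by the compiler:
  // start ends in Running, but parse has to begin with RawArgs.

  val result: (Running, String) = program.run(RawArgs(List("8080")))
}
```

The wrong order is now a type error rather than a runtime surprise.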

If you think this application of flatMap was mind-bending, what would you say if I told you that it can be used for reading the future?

Back to the future

Of course, there is no such thing as reading the future. But you can do something that looks like it: create a place in your function that will wait for the future state to appear. You can calculate the next value from the current one immediately. You can define how you would update your state from present to future, but defer the actual calculation until the last value is known. As a result, each of your state updates will receive not the current state but a future one.

It can be implemented quite easily with Haskell:

newtype ReverseState s a = ReverseState
  { runF :: s -> (a, s)
  } deriving Functor

-- >>= is flatMap, ignore Monad for now
instance Monad (ReverseState s) where
  mx >>= f =
    ReverseState $ \s ->
      let (a, past) = runF mx future
          (b, future) = runF (f a) s
      in (b, past)

Thing is: in Haskell everything is lazy. If you translated it into Scala, things would break quite easily:

final class ReverseState[S, A](runF: S => (S, A)) {
  
  def flatMap[B](f: A => ReverseState[S, B]): ReverseState[S, B] =
    new ReverseState[S, B]({ s =>
      val (past, a) = runF(future)
      val (future, b) = f(a).run(s)
      past -> b
    })
  
  def run(s: S): (S, A) = runF(s)
}

It looks nice and everything, but a depends on future and future depends on a. In Haskell it is OK (if the pipeline is finite), because things are lazy and at some point the reference to value you just passed won’t be used anywhere and function will be able to return. In Scala we evaluate things eagerly, so the code explodes.

This issue was already investigated by Vladimir Pavkin in his article Reverse State Monad in Scala. Is it possible? - the conclusion is: it depends. If you want to have elegant, simple solution like in Haskell - forget about it. However, if you wrap everything in lazy data structure (e.g. Eval), you should be able to implement it (user experience will be sad though - you have to wrap things up in Eval yourself):

// based on http://pavkin.ru/reverse-state-monad-in-scala-is-it-possible/

import $ivy.`org.typelevel::cats-core:1.4.0`, cats._, cats.implicits._

{
class ReverseState[S, A](val runF: Eval[Eval[S] => (Eval[S], Eval[A])]) {

  def map[B](f: A => B): ReverseState[S, B] =
    new ReverseState[S, B](
      runF.map(run =>
        s =>
          run(s) match {
            case (next, a) => next -> a.map(f)
      })
    )

  def flatMap[B](f: Eval[A] => ReverseState[S, B]): ReverseState[S, B] =
    ReverseState[S, B]({ s: Eval[S] =>
      new {
        lazy val pastPair  : Eval[(Eval[S], Eval[A])] = Eval.defer(run(future))
        lazy val past      : Eval[S]                  = pastPair.flatMap(_._1)
        lazy val a         : Eval[A]                  = pastPair.flatMap(_._2)
        lazy val futurePair: Eval[(Eval[S], Eval[B])] = Eval.defer(f(a).run(s))
        lazy val future    : Eval[S]                  = futurePair.flatMap(_._1)
        lazy val b         : Eval[B]                  = futurePair.flatMap(_._2)

        val result: (Eval[S], Eval[B]) = past -> b
      }.result
    })

  def run(s: Eval[S]): Eval[(Eval[S], Eval[A])] = runF.map(_(s))

  def runA(s: Eval[S]): Eval[A] = run(s).flatMap(_._2)
  
  def runS(s: Eval[S]): Eval[S] = run(s).flatMap(_._1)
}

object ReverseState {

  def apply[S, A](fn: Eval[S] => (Eval[S], Eval[A])): ReverseState[S, A] =
    new ReverseState(Eval.later(s => fn(s)))
}
}

Let’s try it out:

def insert(ea: Eval[Int]): ReverseState[List[Int], Int] =
  ReverseState[List[Int], Int] { es =>
    val es1 = es.flatMap { s =>
      ea.map { a =>
        val s1 = s :+ a
        println(s"s = $s1")
        println(s"a = $a")
        println()
        s1
      }
    }

    es1 -> ea
  }


insert(Eval.now(1)).flatMap { i =>
  insert(i.map(_ + 1)).flatMap { j =>
    insert(j.map(_ + 1))
  }
}.runS(Eval.now(List())).value
s = List(3)
a = 3

s = List(3, 2)
a = 2

s = List(3, 2, 1)
a = 1

res4: List[Int] = List(3, 2, 1)

It works! We can see that ReverseState calculated all the values up to the final one, and then fed them into the state-updating function, resulting in updates running in reverse, as future values arrived before past ones.

For-comprehension

All of these examples had one thing in common: flatMap. However in each case flatMap was used for something different:

  • containers use it to allow 1-to-many mappings (or 1-to-{0,1} in case of Option),
  • structures with circuit breaking like Either, Try, Future or Task use it to stop the computation at the first failure,
  • State[S, A] and friends use it to allow modifying current state S value and current computation’s value A in parallel.

We can say that flatMap allows for more fine-grained control than map, but what exactly that means depends entirely on the data structure.

In Scala flatMap compositions can be achieved with for-comprehension. Code:

for {
  a <- aM
  b <- bM
  c <- cM
} yield a + b + c

is a syntactic sugar for

aM.flatMap { a =>
  bM.flatMap { b =>
    cM.map { c =>
      a + b + c
    }
  }
}

If there was no yield

for {
  a <- aM
  b <- bM
  c <- cM
} println(a + b + c)

it would translate to

aM.flatMap { a =>
  bM.flatMap { b =>
    cM.foreach { c =>
      println(a + b + c)
    }
  }
}

In other words:

  • last <- is translated to map or foreach depending on whether there is yield or not,
  • all <- before are translated to flatMap.

Additionally, x = y would put val x = y in block, and x <- y if cond would add withFilter(x => cond) (or filter if withFilter was absent).
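
A quick sketch with List shows a guard and an assignment side by side with a hand-written equivalent. The hand-written version is a simplification: the compiler actually threads y through a tuple and an extra map, but the result is the same.

```scala
object ForDesugarDemo {
  val viaFor: List[Int] =
    for {
      x <- List(1, 2, 3, 4) if x % 2 == 0  // guard
      y = x * 10                           // assignment
      z <- List(y, y + 1)
    } yield z

  // Simplified hand-written equivalent of the comprehension above.
  val byHand: List[Int] =
    List(1, 2, 3, 4)
      .withFilter(x => x % 2 == 0)
      .flatMap { x =>
        val y = x * 10
        List(y, y + 1).map(z => z)
      }
}
```

Both produce List(20, 21, 40, 41).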

It is slightly different from Haskell, which popularized flatMap (though they call it >>=), and from its do notation, which was the inspiration for for-comprehension. One of the differences is that do notation doesn’t use map (which is called fmap there - the functor’s map). Instead, each type that you can use in do notation has a unit function that wraps a value. In do notation return can be used as an alias for unit.

As a result code:

do { a <- aM
   ; b <- bM
   ; c <- cM
   ; return (a + b + c) }

translates to

aM >>= (\ a ->
  bM >>= (\ b ->
    cM >>= (\c ->
      unit (a + b + c))))

It’s as if in Scala you had

for {
  a <- List(1, 2, 3)
} yield a * 2

translated to

List(1, 2, 3).flatMap { a =>
  List(a * 2) // List(_) would be unit for List
}

But in Scala we have the problem that we cannot so easily guess the function that would wrap the result. Because of that, something that is quite obvious in Haskell might be a little less obvious here: that the way unit and flatMap/>>= relate to one another matters.

Monad

Let’s say you have a type constructor M[_] - meaning you can have a type M[A] for any concrete type A. Let’s say you can flatMap it - whether you define it as a method:

class M[A] {
  
  def flatMap[B](f: A => M[B]): M[B]
}

or as a function

def flatMap[A, B](ma: M[A])(f: A => M[B]): M[B]

isn’t relevant now (though it is quite relevant when you are using it, mathematically both implementations are interchangeable). Finally, you have a unit which simply wraps A into M[A]. If M[_], flatMap and unit follow some rules:

  • left identity/left unit law

    unit(a).flatMap(f) == f(a)
    
  • right identity/right unit law

    ma.flatMap(unit) == ma
    
  • associativity

    ma.flatMap(f1).flatMap(f2) ==
        ma.flatMap(v => f1(v).flatMap(f2))
    

Then we can say that M is a monad. If you try to think of something that doesn’t follow these laws, you can reach the conclusion that they basically require your M to be intuitive.
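
As a sanity check, the three laws can be spot-checked for Option with Some(_) playing the role of unit. This is only a sketch over a handful of sample values, not a proof; f1 and f2 are arbitrary functions picked for the example.

```scala
object OptionLawsDemo {
  val unit: Int => Option[Int] = Some(_)
  val f1: Int => Option[Int]   = a => if (a > 0) Some(a * 2) else None
  val f2: Int => Option[Int]   = a => Some(a + 1)

  val samples: List[Option[Int]] = List(None, Some(-1), Some(3))

  // unit(a).flatMap(f) == f(a)
  val leftIdentity: Boolean =
    List(-1, 3).forall(a => unit(a).flatMap(f1) == f1(a))

  // ma.flatMap(unit) == ma
  val rightIdentity: Boolean =
    samples.forall(ma => ma.flatMap(unit) == ma)

  // ma.flatMap(f1).flatMap(f2) == ma.flatMap(v => f1(v).flatMap(f2))
  val associativity: Boolean =
    samples.forall { ma =>
      ma.flatMap(f1).flatMap(f2) == ma.flatMap(v => f1(v).flatMap(f2))
    }
}
```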

That’s it! You have something that you can flatMap, you can wrap any value with your monadic type using some unit, if they don’t do anything offensive it’s a monad.

Usefulness of lawfulness is more obvious in Haskell - where you actually use return in do notation, a bit less in Scala, where you have map or foreach. That is, until you figure out that lawful map is just a special case of flatMap + unit:

class M[A] {
  def flatMap[B](f: A => M[B]): M[B] = ...
  
  def map[B](f: A => B): M[B] =
    flatMap(a => M.unit(f(a)))
}

object M {
  
  def unit[A](a: A): M[A] = ...
}

which means that intuitiveness of map is a result of intuitiveness of flatMap and unit.
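
To make this concrete, here is the smallest container I could think of, a hypothetical Box that always holds exactly one value, with map derived from flatMap and unit exactly as in the skeleton above:

```scala
// Hypothetical one-element container, for illustration only.
final case class Box[A](value: A) {
  def flatMap[B](f: A => Box[B]): Box[B] = f(value)

  // map is nothing more than flatMap followed by re-wrapping with unit.
  def map[B](f: A => B): Box[B] =
    flatMap(a => Box.unit(f(a)))
}

object Box {
  def unit[A](a: A): Box[A] = Box(a)
}

object BoxDemo {
  val mapped: Box[Int] = Box(2).map(_ + 1)

  // With map and flatMap in place, for-comprehension works too.
  val combined: Box[Int] =
    for {
      a <- Box(1)
      b <- Box(2)
    } yield a + b
}
```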

If you are looking for some nice metaphor - the metaphors of a container, a pipeline, a manufacturing line or a railway work nicely, though just in certain contexts. Personally, I think of monads as an abstract interface for combining transformations that gives you more control than map. And if I need a metaphor, I pick the one that works best in a given context.

Free monads

Similarly to some other structures that have their free versions, there are free monads. Similarly to them, free means that you can define them for any algebra S[_] without any assumptions about it and you will be able to get a monad.

How can we define them?

sealed abstract class Free[S[_], A] {
  
  def flatMap[B](f: A => Free[S, B]): Free[S, B] =
    FlatMapped(this, f)
  
  def map[B](f: A => B): Free[S, B] =
    flatMap(a => Free.unit(f(a)))
}

object Free {
  
  def unit[S[_], A](a: A): Free[S, A] = Pure(a)
  def lift[S[_], A](sa: S[A]): Free[S, A] = Suspend(sa)
}

case class Pure[S[_], A](a: A) extends Free[S, A]
case class Suspend[S[_], A](sa: S[A]) extends Free[S, A]
case class FlatMapped[S[_], A, B](
    fsa: Free[S, A],
    f: A => Free[S, B]
) extends Free[S, B]

This is all we need in order to lift S[_] into Free[S, ?]. Pure is our unit - it allows us to lift any value A to Free[S, A]. FlatMapped handles mapping and Suspend is used to lift elements of S[_] algebra. But what is the point of using Free?

Well, long before typed tagless final interpreter became popular, people wanted a way to abstract from the monad you are currently using. More specifically, they wanted to abstract from async/IO monad they used: Future/Task/something else. Some people wanted to define their computations once, and then interpret them into whatever monad they liked later on.

The interpreter we mean here is exactly the same interpreter we talked about in the post about functors - a natural transformation aka ~> aka FunctionK. Let’s illustrate it with an example:

sealed trait MyAlgebra[A]
case class Push(value: Int) extends MyAlgebra[Unit]
case class Pop() extends MyAlgebra[Option[Int]]
case class Log(msg: String) extends MyAlgebra[Unit]

def push(value: Int): Free[MyAlgebra, Unit] =
  Free.lift(Push(value))
def pop(): Free[MyAlgebra, Option[Int]] =
  Free.lift(Pop())
def log(msg: String): Free[MyAlgebra, Unit] =
  Free.lift(Log(msg))

Here, each MyAlgebra[A] tells us that our calculation should be parametrized with A (so if we have an interpreter MyAlgebra ~> Future, we should get Future[A]). The elements of our algebra, Push(value), Pop() and Log(msg), are a kind of abstract function that will get interpreted by a natural transformation. Finally, Free provides a way of mapping and flatMapping these functions.

val program: Free[MyAlgebra, Option[Int]] = for {
  _ <- push(1)
  _ <- push(2)
  _ <- log("hello")
  value <- pop()
} yield value

We also have an interpreter:

import $plugin.$ivy.`org.spire-math::kind-projector:0.9.4`
import cats.arrow.FunctionK
import cats.data.State
import scala.collection.immutable.Queue

val interpreter = new FunctionK[MyAlgebra, 
                                State[Queue[Int], ?]] {
  def apply[A](
    myAlgebra: MyAlgebra[A]
  ): State[Queue[Int], A] = myAlgebra match {
    case Push(value) => State.modify[Queue[Int]] { q =>
      q.enqueue(value)
    }
    case Pop() => State[Queue[Int], Option[Int]] { q =>
      q.dequeueOption
        .map { case (v, q2) => q2 -> Some(v) }
        .getOrElse(q -> None)
    }
    case Log(msg) =>
      State.inspect[Queue[Int], Unit] { _ =>
        println(msg)
      }
  }
}

So, the only thing left is interpretation. What we want to do is both fold over Free and map MyAlgebra into State at once. Such a function is called (surprisingly) foldMap.

def foldMap[F[_], G[_], A](
  free: Free[F, A]
)(
  fK: FunctionK[F, G]
): G[A] = free match {
  case Pure(a) =>
    ??? // lift a:A to G[A]
  case Suspend(sa) =>
    fK(sa)
  case FlatMapped(fsa, f) =>
    ??? // foldMap(fsa)(fK).flatMap(a => foldMap(f(a))(fK))
}

Hmm. If we tried to implement it, we would need a way to lift A into G[A] and a way to flatMap G[A]. (Theoretically, we should also ensure stack safety during the whole mapping, but let’s keep the example simple.) Since we interpret into a monad (and since both Cats and Scalaz have a type class for monads), we can use the Monad type class and its syntax:

import cats.Monad
import cats.implicits._

def foldMap[F[_], G[_]: Monad, A](
  free: Free[F, A]
)(
  fK: FunctionK[F, G]
): G[A] = free match {
  case Pure(a) =>
    a.pure[G]
  case Suspend(sa) =>
    fK(sa)
  case FlatMapped(fsa, f) =>
    foldMap(fsa)(fK).flatMap(a => foldMap(f(a))(fK))
}

Now, we should be able to interpret Free into State:

@ foldMap[MyAlgebra,
          State[Queue[Int], ?],
          Option[Int]](program)(interpreter)
  .run(Queue())
  .value 
hello
res19: (Queue[Int], Option[Int]) = (Queue(2), Some(1))

Various versions of free monads are used for handling IO in a referentially transparent way. Slick has DBIO, which is a variant of free that is later interpreted into a Future by the Slick driver. Doobie internally uses Cats’ Free. Haskell uses IO for defining the whole program - each IO operation returns an IO[Result], which can be turned into an actual IO operation only by the language runtime (which means it is almost impossible to define a function with side effects there).

However, as far as abstracting away from a particular monad goes, TTFI is something easier to understand and implement.

Monad transformers

Monads are cool, nice, powerful et al., but at some point you’ll notice that they do not compose well. Or should I say - they don’t nest well. If you put e.g. State into Task, and you want to be able to update the state and do it asynchronously, in a potentially separate call to a thread pool, you would have to flatMap within flatMap. And in a way that doesn’t look good in a for comprehension:

val x: Task[State[List[Int], Int]] = ???
x.flatMap { state =>
  state.flatMap { i =>
    State.modify[List[Int]] { s => s :+ i }
  }
}
// or
for {
  state <- x
} yield {
  for {
    i <- state
    _ <- State.modify[List[Int]] { _ :+ i }
  } yield ()
}

That is why some people decided that these 2 operations could be done in one flatMap.

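final case class TaskState[S, A](val x: Task[State[S, A]]) {
  
  // schematic only: f(a).x is a Task[State[S, B]], not a State[S, B],
  // so a real implementation has to thread the state through the Task
  def flatMap[B](f: A => TaskState[S, B]): TaskState[S, B] =
    TaskState(x.flatMap { state =>
      state.flatMap { a =>
        f(a).x
      }
    })
  
  def map[B](f: A => B): TaskState[S, B] =
    TaskState(x.map { state =>
      state.map(f)
    })
}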

What if we wanted to change the inner or the outer type to Option?

final case class TaskOption[A](val x: Task[Option[A]]) {
  
  def flatMap[B](f: A => TaskOption[B]): TaskOption[B] =
    TaskOption(x.flatMap { option =>
      option.flatMap { a =>
        f(a).x
      }
    })
  
  def map[B](f: A => B): TaskOption[B] =
    TaskOption(x.map { option =>
      option.map(f)
    })
}
final case class OptionState[S, A](val x: Option[State[S, A]]) {
  
  def flatMap[B](f: A => OptionState[S, B]): OptionState[S, B] =
    OptionState(x.flatMap { state =>
      state.flatMap { a =>
        f(a).x
      }
    })
  
  def map[B](f: A => B): OptionState[S, B] =
    OptionState(x.map { state =>
      state.map(f)
    })
}

The logic seems almost the same. However, changing the inner type affects the type parameters of the whole wrapper, while changing the outer type has no such effect. It makes sense - the outer type is parametrized by whatever the inner type constructor produces. So, if we wanted to generalize, it would be easier to use some Monad instance for the outer type, while the inner type would stay glued to the wrapper. Let’s try to generalize TaskState and OptionState:

final case class StateT[F[_]: Monad, S, A](val run: F[State[S, A]]) {
  
  def flatMap[B](f: A => StateT[F, S, B]): StateT[F, S, B] =
    StateT(run.flatMap { state =>
      state.flatMap { a =>
        f(a).run
      }
    })
  
  def map[B](f: A => B): StateT[F, S, B] =
    StateT(run.map { state =>
      state.map(f)
    })
}

What we just invented is called a monad transformer for State. Such StateT could be used for any F[State[S, A]] for which F has a monad instance. How do monad transformers compose? More or less this way:

val x: Task[State[List[Int], Int]] = ???
(for {
  value <- StateT(x)
} yield value * 2).run

Wrapping values in StateT (or WhateverT for whatever type you are using) and then calling run to get back the unwrapped type is annoying. This is the reason why many people avoid monad transformers as much as possible. If your logic is made up of many services which return nested monads, you might want to either get rid of the nesting… or wrap the result of each service into a transformer right from the beginning. This way you will only have to call run once - at the end of the world.
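To make the wrap-then-unwrap routine concrete, here is a dependency-free sketch of a transformer for Option, the generalization of TaskOption to any outer F. MiniMonad is a made-up stand-in for a library Monad type class, and this OptionT is written by hand for the example (it is not the one shipped with Cats or Scalaz). List plays the role of the outer monad only because it needs no extra imports:

```scala
// Hypothetical minimal monad type class, standing in for a library one.
trait MiniMonad[F[_]] {
  def pure[A](a: A): F[A]
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
}

object MiniMonad {
  implicit val listMonad: MiniMonad[List] = new MiniMonad[List] {
    def pure[A](a: A): List[A] = List(a)
    def flatMap[A, B](fa: List[A])(f: A => List[B]): List[B] = fa.flatMap(f)
  }
}

// The outer F is abstract, the inner Option is glued to the wrapper.
final case class OptionT[F[_], A](value: F[Option[A]]) {

  def flatMap[B](f: A => OptionT[F, B])(implicit F: MiniMonad[F]): OptionT[F, B] =
    OptionT[F, B](F.flatMap[Option[A], Option[B]](value) {
      case Some(a) => f(a).value               // continue with the next step
      case None    => F.pure[Option[B]](None)  // short-circuit, but stay inside F
    })

  def map[B](f: A => B)(implicit F: MiniMonad[F]): OptionT[F, B] =
    flatMap(a => OptionT[F, B](F.pure[Option[B]](Some(f(a)))))
}

object OptionTDemo {
  val xs: List[Option[Int]] = List(Some(1), None, Some(3))

  // wrap, work with the inner value directly, unwrap
  val doubled: List[Option[Int]] = (for {
    value <- OptionT[List, Int](xs)
  } yield value * 2).value
  // List(Some(2), None, Some(6))
}
```

One flatMap on the wrapper does the job of the nested flatMaps, at the price of wrapping at the start and unwrapping at the end.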

If you are interested in stacking monad transformers together and driving them through type classes, then what you are looking for is the Monad Transformer Library, or MTL. It’s basically about giving you type classes for monad transformers, so that you could just define some type F[A] = StateT[EitherT[Task, AppError, ?], A] and then use type class syntax to lift, compose, use TTFI, and whatever you want.

import cats.implicits._
import cats.mtl.implicits._

def service[F[_]: Monad]: Int => F[String] =
  _.toString.pure[F]

def program[F[_]: Monad] = for {
  a <- service[F](1)
  b <- service[F](2)
} yield s"$a $b"

program[StateT[EitherT[Task, AppError, ?], ?]]
// StateT[EitherT[Task, AppError, ?], String]

Kleisli

Another thing that is worth considering is Kleisli. Kleisli is basically a wrapper for an A => F[B] function - exactly the kind of function that you would pass into flatMap. Which means that it has tons of utilities helping with composition:

case class Kleisli[M[_], A, B](run: A => M[B]) {
  
  def andThen[C](f: B => M[C])
                (implicit M: Monad[M]): Kleisli[M, A, C] =
    Kleisli((a: A) => M.flatMap(run(a))(f))
  
  def flatMap[C](f: B => Kleisli[M, A, C])
                (implicit M: Monad[M]): Kleisli[M, A, C] =
    Kleisli((a: A) => M.flatMap[B, C](run(a))(((b: B) => f(b).run(a))))
  
  def map[C](f: B => C)
            (implicit M: Functor[M]): Kleisli[M, A, C] =
    Kleisli((a: A) => M.map(run(a))(f))
  
  def dimap[A0, B1](f: A0 => A, g: B => B1)
                   (implicit M: Functor[M]): Kleisli[M, A0, B1] =
    Kleisli((a0: A0) => M.map(run(f(a0)))(g))
  
  // ...
}

As you can see, you can use Kleisli as a function, a functor or monad (regarding B), a profunctor (regarding A and B) etc. If you have the right Functor, Monad etc. instances in scope, you can simply define functions A => M[B] (even if this M is a monad transformer) and then simply compose them (since Kleisli implements map and flatMap, you can use it in a for comprehension, where every step receives the same input A). Such an approach is advertised e.g. by Doobie, which prefers you to pick your target Monad, define operations on the database and services as Kleisli, compose them together and finally run them with the initial argument.

// Monad[EitherT[Task, Error, ?]] implicit in scope

val service1: A => EitherT[Task, Error, B]
val service2: B => EitherT[Task, Error, C]

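// service2 consumes the B produced by service1, so we chain with andThen
// (a for comprehension would pass the same input A to every step)
val program = Kleisli(service1).andThen(service2)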

program.run(a).run.runAsync

or maybe even

// I believe you would quickly end up writing some
// type Service[A, B] = Kleisli[EitherT[Task, Error, ?], A, B] 
val service1: Kleisli[EitherT[Task, Error, ?], A, B]
val service2: Kleisli[EitherT[Task, Error, ?], B, C]

// output of service1 becomes the input of service2
val program = service1.andThen(service2.run)

In the mentioned Doobie, operations on a database have the form Kleisli[M, Connection, A], so by combining things you are effectively building a function Connection => M[A]. It’s a purely functional, referentially transparent way of defining an operation that gives us a result once we pass it a database connection.
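The idea of building a Connection => M[A] function step by step can be sketched without any library. Everything below is made up for illustration: Connection is a fake in-memory "database", and KleisliOpt is Kleisli with M fixed to Option, so that no Monad instance is needed. It has nothing to do with Doobie's actual API:

```scala
// Hypothetical stand-in for a database connection.
final case class Connection(users: Map[Int, String])

// Kleisli specialised to Option to keep the sketch dependency-free.
final case class KleisliOpt[A, B](run: A => Option[B]) {

  // feed the output of this step into the next one
  def andThen[C](next: B => Option[C]): KleisliOpt[A, C] =
    KleisliOpt[A, C]((a: A) => run(a).flatMap(next))

  // both steps read the same input a
  def flatMap[C](f: B => KleisliOpt[A, C]): KleisliOpt[A, C] =
    KleisliOpt[A, C]((a: A) => run(a).flatMap(b => f(b).run(a)))

  def map[C](f: B => C): KleisliOpt[A, C] =
    KleisliOpt[A, C]((a: A) => run(a).map(f))
}

object KleisliDemo {
  def findUser(id: Int): KleisliOpt[Connection, String] =
    KleisliOpt[Connection, String]((c: Connection) => c.users.get(id))

  // nothing runs yet: this only describes what to do with a Connection
  val greeting: KleisliOpt[Connection, String] = for {
    a <- findUser(1)
    b <- findUser(2)
  } yield s"$a and $b"

  val nameLength: KleisliOpt[Connection, Int] =
    findUser(1).andThen(name => Some(name.length))

  val full    = Connection(Map(1 -> "Ann", 2 -> "Bob"))
  val partial = Connection(Map(1 -> "Ann"))

  val found: Option[String]    = greeting.run(full)     // Some("Ann and Bob")
  val notFound: Option[String] = greeting.run(partial)  // None
  val length: Option[Int]      = nameLength.run(full)   // Some(3)
}
```

The program is assembled once and the connection is supplied only at the very end, which is what makes it easy to run the same description against different connections.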

Comonads

Just like there are products and coproducts, functors and cofunctors, there are also monads and comonads. A comonad is dual to a monad, but what does it actually mean? Well, it basically reverses the direction of arrows in functions:

// let's use type class notation here
trait Comonad[F[_]] extends Functor[F] {
  
  // dual to unit
  def extract[A](fa: F[A]): A
  
  // dual to flatMap
  def extend[A, B](fa: F[A])(f: F[A] => B): F[B]
  
  // dual to flatten
  def duplicate[A](fa: F[A]): F[F[A]] = extend(fa)(identity)
}

As you can see, it is basically an extractor, e.g.

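import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration.Duration

implicit val comonadFuture: Comonad[Future] = new Comonad[Future] {
  
  def map[A, B](fa: Future[A])(f: A => B): Future[B] =
    extend(fa)(fa0 => f(extract(fa0)))
  
  // blocks the calling thread until the Future completes
  def extract[A](fa: Future[A]): A =
    Await.result(fa, Duration.Inf)
  
  def extend[A, B](fa: Future[A])(f: Future[A] => B): Future[B] =
    Future(f(fa))
}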

In the case of product types, comonads might be used as lenses of sorts:

implicit def comonadFirst[A2]: Comonad[(?, A2)] = new Comonad[(?, A2)] {
  
  def map[A, B](fa: (A, A2))(f: A => B): (B, A2) =
    extend(fa)(fa0 => f(extract(fa0)))
  
  def extract[A](fa: (A, A2)): A =
    fa._1
  
  def extend[A, B](fa: (A, A2))(f: ((A, A2)) => B): (B, A2) =
    f(fa) -> fa._2
}

I have to admit, though, that I haven’t seen many examples of comonads in Scala (meanwhile the Haskell community seems to use them more often). Looking at the source code of Cats or Scalaz, I usually saw them in contexts where you had e.g. some computation F[A => B] and wanted to apply A to this function to get B (or F[B]). It would be pretty inefficient to do this in IO operations (because we would be blocking), but it would make more sense with things like Eval or Coeval or other suspended computations (potentially infinite), where we would want to proceed one step.
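One structure where extract never blocks is a non-empty list, so it makes for a cheap playground. The sketch below is my own illustration: Nel is a hypothetical class defined just for this example, and extract and extend are written as plain methods instead of going through the Comonad type class:

```scala
// Hypothetical non-empty list, defined here only for this example.
final case class Nel[A](head: A, tail: List[A]) {

  def toList: List[A] = head :: tail

  // dual to unit: there is always a first element to take out
  def extract: A = head

  // dual to flatMap: f sees an element together with everything after it
  def extend[B](f: Nel[A] => B): Nel[B] = {
    // all non-empty suffixes, starting with the whole list
    def suffixes(nel: Nel[A]): List[Nel[A]] = nel.tail match {
      case Nil    => List(nel)
      case h :: t => nel :: suffixes(Nel(h, t))
    }
    val results = suffixes(this).map(f)
    Nel(results.head, results.tail)
  }

  // map falls out of extend and extract
  def map[B](f: A => B): Nel[B] = extend(nel => f(nel.extract))
}

object NelDemo {
  val numbers = Nel(1, List(2, 3))

  // for each position: the sum of this element and all the following ones
  val remainingSums = numbers.extend(_.toList.sum)  // Nel(6, List(5, 3))
  val timesTen      = numbers.map(_ * 10)           // Nel(10, List(20, 30))
}
```

Where flatMap lets a function decide what structure to produce from a single value, extend lets a function look at the surrounding structure to produce a single value.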

If you want to read the source of Cats and Scalaz…

The more FP-heavy the project, the more monad transformers you’ll get (unfortunately). Because of that, the creators of both basic FP libraries for Scala make sure that the number of operations stays as small as possible when you stack transformers. When you look at the sources, you’ll see that there is no:

class State[S, A](...) {
  // define State's behavior here
}
class StateT[F[_], S, A](...) {
  // define how to compose state with F
}

Instead, you would have more chances to see something like this:

final case class StateT[F[_]: Monad, S, A](val runF: F[S => F[(S, A)]]) {
  
  def run(initial: S): F[(S, A)] = runF.flatMap(_(initial))
    
  def flatMap[B](f: A => StateT[F, S, B]): StateT[F, S, B] =
    StateT[F, S, B](Monad[F].pure { (state1: S) =>
      run(state1).flatMap { case (state2, a) =>
        f(a).run(state2)
      }
    })
  
  def map[B](f: A => B): StateT[F, S, B] =
    StateT[F, S, B](Monad[F].pure { (state1: S) =>
      run(state1).map { case (state2, a) =>
        state2 -> f(a)
      }
    })
}

type Id[A] = A
type State[S, A] = StateT[Id, S, A]

This way, whether you use StateT or State, you will have more or less the same number of operations (for this layer - the F inside will add its own, obviously) - there will be no separate flatMap for state operations and for combining them with the outer F.
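Here is a runnable, dependency-free sketch of that trick. It is simplified on purpose: this StateT wraps S => F[(S, A)] directly instead of keeping the function inside F, and MiniMonad is a made-up stand-in for the library Monad type class. All names live in a hypothetical StateTSketch object:

```scala
object StateTSketch {

  // Hypothetical stand-in for a library Monad type class.
  trait MiniMonad[F[_]] {
    def pure[A](a: A): F[A]
    def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
  }

  type Id[A] = A

  implicit val idMonad: MiniMonad[Id] = new MiniMonad[Id] {
    def pure[A](a: A): Id[A] = a
    def flatMap[A, B](fa: Id[A])(f: A => Id[B]): Id[B] = f(fa)
  }

  implicit val optionMonad: MiniMonad[Option] = new MiniMonad[Option] {
    def pure[A](a: A): Option[A] = Some(a)
    def flatMap[A, B](fa: Option[A])(f: A => Option[B]): Option[B] = fa.flatMap(f)
  }

  // The only place where state-passing logic is written.
  final case class StateT[F[_], S, A](run: S => F[(S, A)]) {

    def flatMap[B](f: A => StateT[F, S, B])(implicit F: MiniMonad[F]): StateT[F, S, B] =
      StateT[F, S, B] { (s1: S) =>
        F.flatMap[(S, A), (S, B)](run(s1)) { case (s2, a) => f(a).run(s2) }
      }

    def map[B](f: A => B)(implicit F: MiniMonad[F]): StateT[F, S, B] =
      flatMap(a => StateT[F, S, B]((s: S) => F.pure[(S, B)]((s, f(a)))))
  }

  // State is just StateT with the "do nothing" outer monad.
  type State[S, A] = StateT[Id, S, A]

  // returns the current counter and increments it
  val next: State[Int, Int] =
    StateT[Id, Int, Int]((s: Int) => (s + 1, s))

  // same idea, but fails (None) once the counter reaches the limit
  def nextBelow(limit: Int): StateT[Option, Int, Int] =
    StateT[Option, Int, Int]((s: Int) => if (s < limit) Some((s + 1, s)) else None)

  val twice: State[Int, Int] = for { a <- next; b <- next } yield a + b
  val counted: (Int, Int) = twice.run(10)             // (12, 21)

  val guarded: StateT[Option, Int, Int] =
    for { a <- nextBelow(11); b <- nextBelow(11) } yield a + b
  val failed: Option[(Int, Int)] = guarded.run(10)    // None
  val passed: Option[(Int, Int)] = guarded.run(9)     // Some((11, 19))
}
```

Both the plain State and the Option-flavoured variant go through the very same flatMap and map; only the MiniMonad instance picked for F differs.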

I said would, because you can expect the highest abstraction possible to be there. So for State-related classes actual implementation would be in:

case class IndexedStateT[F[_], S1, S2, A](...) { ... }

type IndexedState[S1, S2, A] = IndexedStateT[Id, S1, S2, A]
type StateT[F[_], S, A] = IndexedStateT[F, S, S, A]
type State[S, A] = StateT[Id, S, A]

All of these type aliases would have an accompanying object with utilities, to make the illusion that each of them is a separate class more real. This way, we end up with an implementation that doesn’t have any repetitions, is easy to maintain by its creators, not bad to use with an IDE and nearly impossible to read for newcomers.

Summary

In this article I wanted to show some examples of monads, show their similarities and build up some intuition before giving a formal definition. I’m aware that no amount of text, drawings and examples will help until one tries things out oneself.

Still, I believe that a monad is easy to understand if you just write some code and embrace the fact that it’s just a thing you can flatMap - nothing deeper than that. All the powerful structures like IO, State or even ReverseState merely build on the concept that you can combine some stuff and keep control over how it will be combined. If you start imagining how monads are like burritos, you might trap yourself into looking for a deeper meaning that doesn’t exist.

Alternatively - if you are an OOP programmer - just think of a monad as an interface ensuring that there is a flatMap operation, with some contracts you care about only if you implement it. It should make your life easier.
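The contracts in question are usually stated as three laws: left identity, right identity and associativity. Below is a small spot check of them for Option, using a few sample values. The names f, g, m and unit are chosen for this example, and checking a handful of values is of course not a proof:

```scala
object MonadLawsDemo {
  def unit[A](a: A): Option[A] = Some(a)

  val f: Int => Option[Int] = x => if (x >= 0) Some(x + 1) else None
  val g: Int => Option[Int] = x => if (x % 2 == 0) Some(x / 2) else None
  val m: Option[Int] = Some(3)

  // unit followed by flatMap(f) is the same as just calling f
  val leftIdentity: Boolean = unit(3).flatMap(f) == f(3)

  // flatMapping with unit changes nothing
  val rightIdentity: Boolean = m.flatMap(a => unit(a)) == m

  // it does not matter how consecutive flatMaps are grouped
  val associativity: Boolean =
    m.flatMap(f).flatMap(g) == m.flatMap(a => f(a).flatMap(g))
}
```

If you only use monads that somebody else implemented, these checks are the implementer's job, which is exactly why the interface view works so well in practice.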