Study note of the Monad concept in Scala. It is based on the book Functional Programming, Simplified (FPS). As the author suggests, one has to understand the state monad to really understand the monad concept. Implementing a lazy IO monad with the state monad techniques helps, and monad transformers are the last piece of the puzzle.

1 The IO Monad in the FPS book and in Cats Effect

The main conclusion is that the Scala standard library doesn't define a Monad type. Instead, a for-comprehension works with any type that defines map and flatMap methods.
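Here is a minimal sketch of that point. It uses a hypothetical `Box` type that is not part of the book. `Box` extends no Monad trait and only defines `map` and `flatMap`. The compiler still accepts it in a for-comprehension, because it rewrites the comprehension into calls to those two methods.

```scala
// Hypothetical Box type: just a wrapper with map and flatMap, no Monad trait involved.
final case class Box[A](value: A) {
  def map[B](f: A => B): Box[B] = Box(f(value))
  def flatMap[B](f: A => Box[B]): Box[B] = f(value)
}

object BoxDemo {
  // The compiler rewrites this for-comprehension ...
  val viaFor: Box[Int] = for {
    a <- Box(1)
    b <- Box(2)
  } yield a + b

  // ... into roughly these nested calls.
  val desugared: Box[Int] = Box(1).flatMap(a => Box(2).map(b => a + b))
}
```

Both values are `Box(3)`.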

The IO monad has two benefits: it marks code as unsafe (side-effecting), and it can be used in a for expression. The FPS book defines it as follows:

class IO[A] private (constructorCodeBlock: => A) {

  // evaluates the by-name block each time run is called
  def run: A = constructorCodeBlock

  def flatMap[B](bind: A => IO[B]): IO[B] = {
    val result1: IO[B] = bind(run) // runs this IO immediately
    val result2: B = result1.run    // runs the IO returned by bind immediately
    IO(result2)                     // wraps the already-computed result
  }

  def map[B](f: A => B): IO[B] = flatMap(a => IO(f(a)))

}

object IO {
  def apply[A](a: => A): IO[A] = new IO(a)

  def getLine: IO[String] = IO(scala.io.StdIn.readLine())
  def putStrLn(s: String): IO[Unit] = IO(println(s))
}


object Main {
  def main(args: Array[String]): Unit = {
    import IO._

    // with this eager IO, the effects run while the for expression is evaluated
    for {
      _ <- putStrLn("What's your name? ")
      name <- getLine
      _ <- putStrLn("Welcome " + name)
    } yield ()
  }
}

The author's version uses eager evaluation. Whenever flatMap or map is called, the IO operation passed to the constructor runs immediately. In bind(run), the by-name argument is evaluated first and its result is passed to the bind function. The IO returned by bind is then run as well, and its result is wrapped in a new IO.
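To make the eager behavior visible, the sketch below copies the book's IO under the hypothetical name `EagerIO`. It records effects in a `ListBuffer` instead of printing them. If the trace above is right, both effects happen while `program` is being built. Calling `run` on the finished program then performs no further effects, because the final IO only wraps an already-computed value.

```scala
import scala.collection.mutable.ListBuffer

// Copy of the eager IO above, renamed EagerIO (hypothetical name) to keep the sketch self-contained.
class EagerIO[A] private (constructorCodeBlock: => A) {
  def run: A = constructorCodeBlock

  def flatMap[B](bind: A => EagerIO[B]): EagerIO[B] = {
    val result1: EagerIO[B] = bind(run) // runs this IO now
    val result2: B = result1.run        // runs the IO returned by bind now
    EagerIO(result2)                    // wraps an already-computed value
  }

  def map[B](f: A => B): EagerIO[B] = flatMap(a => EagerIO(f(a)))
}

object EagerIO {
  def apply[A](a: => A): EagerIO[A] = new EagerIO(a)
}

object EagerDemo {
  // Effects append to this log instead of printing, so they can be inspected.
  val log = ListBuffer.empty[String]

  def record(s: String): EagerIO[Unit] = EagerIO { log += s; () }

  // Only composes the program, yet with the eager IO both effects already happen here.
  val program: EagerIO[Unit] = for {
    _ <- record("first")
    _ <- record("second")
  } yield ()

  val loggedBeforeRun: List[String] = log.toList
}
```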

The Cats Effect implementation separates composition from execution.

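import cats.effect.IO
// Cats Effect 3 needs an implicit IORuntime for unsafeRunSync; drop this import on Cats Effect 2
import cats.effect.unsafe.implicits.global

object Main {
  def main(args: Array[String]): Unit = {

    // only describes the program; nothing is printed or read yet
    val program: IO[Unit] = for {
      _ <- IO { println("Welcome to Scala! What's your name?") }
      name <- IO { scala.io.StdIn.readLine() }
      nameUC = name.toUpperCase
      _ <- IO { println(s"Well hello, $nameUC!") }
    } yield ()

    // the effects run here
    program.unsafeRunSync()
  }
}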

Building program performs no IO. The actions run only when program.unsafeRunSync() is called.

2 The State Monad

Working with a state monad involves three parts: the state monad itself, a concrete state type, and an application that composes and runs the state changes. The following is a demo implementation of a state monad.

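case class StateMonad[S, A](run: S => (S, A)) {
  // builds a new StateMonad; nothing runs until its run function is called
  def flatMap[B](g: A => StateMonad[S, B]): StateMonad[S, B] = StateMonad { (state: S) =>
    val (newState, newValue) = run(state) // run this step
    val monad = g(newValue)               // build the next step from its value
    monad.run(newState)                   // run the next step on the new state
  }

  def map[B](f: A => B): StateMonad[S, B] = flatMap(a => StateMonad.point[S, B](f(a)))
}

object StateMonad {
  // wraps a value without changing the state
  def point[S, A](a: A): StateMonad[S, A] = StateMonad((s: S) => (s, a))
}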

First, StateMonad is a case class with one constructor parameter, run: S => (S, A). Here S is the state type and A is the type of the computed value. run is a function that takes an initial state and returns a new state together with a result. An instance of StateMonad defines how the state changes and how the result is calculated, and multiple instances define a sequence of changes. To compose these changes, we define the two typical monad methods: flatMap and map.

The flatMap method is the essential method that defines the composition behavior. It is important to know that this method simply creates a new StateMonad instance from its function parameter; nothing else is executed. The logic of the newly created monad is its run constructor argument. When that run function is called with a state, it does three things:

  • it executes the run function of the current instance to produce a new state and a new value.
  • it passes the new value to the function parameter g to create the next monad instance.
  • it runs that instance on the new state.

The function parameter takes one of two forms. At the last step, it is the function built by map, so g(newValue) is point(f(newValue)), which returns the new state unchanged together with the mapped value. Otherwise, g returns the StateMonad built by the next flatMap call, and running that instance continues the chain.
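The sketch below checks the claim that flatMap and map only build functions. It repeats StateMonad and adds a hypothetical `step` that counts how often its run function is invoked. The names `step`, `calls` and `LazyStateDemo` are illustrative only.

```scala
// Same StateMonad as above, repeated so this sketch compiles on its own.
final case class StateMonad[S, A](run: S => (S, A)) {
  def flatMap[B](g: A => StateMonad[S, B]): StateMonad[S, B] = StateMonad { (state: S) =>
    val (newState, newValue) = run(state)
    g(newValue).run(newState)
  }
  def map[B](f: A => B): StateMonad[S, B] = flatMap(a => StateMonad.point[S, B](f(a)))
}

object StateMonad {
  def point[S, A](a: A): StateMonad[S, A] = StateMonad((s: S) => (s, a))
}

object LazyStateDemo {
  // Counts how many times a step's run function is actually invoked.
  var calls = 0

  // Hypothetical step: adds n to an Int state and returns the new state as the value.
  def step(n: Int): StateMonad[Int, Int] = StateMonad { (s: Int) =>
    calls += 1
    (s + n, s + n)
  }

  val plan: StateMonad[Int, Int] = step(1).flatMap(_ => step(2)).map(_ * 10)
  val callsAfterBuild: Int = calls // flatMap and map only built new functions

  val (finalState, result) = plan.run(0)
  val callsAfterRun: Int = calls   // each step ran exactly once
}
```

After the plan is built, `calls` is still 0. After `plan.run(0)`, it is 2, the final state is 3, and the result is 30.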

In the simplest case, there is only one monad instance and only map is called. Only two actions are performed: the instance's run function is called, and then the value is mapped. The result is the changed state and the mapped value. The demo code is as follows:

case class State(value: Int)

object State {
  def change(amount: Int) = StateMonad[State, Int] { state: State =>
    val newValue = state.value + amount
    val newState = State(newValue)
    (newState, newValue)
  }
}

object Main {
  def main(args: Array[String]) = {
    val singleChange: StateMonad[State, Int] = State.change(7)

    // same as a map call
    // val plan = singleChange.map(_ * 10)
    val plan: StateMonad[State, Int] = for {
      a <- singleChange
    } yield (a * 10)

    val initialState = State(0)
    val (state, result) = plan.run(initialState)

    println(s"State: ${state}, Result: ${result}")
    // State: State(7), Result: 70
  }
}

If there are multiple steps, every intermediate step in a for expression just creates a new monad instance whose run logic has three steps:

  • run the current step to generate a new state and a new value.
  • evaluate the next expression in the for expression. This only creates a new monad instance by calling the next monad's flatMap (or map at the last step).
  • call the run method of the newly created monad with the new state. The newly created monad links to its next monad instance through the function passed in the previous flatMap call.

The demo code is as follows:

object Main {
  def main(args: Array[String]) = {

    val m1: StateMonad[State, Int] = State.change(10)
    val m2: StateMonad[State, Int] = State.change(20)
    val m3: StateMonad[State, Int] = State.change(30)

    // desugars to:
    // val plan = m1.flatMap(_ => m2.flatMap(_ => m3.map(last => last * 10)))
    val plan: StateMonad[State, Int] = for {
      _ <- m1
      _ <- m2
      last <- m3
    } yield last * 10

    val initialState = State(0)
    val (state, result) = plan.run(initialState)

    println(s"State: ${state}, Result: ${result}")
    // State: State(60), Result: 600
  }
}

The benefits of the state monad are:

  • no need to pass the state around; the state is passed inside flatMap and map.
  • composition with a for expression is simple.
  • the plan and its execution are separated, so multiple plans can be composed and reused.
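For comparison, here is a sketch of the three-step plan without the monad. A hypothetical `change(state, amount)` function does the same arithmetic as `State.change`. Each intermediate state has to be named and passed along by hand, and that bookkeeping is exactly what flatMap and map take over.

```scala
object ManualStateDemo {
  final case class State(value: Int)

  // Same arithmetic as State.change, but as a plain function: the caller threads the state.
  def change(state: State, amount: Int): (State, Int) = {
    val newValue = state.value + amount
    (State(newValue), newValue)
  }

  // Equivalent of the three-step plan, with every intermediate state named by hand.
  def plan(s0: State): (State, Int) = {
    val (s1, _) = change(s0, 10)
    val (s2, _) = change(s1, 20)
    val (s3, last) = change(s2, 30)
    (s3, last * 10)
  }
}
```

`plan(State(0))` returns `(State(60), 600)`, the same as the for-expression version.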

Multiple plans can be composed as follows:

object Main {
  def main(args: Array[String]) = {

    val m1: StateMonad[State, Int] = State.change(10)
    val m2: StateMonad[State, Int] = State.change(20)
    val m3: StateMonad[State, Int] = State.change(30)

    // desugars to:
    // val plan = m1.flatMap(_ => m2.flatMap(_ => m3.map(last => last * 10)))
    val plan: StateMonad[State, Int] = for {
      _ <- m1
      _ <- m2
      last <- m3
    } yield last * 10

    val plan2 = for {
      _ <- plan
      last2 <- plan
    } yield last2

    val initialState = State(0)
    val (state, result) = plan2.run(initialState)

    println(s"State: ${state}, Result: ${result}")
    // State: State(120), Result: 1200
  }
}

The drawbacks are the extra setup code and the fact that the approach is not easy to understand.

3 A Lazy IO Monad

Based on the state monad code, the following is an implementation of a lazy IO monad.

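class IOMonad[A] private (runParameter: => A) {
  // evaluates the by-name block each time run is called
  def run: A = runParameter

  // only builds a new IOMonad; the block below runs when its run is called
  def flatMap[B](g: A => IOMonad[B]): IOMonad[B] = IOMonad[B] {
    val newValue = run
    val monad: IOMonad[B] = g(newValue)
    monad.run
  }

  def map[B](f: A => B): IOMonad[B] = flatMap(a => IOMonad.point[B](f(a)))
}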

object IOMonad {
  def apply[A](run: => A) = new IOMonad[A](run)
  def point[A](a: A): IOMonad[A] = IOMonad(a)

  def putStrLn(message: String) = IOMonad[Unit] {
    println(message)
  }

  def getLine = IOMonad[String] {
    scala.io.StdIn.readLine()
  }
}

import IOMonad._

object Main {
  def main(args: Array[String]) = {

    // builds the plan; no IO is performed yet
    val plan = for {
      _ <- putStrLn("First name?")
      firstName <- getLine
      _ <- putStrLn("Last name?")
      lastName <- getLine
      _ <- putStrLn(s"First: $firstName, Last: $lastName")
    } yield ()

    println("before run") // printed before any prompt
    plan.run              // all IO happens here

  }
}

Following are the implementation details:

  • A case class cannot have a by-name parameter. The monad is therefore defined as a regular class with a run method that accesses the by-name constructor parameter, and the companion object defines the apply method to create a new instance.
  • flatMap just creates a monad instance. The instance's constructor parameter is a block that runs the current instance's constructor parameter, generates a new monad instance, and then runs the new instance's code. No IO functions are executed until run is called on the monad.
  • println and readLine are wrapped in the IO monad.
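The sketch below repeats IOMonad without the console helpers. In place of console IO, it uses a hypothetical `record` function that appends to a `ListBuffer`, so the laziness can be checked without a terminal. Nothing is recorded while the plan is built. Because `run` re-evaluates the by-name parameter on each call, running the same plan twice also repeats its effects.

```scala
import scala.collection.mutable.ListBuffer

// Same lazy IOMonad as above, repeated so the sketch compiles on its own.
class IOMonad[A] private (runParameter: => A) {
  def run: A = runParameter
  def flatMap[B](g: A => IOMonad[B]): IOMonad[B] = IOMonad[B] {
    val newValue = run
    g(newValue).run
  }
  def map[B](f: A => B): IOMonad[B] = flatMap(a => IOMonad.point[B](f(a)))
}

object IOMonad {
  def apply[A](run: => A): IOMonad[A] = new IOMonad[A](run)
  def point[A](a: A): IOMonad[A] = IOMonad(a)
}

object LazyIODemo {
  // Hypothetical effect: appends to a log instead of printing.
  val log = ListBuffer.empty[String]
  def record(s: String): IOMonad[Unit] = IOMonad { log += s; () }

  val plan: IOMonad[Unit] = for {
    _ <- record("first")
    _ <- record("second")
  } yield ()

  val logBeforeRun: List[String] = log.toList // nothing has run yet
}
```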

4 Monad Transformers

A monad transformer stacks its own effect on top of another monad. Not all monads have monad transformers. To define the monad transformer StateT, we need a Monad trait that the monad stacked underneath implements. Whether a monad transformer exists depends on how the monad is implemented. For example, it is easy to turn the state monad into a monad transformer, but it is impossible to make a strict IO monad a monad transformer.

First, here is a formal definition of Monad:

trait Monad[F[_]] {
  def lift[A](a: A): F[A]
  def flatMap[A, B](ma: F[A])(f: A => F[B]): F[B]
  def map[A, B](ma: F[A])(f: A => B): F[B] = flatMap(ma)(a => lift[B](f(a)))
}

F[_] is a type constructor: it becomes a concrete type once it is applied to a type parameter. For example, IO is a type constructor and IO[Int] is a type. An implementation of Monad[F] needs to implement lift and flatMap; map is derived from them.
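As an illustration, the sketch below gives a `Monad` instance for the standard `Option` type. The instance is hypothetical and not from the book. It also defines a hypothetical `addBoth` function written only against `Monad[F]`, so it works for any `F` that has an instance in implicit scope.

```scala
// The Monad trait from the note, repeated so this sketch compiles on its own.
trait Monad[F[_]] {
  def lift[A](a: A): F[A]
  def flatMap[A, B](ma: F[A])(f: A => F[B]): F[B]
  def map[A, B](ma: F[A])(f: A => B): F[B] = flatMap(ma)(a => lift[B](f(a)))
}

object OptionMonadDemo {
  // Hypothetical instance: lift wraps in Some, flatMap delegates to Option's own flatMap.
  implicit val optionMonad: Monad[Option] = new Monad[Option] {
    def lift[A](a: A): Option[A] = Some(a)
    def flatMap[A, B](ma: Option[A])(f: A => Option[B]): Option[B] = ma.flatMap(f)
  }

  // Written once against Monad[F]; works for any F with an instance in scope.
  def addBoth[F[_]](fa: F[Int], fb: F[Int])(implicit M: Monad[F]): F[Int] =
    M.flatMap(fa)(a => M.map(fb)(b => a + b))

  val sum: Option[Int] = addBoth(Option(1), Option(2))
  val missing: Option[Int] = addBoth(Option(1), Option.empty[Int])
}
```

`sum` is `Some(3)` and `missing` is `None`.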

The source code of a state monad transformer is as follows:

case class StateT[F[_], S, A](run: S => F[(S, A)]) {
  def flatMap[B](
      g: A => StateT[F, S, B]
  )(implicit M: Monad[F]): StateT[F, S, B] = StateT { (s0: S) =>
    {
      M.flatMap(run(s0)) {
        case (s1, a) => g(a).run(s1)
      }
    }
  }

  def map[B](f: A => B)(implicit M: Monad[F]): StateT[F, S, B] =
    flatMap(a => StateT.point(f(a)))
}

object StateT {
  def point[M[_], S, A](v: A)(implicit M: Monad[M]): StateT[M, S, A] =
    StateT(s => M.lift((s, v))) // pass the pair as a single tuple argument

  def lift[F[_], S, A](fa: F[A])(implicit M: Monad[F]): StateT[F, S, A] =
    StateT { (s: S) => M.map(fa)(a => (s, a)) }
}

StateT has a state of type S, and a state change produces a result of type A. To use a type F[_], there must be an implicit value M of type Monad[F]. This value is used to implement the four methods of StateT:

  • flatMap calls M.flatMap() to chain the actions.
  • point calls M.lift() to convert a value of type A to a StateT.
  • map is implemented via flatMap and point.
  • lift calls M.map to lift an F[A] value into a StateT without changing the state.

In short, StateT uses an implicit Monad[F] to drive the execution.
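To see the implicit `Monad[F]` driving execution, the sketch below stacks StateT on `Option` instead of an IO type. The `Monad[Option]` instance and the `withdraw` step are hypothetical examples. `Option`'s `flatMap` stops at `None`, so a failed step ends the whole plan. That effect comes from the inner monad, not from StateT.

```scala
trait Monad[F[_]] {
  def lift[A](a: A): F[A]
  def flatMap[A, B](ma: F[A])(f: A => F[B]): F[B]
  def map[A, B](ma: F[A])(f: A => B): F[B] = flatMap(ma)(a => lift[B](f(a)))
}

// StateT as above (flatMap, map, point), repeated so the sketch compiles on its own.
final case class StateT[F[_], S, A](run: S => F[(S, A)]) {
  def flatMap[B](g: A => StateT[F, S, B])(implicit M: Monad[F]): StateT[F, S, B] =
    StateT { (s0: S) =>
      M.flatMap(run(s0)) { case (s1, a) => g(a).run(s1) }
    }

  def map[B](f: A => B)(implicit M: Monad[F]): StateT[F, S, B] =
    flatMap(a => StateT.point[F, S, B](f(a)))
}

object StateT {
  def point[F[_], S, A](v: A)(implicit M: Monad[F]): StateT[F, S, A] =
    StateT((s: S) => M.lift((s, v)))
}

object StateTOptionDemo {
  implicit val optionMonad: Monad[Option] = new Monad[Option] {
    def lift[A](a: A): Option[A] = Some(a)
    def flatMap[A, B](ma: Option[A])(f: A => Option[B]): Option[B] = ma.flatMap(f)
  }

  // Hypothetical step: withdraw from an Int balance, or fail with None if it would go negative.
  def withdraw(amount: Int): StateT[Option, Int, Int] = StateT { (balance: Int) =>
    if (amount <= balance) Some((balance - amount, balance - amount)) else None
  }

  val plan: StateT[Option, Int, Int] = for {
    _ <- withdraw(30)
    left <- withdraw(50)
  } yield left
}
```

Starting from 100, the plan ends with `Some((20, 20))`. Starting from 60, the second withdrawal fails and the result is `None`.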

There are two kinds of functions: StateT methods and F[_] methods. StateT methods return StateT values and can be used in a for expression directly. However, you need the lift method above to convert the results of F[_] methods to StateT values.
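The following sketch applies lift on the same `Option`-based setup. It again uses a hypothetical `Monad[Option]` instance, plus a hypothetical `double` step. lift turns an `Option[Int]` into a StateT step that leaves the state untouched. The type arguments are given explicitly because nothing in the argument fixes the state type `S`. When `lift(...)` is the first generator of a for expression, Scala would likely infer `Nothing` for `S` otherwise.

```scala
trait Monad[F[_]] {
  def lift[A](a: A): F[A]
  def flatMap[A, B](ma: F[A])(f: A => F[B]): F[B]
  def map[A, B](ma: F[A])(f: A => B): F[B] = flatMap(ma)(a => lift[B](f(a)))
}

final case class StateT[F[_], S, A](run: S => F[(S, A)]) {
  def flatMap[B](g: A => StateT[F, S, B])(implicit M: Monad[F]): StateT[F, S, B] =
    StateT { (s0: S) =>
      M.flatMap(run(s0)) { case (s1, a) => g(a).run(s1) }
    }

  def map[B](f: A => B)(implicit M: Monad[F]): StateT[F, S, B] =
    flatMap(a => StateT.point[F, S, B](f(a)))
}

object StateT {
  def point[F[_], S, A](v: A)(implicit M: Monad[F]): StateT[F, S, A] =
    StateT((s: S) => M.lift((s, v)))

  // Pairs the current state, unchanged, with the value produced by fa.
  def lift[F[_], S, A](fa: F[A])(implicit M: Monad[F]): StateT[F, S, A] =
    StateT((s: S) => M.map(fa)(a => (s, a)))
}

object LiftDemo {
  implicit val optionMonad: Monad[Option] = new Monad[Option] {
    def lift[A](a: A): Option[A] = Some(a)
    def flatMap[A, B](ma: Option[A])(f: A => Option[B]): Option[B] = ma.flatMap(f)
  }

  // Hypothetical step: doubles the Int state and returns the new state as the value.
  val double: StateT[Option, Int, Int] = StateT { (s: Int) => Some((s * 2, s * 2)) }

  val plan: StateT[Option, Int, Int] = for {
    bonus <- StateT.lift[Option, Int, Int](Option(5)) // value from Option, state untouched
    doubled <- double
  } yield doubled + bonus
}
```

The plan runs in three stages, starting from state 3:

  • lift contributes the value 5, and the state is still 3.
  • `double` moves the state to 6.
  • the plan returns `Some((6, 11))`.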

To use the monad transformer above, you need to define an F[_] and an implicit Monad[F]. The following is an example for IO operations:

class MyIO[A] private (runParameter: => A) {
  def run = runParameter
}

object MyIO {
  def apply[A](run: => A) = new MyIO[A](run)

  def putStrLn(message: String) = MyIO[Unit] {
    println(message)
  }

  def getLine = MyIO[String] {
    scala.io.StdIn.readLine()
  }

  implicit object MyIOMonad extends Monad[MyIO] {
    def lift[A](a: A): MyIO[A] = MyIO(a)

    def flatMap[A, B](fa: MyIO[A])(g: A => MyIO[B]): MyIO[B] = MyIO[B] {
      val newValue = fa.run
      val monad: MyIO[B] = g(newValue)
      monad.run
    }
  }
}

The following is an application of the StateT and MyIO constructs:

object Main {
  def main(args: Array[String]): Unit = {

    // for its lift method
    import StateT._

    case class IntState(i: Int)

    def toInt(s: String): Int = {
      try {
        s.toInt
      } catch {
        case e: NumberFormatException => 0
      }
    }

    // fixes the state type to IntState; with a bare lift(...) as the first generator,
    // S would be inferred as Nothing and the for expression would not type-check
    def liftIO[A](io: MyIO[A]): StateT[MyIO, IntState, A] = lift(io)

    def add(i: Int) = StateT[MyIO, IntState, Int] { oldState =>
      val newValue = oldState.i + i
      val newState = IntState(newValue)
      MyIO((newState, newValue))
    }

    def multiply(i: Int) = StateT[MyIO, IntState, Int] { oldState =>
      val newValue = oldState.i * i
      val newState = IntState(newValue)
      MyIO((newState, newValue))
    }

    val plan = for {
      _ <- liftIO(MyIO.putStrLn("give me an int: "))
      input <- liftIO(MyIO.getLine)
      number <- liftIO(MyIO(toInt(input)))
      _ <- add(number)
      x <- multiply(2)
    } yield x

    val init = IntState(1)
    val result = plan.run(init).run
    println(result)
  }
}

// the output

// give me an int:
// 20
// (IntState(42),42)
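The following sketch checks the earlier claim that the state monad turns easily into a transformer. It stacks StateT on an identity type constructor `Id[A] = A`, which adds no effect of its own. `Id` and its `Monad` instance are assumptions introduced here and are not part of the book's code. With them, the three-step plan from the state monad section should give the same `(State(60), 600)` result.

```scala
trait Monad[F[_]] {
  def lift[A](a: A): F[A]
  def flatMap[A, B](ma: F[A])(f: A => F[B]): F[B]
  def map[A, B](ma: F[A])(f: A => B): F[B] = flatMap(ma)(a => lift[B](f(a)))
}

final case class StateT[F[_], S, A](run: S => F[(S, A)]) {
  def flatMap[B](g: A => StateT[F, S, B])(implicit M: Monad[F]): StateT[F, S, B] =
    StateT { (s0: S) =>
      M.flatMap(run(s0)) { case (s1, a) => g(a).run(s1) }
    }

  def map[B](f: A => B)(implicit M: Monad[F]): StateT[F, S, B] =
    flatMap(a => StateT.point[F, S, B](f(a)))
}

object StateT {
  def point[F[_], S, A](v: A)(implicit M: Monad[F]): StateT[F, S, A] =
    StateT((s: S) => M.lift((s, v)))
}

object IdDemo {
  // Identity type constructor (an assumption here): Id[A] is just A, so no extra effect.
  type Id[A] = A

  implicit val idMonad: Monad[Id] = new Monad[Id] {
    def lift[A](a: A): Id[A] = a
    def flatMap[A, B](ma: Id[A])(f: A => Id[B]): Id[B] = f(ma)
  }

  final case class State(value: Int)

  // Same arithmetic as State.change in the state monad section.
  def change(amount: Int): StateT[Id, State, Int] = StateT { (state: State) =>
    val newValue = state.value + amount
    (State(newValue), newValue)
  }

  val plan: StateT[Id, State, Int] = for {
    _ <- change(10)
    _ <- change(20)
    last <- change(30)
  } yield last * 10
}
```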