The Option monad in Scala

In the same way that we organize information into data structures, we can organize coding patterns into monads. Take lists, for example: we use them when we need an ordered collection of values, and they come with a set of useful methods to manipulate them, such as iterators, maps, filters and folds. If we don't need the order, or we need to test for the presence of elements, we tend to use sets, but we still keep maps, filters and folds. Different requirements call for different data structures, and we try to keep the same methods for everything that is not specific to the structure we choose. In this article I will show a way to structure code using the Option monad.
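
As a small illustration of that point, the sketch below applies the same three operations to a List, a Set and an Option. The values and names are made up for this example and do not come from the rest of the article.

```scala
// Illustrative values only; none of these names are used later on.
val numbers: List[Int] = List(3, 1, 2, 3)
val unique: Set[Int]   = numbers.toSet   // duplicates and order are dropped
val maybe: Option[Int] = Some(3)

// map: transform every element, keeping the kind of container.
val doubledList = numbers.map(_ * 2)
val doubledSet  = unique.map(_ * 2)
val doubledOpt  = maybe.map(_ * 2)

// filter: keep only the elements that satisfy a predicate.
val bigList = numbers.filter(_ > 1)
val bigSet  = unique.filter(_ > 1)
val bigOpt  = maybe.filter(_ > 5)        // the value 3 fails the test

// fold: combine all elements into a single result.
val sumList = numbers.foldLeft(0)(_ + _)
val sumSet  = unique.foldLeft(0)(_ + _)
val sumOpt  = maybe.foldLeft(0)(_ + _)
```

An Option behaves here like a collection holding zero or one element, which is why the same vocabulary carries over.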

Illustration program: The language of arithmetic expressions

Suppose that we have to implement the following kind of program: given an arithmetic expression, evaluate it. An arithmetic expression is either a number or the addition, subtraction, multiplication or division of two arithmetic expressions. We can model this in Scala as follows:

sealed abstract class Expr
case class Con(c : Int) extends Expr
case class Add[E <: Expr](a : E, b : E) extends Expr
case class Min[E <: Expr](a : E, b : E) extends Expr
case class Mul[E <: Expr](a : E, b : E) extends Expr
case class Div[E <: Expr](a : E, b : E) extends Expr

So far so good. Now we need to implement the method that evaluates these expressions. A first attempt could be as follows:

def eval(e : Expr) : Int = {
  e match {
    case Con(c)      => c
    case Add(e1, e2) => eval(e1) + eval(e2)
    case Min(e1, e2) => eval(e1) - eval(e2)
    case Mul(e1, e2) => eval(e1) * eval(e2)
    case Div(e1, e2) => eval(e1) / eval(e2)
  }
}

Let's see our code in action:

scala> eval(Add(Con(1),Div(Con(3),Con(3))))
res0: Int = 2

This looks good, but there is a problem: evaluating a division throws an exception when the divisor is zero. For example:

scala> eval(Div(Con(1),Con(0)))
java.lang.ArithmeticException: / by zero
   at .eval(<console>:28)
     ... 32 elided

To avoid raising exceptions we can use the Option type. Instead of returning an integer we can return an Option[Int]. If we divide by zero at any point we return None, and if we compute a value v we return Some(v). Let's try to apply this modification to the code.

def eval2(e : Expr) : Option[Int] = {
  e match {
    case Con(c)      => Some(c)
    case Add(e1, e2) =>
      val leftExp  = eval2(e1)
      val rightExp = eval2(e2)
      if(leftExp.isEmpty || rightExp.isEmpty) None
      else Some(leftExp.get + rightExp.get)
    case Min(e1, e2) =>
      val leftExp  = eval2(e1)
      val rightExp = eval2(e2)
      if(leftExp.isEmpty || rightExp.isEmpty) None
      else Some(leftExp.get - rightExp.get)
    case Mul(e1, e2) =>
      val leftExp  = eval2(e1)
      val rightExp = eval2(e2)
      if(leftExp.isEmpty || rightExp.isEmpty) None
      else Some(leftExp.get * rightExp.get)
    case Div(e1, e2) =>
      val leftExp  = eval2(e1)
      val rightExp = eval2(e2)
      if(leftExp.isEmpty || rightExp.isEmpty || rightExp.get == 0) None
      else Some(leftExp.get / rightExp.get)
  }
}

We can see that this works as expected:

scala> eval2(Add(Con(2),Div(Con(1),Con(0))))
res3: Option[Int] = None

scala> eval2(Add(Con(2),Div(Con(1),Con(1))))
res4: Option[Int] = Some(3)

But it is quite clear that the code should be improved. The structure of the computation is self-evident: we handle values that represent either an error (None) or a successful computation (Some). Every time we get one of these values, we need to test whether we are dealing with a success or a failure. If it is a failure, we have to propagate it, and if it is a success, we have to extract the result.

This pattern will not only be present in our eval2 function, it will be present in any piece of code that calls it, and more generally, this is a pattern used in any code that deals with the Option class.
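
To make the pattern concrete, here is a minimal sketch that writes the "test, propagate or extract" step once as a function. The names bind and safeDiv are hypothetical and introduced only for this illustration; bind is assumed to mirror what Option's built-in flatMap already does.

```scala
// Hypothetical helper (not from the article). It behaves like Option's
// built-in flatMap and is written out only to make the pattern visible.
def bind[A, B](value: Option[A])(next: A => Option[B]): Option[B] =
  value match {
    case None    => None      // failure: propagate it unchanged
    case Some(a) => next(a)   // success: extract the result and continue
  }

// Hypothetical safe division used to exercise the helper.
def safeDiv(a: Int, b: Int): Option[Int] =
  if (b == 0) None else Some(a / b)

// (12 / 3) / 2 and (12 / 0) / 2, where any division may fail.
val ok     = bind(safeDiv(12, 3))(q => safeDiv(q, 2))
val failed = bind(safeDiv(12, 0))(q => safeDiv(q, 2))

// The standard library method expresses the same chaining.
val sameOk = safeDiv(12, 3).flatMap(q => safeDiv(q, 2))
```

Here ok and sameOk both evaluate to Some(2), while failed is None because the first division fails and the second step is never run.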

We can capture the pattern by using the for notation in Scala.

def eval3(e : Expr) : Option[Int] = {
  e match {
    case Con(c)      => Some(c)
    case Add(e1, e2) =>
      for(leftExp  <- eval3(e1);
          rightExp <- eval3(e2))
      yield leftExp + rightExp
    case Min(e1, e2) =>
      for(leftExp  <- eval3(e1);
          rightExp <- eval3(e2))
      yield leftExp - rightExp
    case Mul(e1, e2) =>
      for(leftExp  <- eval3(e1);
          rightExp <- eval3(e2))
      yield leftExp * rightExp
    case Div(e1, e2) =>
      for(leftExp  <- eval3(e1);
          rightExp <- eval3(e2);
          if rightExp != 0)
      yield leftExp / rightExp
  }
}

This lets us avoid a lot of the boilerplate and exposes the structure and logic of the code clearly.
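
The for notation is syntactic sugar over methods that Option already provides. As a rough sketch, the division case can be written both ways; divFor and divMethods are hypothetical names used only here, and the second function approximates how the compiler rewrites the first using flatMap, withFilter and map.

```scala
// Hypothetical names, for illustration only.

// The same shape as the Div case: two generators and a guard.
def divFor(left: Option[Int], right: Option[Int]): Option[Int] =
  for (l <- left;
       r <- right;
       if r != 0)
  yield l / r

// Approximately what the for expression above expands to:
// every generator but the last becomes flatMap, the guard becomes
// withFilter, and the last generator together with yield becomes map.
def divMethods(left: Option[Int], right: Option[Int]): Option[Int] =
  left.flatMap(l => right.withFilter(r => r != 0).map(r => l / r))
```

Both functions return Some(2) for Some(6) and Some(3), and None when either argument is None or the divisor is Some(0).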
