In the same way that we organize information in data structures, we can organize coding patterns into monads. Think, for example, of lists: we use them when we need an ordered collection of values, and they come with a set of useful operations to manipulate them, such as iterators, maps, filters and folds. If we don't care about order, or we need to test whether an element is present, we tend to use sets instead, but we still keep maps, filters and folds. Different requirements call for different data structures, yet we try to manipulate them with the same operations for everything that is not specific to the structure we chose. In this article I will show a way to structure code using the Option monad.
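To make the analogy concrete, here is a small sketch that applies the same map, filter and foldLeft calls to a List and to a Set; Option is included for map and filter, which it provides as well. The sample values and the names xs, s and o are arbitrary and only used for illustration.

```scala
// Arbitrary sample data: a List keeps order and duplicates, a Set drops both.
val xs: List[Int] = List(3, 1, 2, 3)
val s: Set[Int] = Set(3, 1, 2, 3)   // holds the elements 1, 2 and 3
val o: Option[Int] = Some(3)        // zero or one element

// map: transform every element
val xsDoubled = xs.map(_ * 2)       // List(6, 2, 4, 6)
val sDoubled = s.map(_ * 2)         // the set {2, 4, 6}
val oDoubled = o.map(_ * 2)         // Some(6)

// filter: keep only the elements that satisfy a predicate
val xsOdd = xs.filter(_ % 2 == 1)   // List(3, 1, 3)
val sOdd = s.filter(_ % 2 == 1)     // the set {1, 3}
val oOdd = o.filter(_ % 2 == 1)     // Some(3)

// foldLeft: combine all elements into a single value
val xsSum = xs.foldLeft(0)(_ + _)   // 9
val sSum = s.foldLeft(0)(_ + _)     // 6, since the duplicate 3 is gone

// What differs is what each structure is good at, e.g. membership tests on a Set.
val hasTwo = s.contains(2)          // true
```

The calls are written the same way for every structure; only the structure, and therefore the result, changes.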
Illustrative program: the language of arithmetic expressions
Suppose now that we have to implement the following kind of program: given an arithmetic expression, we have to evaluate it. An arithmetic expression is either a number or the addition, subtraction, multiplication or division of two arithmetic expressions. We can model this in Scala as follows:
sealed abstract class Expr
case class Con(c : Int) extends Expr            // a constant
case class Add(a : Expr, b : Expr) extends Expr // a + b
case class Min(a : Expr, b : Expr) extends Expr // a - b ("minus")
case class Mul(a : Expr, b : Expr) extends Expr // a * b
case class Div(a : Expr, b : Expr) extends Expr // a / b
So far so good. Now we need to implement a function to evaluate these expressions. A first attempt could be as follows:
def eval(e : Expr) : Int = {
  e match {
    case Con(c) => c
    case Add(e1, e2) => eval(e1) + eval(e2)
    case Min(e1, e2) => eval(e1) - eval(e2)
    case Mul(e1, e2) => eval(e1) * eval(e2)
    case Div(e1, e2) => eval(e1) / eval(e2)
  }
}
Let's see our code in action:
scala> eval(Add(Con(1),Div(Con(3),Con(3))))
res0: Int = 2
This looks good, but there is a problem: evaluating a division by zero raises an exception. For example:
scala> eval(Div(Con(1),Con(0)))
java.lang.ArithmeticException: / by zero
  at .eval(<console>:28)
  ... 32 elided
To avoid raising exceptions we can use the Option type. Instead of returning an Int, the evaluator can return an Option[Int]: if we divide by zero at any point we return None, and if we compute a value v we return Some(v). Let's apply this modification to the code:
def eval2(e : Expr) : Option[Int] = {
  e match {
    case Con(c) => Some(c)
    case Add(e1, e2) =>
      val leftExp = eval2(e1)
      val rightExp = eval2(e2)
      if(leftExp.isEmpty || rightExp.isEmpty) None
      else Some(leftExp.get + rightExp.get)
    case Min(e1, e2) =>
      val leftExp = eval2(e1)
      val rightExp = eval2(e2)
      if(leftExp.isEmpty || rightExp.isEmpty) None
      else Some(leftExp.get - rightExp.get)
    case Mul(e1, e2) =>
      val leftExp = eval2(e1)
      val rightExp = eval2(e2)
      if(leftExp.isEmpty || rightExp.isEmpty) None
      else Some(leftExp.get * rightExp.get)
    case Div(e1, e2) =>
      val leftExp = eval2(e1)
      val rightExp = eval2(e2)
      if(leftExp.isEmpty || rightExp.isEmpty || rightExp.get == 0) None
      else Some(leftExp.get / rightExp.get)
  }
}
We can see that this works as expected:
scala> eval2(Add(Con(2),Div(Con(1),Con(0))))
res3: Option[Int] = None

scala> eval2(Add(Con(2),Div(Con(1),Con(1))))
res4: Option[Int] = Some(3)
But it is quite clear that the code should be improved. The structure of the computation is evident: we take values that represent either an error (None) or a successful computation (Some). Every time we deal with one of these values we need to test whether it is a success or a failure. If it is a failure we have to propagate it, and if it is a success we have to extract the result and continue.
This pattern is not confined to our eval2 function: it will also appear in any piece of code that calls it and, more generally, in any code that deals with the Option class.
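The test, propagate and extract steps described above can be written once as a higher-order function. The sketch below uses a hypothetical helper named bind (not part of the original code) together with a hypothetical safeDiv that follows the same convention as the Div case of eval2. The Scala standard library already offers exactly this operation as Option.flatMap, shown at the end for comparison.

```scala
// Hypothetical helper: captures, once, the steps that eval2 repeats in every case.
def bind[A, B](opt: Option[A])(f: A => Option[B]): Option[B] =
  if (opt.isEmpty) None // failure: propagate it
  else f(opt.get)       // success: extract the value and continue with f

// Hypothetical safe division: a zero divisor is reported as None.
def safeDiv(a: Int, b: Int): Option[Int] =
  if (b == 0) None else Some(a / b)

// Two fallible steps, (x / y) / z, chained with bind: no isEmpty/get at the call site.
def divTwice(x: Int, y: Int, z: Int): Option[Int] =
  bind(safeDiv(x, y))(q => safeDiv(q, z))

// The same chain using the standard library's Option.flatMap.
def divTwiceFlatMap(x: Int, y: Int, z: Int): Option[Int] =
  safeDiv(x, y).flatMap(q => safeDiv(q, z))
```

For example, divTwice(100, 5, 2) gives Some(10), while a zero in either divisor gives None without any explicit check in divTwice itself.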
We can capture this pattern by using Scala's for notation (for comprehensions):
def eval3(e : Expr) : Option[Int] = {
  e match {
    case Con(c) => Some(c)
    case Add(e1, e2) =>
      for(leftExp <- eval3(e1); rightExp <- eval3(e2)) yield leftExp + rightExp
    case Min(e1, e2) =>
      for(leftExp <- eval3(e1); rightExp <- eval3(e2)) yield leftExp - rightExp
    case Mul(e1, e2) =>
      for(leftExp <- eval3(e1); rightExp <- eval3(e2)) yield leftExp * rightExp
    case Div(e1, e2) =>
      for(leftExp <- eval3(e1); rightExp <- eval3(e2); if rightExp != 0) yield leftExp / rightExp
  }
}
This lets us avoid most of the boilerplate and exposes the structure and logic of the code clearly.
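A for comprehension is syntactic sugar: the compiler rewrites it into calls to flatMap, map and withFilter, which for Option implement the same propagate-or-extract logic that eval2 spelled out by hand. The sketch below isolates the Div case over two already-evaluated operands, written once with for and once in a hand-desugared form; divFor and divDesugared are hypothetical names used only for illustration.

```scala
// The Div case of eval3, over operands that have already been evaluated.
def divFor(left: Option[Int], right: Option[Int]): Option[Int] =
  for (l <- left; r <- right; if r != 0) yield l / r

// Roughly what the compiler produces: every generator except the last becomes
// flatMap, the guard becomes withFilter, and the last generator plus yield becomes map.
def divDesugared(left: Option[Int], right: Option[Int]): Option[Int] =
  left.flatMap(l => right.withFilter(r => r != 0).map(r => l / r))
```

Because the guard turns into withFilter, a zero divisor makes the result None, playing the role of the explicit rightExp.get == 0 test in eval2.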