Introduction to tail recursion

In functional programming we tend to use recursion to solve problems that would otherwise require loops. Recursion lets us reason by induction, which helps reduce errors and avoid side effects such as changing the state of variables. However, recursion is also considered inefficient, because making a function call is more expensive (in terms of time and memory) than jumping to the next iteration of a loop: each call pushes a new frame onto the call stack, and very deep recursion can even overflow it.

Today we will see how to solve a problem using recursion while avoiding this inefficiency.

We will illustrate the technique with a function that calculates the factorial of an integer. We start with the trivial recursive solution:

def factorial(n : Int) : Int = {
  if (n == 0) 1
  else n * factorial(n-1) // the multiplication happens after the recursive call returns
}

The code is so simple that it is easy to convince ourselves that it is correct (apart from Int overflow, which occurs for n > 12). However, each recursive call to factorial adds a new stack frame, an overhead that a loop would avoid. To optimize the code, we would like to write a loop instead of a recursive function. In this example, the loop is easy to write from scratch, but let's try to derive the loop from the recursive function instead.

To do so, we will rewrite it in tail-recursive form. This means that a recursive call, if there is one, must be the last operation the function performs. In the current version of factorial, we call factorial in the else branch, but that call is not the last operation executed in that branch, because its result still has to be multiplied by n.
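Expanding a small call by hand makes this visible. In the worked example below, every multiplication by n is left pending until the inner call returns, and each pending multiplication needs its own stack frame:

```latex
\begin{aligned}
\mathrm{factorial}(3) &= 3 \cdot \mathrm{factorial}(2) \\
&= 3 \cdot (2 \cdot \mathrm{factorial}(1)) \\
&= 3 \cdot (2 \cdot (1 \cdot \mathrm{factorial}(0))) \\
&= 3 \cdot (2 \cdot (1 \cdot 1)) = 6
\end{aligned}
```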

So we need to rearrange the code. The usual way to do this is to add an accumulator, i.e. an extra parameter that carries the partial result, as we can see below.

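def factorial(n : Int) : Int = {
  // acc carries the product of the factors processed so far
  def factorialAcc(acc : Int, n : Int) : Int = {
    if (n == 0) acc
    else factorialAcc(n * acc, n - 1) // nothing is left to do after this call
  }
  factorialAcc(1, n)
}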

Now, factorialAcc is tail recursive and calculates the factorial in the accumulator.
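Tracing the same input through factorialAcc shows the difference: the partial product travels in acc, so no work is pending when a call returns. At every step the product acc · n! stays equal to 3! = 6:

```latex
\begin{aligned}
\mathrm{factorialAcc}(1, 3) &= \mathrm{factorialAcc}(3 \cdot 1,\ 2) = \mathrm{factorialAcc}(3, 2) \\
&= \mathrm{factorialAcc}(2 \cdot 3,\ 1) = \mathrm{factorialAcc}(6, 1) \\
&= \mathrm{factorialAcc}(1 \cdot 6,\ 0) = \mathrm{factorialAcc}(6, 0) \\
&= 6
\end{aligned}
```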

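There is a systematic way to transform factorialAcc into a loop. Each parameter becomes a mutable variable initialized with the arguments of the first call: n acts as the loop counter (i) and the accumulator becomes the variable that holds the result (solution). The base-case test n == 0 becomes the loop exit condition, the expression passed as the first argument of the recursive call (n * acc) becomes the body of the loop, and the second argument (n - 1) computes the update of the loop variable. When the loop ends, the accumulator holds the answer. We can see this solution below.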
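To see that the recipe is mechanical, here is a minimal sketch applying it to a different tail-recursive function, Euclid's greatest common divisor. This is our own illustration, and gcdRec, gcdLoop and GcdExample are hypothetical names. It also exposes one subtlety: when several parameters change in the recursive call, every new value must be computed from the old values, so a temporary variable may be needed.

```scala
import scala.annotation.tailrec

// Illustrative example (not from the original text): the same
// recursion-to-loop recipe applied to Euclid's algorithm.
object GcdExample {
  // Tail-recursive version: the recursive call is the last thing evaluated.
  @tailrec
  def gcdRec(a: Int, b: Int): Int =
    if (b == 0) a
    else gcdRec(b, a % b)

  // Loop derived with the recipe: parameters become vars, the base-case
  // test becomes the loop exit, and the arguments of the recursive call
  // become the updates performed in each iteration.
  def gcdLoop(a: Int, b: Int): Int = {
    var x = a
    var y = b
    while (y != 0) {
      // Both new values depend on the old x and y, so compute the
      // remainder before overwriting either variable.
      val r = x % y
      x = y
      y = r
    }
    x // base case: return the first parameter
  }
}
```

In factorial no temporary was needed because solution is updated using i before i itself is decremented; with gcd, both updates read both old values.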

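def factorial(n : Int) : Int = {
  var i = n          // the parameter n becomes the loop counter
  var solution = 1   // the accumulator, initialized as in factorialAcc(1, n)
  while(i != 0) {    // loop until the base case n == 0
    solution = i*solution  // first argument n * acc; uses i before it is decremented
    i = i - 1              // second argument n - 1
  }
  solution           // the base case returns the accumulator
}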
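A quick way to gain confidence in the derivation is to put the three definitions side by side and check that they agree on every input whose result fits in an Int (0 to 12, since 13! overflows). This is a minimal sketch; the object and method names are ours, not the original's.

```scala
// Hypothetical names: the three factorial versions from the text, side by side.
object FactorialVersions {
  // Direct recursion: the multiplication happens after the call returns.
  def factorialRec(n: Int): Int =
    if (n == 0) 1
    else n * factorialRec(n - 1)

  // Tail recursion with an accumulator holding the partial product.
  def factorialAcc(n: Int): Int = {
    def loop(acc: Int, n: Int): Int =
      if (n == 0) acc
      else loop(n * acc, n - 1)
    loop(1, n)
  }

  // Explicit loop derived from the tail-recursive version.
  def factorialLoop(n: Int): Int = {
    var i = n
    var solution = 1
    while (i != 0) {
      solution = i * solution
      i = i - 1
    }
    solution
  }

  // True when all three versions return the same value for every n in 0 to 12.
  def allAgree: Boolean =
    (0 to 12).forall { n =>
      val expected = factorialRec(n)
      factorialAcc(n) == expected && factorialLoop(n) == expected
    }
}
```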

This may seem tedious, because we could write a more intuitive loop from scratch with less work. However, there is good news: you don't need to do the transformation from a tail-recursive function into a loop yourself! The Scala compiler already turns a self-recursive call in tail position into a loop, provided the method cannot be overridden (as is the case for a local function like factorialAcc). The @tailrec annotation makes this explicit: if you annotate a function with it, the compiler checks that the optimization can be applied and reports a compilation error if it cannot. So we can simply write the following to get optimized code:

import scala.annotation.tailrec

def factorial(n : Int) : Int = {
  // @tailrec: compilation fails unless this function can be compiled into a loop
  @tailrec def factorialAcc(acc : Int, n : Int) : Int = {
    if (n == 0) acc
    else factorialAcc(n * acc, n - 1)
  }
  factorialAcc(1, n)
}

This will compile correctly.
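To see why this matters in practice, here is a minimal sketch (the names DeepRecursion, sumToRec and sumToAcc are hypothetical) that sums the integers from 1 to n with a Long result. With n = 1000000, the version whose recursive call is not in tail position will very likely throw a StackOverflowError with default JVM settings, while the @tailrec version runs in constant stack space.

```scala
import scala.annotation.tailrec

// Hypothetical example contrasting stack usage of the two styles.
object DeepRecursion {
  // Not tail recursive: each call waits for the result of the next one,
  // so a large n needs one stack frame per level.
  def sumToRec(n: Long): Long =
    if (n == 0) 0L
    else n + sumToRec(n - 1)

  // Tail recursive: @tailrec guarantees the call is compiled into a loop,
  // so the stack does not grow with n.
  def sumToAcc(n: Long): Long = {
    @tailrec
    def go(acc: Long, n: Long): Long =
      if (n == 0) acc
      else go(acc + n, n - 1)
    go(0L, n)
  }
}
```

For example, DeepRecursion.sumToAcc(1000000L) returns 500000500000.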

Note that if we add the @tailrec annotation to the first definition of factorial that we gave,

import scala.annotation.tailrec

@tailrec
def factorial(n : Int) : Int = {
  if(n == 0) 1
  else n * factorial(n-1)
}

we get a compilation error, because the recursive call is not in tail position:

<console>:15: error: could not optimize @tailrec annotated
              method factorial: it contains a recursive
              call not in tail position
                  else n * factorial(n-1)
                         ^
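A related case worth knowing: the compiler also rejects @tailrec on a method that a subclass could override, because an override would change which method the recursive call invokes. A minimal sketch, assuming a hypothetical class Countdown: marking the method final (or private) makes it eligible for the optimization.

```scala
import scala.annotation.tailrec

// Hypothetical class used only to illustrate the overriding restriction.
class Countdown {
  // `final` prevents overriding, so the self-call can be compiled into a loop.
  // Without `final` (or `private`), this @tailrec annotation would be rejected.
  @tailrec
  final def steps(n: Int, acc: Int): Int =
    if (n <= 0) acc
    else steps(n - 1, acc + 1)
}
```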

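In summary, you can now solve problems by thinking in terms of recursion and then optimize performance by applying this technique: rewrite the function in tail-recursive form with an accumulator, and let the compiler, checked by @tailrec, turn it into a loop.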
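One practical caveat: because all the examples use Int, the result overflows for n > 12. Here is a minimal sketch of the same accumulator technique with BigInt; the object name BigFactorial and the input check are our additions, not part of the original code.

```scala
import scala.annotation.tailrec

// Hypothetical variant: same tail-recursive structure, BigInt accumulator.
object BigFactorial {
  def factorial(n: Int): BigInt = {
    // The Int versions never reach n == 0 for negative input, so reject it.
    require(n >= 0, "factorial is undefined for negative numbers")

    @tailrec
    def factorialAcc(acc: BigInt, n: Int): BigInt =
      if (n == 0) acc
      else factorialAcc(acc * BigInt(n), n - 1)

    factorialAcc(BigInt(1), n)
  }
}
```

For example, BigFactorial.factorial(13) returns 6227020800, which no longer fits in an Int.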

I hope you enjoyed the introduction.
