A Tail-recursion-ification Example

This post walks through some of the thinking behind how we might approach the problem of converting a non-tail recursive function to a tail recursive one. The code we present also illustrates some interesting Scala techniques: implicit conversions of primitive classes to Ordered, type aliases and function passing. For those of you not already familiar with tail recursion, there are some great articles out there to get up to speed on the concept (for example, this one).

To summarize very briefly:

  • tail recursion is a characteristic of recursive methods wherein every recursive call is the very last action of the branch that makes it; nothing remains to be done with the call's result except return it.
  • recursive functions that are not tail recursive have the annoying property of blowing out your stack because they save intermediate state in each stack frame, state which is then combined with the result of the recursive call. Given that stack memory is finite, and assuming that your input could be arbitrarily long (therefore requiring an arbitrarily large number of stack frames to store intermediate state), you will likely run into a stack overflow at some point.
  • Scala automatically optimizes tail recursive functions by converting the recursion into a loop, provided the method cannot be overridden (that is, it is final, private, or a local function). This article presents a really good simple example of how tail recursion in the source code is translated to a loop in the byte code.
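
To make the three points above concrete, here is a minimal sketch that sums a list both ways. The names SumSketch, sumNaive, and sumTailRec are made up for this illustration and do not come from the project discussed in this post.

```scala
import scala.annotation.tailrec

object SumSketch {
  // Not tail recursive: the addition happens after the recursive call returns,
  // so every element of the list costs one stack frame.
  def sumNaive(xs: List[Long]): Long = xs match {
    case Nil          => 0L
    case head :: tail => head + sumNaive(tail)
  }

  // Tail recursive: the running total travels down in `acc`, so the recursive
  // call is the last action and the compiler can rewrite it as a loop.
  def sumTailRec(xs: List[Long]): Long = {
    @tailrec
    def loop(rest: List[Long], acc: Long): Long = rest match {
      case Nil          => acc
      case head :: tail => loop(tail, acc + head)
    }
    loop(xs, 0L)
  }
}
```

Calling SumSketch.sumTailRec on a list of a million elements completes normally, whereas sumNaive on the same input is likely to throw a StackOverflowError under default JVM stack settings.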

Now that you are a tail recursion expert, let’s look at some code:

package scalaTry

class Merger[T]()(implicit e: T => Ordered[T]) {
  def compare(a: T, b: T): Unit = {
    if (a < b)
      println("a less")
    else
      println("a NOT less")
  }

  final def merge(a: List[T],
                  b: List[T]): List[T] = {
    (a,b) match{
      case (Nil,Nil) =>
        Nil
      case (a1::tail, Nil) =>
        a1::tail
      case (Nil, b1::tail) =>
        b1::tail
      case (a1::tail1, b1::tail2) =>
        if (a1 < b1) {
          a1::b1::merge(tail1,tail2)
        } else {
          b1::a1::merge(tail1,tail2)
        }
      case default =>
        throw new
            IllegalStateException(
              "Should never get here")
    }
  }

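The merge method of the Merger class shown above recursively merges two sorted lists, a and b. The base cases are those where one or both of the input lists are empty (Nil): we return Nil when both are Nil, and otherwise we return the non-Nil list. If both lists are non-Nil, the logic is (conceptually) to create a two element list consisting of the heads of a and b (in their proper order), and append to this two element list the result of merging the tails of a and b (the two branches of the if expression). Note that emitting both heads at every step is a simplification: the result is fully sorted only when each pair of heads is no greater than the elements that follow, which holds for the test inputs used later in this post but not in general (merging List(1, 2) with List(3, 4) this way yields List(1, 3, 2, 4)).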

One interesting aspect of the Merger class, the use of Ordered and the implicit conversion of T to an Ordered[T] declared in the class signature, is not specifically related to tail recursion. This parameter genericizes (in Java-speak) the Merger class such that the merge method can accept Lists of any type T for which an implicit conversion to Ordered[T] is available.
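
For comparison, a closely related way to stay generic over "anything that can be compared" is the standard library's Ordering type class. This is a swapped-in technique, not what Merger uses, and the names OrderingSketch and smaller are hypothetical. A minimal sketch:

```scala
object OrderingSketch {
  // Hypothetical helper: returns the smaller of two values of any type T
  // for which an implicit Ordering[T] is in scope. The standard library
  // supplies such instances for Int, String, and many other types.
  def smaller[T](a: T, b: T)(implicit ord: Ordering[T]): T =
    if (ord.lt(a, b)) a else b
}
```

OrderingSketch.smaller(2, 5) picks the Int ordering and OrderingSketch.smaller("foo", "bar") picks the String ordering, with no change to the method itself. That is the same effect the implicit parameter has on Merger.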

But let’s get back to tail recursion, or the lack of it in the merge method. As discussed above, the two branches of the if expression do the conceptual equivalent of creating a two element list, then recursing on the tails of the input lists. The two head elements must be kept in the current stack frame until the recursive call to merge returns, at which point they are prepended to the result of the recursive call. Each pair of heads therefore costs one stack frame, so the stack depth grows linearly with the length of the input lists.

So what to do? The merge2 function shown below shows one way we can avoid a stack overflow. In this function we explicitly create the list that is to be prepended to the result of merging the tails. But instead of keeping that list on the stack and joining it to the result of the recursive call on the tails, we explicitly pass it down to the recursive call as the prefixSoFar. At the point of each recursive call we rely on the elements in prefixSoFar being less than or equal to the head items of both a and b (the variables a1 and b1). Thus when we create a new prefix (the val prefix in each branch of the if expression), prefixSoFar comes first, followed by either a1 then b1, or by b1 then a1. When one or both of the input lists a and b are Nil, we return prefixSoFar, potentially tacking on any remaining elements of the non-Nil input list (as in the second and third cases of the match). The approach relies on these remaining elements being equal to or greater than any of the elements of prefixSoFar (per the ordering supplied by Ordered).

A final comment on the merge2 method: note that Scala provides the handy annotation tailrec, applied here to the inner merge2 helper, to declare a function as tail recursive and have the compiler double check that this declaration is correct. If you were to use this annotation right before the method declaration of a non-tail recursive function such as merge, you would get a compiler error.

  def merge2(a: List[T], b: List[T]): List[T] = {
    @scala.annotation.tailrec
    def merge2(a: List[T],
               b: List[T],
               prefixSoFar: List[T]): List[T] = {
      (a,b) match{
        case (Nil,Nil) =>
          prefixSoFar
        case (a1::tail, Nil) =>
          prefixSoFar ::: a1 :: tail
        case (Nil, b1::tail) =>
          prefixSoFar ::: b1 :: tail
        case (a1::tail1, b1::tail2) =>
          if (a1 < b1) {
            val prefix: List[T] =
              prefixSoFar ::: a1 :: b1 :: Nil
            merge2(tail1, tail2, prefix)
          } else {
            val prefix: List[T] =
              prefixSoFar ::: b1 :: a1 :: Nil
            merge2(tail1, tail2, prefix)
          }
        case default =>
          throw new
              IllegalStateException(
                "Should never get here")
      }
    }

    merge2(a, b, Nil)
  }
}
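
One cost worth noting in merge2: each prefixSoFar ::: ... expression copies the whole prefix built so far, so the total work grows quadratically with the input length. A common alternative is to prepend to an accumulator and reverse it once at the end. The sketch below also consumes only the smaller head per step, so it stays sorted for arbitrary sorted inputs. It is a variation on the post's method rather than the author's code: the names MergeSketch, mergeSorted, and loop are hypothetical, and it uses Ordering in place of the implicit conversion to Ordered.

```scala
import scala.annotation.tailrec

object MergeSketch {
  // Merges two sorted lists into one sorted list using constant stack space.
  def mergeSorted[T](a: List[T], b: List[T])(implicit ord: Ordering[T]): List[T] = {
    @tailrec
    def loop(a: List[T], b: List[T], acc: List[T]): List[T] =
      (a, b) match {
        // One side is exhausted: un-reverse the accumulator once and
        // attach whatever is left of the other side.
        case (Nil, rest) => acc.reverse ::: rest
        case (rest, Nil) => acc.reverse ::: rest
        // Take only the smaller head; the other list is passed on unchanged.
        case (a1 :: tail1, b1 :: tail2) =>
          if (ord.lteq(a1, b1)) loop(tail1, b, a1 :: acc)
          else loop(a, tail2, b1 :: acc)
      }
    loop(a, b, Nil)
  }
}
```

Prepending with :: is a constant time operation, so the whole merge is linear in the combined length of the inputs, plus one final reverse.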

So, now we have a merge method that won’t blow out our stack. Next we need to test it, which we shall do with the code below. doMergeTest accepts two functions, one that merges two String lists, and another that merges two Int lists. We define the type aliases MergeStringLists and MergeIntLists that capture the signatures of these two functions. doMergeTest accepts an argument of the first signature as fun1, and one of the second signature as fun2. In the test ‘merging works’ we first pass in a function pair based on the merge function, and then pass in a function pair based on the tail recursive merge2 function. If you grab the project from this GitHub repo and run it, my guess is that you would see all green. Try it out!

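// Imports assumed for ScalaTest 3.1+ with the scalatestplus-junit runner
import org.junit.runner.RunWith
import org.scalatest.funsuite.AnyFunSuite
import org.scalatestplus.junit.JUnitRunner
import scalaTry.Merger

@RunWith(classOf[JUnitRunner])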
class MergeTest extends AnyFunSuite {

  type MergeStringLists = (List[String], List[String]) => List[String]
  type MergeIntLists = (List[Int], List[Int]) => List[Int]

  def doMergeTest (fun1: MergeStringLists,
                   fun2: MergeIntLists): Unit = {
    var res = fun1(List[String](), List[String]())
    assert(res.equals(List[String]()))

    res = fun1(List[String](), "foo" :: List[String]())
    assert(res.equals(List[String]("foo")))


    var res2: List[Int] = fun2(2 :: List[Int](), 5 :: List[Int]())
    assert(res2.equals(List[Int](2, 5)))

    res2 = fun2(5 :: Nil, 2 :: Nil)
    assert(res2.equals(List[Int](2, 5)))

    res2 = fun2(5 :: 6 :: Nil, 6 :: Nil)
    assert(res2.equals(List[Int](5, 6, 6)))

    res2 = fun2(6 :: Nil, 5 :: 6 :: Nil)
    assert(res2.equals(List[Int](5, 6, 6)))


    res2 = fun2(5 :: 6 :: Nil, 4 :: 6 :: Nil)
    assert(res2.equals(List[Int](4, 5, 6, 6)))

    res2 = fun2(3 :: 6 :: Nil, 5 :: 6 :: Nil)
    assert(res2.equals(List[Int](3, 5, 6, 6)))
    ()
  }

  test("merging works") {
    val fun1: MergeStringLists =
      (list1: List[String], list2: List[String]) =>
        new Merger[String]().merge(list1, list2)
    val fun2: MergeIntLists =
      (list1: List[Int], list2: List[Int]) =>
        new Merger[Int]().merge(list1, list2)
    doMergeTest(fun1, fun2)

    val fun3: MergeStringLists =
      (list1: List[String], list2: List[String]) =>
        new Merger[String]().merge2(list1, list2)
    val fun4: MergeIntLists =
      (list1: List[Int], list2: List[Int]) =>
        new Merger[Int]().merge2(list1, list2)
    doMergeTest(fun3, fun4)
  }
}
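
The type aliases in the test are ordinary function types, so any method with a matching signature can be passed where one is expected. Here is a small self-contained sketch of that idea. AliasSketch, concatSorted, and check are hypothetical names, and concatSorted is only a stand-in for a real merge implementation.

```scala
object AliasSketch {
  // Same shape as the MergeIntLists alias used in the test class.
  type MergeIntLists = (List[Int], List[Int]) => List[Int]

  // Hypothetical stand-in for a merge: concatenate, then sort.
  def concatSorted(a: List[Int], b: List[Int]): List[Int] = (a ::: b).sorted

  // Accepts any function matching the alias and checks one merge result.
  def check(merge: MergeIntLists): Boolean =
    merge(List(5, 6), List(4, 6)) == List(4, 5, 6, 6)

  // The method name is converted to a function value automatically because
  // the expected parameter type is a function type.
  val ok: Boolean = check(concatSorted)
}
```

The explicit lambdas in the test, such as (list1, list2) => new Merger[Int]().merge(list1, list2), do the same job while also constructing the Merger instance.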
