
Tail Recursion, CPS, and a General Method for Converting Recursion to Loops

This article introduces what tail recursion is, what CPS is, and how to use these two concepts to convert recursion into loops.

Tail Recursion

On Wikipedia, tail recursion is defined as follows:

Tail recursion is a programming technique. A recursive function is a function that calls itself. If, in a recursive function, the result of the recursive call is always returned directly, the function is said to be tail-recursive.

Compared with ordinary recursion, tail recursion has one advantage [tail-recursion-optimization]: since nothing remains to be done after the recursive call finishes, and its result is returned directly as the result of the current call, the current stack frame does not need to be preserved. The recursive call can instead reuse it, simply overwriting the variables stored in it.

In general, tail recursion can be considered equivalent to a loop.

Let us look at an example: the factorial function.

#include <assert.h>

int factorial(int n) {
    assert (n >= 0);
    if (n == 0) {
        return 1;
    } else {
        return n * factorial(n - 1);
    }
}

This function is not tail-recursive, because after the recursive call returns it still has to multiply the result by n. But we can easily rewrite it in tail-recursive form:

static int factorial_recursive_aux(int acc, int n) {
    if (n) {
        return factorial_recursive_aux(acc * n, n - 1);
    } else {
        return acc;
    }
}

int factorial_recursive(int n) {
    assert (n >= 0);
    return factorial_recursive_aux(1, n);
}

By introducing an accumulator parameter acc, we move the computation that originally happened after the recursive call to the point where the recursive call is made, turning an ordinary recursive call into a tail call. Following the optimization principle [tail-recursion-optimization] mentioned above, we can then optimize this tail-recursive function as follows:

static int factorial_recursive_opt(int acc, int n) {
FRO_START:
    if (n) {
        acc *= n;
        n--;
        goto FRO_START;
    } else {
        return acc;
    }
}

It is easy to see that such a function is equivalent to the following loop:

int acc = 1;
for (; n; n--) {
    acc *= n;
}

Note that although a while loop would be more concise here, I deliberately wrote a for loop so that it corresponds one-to-one with the goto version above: the test, the update, and the body stay in the same places. In more complex cases, checking and updating the variables takes more steps, and squeezing them into a while loop becomes rather awkward.
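For reference, the loop above can be wrapped into a complete function. The sketch below is in C++ (the names factorial_for and factorial_while are mine, not from the text) and also shows a while form. Note that the decrement has to come after the multiplication: the tempting while (n--) acc *= n; multiplies by n - 1, ..., 0 and returns 0 for every n > 0.

```cpp
#include <cassert>

// The loop from the text, wrapped in a function. The test (n), the
// body (acc *= n) and the update (n--) sit where the goto version had them.
int factorial_for(int n) {
    assert(n >= 0);
    int acc = 1;
    for (; n; n--) {
        acc *= n;
    }
    return acc;
}

// An equivalent while loop: n is decremented only after it has been
// multiplied into acc.
int factorial_while(int n) {
    assert(n >= 0);
    int acc = 1;
    while (n) {
        acc *= n--;
    }
    return acc;
}
```

For example, factorial_for(5) and factorial_while(5) both return 120.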

Unfortunately, converting a recursive call into a tail-recursive call is not always this easy. For example, consider the following:

int fibonacci(int n) {
    if (n <= 0) {
        return 0;
    } else if (n == 1 || n == 2) {
        return 1;
    } else {
        return fibonacci(n - 1) + fibonacci(n - 2);
    }
}

More complex examples include merge sort and quicksort.
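Quicksort shows the difficulty well: of its two recursive calls, only the second is in tail position. Below is a minimal C++ sketch, assuming the Lomuto partition scheme (the helper name lomuto_partition is mine, not from the text). The tail call becomes a loop exactly as with factorial, but the first call is followed by more work and has to stay an ordinary recursive call.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Lomuto partition of a[lo..hi] around a[hi]; returns the pivot's final index.
static int lomuto_partition(std::vector<int>& a, int lo, int hi) {
    int pivot = a[hi];
    int i = lo;
    for (int j = lo; j < hi; ++j) {
        if (a[j] < pivot) {
            std::swap(a[i], a[j]);
            ++i;
        }
    }
    std::swap(a[i], a[hi]);
    return i;
}

// Sorts a[lo..hi]. The second recursive call, quicksort(a, p + 1, hi),
// was a tail call, so it is replaced by "lo = p + 1" and another loop
// iteration. The first call is not a tail call and remains recursive.
void quicksort(std::vector<int>& a, int lo, int hi) {
    while (lo < hi) {
        int p = lomuto_partition(a, lo, hi);
        quicksort(a, lo, p - 1);  // not a tail call
        lo = p + 1;               // former tail call: reuse this frame
    }
}
```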

Because the following content is easier to describe in a functional language, here is the F# code corresponding to the Fibonacci function first:

let rec fibonacci n =
    match n with
    | _ when n <= 0 -> 0
    | 1 | 2 -> 1
    | _ -> (fibonacci (n - 1)) + (fibonacci (n - 2))
;;

Continuation Passing Style (CPS)

On Wikipedia, CPS is defined as follows:

In functional programming, continuation-passing style (CPS) is a style of programming in which control is passed explicitly in the form of a continuation.

The definition is rather abstract, so let us look at an example:

let rec factorial n =
    match n with
    | _ when n < 0 -> failwith "n must not be less than 0."
    | 0 -> 1
    | _ -> n * (factorial (n - 1))
;;

let factorial_cps n =
    let rec factorial_aux n cont =
        match n with
        | 0 -> cont 1
        | _ -> factorial_aux (n - 1) (fun acc -> cont (acc * n))
    in
    if n < 0 then failwith "n must not be less than 0."
    else factorial_aux n (fun x -> x)
;;

As you can see, we take the operation that originally happened after the recursive call (multiplying by n) and, by means of a continuation, keep it to be called at some later moment (in this example, when n = 0). In addition, on each recursive call the continuation is passed as one of the arguments. This is continuation-passing style.

The benefit of CPS is that it can turn an ordinary recursive call that is not tail-recursive into a tail call. Since tail calls can in turn be converted into loops, the method for converting general recursion into loops now becomes fairly clear.
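For comparison, here is a sketch of the same CPS factorial in C++, assuming C++11 and std::function as the continuation type (the names Cont, factorial_aux and factorial_cps are chosen to mirror the F# code; they are not from any library). Using std::function gives every continuation the same type, no matter which lambda created it.

```cpp
#include <cassert>
#include <functional>

// A continuation receives the result of the remaining computation.
using Cont = std::function<int(int)>;

// Mirrors factorial_aux from the F# code: instead of multiplying after
// the recursive call returns, the multiplication by n is captured in a
// new continuation, so the recursive call is the last thing done.
static int factorial_aux(int n, Cont cont) {
    if (n == 0) {
        return cont(1);
    }
    return factorial_aux(n - 1, [n, cont](int acc) { return cont(acc * n); });
}

int factorial_cps(int n) {
    assert(n >= 0);
    return factorial_aux(n, [](int x) { return x; });  // identity continuation
}
```

C++ does not guarantee tail-call elimination, and when cont(1) is finally invoked, each of the n nested continuations calls the next one, so this version still uses stack space proportional to n. It only illustrates the CPS transformation itself.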

Next, let us look at a more complex example [lambda]: Fibonacci. (Here we set aside optimizing it with dynamic programming.)

let fibonacci_cps n =
    let rec fibonacci_aux n cont =
        match n with
        | 1 | 2 -> cont 1
        | _ -> fibonacci_aux
                    (n - 1)
                    (fun acc1 -> fibonacci_aux
                                    (n - 2)
                                    (fun acc2 -> cont (acc1 + acc2)))
    in
    match n with
    | _ when n <= 0 -> 0
    | 1 | 2 -> 1
    | _ -> fibonacci_aux n (fun x -> x)
;;

Comparing this with the fibonacci function written earlier, it is easy to see that the only difference is that the main work (the recursion) is now done by fibonacci_aux. Apart from not handling the n <= 0 case, fibonacci_aux differs from the original fibonacci function mainly in that it uses anonymous functions to pass along, as arguments, whatever remains to be done after the current step. By passing the operations that must happen after each recursive call as parameters, and performing them only after all the recursive steps have finished, we have turned a complex, non-tail recursive call into a tail call.

If this point is unclear, try expanding the case n = 4 by hand.

Below, fibonacci_aux is abbreviated as f:
   f 4 (fun x -> x)
=> f 3 (fun x -> f 2 (fun y -> (fun xx -> xx) (x + y)))
=> f 3 (fun x -> f 2 (fun y -> x + y))
=> f 3 (fun x -> (fun y -> x + y) 1)
=> f 3 (fun x -> x + 1)
=> f 2 (fun x -> f 1 (fun y -> (fun xx -> xx + 1) (x + y)))
=> f 2 (fun x -> (fun y -> (fun xx -> xx + 1) (x + y)) 1)
=> f 2 (fun x -> (fun xx -> xx + 1) (x + 1))
=> (fun x -> (fun xx -> xx + 1) (x + 1)) 1
=> (fun xx -> xx + 1) 2
=> 3

Converting Recursion to Loops

At this point, we at least know how to convert a general recursive call into a tail-recursive call in a functional programming language. Next, we discuss how to do it in C++ and C.

An Attempt in C++

Since C++11 introduced lambdas [CXX11-lambda], in theory this should be about as easy as in a functional programming language:

template<typename _FContTy>
static int fibonacci_aux(int n, _FContTy cont)
{
    if (n == 1 || n == 2) {
        return cont(1);
    } else {
        return fibonacci_aux(n - 1, [=](int acc1) {
            return fibonacci_aux(n - 2, [=](int acc2) {
                return cont(acc1 + acc2);
            });
        });
    }
}

int fibonacci(int n)
{
    return fibonacci_aux(n, [](int x) { return x; });
}

In practice, however, this code cannot be compiled by either the latest Visual Studio [latest-vs] or the latest G++ [latest-gxx] at the time of writing: the compiler stops responding for a long time while its memory usage grows wildly. This is not a compiler bug. Every lambda has its own unique type, so instantiating fibonacci_aux for one continuation type requires instantiating it again for the type of the lambda that wraps that continuation, and so on without end: the template instantiation itself recurses infinitely. Interested readers can verify this by simulating the compiler's work by hand. Moreover, even in the best case, two different fibonacci_aux functions would be generated (because two different lambda types are passed as cont), so the generated code would not be a tail-recursive call of a single function but an indirect tail call between functions.

The compiler cannot handle this well on its own, but we can handle it manually: using inheritance and polymorphism, we can give cont a single, uniform type.

#include <memory>

using namespace std;

class ContF {
protected:
    virtual int Imp(int x) const = 0;
public:
    virtual ~ContF() = default;
    // Invoking a continuation dispatches to the subclass's Imp.
    int operator()(int x) const { return Imp(x); }
};

static int f(int n, shared_ptr<ContF> cont);

class ContIdentityF : public ContF {
protected:
    virtual int Imp(int x) const override { return x; }
};

class ContInnerF : public ContF {
private:
    int x;
    shared_ptr<ContF> cont;
protected:
    virtual int Imp(int y) const override {
        return cont->operator() (x + y);
    }
public:
    ContInnerF(int x, shared_ptr<ContF> cont) : x(x), cont(cont) { }
};

class ContOuterF : public ContF {
private:
    int n;
    shared_ptr<ContF> cont;
protected:
    virtual int Imp(int x) const override {
        return f(n - 2, make_shared<ContInnerF>(x, cont));
    }
public:
    ContOuterF(int n, shared_ptr<ContF> cont) : n(n), cont(cont) { }
};

static int f(int n, shared_ptr<ContF> cont) {
    if (n == 1 || n == 2) {
        return cont->operator() (1);
    } else {
        return f(n - 1, make_shared<ContOuterF>(n, cont));
    }
}

int fibonacci(int n) {
    if (n <= 0) {
        return 0;
    } else {
        return f(n, make_shared<ContIdentityF>());
    }
}
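The standard library offers the same type unification: std::function erases the concrete lambda type, so only one instantiation of the auxiliary function exists and the infinite template instantiation disappears. A minimal sketch, assuming C++11 (fib_aux and fibonacci_fn are hypothetical names):

```cpp
#include <cassert>
#include <functional>

// Every continuation, whatever lambda created it, has this one type.
using Cont = std::function<int(int)>;

static int fib_aux(int n, const Cont& cont) {
    if (n == 1 || n == 2) {
        return cont(1);
    }
    // Outer continuation: receives fib(n - 1) as acc1, then computes fib(n - 2).
    return fib_aux(n - 1, [n, cont](int acc1) {
        // Inner continuation: receives fib(n - 2) as acc2 and combines them.
        return fib_aux(n - 2, [acc1, cont](int acc2) { return cont(acc1 + acc2); });
    });
}

int fibonacci_fn(int n) {
    if (n <= 0) {
        return 0;
    }
    return fib_aux(n, [](int x) { return x; });  // identity continuation
}
```

Like the class-based version, this one still recurses through calls made inside the continuations, so it is not yet a loop; it only sidesteps the infinite template instantiation.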

Even so, f is now called from inside the functor (through the virtual function Imp), so the recursion is an indirect tail call, and the compiler still will not "intelligently" optimize it for us. We therefore go one step further: separate the functor's code from its data, which makes it possible to merge all the code into a single function afterwards.

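#include <cassert>
#include <memory>

using namespace std;

enum class ContDataType { Identity, Outer, Inner };

// Data of one continuation; the code that interprets it lives in f_cont.
struct ContData {
    const ContDataType type;
    const int x;                   // Outer: the n of the call; Inner: fib(n - 1)
    const shared_ptr<ContData> p;  // the continuation that receives the result

    ContData(ContDataType type, int x, shared_ptr<ContData> p)
        : type(type), x(x), p(p)
    {
        assert(type == ContDataType::Outer
            || type == ContDataType::Inner);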
    }
    ContData()
        : type(ContDataType::Identity), x(0), p(nullptr)
    { }
};

static int f(int n, shared_ptr<ContData> data);

static int f_cont(int x, shared_ptr<ContData> data) {
    switch (data->type)
    {
    case ContDataType::Identity:
        return x;
    case ContDataType::Outer:
        return f(
            data->x - 2,
            make_shared<ContData>(ContDataType::Inner, x, data->p));
    case ContDataType::Inner:
        return f_cont(data->x + x, data->p);
    }
}

static int f(int n, shared_ptr<ContData> cont) {
    if (n == 1 || n == 2) {
        return f_cont(1, cont);
    } else {
        return f(
            n - 1,
            make_shared<ContData>(ContDataType::Outer, n, cont));
    }
}

int fib(int n) {
    if (n <= 0) {
        return 0;
    } else {
        return f(n, make_shared<ContData>());
    }
}

Next, we manually merge function f and function f_cont, and perform tail-recursion optimization.

static int f(int n, shared_ptr<ContData> cont) {
F_START:
    if (n == 1 || n == 2) {
        n = 1;
        goto CONT_START;
    } else {
        cont = make_shared<ContData>(ContDataType::Outer, n, cont);
        n--;
        goto F_START;
    }

CONT_START:
    switch (cont->type)
    {
    case ContDataType::Identity:
        return n;
    case ContDataType::Outer: {
            int nn = n;
            n = cont->x - 2;
            cont = make_shared<ContData>(ContDataType::Inner, nn, cont->p);
            goto F_START;
        }
    case ContDataType::Inner:
        n = cont->x + n;
        cont = cont->p;
        goto CONT_START;
    }
}

At this point, we have completed the C++ version of the conversion: the general recursive function has become a single function in which every recursive call has been replaced by a goto, that is, a loop.
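The goto version can also be written as an ordinary while loop. Below is a minimal sketch, assuming C++11 with a std::vector as the explicit stack (the names fib_loop, Frame and descending are mine, not from the text). The two labels F_START and CONT_START become the two branches selected by the descending flag. Because a consumed Outer record is always replaced by an Inner record with the same parent, the continuation chain can live in a plain vector instead of linked shared_ptr nodes.

```cpp
#include <cassert>
#include <vector>

// One pending continuation. An empty stack plays the role of Identity.
struct Frame {
    enum Kind { Outer, Inner } kind;
    int x;  // Outer: the n that created it; Inner: the saved fib(n - 1)
};

int fib_loop(int n) {
    if (n <= 0) {
        return 0;
    }
    std::vector<Frame> stack;
    int value = 0;           // argument passed to the current continuation
    bool descending = true;  // true: F_START; false: CONT_START
    while (true) {
        if (descending) {
            if (n == 1 || n == 2) {
                value = 1;
                descending = false;
            } else {
                stack.push_back({Frame::Outer, n});
                n--;
            }
        } else {
            if (stack.empty()) {
                return value;  // identity continuation
            }
            Frame top = stack.back();
            stack.pop_back();
            if (top.kind == Frame::Outer) {
                // value is fib(top.x - 1); save it and compute fib(top.x - 2).
                stack.push_back({Frame::Inner, value});
                n = top.x - 2;
                descending = true;
            } else {
                // value is fib(k - 2) and top.x is fib(k - 1).
                value = top.x + value;
            }
        }
    }
}
```

For example, fib_loop(4) pushes Outer(4) and Outer(3), and returns 3 after the remaining continuations have been applied.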

C Implementation

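C is harder still, because it has neither shared_ptr nor structs with constructors. These are no longer serious obstacles, though. Below is a C implementation, in which the continuation records live in a fixed-size array that serves as an explicit stack and refer to their parent by index instead of by pointer.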

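enum { Identity, Outer, Inner };

/* One continuation record; records live in an array used as a stack. */
typedef struct ContData {
    int type;  /* Identity, Outer or Inner */
    int x;     /* Outer: the n of the call; Inner: the saved fib(n - 1) */
    int pIdx;  /* index of the parent continuation in the array */
} ContData;

/* A constant expression, so that it can size the array in fib(). */
enum { STACK_SIZE = 10000 };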

static int f(int n, ContData stack[], int stackTop) {
F_START:
    if (n == 1 || n == 2) {
        n = 1;
        goto CONT_START;
    } else {
        stackTop++;
        stack[stackTop].type = Outer;
        stack[stackTop].x = n;
        stack[stackTop].pIdx = stackTop - 1;
        n--;
        goto F_START;
    }

CONT_START:
    switch (stack[stackTop].type)
    {
    case Identity:
        return n;
    case Outer: {
            int nn = n;
            n = stack[stackTop].x - 2;
            stackTop++;
            stack[stackTop].type = Inner;
            stack[stackTop].x = nn;
            stack[stackTop].pIdx = stack[stackTop - 1].pIdx;
            goto F_START;
        }
    case Inner:
        n = stack[stackTop].x + n;
        stackTop = stack[stackTop].pIdx;
        goto CONT_START;
    }
}

int fib(int n) {
    if (n <= 0) {
        return 0;
    } else {
        ContData stack[STACK_SIZE];
        stack[0].type = Identity;
        return f(n, stack, 0);
    }
}
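The same method applies to factorial_cps from the CPS section. In the sketch below, assuming C++11 (factorial_defunc is a hypothetical name), the only non-identity continuation is fun acc -> cont (acc * n), so its data record needs to hold just n, and the explicit stack becomes a std::vector<int>:

```cpp
#include <cassert>
#include <vector>

// factorial_cps after the same steps as fib: each continuation
// "fun acc -> cont (acc * n)" is stored as the single int n, and the
// continuation chain becomes an explicit stack.
int factorial_defunc(int n) {
    assert(n >= 0);
    std::vector<int> pending;
    // Descend: each call factorial_aux (n - 1) (fun acc -> cont (acc * n))
    // pushes one record, until n reaches 0.
    while (n != 0) {
        pending.push_back(n);
        n--;
    }
    // n = 0: call cont 1, then apply the continuations from the top of
    // the stack down (the most recently created one runs first).
    int value = 1;
    while (!pending.empty()) {
        value *= pending.back();
        pending.pop_back();
    }
    // Empty stack: the identity continuation (fun x -> x) returns value.
    return value;
}
```

Since multiplication is associative and commutative, the pending stack can be collapsed into a single running product, which gives back the accumulator version factorial_recursive_aux from the beginning of the article.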

  • [tail-recursion-optimization]: Tail call elimination allows procedure calls in tail position to be implemented as efficiently as goto statements, thus allowing efficient structured programming. In the words of Guy L. Steele "in general procedure calls may be usefully thought of as GOTO statements which also pass parameters, and can be uniformly coded as [machine code] JUMP instructions". See Wikipedia.

  • [lambda]: F# uses the fun keyword to create an anonymous function. See MSDN.

  • [CXX11-lambda]: See C++11 FAQ.

  • [latest-vs]: Visual Studio 2012 Update 3. CL version 17.00.60610.1.

  • [latest-gxx]: g++ 4:4.7.2-1 in Debian jessie.