Saturday, April 21, 2012

Intro to Functional Programming (with examples in Java) - Part 4

Immutable data structures

In the past few posts, we've seen some interesting things that can be done with first-class and higher-order functions. However, I've glossed over the fact that functional programming also tends to rely on parameters being immutable. For example, in the previous post, memoization relies on an underlying HashMap, where the keys are function arguments. We assumed that the arguments implemented reasonable hashCode and equals methods, but things fall apart if those function arguments can be modified after they've been memoized. (In particular, their hashCode method should return a new value, but the existing memoization HashMap has already placed them in a bucket based on their previous hashCode.) Since functional programming dictates that the same function called with the same arguments should always return the same value, one of the easiest ways to guarantee that is to ensure that inputs and outputs are immutable.

Immutability also makes it easier to guarantee thread safety. If a given value is immutable, then any thread can access it, and read from it, without fearing that another thread may modify it (since it cannot be modified).

In some of my first few posts, I described an implementation of immutable binary trees. In this post, we're going to look at some simpler immutable data structures, namely ImmutableLists and Options.

Immutable Lists

For lists, we're going to implement a singly linked list structure that supports only two "modification" operations, prepend and tail. These add a new element at the head of the list, and return the list without the head element, respectively. These are both O(1) operations, since prepend simply creates a new head node whose tail is the previous list, while tail returns the list pointed to by the tail of the head node. Neither operation modifies the existing list; each instead returns a "new" list reference (where prepend creates a new list reference, and tail returns a list reference that already exists). To extract the actual elements from the list, we'll have a single operation, head, that returns the element at the front of the list. Readers familiar with Lisp should recognize head and tail as car and cdr. For convenient list traversal, we'll implement the familiar Java Iterable interface, so we can use the friendly Java 1.5 "foreach" syntax. Also, for convenience, we'll define an isEmpty method.

As with the immutable binary trees, we're going to take advantage of immutability to assert that there is only one empty list (regardless of type arguments), and create a specialized singleton subclass for it, called EmptyList. Everything else will be a NonEmptyList. Here is the code:
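As a minimal sketch of the structure described above: the class and method names come from the prose, while the EMPTY constant, the package-private class visibility, and the exception types are assumptions made for brevity.

```java
import java.util.Iterator;
import java.util.NoSuchElementException;

abstract class ImmutableList<T> implements Iterable<T> {
    // Immutability lets every empty list, whatever its type argument, share one instance.
    private static final ImmutableList<?> EMPTY = new EmptyList<Object>();

    private ImmutableList() {}  // instances only come from nil() and prepend()

    @SuppressWarnings("unchecked")
    public static <T> ImmutableList<T> nil() { return (ImmutableList<T>) EMPTY; }

    public abstract T head();
    public abstract ImmutableList<T> tail();
    public abstract boolean isEmpty();

    // O(1): the new node simply points at the existing list.
    public ImmutableList<T> prepend(T element) { return new NonEmptyList<T>(element, this); }

    @Override
    public Iterator<T> iterator() {
        return new Iterator<T>() {
            private ImmutableList<T> rest = ImmutableList.this;
            @Override public boolean hasNext() { return !rest.isEmpty(); }
            @Override public T next() {
                T value = rest.head();  // throws NoSuchElementException once exhausted
                rest = rest.tail();
                return value;
            }
            @Override public void remove() { throw new UnsupportedOperationException(); }
        };
    }

    private static final class EmptyList<T> extends ImmutableList<T> {
        @Override public T head() { throw new NoSuchElementException("head of empty list"); }
        @Override public ImmutableList<T> tail() { throw new NoSuchElementException("tail of empty list"); }
        @Override public boolean isEmpty() { return true; }
    }

    private static final class NonEmptyList<T> extends ImmutableList<T> {
        private final T head;
        private final ImmutableList<T> tail;
        private NonEmptyList(T head, ImmutableList<T> tail) { this.head = head; this.tail = tail; }
        @Override public T head() { return head; }
        @Override public ImmutableList<T> tail() { return tail; }
        @Override public boolean isEmpty() { return false; }
    }
}
```

With this sketch, ImmutableList.<String>nil().prepend("b").prepend("a") yields a list whose head is "a", and the original one-element list is left untouched.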

That's not terribly complicated, but also a little unpleasant to use. In particular, the only way we have of creating a list is via the nil() factory method, which returns an empty list, followed by a series of prepend() operations. Let's add a factory method that takes a varargs parameter to produce a list. Since we only have the prepend operation at our disposal, we'll need to iterate through the varargs array backwards:
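One way such a factory could look, as a sketch: the method name list is an assumption, and the surrounding class is a trimmed stand-in for the list type (no Iterable support) so that the example compiles on its own.

```java
import java.util.NoSuchElementException;

abstract class ImmutableList<T> {
    private static final ImmutableList<?> EMPTY = new EmptyList<Object>();

    @SuppressWarnings("unchecked")
    static <T> ImmutableList<T> nil() { return (ImmutableList<T>) EMPTY; }

    // Hypothetical factory name: walks the varargs array backwards so that
    // the first argument ends up at the head of the resulting list.
    @SafeVarargs
    static <T> ImmutableList<T> list(T... elements) {
        ImmutableList<T> result = nil();
        for (int i = elements.length - 1; i >= 0; i--) {
            result = result.prepend(elements[i]);
        }
        return result;
    }

    abstract T head();
    abstract ImmutableList<T> tail();
    abstract boolean isEmpty();

    ImmutableList<T> prepend(T element) { return new NonEmptyList<T>(element, this); }

    private static final class EmptyList<T> extends ImmutableList<T> {
        T head() { throw new NoSuchElementException("head of empty list"); }
        ImmutableList<T> tail() { throw new NoSuchElementException("tail of empty list"); }
        boolean isEmpty() { return true; }
    }

    private static final class NonEmptyList<T> extends ImmutableList<T> {
        private final T head;
        private final ImmutableList<T> tail;
        NonEmptyList(T head, ImmutableList<T> tail) { this.head = head; this.tail = tail; }
        T head() { return head; }
        ImmutableList<T> tail() { return tail; }
        boolean isEmpty() { return false; }
    }
}
```

Iterating forwards and prepending would instead produce the arguments in reverse order, which is why the loop counts down.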

Let's create a quick test to confirm that the list factory method and list iteration actually work:

Now, since it's a frequently used operation, let's add a size method. The size of a list can be defined recursively, as follows: the size of the empty list is 0, the size of a non-empty list is the size of its tail plus 1. As a first try, we'll define size exactly that way:
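The recursive definition translates almost word for word into one method per subclass. This is a sketch on a trimmed stand-in for the list type, and it is deliberately naive: it needs one stack frame per element.

```java
import java.util.NoSuchElementException;

abstract class ImmutableList<T> {
    private static final ImmutableList<?> EMPTY = new EmptyList<Object>();

    @SuppressWarnings("unchecked")
    static <T> ImmutableList<T> nil() { return (ImmutableList<T>) EMPTY; }

    abstract T head();
    abstract ImmutableList<T> tail();
    abstract boolean isEmpty();
    abstract int size();

    ImmutableList<T> prepend(T element) { return new NonEmptyList<T>(element, this); }

    private static final class EmptyList<T> extends ImmutableList<T> {
        T head() { throw new NoSuchElementException("head of empty list"); }
        ImmutableList<T> tail() { throw new NoSuchElementException("tail of empty list"); }
        boolean isEmpty() { return true; }
        int size() { return 0; }  // the size of the empty list is 0
    }

    private static final class NonEmptyList<T> extends ImmutableList<T> {
        private final T head;
        private final ImmutableList<T> tail;
        NonEmptyList(T head, ImmutableList<T> tail) { this.head = head; this.tail = tail; }
        T head() { return head; }
        ImmutableList<T> tail() { return tail; }
        boolean isEmpty() { return false; }
        // The addition happens after the recursive call returns,
        // so every element costs one stack frame.
        int size() { return 1 + tail.size(); }
    }
}
```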

That's nice and simple. Let's create a unit test for it:

Uh-oh -- at least on my machine, that test triggered a StackOverflowError. If it doesn't trigger a StackOverflowError on your machine (which should be unlikely -- the Java stack usually overflows around 50k calls), try increasing the value of the size local variable in the test. The problem is that our recursive definition is triggering 100k recursive calls to size(). But "Michael," I hear you say, "aren't recursive calls a fundamental keystone of functional programming?" Well, yes, but even in Scheme, Haskell, or Erlang, I believe this particular implementation would blow the stack. For those languages, the problem is that our method is not tail-recursive. A tail-recursive method/function is one where any recursive call is the last operation that executes. Since we add 1 to the result of the recursive call to size, that addition is the last operation to execute. Many functional languages implement what's called tail-call optimization, which basically turns tail-calls into a goto that returns to the beginning of the method. To clarify, let's first implement a tail-recursive version of size:
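A sketch of what a tail-recursive size could look like, again on a trimmed stand-in for the list type. The helper name sizeHelper matches the prose, but its exact signature is an assumption.

```java
import java.util.NoSuchElementException;

abstract class ImmutableList<T> {
    private static final ImmutableList<?> EMPTY = new EmptyList<Object>();

    @SuppressWarnings("unchecked")
    static <T> ImmutableList<T> nil() { return (ImmutableList<T>) EMPTY; }

    abstract T head();
    abstract ImmutableList<T> tail();
    abstract boolean isEmpty();

    ImmutableList<T> prepend(T element) { return new NonEmptyList<T>(element, this); }

    int size() { return sizeHelper(this, 0); }

    // The accumulator carries the count so far, so nothing is left to do
    // after the recursive call returns: the call is in tail position.
    private static int sizeHelper(ImmutableList<?> list, int accumulator) {
        if (list.isEmpty()) {
            return accumulator;
        }
        return sizeHelper(list.tail(), accumulator + 1);
    }

    private static final class EmptyList<T> extends ImmutableList<T> {
        T head() { throw new NoSuchElementException("head of empty list"); }
        ImmutableList<T> tail() { throw new NoSuchElementException("tail of empty list"); }
        boolean isEmpty() { return true; }
    }

    private static final class NonEmptyList<T> extends ImmutableList<T> {
        private final T head;
        private final ImmutableList<T> tail;
        NonEmptyList(T head, ImmutableList<T> tail) { this.head = head; this.tail = tail; }
        T head() { return head; }
        ImmutableList<T> tail() { return tail; }
        boolean isEmpty() { return false; }
    }
}
```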

In this case, we've implemented a helper method with a so-called accumulator parameter, which keeps track of the size counted so far as the recursion proceeds. The recursive call to sizeHelper is simply modifying the input parameters, and jumping back to the beginning of the method (which is why tail-call optimization is so easy to implement). Unfortunately, Java compilers do not traditionally implement tail-call optimization, so this code will still cause a StackOverflowError on our unit test. Instead, we can simulate tail-call optimization by optimizing by hand:
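The hand-optimized helper might look like the sketch below, on a trimmed stand-in for the list type. The recursive call becomes a reassignment of the two parameters followed by a jump back to the top of the loop.

```java
import java.util.NoSuchElementException;

abstract class ImmutableList<T> {
    private static final ImmutableList<?> EMPTY = new EmptyList<Object>();

    @SuppressWarnings("unchecked")
    static <T> ImmutableList<T> nil() { return (ImmutableList<T>) EMPTY; }

    abstract T head();
    abstract ImmutableList<T> tail();
    abstract boolean isEmpty();

    ImmutableList<T> prepend(T element) { return new NonEmptyList<T>(element, this); }

    int size() { return sizeHelper(this, 0); }

    private static int sizeHelper(ImmutableList<?> list, int accumulator) {
        while (true) {
            if (list.isEmpty()) {
                return accumulator;
            }
            // Formerly sizeHelper(list.tail(), accumulator + 1):
            // rebind the parameters and loop instead of calling.
            accumulator = accumulator + 1;
            list = list.tail();
        }
    }

    private static final class EmptyList<T> extends ImmutableList<T> {
        T head() { throw new NoSuchElementException("head of empty list"); }
        ImmutableList<T> tail() { throw new NoSuchElementException("tail of empty list"); }
        boolean isEmpty() { return true; }
    }

    private static final class NonEmptyList<T> extends ImmutableList<T> {
        private final T head;
        private final ImmutableList<T> tail;
        NonEmptyList(T head, ImmutableList<T> tail) { this.head = head; this.tail = tail; }
        T head() { return head; }
        ImmutableList<T> tail() { return tail; }
        boolean isEmpty() { return false; }
    }
}
```

Because the loop uses constant stack space, a list of 100k elements is no longer a problem.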

If Java did support tail-call optimization, it would produce roughly the same bytecode as the code above. The code above passes the unit test with flying colours. That said, because Java does not support tail-call optimization, we are probably better off inlining sizeHelper to produce the following:

Note that Scala, a JVM language, would actually optimize the previous tail-call (by producing the goto bytecode to return to the start of the helper method), so we wouldn't have to optimize it into a while loop. That said, if you come from an object-oriented background, the optimized version may actually be more intuitive. Where Scala is not able to match "traditional" functional programming languages is with "mutual tail-calls", where function a tail-calls function b, which tail-calls function a. I believe you could implement this by hand in C by longjmping to the beginning of the body of the other function/procedure, but my understanding is that the JVM's builtin security requirements prevent you from gotoing to code outside the current method. Basically, the goto bytecode is only designed to make control statements (if, while, for, switch, etc.) work within the current method, I think. If I recall correctly, Clojure (another JVM language) is able to produce optimized mutual tail-calls if you use "trampolines", but I have no idea how those work at the bytecode level.

For another example of a function that we can implement with tail-recursion (and would be implemented with tail-recursion in languages that lack while loops), consider reversing a list. When we reverse a list, the first element becomes the last element, then we prepend the second element to that first element, then prepend the third element to that, etc. As a recursive function, we pass in the original list and an empty list as the initial accumulator value. We prepend the head of the original list to the accumulator, and recurse with the tail of the original list and the new accumulator value:
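A sketch of that recursive formulation, on a trimmed stand-in for the list type. The helper name reverseHelper is an assumption.

```java
import java.util.NoSuchElementException;

abstract class ImmutableList<T> {
    private static final ImmutableList<?> EMPTY = new EmptyList<Object>();

    @SuppressWarnings("unchecked")
    static <T> ImmutableList<T> nil() { return (ImmutableList<T>) EMPTY; }

    abstract T head();
    abstract ImmutableList<T> tail();
    abstract boolean isEmpty();

    ImmutableList<T> prepend(T element) { return new NonEmptyList<T>(element, this); }

    ImmutableList<T> reverse() { return reverseHelper(this, ImmutableList.<T>nil()); }

    // Tail-recursive: each step moves the head of what remains onto the
    // front of the accumulator, so the first element ends up last.
    private static <T> ImmutableList<T> reverseHelper(ImmutableList<T> remaining,
                                                      ImmutableList<T> accumulator) {
        if (remaining.isEmpty()) {
            return accumulator;
        }
        return reverseHelper(remaining.tail(), accumulator.prepend(remaining.head()));
    }

    private static final class EmptyList<T> extends ImmutableList<T> {
        T head() { throw new NoSuchElementException("head of empty list"); }
        ImmutableList<T> tail() { throw new NoSuchElementException("tail of empty list"); }
        boolean isEmpty() { return true; }
    }

    private static final class NonEmptyList<T> extends ImmutableList<T> {
        private final T head;
        private final ImmutableList<T> tail;
        NonEmptyList(T head, ImmutableList<T> tail) { this.head = head; this.tail = tail; }
        T head() { return head; }
        ImmutableList<T> tail() { return tail; }
        boolean isEmpty() { return false; }
    }
}
```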

As with size, this implementation of reverse will overflow the stack on large lists. Also, like with size, we can perform the tail-call optimization by hand in almost exactly the same way:

Below, we will make extensive use of reverse, and use while loops to implement several other methods that would traditionally (in the functional world) be written with tail-recursive functions.

Anyway, let's move on to an even simpler collection type, which essentially makes null obsolete.

Option

The Option type is effectively a collection of size zero or one, with the semantics of "maybe" having a value. (In fact, the Haskell equivalent of Option is the Maybe monad. I personally like the name Maybe better, since the idea of a function that returns Maybe int nicely expresses that it may return an int, but might not. That said, I've decided to go with Scala's naming in this case, and call it Option.)

Why would we want to return or keep an Option value instead of just returning or keeping an object reference? We do this partly to explicitly document the fact that a method may not return a valid value, or that a field may not be initialized yet. In traditional Java, this is accomplished by returning null or storing a null object reference. Unfortunately, any object reference could be null, so this leads to the common situation of developers adding defensive null-checks everywhere, including on return values from methods that will never return null (since they are guaranteed to produce a valid result in non-exceptional circumstances). With Option, you're able to explicitly say "This method may not have a defined return value for all inputs." Assuming you and your team settle on a coding standard where you never return null from a method, you should be able to guarantee that a method that returns String returns a well-defined (not null) String, while a method that returns Option<String> might not.

As with our other immutable collections, we'll follow the same pattern: abstract base class (in this case Option), with a singleton subclass for the empty collection (called None) and a subclass for non-empty collections (called Some). As with immutable lists, there will be no public constructors, but instances will be made available through factory methods on the base Option class.
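A minimal sketch of that pattern: Option, None, Some, isDefined, and get are named in the text, while the factory names some() and none() and the exception thrown by get() on a None are assumptions.

```java
import java.util.NoSuchElementException;

abstract class Option<T> {
    // A single None instance is shared across all type arguments.
    private static final Option<?> NONE = new None<Object>();

    private Option() {}  // instances only come from the factory methods

    @SuppressWarnings("unchecked")
    public static <T> Option<T> none() { return (Option<T>) NONE; }

    public static <T> Option<T> some(T value) { return new Some<T>(value); }

    public abstract boolean isDefined();
    public abstract T get();

    private static final class None<T> extends Option<T> {
        @Override public boolean isDefined() { return false; }
        @Override public T get() { throw new NoSuchElementException("get() called on None"); }
    }

    private static final class Some<T> extends Option<T> {
        private final T value;
        private Some(T value) { this.value = value; }
        @Override public boolean isDefined() { return true; }
        @Override public T get() { return value; }
    }
}
```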

Now, if we have a method that returns an Option, we can call isDefined() before calling get() to avoid having an exception thrown. Of course, this is only marginally better than just using null pointers, since the Option warns us that we should call isDefined(), but we're still using the same clunky system of checks. It would be nicer if we could just say "Do something to the value held in this Option, if it is a Some, or propagate the None otherwise", or simply "Do something with a Some, and nothing with a None". Fortunately, we can do both of these. Let's do the second one first, by making Option iterable:
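One way to make Option iterable, as a sketch: it delegates to the JDK's empty and singleton lists for the iterators, which is a shortcut of mine and may differ from a hand-written iterator. The Greeter class is a hypothetical caller added only to show the for-each idiom.

```java
import java.util.Collections;
import java.util.Iterator;
import java.util.NoSuchElementException;

abstract class Option<T> implements Iterable<T> {
    private static final Option<?> NONE = new None<Object>();

    @SuppressWarnings("unchecked")
    static <T> Option<T> none() { return (Option<T>) NONE; }
    static <T> Option<T> some(T value) { return new Some<T>(value); }

    abstract boolean isDefined();
    abstract T get();

    private static final class None<T> extends Option<T> {
        boolean isDefined() { return false; }
        T get() { throw new NoSuchElementException("get() called on None"); }
        // A None iterates over nothing.
        public Iterator<T> iterator() { return Collections.<T>emptyList().iterator(); }
    }

    private static final class Some<T> extends Option<T> {
        private final T value;
        Some(T value) { this.value = value; }
        boolean isDefined() { return true; }
        T get() { return value; }
        // A Some iterates over exactly one element.
        public Iterator<T> iterator() { return Collections.singletonList(value).iterator(); }
    }
}

// Hypothetical caller: the loop body runs once for a Some and never for a None.
final class Greeter {
    static String greet(Option<String> name) {
        StringBuilder out = new StringBuilder();
        for (String n : name) {
            out.append("Hello, ").append(n);
        }
        return out.toString();
    }
}
```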

Here is a test showing how we can iterate over Options to execute code only if there is an actual value:

Using this idiom, we can wrap the None check and the extraction of the value into the for statement. It's a little more elegant than writing if (a.isDefined()) a.get(), and ties in nicely with the idea of thinking of Option as a collection of size 0 or 1.

We can also replace a common Java null idiom, where we use a default value if a variable is null, using the getOrElse method:
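A sketch of getOrElse on a trimmed stand-in for Option, assuming it takes the default as a plain, eagerly evaluated value:

```java
import java.util.NoSuchElementException;

abstract class Option<T> {
    private static final Option<?> NONE = new None<Object>();

    @SuppressWarnings("unchecked")
    static <T> Option<T> none() { return (Option<T>) NONE; }
    static <T> Option<T> some(T value) { return new Some<T>(value); }

    abstract boolean isDefined();
    abstract T get();
    // Returns the wrapped value for a Some, or the supplied default for a None.
    abstract T getOrElse(T defaultValue);

    private static final class None<T> extends Option<T> {
        boolean isDefined() { return false; }
        T get() { throw new NoSuchElementException("get() called on None"); }
        T getOrElse(T defaultValue) { return defaultValue; }
    }

    private static final class Some<T> extends Option<T> {
        private final T value;
        Some(T value) { this.value = value; }
        boolean isDefined() { return true; }
        T get() { return value; }
        T getOrElse(T defaultValue) { return value; }
    }
}
```

This replaces the familiar x != null ? x : fallback expression with name.getOrElse(fallback).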

Here is a test that shows how that works (and verifies the standard get() behaviour):

Before getting into some other collection operations that we'll add to ImmutableList and Option, let's add a convenient factory method that wraps an object reference in Some or returns None, depending on whether the reference is null.
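Such a factory could be as small as the sketch below. The name option() is the one used later in this post, and the rest of the class is a trimmed stand-in for Option.

```java
import java.util.NoSuchElementException;

abstract class Option<T> {
    private static final Option<?> NONE = new None<Object>();

    @SuppressWarnings("unchecked")
    static <T> Option<T> none() { return (Option<T>) NONE; }
    static <T> Option<T> some(T value) { return new Some<T>(value); }

    // Wraps a possibly-null reference: null becomes None, anything else becomes Some.
    static <T> Option<T> option(T value) {
        return value == null ? Option.<T>none() : Option.<T>some(value);
    }

    abstract boolean isDefined();
    abstract T get();

    private static final class None<T> extends Option<T> {
        boolean isDefined() { return false; }
        T get() { throw new NoSuchElementException("get() called on None"); }
    }

    private static final class Some<T> extends Option<T> {
        private final T value;
        Some(T value) { this.value = value; }
        boolean isDefined() { return true; }
        T get() { return value; }
    }
}
```

For example, option(System.getenv("HOME")) turns a lookup that may return null into an Option<String>.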

And the test:

Higher-order Collection Methods

In the second post in this series, we saw the map and fold higher-order functions. Since these require a List parameter anyway (and, specifically, according to my previous implementation, one of those nasty mutable java.util.Lists), for convenience, we can add map and fold operations to our immutable collections as methods. While we're at it, we'll add two more higher-order functions, filter and flatMap. Since all of our collections should support for-each syntax, we'll specify the contract for our methods as an interface extending Iterable.

We'll implement these methods on ImmutableList and Option. Let's look at map first:
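A sketch of map for the list case only, on a trimmed stand-in for the list type. It uses java.util.function.Function (Java 8) in place of the function types from the earlier posts in this series, which is an assumption made to keep the example self-contained. Results are prepended to an accumulator as the list is walked, so one reverse at the end restores the original order.

```java
import java.util.NoSuchElementException;
import java.util.function.Function;

abstract class ImmutableList<T> {
    private static final ImmutableList<?> EMPTY = new EmptyList<Object>();

    @SuppressWarnings("unchecked")
    static <T> ImmutableList<T> nil() { return (ImmutableList<T>) EMPTY; }

    abstract T head();
    abstract ImmutableList<T> tail();
    abstract boolean isEmpty();

    ImmutableList<T> prepend(T element) { return new NonEmptyList<T>(element, this); }

    ImmutableList<T> reverse() {
        ImmutableList<T> remaining = this;
        ImmutableList<T> accumulator = nil();
        while (!remaining.isEmpty()) {
            accumulator = accumulator.prepend(remaining.head());
            remaining = remaining.tail();
        }
        return accumulator;
    }

    <R> ImmutableList<R> map(Function<T, R> f) {
        ImmutableList<R> reversedResult = nil();
        ImmutableList<T> remaining = this;
        while (!remaining.isEmpty()) {
            reversedResult = reversedResult.prepend(f.apply(remaining.head()));
            remaining = remaining.tail();
        }
        return reversedResult.reverse();
    }

    private static final class EmptyList<T> extends ImmutableList<T> {
        T head() { throw new NoSuchElementException("head of empty list"); }
        ImmutableList<T> tail() { throw new NoSuchElementException("tail of empty list"); }
        boolean isEmpty() { return true; }
    }

    private static final class NonEmptyList<T> extends ImmutableList<T> {
        private final T head;
        private final ImmutableList<T> tail;
        NonEmptyList(T head, ImmutableList<T> tail) { this.head = head; this.tail = tail; }
        T head() { return head; }
        ImmutableList<T> tail() { return tail; }
        boolean isEmpty() { return false; }
    }
}
```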

To establish that these map methods work, let's try mapping a function that doubles integers:

Next, we'll move on to the two folds. For an ImmutableList, the foldLeft operation traverses the list in the natural order, while a foldRight is easier if we reverse the list first. For Option, since there is (at most) one element, the only difference between the folds is the order in which parameters are passed to the function argument. To distinguish between foldLeft and foldRight, we need to use a non-commutative operation, so we'll use string concatenation in the tests.
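A sketch of the two folds for the list case, on a trimmed stand-in for the list type. The parameter order (seed first, then the function) and the use of java.util.function.BiFunction are assumptions.

```java
import java.util.NoSuchElementException;
import java.util.function.BiFunction;

abstract class ImmutableList<T> {
    private static final ImmutableList<?> EMPTY = new EmptyList<Object>();

    @SuppressWarnings("unchecked")
    static <T> ImmutableList<T> nil() { return (ImmutableList<T>) EMPTY; }

    abstract T head();
    abstract ImmutableList<T> tail();
    abstract boolean isEmpty();

    ImmutableList<T> prepend(T element) { return new NonEmptyList<T>(element, this); }

    ImmutableList<T> reverse() {
        ImmutableList<T> accumulator = nil();
        for (ImmutableList<T> rest = this; !rest.isEmpty(); rest = rest.tail()) {
            accumulator = accumulator.prepend(rest.head());
        }
        return accumulator;
    }

    // Natural order: f(f(f(seed, e1), e2), e3)
    <B> B foldLeft(B seed, BiFunction<B, T, B> f) {
        B accumulator = seed;
        for (ImmutableList<T> rest = this; !rest.isEmpty(); rest = rest.tail()) {
            accumulator = f.apply(accumulator, rest.head());
        }
        return accumulator;
    }

    // Right to left: f(e1, f(e2, f(e3, seed))), computed by walking the reversed list.
    <B> B foldRight(B seed, BiFunction<T, B, B> f) {
        B accumulator = seed;
        for (ImmutableList<T> rest = reverse(); !rest.isEmpty(); rest = rest.tail()) {
            accumulator = f.apply(rest.head(), accumulator);
        }
        return accumulator;
    }

    private static final class EmptyList<T> extends ImmutableList<T> {
        T head() { throw new NoSuchElementException("head of empty list"); }
        ImmutableList<T> tail() { throw new NoSuchElementException("tail of empty list"); }
        boolean isEmpty() { return true; }
    }

    private static final class NonEmptyList<T> extends ImmutableList<T> {
        private final T head;
        private final ImmutableList<T> tail;
        NonEmptyList(T head, ImmutableList<T> tail) { this.head = head; this.tail = tail; }
        T head() { return head; }
        ImmutableList<T> tail() { return tail; }
        boolean isEmpty() { return false; }
    }
}
```

On the list "a", "b", "c" with seed "x", foldLeft with (acc, s) -> acc + s gives "xabc", while foldRight with (s, acc) -> s + acc gives "abcx".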

The filter method returns a sub-collection, based on a predicate that gets evaluated on every element of the original collection. If the predicate returns true for a given element, that element is included in the output collection. For Option, filter will return a Some if and only if the original Option was a Some and the predicate returns true for the wrapped value. For an ImmutableList, the code for filter is similar to map, but we only copy into the accumulator the values that satisfy the predicate.
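For the list case, filter might look like the following sketch, on a trimmed stand-in for the list type and with java.util.function.Predicate standing in for the predicate type (an assumption). The shape is the same as map, except that an element is only copied when the predicate accepts it.

```java
import java.util.NoSuchElementException;
import java.util.function.Predicate;

abstract class ImmutableList<T> {
    private static final ImmutableList<?> EMPTY = new EmptyList<Object>();

    @SuppressWarnings("unchecked")
    static <T> ImmutableList<T> nil() { return (ImmutableList<T>) EMPTY; }

    abstract T head();
    abstract ImmutableList<T> tail();
    abstract boolean isEmpty();

    ImmutableList<T> prepend(T element) { return new NonEmptyList<T>(element, this); }

    ImmutableList<T> reverse() {
        ImmutableList<T> accumulator = nil();
        for (ImmutableList<T> rest = this; !rest.isEmpty(); rest = rest.tail()) {
            accumulator = accumulator.prepend(rest.head());
        }
        return accumulator;
    }

    ImmutableList<T> filter(Predicate<? super T> predicate) {
        ImmutableList<T> reversedResult = nil();
        ImmutableList<T> remaining = this;
        while (!remaining.isEmpty()) {
            if (predicate.test(remaining.head())) {
                reversedResult = reversedResult.prepend(remaining.head());
            }
            remaining = remaining.tail();
        }
        return reversedResult.reverse();  // restore the original ordering
    }

    private static final class EmptyList<T> extends ImmutableList<T> {
        T head() { throw new NoSuchElementException("head of empty list"); }
        ImmutableList<T> tail() { throw new NoSuchElementException("tail of empty list"); }
        boolean isEmpty() { return true; }
    }

    private static final class NonEmptyList<T> extends ImmutableList<T> {
        private final T head;
        private final ImmutableList<T> tail;
        NonEmptyList(T head, ImmutableList<T> tail) { this.head = head; this.tail = tail; }
        T head() { return head; }
        ImmutableList<T> tail() { return tail; }
        boolean isEmpty() { return false; }
    }
}
```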

Our last higher-order method is a little more complicated. flatMap takes a function that returns collections, maps it across all elements in our collection, and then "flattens" the result down to a collection of elements. Say you have a function that takes an Integer and returns an ImmutableList<Integer>. If you map that function over an ImmutableList<Integer>, the result is an ImmutableList<ImmutableList<Integer>>, or a list of lists of integers. However, if you flatMap it, you get back an ImmutableList<Integer> consisting of the concatenation of all of the lists you would get using map. Interestingly, both map and filter can be implemented in terms of flatMap. For map, we simply wrap the return value from the passed function in single-element lists (or Some), while for filter we return a single-element list (or Some) for elements where the predicate evaluates to true and empty lists (or None) for elements where the predicate is false. Here are implementations and tests for flatMap:
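A sketch of flatMap for the list case, on a trimmed stand-in for the list type and with java.util.function.Function standing in for the function type (an assumption). Each inner list is drained into the accumulator in order, and a single reverse at the end restores the overall ordering.

```java
import java.util.NoSuchElementException;
import java.util.function.Function;

abstract class ImmutableList<T> {
    private static final ImmutableList<?> EMPTY = new EmptyList<Object>();

    @SuppressWarnings("unchecked")
    static <T> ImmutableList<T> nil() { return (ImmutableList<T>) EMPTY; }

    abstract T head();
    abstract ImmutableList<T> tail();
    abstract boolean isEmpty();

    ImmutableList<T> prepend(T element) { return new NonEmptyList<T>(element, this); }

    ImmutableList<T> reverse() {
        ImmutableList<T> accumulator = nil();
        for (ImmutableList<T> rest = this; !rest.isEmpty(); rest = rest.tail()) {
            accumulator = accumulator.prepend(rest.head());
        }
        return accumulator;
    }

    <R> ImmutableList<R> flatMap(Function<T, ImmutableList<R>> f) {
        ImmutableList<R> reversedResult = nil();
        ImmutableList<T> remaining = this;
        while (!remaining.isEmpty()) {
            ImmutableList<R> inner = f.apply(remaining.head());
            while (!inner.isEmpty()) {
                reversedResult = reversedResult.prepend(inner.head());
                inner = inner.tail();
            }
            remaining = remaining.tail();
        }
        return reversedResult.reverse();
    }

    private static final class EmptyList<T> extends ImmutableList<T> {
        T head() { throw new NoSuchElementException("head of empty list"); }
        ImmutableList<T> tail() { throw new NoSuchElementException("tail of empty list"); }
        boolean isEmpty() { return true; }
    }

    private static final class NonEmptyList<T> extends ImmutableList<T> {
        private final T head;
        private final ImmutableList<T> tail;
        NonEmptyList(T head, ImmutableList<T> tail) { this.head = head; this.tail = tail; }
        T head() { return head; }
        ImmutableList<T> tail() { return tail; }
        boolean isEmpty() { return false; }
    }
}
```

Returning an empty list from the function drops the element entirely, which is how filter can be expressed through flatMap.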

A nice way of combining these two collection types is to flatMap a function that returns Option across an ImmutableList. The elements that evaluate to None simply disappear, while the Some values are fully realized (as in, you don't need to invoke get() on them).

Summary

In this post, we got to see a couple of immutable collection types.

While ImmutableList as I've implemented it takes a performance hit by using its own interface (since it tends to have to call reverse to provide consistent ordering), there are a couple of improvements that could be made. One would be to provide additional implementations of the methods that would contain all of the logic before the call to reverse, for the cases where order doesn't matter, and then implement the ordered versions by calling the unordered version followed by a call to reverse. The other would be to make NonEmptyList's tail member non-final, and make use of our private access to implement efficient appends by the higher-order methods. The code for that would be a fair bit more verbose (since we would have to hold a reference to the element preceding the terminal EmptyList). That said, it's possible that the additional operations per iteration might outweigh the cost of traversing (and creating) the list twice.

I think Option is one of my favourite types from Scala. The idea that Option gives you a warning through the type system that a value may not be initialized is fantastic. Sure, you're wrapping your object in another object (at a cost of 8 bytes on a 32-bit JVM, I believe), but (with some discipline) you have some assurance that things that don't return Option will not return null. With the option() factory method, you can even wrap the return values from "old-fashioned" Java APIs. Unfortunately, I haven't yet seen a language where NullPointerExceptions are completely avoidable (or aren't masked by treating uninitialized values as some default). Since Scala works on the JVM, null still exists. Haskell has the "bottom" type (which actually corresponds more closely to Scala's Nothing), which effectively takes the place of null (as I understand it -- I'm a dabbler at best in Haskell), though the use of bottom is heavily discouraged, from what I've read.

In between these collection types, I managed to sneak in an explanation of tail-call optimization, and how it can be used to avoid blowing your stack on recursive calls on a list. Tail-calls are not restricted to recursive functions (or even mutually-recursive sets of functions), but the stack-saving nature of tail-call optimization tends to be most apparent when recursion is involved. The other bonus of TCO (with apologies for my vague memories of some form of assembly) is that the tail-call gets turned into a JMP ("jump" or "goto"), rather than a JSR ("jump to subroutine", which also saves a return address), so the final evaluation can immediately return to the last spot where a tail-call didn't occur. That said, I should remind you that tail-call optimization only works when you're delegating your return to another method/function/procedure (that is, the compiler doesn't need to push your current location onto the stack in order to return). When you're traversing a binary tree, and need to visit the left and right children via recursion, at least the first recursive call is not a tail-call. Of course, in Java, with default JVM settings, in my experience, the stack seems to overflow around 50k calls. If you have a balanced binary tree with 50k levels, then you probably have enough elements to make traversal outlast the earth being consumed by the sun, in which case a StackOverflowError is the least of your worries, since you and everyone you love will be long dead. (Actually, since the recursion is likely to be depth-first, you're lucky -- the stack will overflow quite quickly, and you can call your mom to remind her that you love her and you're glad that she didn't die while your program ran. Moms love that kind of thing.)

For some really interesting information about tail-call optimization on the JVM at the bytecode level, I suggest you read John Rose's blog post on the subject from 2007.

I think this may be the longest post in the series so far. In particular, I've gone whole-hog on using unit tests to both establish that my code actually works (give or take copy-paste errors between my IDE and GitHub gists) and show examples of how to use these methods.

I must confess that this post was also more comfortable to write. Java doesn't make writing functions as objects terribly pleasant (since it wasn't really designed for it). I think the code in the last couple of posts involved more type arguments than actual logic. Working with data types is much easier, by comparison.

Where do we go from here?

I had been thinking about trying to implement a mechanism to combine lazy-loading of Function0 values with parallel execution, by replacing uninitialized values with Futures linked to a running thread on a ForkJoinPool. Unfortunately, the more I think about it, the more I realize that "parallelize everything" is a really bad idea (which is why nobody actually does it). I might be able to do it like memoization, where you can selectively memoize certain functions that are frequently used on a small set of inputs, so you would be able to parallelize only your "hot spot" functions. Unfortunately, my memoization implementation is kind of an ugly hack, requiring a parallel MemoizedFunctionN hierarchy, and I'm not particularly enthusiastic about creating a separate ParallelizedFunctionN hierarchy. Sure, I can write code generators for these different function specializations, but it's still not particularly elegant.

The last remaining "standard" functional programming concept I can think of that I haven't covered (besides pattern matching, which I don't think I can implement nicely in Java) is tuples. If I dedicate a post exclusively to tuples, I think it should be fairly short. After these last couple of lengthy posts, that may be a good way to go.