Thu Jul 28 2016

kNN in F# (Benchmarking is hard)

I was reading over some old blog posts and stumbled upon a series comparing implementations of the same "kNN example code": "Comparing a Machine Learning Algorithm Implemented in F# and OCaml" and "kNN Algorithm in golang and Haskell". Others on the /r/haskell reddit have already pointed out the apples-to-peaches comparison going on here, but I think this is a fine way to illustrate a few gotchas / features of F# that may not be super obvious.

To begin, here are some numbers from my machine. In the interest of "benchmarking is hard" and isolating just run time, these times do not include compile times (fsc is an order of magnitude slower to compile). Also, the .NET runtime should be faster than Mono, but I haven't tested on a Windows machine.

$ time ./golang-k-nn
real 0m4.265s
$ time ./golang-k-nn-speedup
real 0.790s
$ time mono fsharp-k-nn.exe
real 18.803s

Optimizing this F# code:

At the core of the code, the kNN classifier just calls a distance function to compare two images, and selects the minimum value. This means any optimizations to the distance function go a long way towards net speedup of the code base. Here's the original F# code:

let distance (p1: int[]) (p2: int[]) =
    Math.Sqrt (float(Array.sum (Array.map2 ( fun a b -> (pown (a-b) 2)) p1 p2) ))
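
For context, here is a minimal sketch of how a classifier might call this function. The Observation record shape and the trainingSet parameter are my assumptions (the post only shows the Label and Pixels fields, and its classify takes just the pixels), so treat this as an illustration rather than the original code.

```fsharp
open System

// Assumed record shape; only the Label and Pixels field names come from the post.
type Observation = { Label: string; Pixels: int[] }

// The original distance function from the post.
let distance (p1: int[]) (p2: int[]) =
    Math.Sqrt (float (Array.sum (Array.map2 (fun a b -> pown (a - b) 2) p1 p2)))

// 1-nearest-neighbour: return the label of the closest training image.
// trainingSet is a hypothetical parameter added to keep the sketch self-contained.
let classify (trainingSet: Observation[]) (pixels: int[]) =
    let nearest = trainingSet |> Array.minBy (fun x -> distance x.Pixels pixels)
    nearest.Label
```

Every classification calls distance once per training image, which is why the distance function dominates the run time.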

First off, pown foo 2 is going to be slower than foo * foo, because F# does not optimize for that special case. Switching that out instantly makes the program about 1.5x faster, cutting the run time from 18.8s to 11.9s:

let distance (p1: int[]) (p2: int[]) =
    Math.Sqrt (float (Array.sum (Array.map2 (fun a b -> (a-b) * (a-b)) p1 p2)))

$ time mono fsharp-k-nn.exe
> real 11.934s

Secondly, F# doesn't inline functions terribly well: imperative-style code and tail-recursive functions both tend to generate much more efficient IL than chained Array.map2 and Array.sum calls, which also allocate an intermediate array. Here's a tail-recursive version:

let distance (p1: int[]) (p2: int[]) =
    let rec iterate s = function
        | -1 -> sqrt (float s)
        | n -> let v = p1.[n] - p2.[n] in iterate (s + v * v) (n-1)
    iterate 0 (p1.Length-1)

$ time mono fsharp-k-nn.exe
> real 4.541s
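
For comparison, the imperative-style equivalent of the tail-recursive distance function could look like the sketch below. I have not timed this variant separately, so it is here only to show the shape of the loop.

```fsharp
// Imperative-style equivalent of the tail-recursive distance function:
// a single pass with a mutable accumulator and no intermediate array.
let distance (p1: int[]) (p2: int[]) =
    let mutable s = 0
    for i in 0 .. p1.Length - 1 do
        let v = p1.[i] - p2.[i]
        s <- s + v * v
    sqrt (float s)
```

Both forms keep the running sum in a local and avoid the closure and intermediate array that the Array.map2 version needs.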

Aha! We're in the ballpark of the naive golang implementation. That makes sense, because the above code does basically the same sequence of operations as the golang version (except that it should be slightly faster, since it operates on int arrays instead of float arrays).

Now, let's see if we can apply the extra set of optimizations in the sped-up golang version to our F# code:

Goroutines

The optimized Go version parallelizes the computation using goroutines, going from this:

total := 0
for _, test := range validationSample {
    if is_correct(test) {
        total++
    }
}

to this:

total := 0
channel := make(chan int)

for _, test := range validationSample {
    test := test // per-iteration copy, so each goroutine sees its own value
    go func() {
        if is_correct(test) {
            channel <- 1
        } else {
            channel <- 0
        }
    }()
}
for i := 0; i < len(validationSample); i++ {
    total += <-channel
}

In F#, to parallelize the classification of validationSample, we change:

let num_correct =
    validationsample
    |> Array.map (fun p -> if (classify p.Pixels ) = p.Label then 1 else 0)
    |> Array.sum

to this:

let num_correct =
    validationsample
    |> Array.Parallel.map (fun p -> if (classify p.Pixels ) = p.Label then 1 else 0)
    |> Array.sum

Notice the 9-character difference! If we further change the distance function to include a bailout parameter, like the Go solution does, we can cut the time down to 1.7 seconds.
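
A minimal sketch of the bailout idea, assuming it means "stop summing as soon as the running total can no longer beat the best distance found so far". The names distanceSq and trainingSet and the Observation record shape are my own for illustration, not taken from either code base. Since only the ordering of distances matters, this version also compares squared distances and skips the square root.

```fsharp
open System

// Assumed record shape; only the Label and Pixels field names come from the post.
type Observation = { Label: string; Pixels: int[] }

// Squared distance with an early exit: stop as soon as the running sum
// reaches `bailout`, the best squared distance seen so far.
let distanceSq (bailout: int) (p1: int[]) (p2: int[]) =
    let rec iterate s n =
        if n < 0 || s >= bailout then s
        else
            let v = p1.[n] - p2.[n]
            iterate (s + v * v) (n - 1)
    iterate 0 (p1.Length - 1)

// Hypothetical classifier that threads the best distance through a fold,
// so each comparison can bail out against the current best.
let classify (trainingSet: Observation[]) (pixels: int[]) =
    let folder (bestLabel, bestDist) (x: Observation) =
        let d = distanceSq bestDist x.Pixels pixels
        if d < bestDist then (x.Label, d) else (bestLabel, bestDist)
    trainingSet
    |> Array.fold folder ("", Int32.MaxValue)
    |> fst
```

A candidate that bails out returns a value at or above the current best, so it is rejected without ever finishing its loop. Most training images are far from any given test image, which is where the savings would come from.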

Conclusion

At this point, we've improved the time of the initial F# version more than ten-fold, from 18s to 1.7s, by inlining the expensive parts of the distance function and using the same techniques that were applied to the golang implementation. To further approach the speed of C, the distance function could be vectorized using the new System.Numerics.Vector libraries.
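
As a rough, unbenchmarked sketch of what that vectorized version might look like, the function below (distanceSimd is a name I made up) processes Vector<int>.Count elements per step and handles any leftover elements with a scalar loop. It assumes pixel values small enough that the per-lane int sums cannot overflow, which holds for 8-bit image data.

```fsharp
open System.Numerics

// Hypothetical SIMD variant of the distance function.
let distanceSimd (p1: int[]) (p2: int[]) =
    let width = Vector<int>.Count
    let mutable acc = Vector<int>.Zero
    let mutable i = 0
    // Main loop: subtract, square and accumulate `width` elements at a time.
    while i <= p1.Length - width do
        let a = Vector<int>(p1, i)
        let b = Vector<int>(p2, i)
        let d = a - b
        acc <- acc + d * d
        i <- i + width
    // Add up the per-lane partial sums.
    let mutable sum = 0
    for lane in 0 .. width - 1 do
        sum <- sum + acc.[lane]
    // Scalar tail for lengths that are not a multiple of `width`.
    while i < p1.Length do
        let v = p1.[i] - p2.[i]
        sum <- sum + v * v
        i <- i + 1
    sqrt (float sum)
```

Whether this actually beats the tail-recursive version depends on the runtime's SIMD support (Vector.IsHardwareAccelerated reports it), so it would need measuring before drawing any conclusions.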