January 22, 2019

Knet v1.2.0: iterators, iterators, iterators...

The new Knet release is all about iterators: iterators for minibatching, iterators for training, iterators for monitoring progress and detecting convergence, and so on. Why am I so excited about iterators all of a sudden? Allow me to explain.

Knet has used iterators for data generation since 2015. That was about it until recently when I was looking for a way to improve the training interface. See, at the core of every deep learning project there is a training loop that looks like this:

function train(model,data)
    for (x,y) in data
        # improve model parameters so model(x) approaches y
    end
end
And these things can run for hours or days. You want the user to have full control of this loop: how many iterations to go, how to detect convergence and quit, how to monitor progress, how to take model snapshots or measure dev accuracy every n iterations etc.

My original (non)solution was to write a new `train` function for every experiment. Why restrict the user with a bad interface when they can write their own 5-line loop? (Of course, then why write any package at all? But that's another discussion.)

My next (pseudo)solution was to provide a `train` function with lots of keyword arguments. I soon gave up on that idea when it became clear that I was on my way to implementing a Turing-complete programming language using keyword arguments.

Then I thought I had a brilliant flash of insight based on callback functions. See, if `train` just accepts a callback function that gets called inside the for loop, the user can implement any behavior:
function train(model,data,callback)
    for (x,y) in data
        callback() || break
        # improve model parameters so model(x) approaches y
    end
end
You want to display a progress bar, do something every n iterations, or quit after N iterations? Just implement some callback function with state and you are all set! Brilliant? Everybody hated it. Including me. It turns out callback functions are awkward to write and do not lead to very readable code.
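
To make the awkwardness concrete, here is a minimal sketch of the "quit after N iterations" case written with a stateful callback. The helper name `stop_after` is hypothetical, and the parameter update is replaced by a step counter so the snippet runs on its own:

```julia
# Same shape as the callback-based `train` above; the parameter update is
# replaced by a counter so that the effect of the callback is observable.
function train(model, data, callback)
    steps = 0
    for (x, y) in data
        callback() || break
        steps += 1               # stand-in for the real parameter update
    end
    return steps
end

# Hypothetical helper: a closure that hides a counter and starts returning
# false once it has been called more than `limit` times.
function stop_after(limit)
    calls = 0
    return function ()
        calls += 1
        return calls <= limit
    end
end

# An endless data stream; only the hidden state in the callback stops the loop.
steps = train(nothing, Iterators.cycle([(1, 2)]), stop_after(5))
```

The stopping rule lives in hidden state inside the closure, far from the loop it controls, which is a big part of why this style reads poorly.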

Then finally I rediscovered iterators, and iterators that wrap other iterators (inspired by Tqdm.jl). I knew iterators could be these lazy collections that produce their next element only when asked. (Here is a summary with doc links to refresh your memory.) See, once you implement the training loop as an iterator, you can pause, restart and terminate it whenever you want:
train(model,data) = ((update model and return loss) for (x,y) in data)
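
As a rough, runnable sketch of that one-liner (not Knet's actual implementation), take a toy model with a single weight. The names `ToyModel` and `step!` are made up for illustration:

```julia
# Toy problem: fit a scalar weight w so that w*x ≈ y under squared error.
mutable struct ToyModel
    w::Float64
end

# One gradient step on a single (x, y) pair; returns the loss before the update.
function step!(m::ToyModel, x, y; lr=0.1)
    err = m.w * x - y
    m.w -= lr * 2 * err * x      # gradient of err^2 with respect to w
    return err^2
end

# The training "loop" is a lazy generator: nothing runs until it is iterated.
train(m, data) = (step!(m, x, y) for (x, y) in data)

model = ToyModel(0.0)
data  = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]
itr   = train(model, data)
w_before = model.w               # still 0.0: building the iterator did no work

# Pull one loss at a time; we can stop here, inspect the model, and resume later.
loss1, state = iterate(itr)          # the first update happens here
loss2, state = iterate(itr, state)   # the second update happens here
```

Each call to `iterate` performs exactly one update, so the caller decides when the next step happens and when to stop.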
What I realized is that iterators also turn the for loop inside out: they make its guts visible so that you have explicit control. You can monitor and display its progress, take snapshots, or whatever, all with very explicit and readable code. Here are some actual examples from Knet v1.2.0 (`sgd` is a train iterator, `f` is the model, `d` is the data):

* To display a progress bar use `progress(sgd(f,d))`.
* To run until convergence use `converge(sgd(f,cycle(d)))`.
* To run multiple epochs use `sgd(f,repeat(d,n))`.
* To run a given number of iterations use `sgd(f,take(cycle(d),n))`.
* To do a task every n iterations use `(task(x) for x in every(n, sgd(f,cycle(d))))`.
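
The same composition pattern can be tried without Knet. The sketch below uses `cycle` and `take` from `Base.Iterators`, the `every` one-liner given in the notes of this post, and a hypothetical `toy_sgd` that stands in for a real training iterator:

```julia
using Base.Iterators: cycle, take

# Hypothetical stand-in for a Knet training iterator: each element it yields
# is the loss of one update step on the single weight stored in `w`.
function toy_sgd(w::Vector{Float64}, data; lr=0.1)
    function step(x, y)
        err = w[1] * x - y
        w[1] -= lr * 2 * err * x     # gradient step for squared error
        return err^2
    end
    return (step(x, y) for (x, y) in data)
end

# The `every` helper as defined in the notes of this post.
every(n, itr) = (x for (i, x) in enumerate(itr) if i % n == 0)

d = [(1.0, 2.0), (2.0, 4.0)]         # targets follow y = 2x

# Run exactly 10 updates, cycling through the 2-element dataset as needed.
w = [0.0]
losses = collect(take_losses for take_losses in toy_sgd(w, take(cycle(d), 10)))

# Same run on a fresh weight, but only keep the loss of every 3rd update.
w2 = [0.0]
sampled = collect(every(3, toy_sgd(w2, take(cycle(d), 10))))
```

Because `every` only filters what it is handed, all 10 updates still happen in the second run; only the values that reach the caller are thinned out.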

Each of the functions like `progress`, `converge`, `sgd` etc. takes and returns an iterator, so they can be composed like crazy. Here is an example from the Knet tutorial that shows how to (1) train a model on `dtrn`, (2) measure loss on `dtst` every 100 iterations, (3) quit when `dtst` performance converges, and (4) display a progress bar:
a = adam(model,cycle(dtrn))
b = (model(dtst) for _ in every(100,a))
c = converge(b, alpha=0.1)
progress!(c, alpha=1)
The code reads like the English description! Imagine trying to implement this using keyword arguments or callback functions... and that is why I am excited about iterators.
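
If you want to see what "an iterator that wraps another iterator" looks like from the inside, here is a minimal sketch using Julia's `Base.iterate` protocol. The type `UntilNoImprovement` is hypothetical and much cruder than Knet's `converge`: it simply stops at the first value that fails to improve on the previous one.

```julia
# Hypothetical wrapper, not Knet's `converge`: passes values through and stops
# as soon as a value is not strictly smaller than the one before it.
struct UntilNoImprovement{I}
    itr::I
end

# The length is not known in advance because we may stop early.
Base.IteratorSize(::Type{<:UntilNoImprovement}) = Base.SizeUnknown()
Base.eltype(::Type{UntilNoImprovement{I}}) where {I} = eltype(I)

# State is `nothing` at the start, then (previous value, inner iterator state).
function Base.iterate(u::UntilNoImprovement, state=nothing)
    prev, inner = state === nothing ? (Inf, iterate(u.itr)) :
                                      (state[1], iterate(u.itr, state[2]))
    inner === nothing && return nothing   # wrapped iterator is exhausted
    x, s = inner
    x >= prev && return nothing           # no improvement: stop early
    return x, (x, s)
end

# Works on any iterator of numbers, including a lazy generator.
kept  = collect(UntilNoImprovement([5.0, 3.0, 2.0, 2.5, 1.0]))
kept2 = collect(UntilNoImprovement(x^2 for x in [3, 2, 1, 2, 0]))
```

The wrapper knows nothing about models or data; it only sees a stream of numbers, which is what lets it sit anywhere in a pipeline like the one above.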

Notes:
* The more nitpicky reader will probably point out that I should have called these things generators or coroutines or streams or something rather than iterators, but you get the idea.
* every(n,itr) = (x for (i,x) in enumerate(itr) if i%n == 0) should be a Julia primitive! (Thank you @CarloLucibello for pointing out that `IterTools.takenth` does the same thing.)
* @lostella has a wonderful post on iterators.
* Here are the relevant links in Julia docs: Interfaces, Collections, Iteration Utilities and Generator expressions.
* Here is a link to the discussion on Julia discourse.
