Integration between torchnet and torch-dataframe – a closer look at the mnist example

It's all about the numbers and getting the tensors right. The image is cc by David Asch.


In previous posts we’ve looked into the basic structure of the torch-dataframe package. In this post we’ll go through the [mnist example][mnist ex] that shows how to best integrate the dataframe with [torchnet](https://github.com/torchnet/torchnet).

# All posts in the *torch-dataframe* series

1. [Intro to the torch-dataframe][intro]
2. [Modifications][mods]
3. [Subsetting][subs]
4. [The mnist example][mnist ex]
5. [Multilabel classification][multilabel]

[intro]: http://gforge.se/2016/08/deep-learning-with-torch-dataframe-a-gentle-introduction-to-torch/
[mods]: http://gforge.se/2016/08/the-torch-dataframe-basics-on-modifications/
[subs]: http://gforge.se/2016/08/the-torch-dataframe-subsetting-and-sampling/
[mnist ex]: http://gforge.se/2016/08/integration-between-torchnet-and-torch-dataframe-a-closer-look-at-the-mnist-example/
[multilabel]: http://gforge.se/2016/08/setting-up-a-multilabel-classification-network-with-torch-dataframe/

# The `getIterator` function

Everything interesting is located in the `getIterator` function. Its purpose is to select the *train*, *test* or *validate* dataset and return an instance of `tnt.DatasetIterator` that returns a table with `input` and `target` tensors:

```lua
{
   input = …,
   target = …
}
```
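To make the shape of that table concrete, here is a minimal plain-Lua sketch of an iterator that yields `{ input = ..., target = ... }` batches. It uses no torch or torchnet code; `make_batch_iterator` is a hypothetical helper that only mimics the shape of the output, not torchnet's `DatasetIterator` API (torchnet's iterator is called as `iterator()` to obtain the loop function, while this sketch returns the loop function directly).

```lua
-- Hypothetical plain-Lua sketch: no torch, no Dataframe.
-- Yields tables shaped like { input = {...}, target = {...} }.
local function make_batch_iterator(inputs, targets, batch_size)
  local n = #inputs
  local pos = 1
  return function()
    if pos > n then return nil end            -- epoch finished
    local batch = { input = {}, target = {} }
    local last = math.min(pos + batch_size - 1, n)
    for i = pos, last do
      batch.input[#batch.input + 1] = inputs[i]
      batch.target[#batch.target + 1] = targets[i]
    end
    pos = last + 1
    return batch
  end
end

-- Usage: five samples in batches of two gives batch sizes 2, 2, 1
local inputs  = { {0, 1}, {1, 0}, {1, 1}, {0, 0}, {1, 2} }
local targets = { 1, 2, 3, 1, 2 }
local sizes = {}
for batch in make_batch_iterator(inputs, targets, 2) do
  sizes[#sizes + 1] = #batch.input
end
print(table.concat(sizes, ", "))  -- 2, 2, 1
```

The last batch is simply smaller; whether a real iterator keeps or drops such a partial batch is a design choice this sketch does not settle.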

# The mnist dataset

The mnist dataset is loaded via the mnist package:

```lua
local mnist = require 'mnist'
-- e.g. mode = 'train' calls mnist.traindataset()
local mnist_dataset = mnist[mode .. 'dataset']()
```
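The `mode .. 'dataset'` expression builds the function name from a string, so a single code path serves both splits. The sketch below shows that lookup with a hypothetical stub table standing in for the real `mnist` package (which exposes `mnist.traindataset()` and `mnist.testdataset()`); the returned tables are placeholders, not the package's actual dataset objects.

```lua
-- Hypothetical stub standing in for the real `mnist` package.
local mnist = {
  traindataset = function() return { name = "train", size = 60000 } end,
  testdataset  = function() return { name = "test",  size = 10000 } end,
}

-- `mode .. 'dataset'` builds the key, e.g. 'train' -> 'traindataset'
local function load_split(mode)
  local loader = mnist[mode .. 'dataset']
  assert(loader, "unknown mode: " .. tostring(mode))
  return loader()
end

print(load_split('train').size)  -- 60000
print(load_split('test').size)   -- 10000
```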

The labels from the dataset are then converted to a Dataframe:

```lua
local df = Dataframe(
   Df_Dict{
      label = mnist_dataset.label:totable(),
      -- row_id links each row to its image in mnist_dataset.data
      row_id = torch.range(1, mnist_dataset.data:size(1)):totable()
   })
```
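In plain Lua, the two columns handed to `Df_Dict` amount to a list of labels and a running index. This sketch assumes a tiny hypothetical `labels` list in place of `mnist_dataset.label:totable()`; the point is that `row_id` is the key that later connects a dataframe row to its image.

```lua
-- Plain-Lua sketch of the two columns passed to Df_Dict.
-- `labels` is a hypothetical stand-in for mnist_dataset.label:totable().
local labels = { 5, 0, 4, 1, 9 }

local columns = { label = {}, row_id = {} }
for i, lbl in ipairs(labels) do
  columns.label[i]  = lbl
  columns.row_id[i] = i   -- same values as torch.range(1, n):totable()
end

-- row_id is the key that links a dataframe row back to its image
print(columns.row_id[3], columns.label[3])  -- 3  4 (tab-separated)
```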

The image data itself is not stored in the dataframe; it is fetched from an external resource, just as you would do with any external data storage:

```lua
-- Since the mnist package has already taken care of the data
-- splitting, we create a single subset
df:create_subsets{
   subsets = Df_Dict{core = 1},
   data_retriever = function(row)
      return ext_resource[row.row_id]
   end,
   label_retriever = Df_Array("label")
}
local subset = df["/core"]
```

Note that we create only a single subset here, `core`, which receives all rows (`core = 1`), as the data is already split.
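The `data_retriever` pattern is easy to see in isolation: the dataframe row only carries an id, and the retriever uses it to look up the actual data somewhere else. In this plain-Lua sketch `ext_resource` is a hypothetical table of tiny flattened "images"; in the real example it is a torch tensor indexed the same way.

```lua
-- Plain-Lua sketch of the external-resource pattern.
-- ext_resource is a hypothetical table of flattened images (global, as in
-- the example, so the retriever can reach it).
ext_resource = {
  { 0, 0, 1, 1 },   -- image for row_id 1
  { 1, 1, 0, 0 },   -- image for row_id 2
}

local function data_retriever(row)
  -- the dataframe row only carries the id; the pixels live elsewhere
  return ext_resource[row.row_id]
end

local row = { label = 7, row_id = 2 }
local img = data_retriever(row)
print(img[1], row.label)  -- 1  7 (tab-separated)
```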

# The iterators

The dataframe has two specialized iterator classes, `Df_Iterator` and `Df_ParallelIterator`. The difference is that `Df_ParallelIterator` lets you set up multiple threads (the `nthread` argument) that take care of loading the external data. *Note*: this differs from torchnet's own parallel iterator, which exports the entire dataset to the threads and then does everything within each thread. We chose a different approach because torchnet's approach won't work with the samplers, and we believe that the extra cost is negligible as long as you don't keep all your data in the csv file itself.

## The plain iterator

Here we set up the external resource and then create the iterator that will have access to that resource:

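```lua
-- Flatten each 28x28 image into a 784-long row and convert the
-- ByteTensor to doubles; kept global so the data_retriever can reach it
ext_resource = mnist_dataset.data:reshape(mnist_dataset.data:size(1),
   mnist_dataset.data:size(2) * mnist_dataset.data:size(3)):double()

return Df_Iterator{
   dataset = subset,
   batch_size = 128,
   -- mnist labels are 0-9 while torch criteria expect 1-based classes
   target_transform = function(val)
      return val + 1
   end
}
```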
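The iterator setup performs two small transformations. This plain-Lua sketch spells them out on a tiny hypothetical 2 x 3 image: the `reshape` call reads each image row by row (row-major order, as torch stores contiguous tensors), and `target_transform` shifts the 0-based mnist labels to the 1-based class indices that torch criteria such as `nn.ClassNLLCriterion` expect. The helper `flatten` is ours, not part of torch.

```lua
-- Plain-Lua sketch of the two transformations in the plain iterator.
-- 1) reshape(n, rows * cols): each rows x cols image becomes one flat
--    vector, read row by row (row-major order).
local function flatten(image)
  local flat = {}
  for r = 1, #image do
    for c = 1, #image[r] do
      flat[#flat + 1] = image[r][c]
    end
  end
  return flat
end

-- 2) target_transform: labels 0-9 become class indices 1-10.
local function target_transform(val)
  return val + 1
end

local image = { { 1, 2, 3 }, { 4, 5, 6 } }   -- a tiny 2 x 3 "image"
local flat = flatten(image)
print(#flat, flat[4])                             -- 6  4
print(target_transform(0), target_transform(9))   -- 1  10
```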

## The parallel iterator

This is similar to the above, but in addition we must load the required packages and the external resource inside each thread:

```lua
return Df_ParallelIterator{
   dataset = subset,
   batch_size = 128,
   init = function(idx)
      -- Load the libraries needed
      require 'torch'
      require 'Dataframe'

      -- Load the dataset's external resource
      local mnist = require 'mnist'
      local mnist_dataset = mnist[mode .. 'dataset']()
      ext_resource = mnist_dataset.data:reshape(mnist_dataset.data:size(1),
         mnist_dataset.data:size(2) * mnist_dataset.data:size(3)):double()
   end,
   nthread = 2,
   target_transform = function(val)
      return val + 1
   end
}
```
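Why does `init` have to repeat the `require` calls and rebuild `ext_resource`? Each worker thread runs in its own Lua state, so nothing the main script defined is visible there. The sketch below is a plain-Lua simulation, not real threads: each "worker" is just an empty table, and the hypothetical `init(idx, worker)` fills it in, much as `init(idx)` is called once per thread in the example above.

```lua
-- Plain-Lua simulation (not real threads) of per-worker initialization.
local function init(idx, worker)
  -- stands in for the require calls and for rebuilding ext_resource
  worker.id = idx
  worker.ext_resource = { "image-1", "image-2", "image-3" }
end

local nthread = 2
local workers = {}
for idx = 1, nthread do
  workers[idx] = {}            -- fresh, empty state, like a new thread
  init(idx, workers[idx])
end

-- each worker can now serve data_retriever-style lookups on its own
print(workers[2].id, workers[2].ext_resource[3])  -- 2  image-3
```

The cost of this design is that every thread holds its own copy of the external resource; the benefit is that the dataframe and its samplers stay in the main thread.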

# The reset call

The torchnet engines don't resample the dataset after each epoch; they simply restart after `my_dataset:size()` iterations. We therefore need to add a hook that invokes `reset_sampler`. This is only needed for the samplers that require resetting (linear, ordered and permutation), but it is recommended as standard practice since it makes it easier to switch between samplers. The hook belongs to the engine and is set a little further down in the script:

```lua
engine.hooks.onEndEpoch = function(state)
   print("End epoch no " .. state.epoch)
   state.iterator.dataset:reset_sampler()
end
```
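To see what goes wrong without the hook, here is a plain-Lua sketch of a permutation-style sampler that walks through a shuffled index list once and then runs dry. `make_permutation_sampler` and its `next_index` method are hypothetical stand-ins, not the dataframe's sampler API; only the `reset_sampler` name mirrors the call in the hook above.

```lua
-- Plain-Lua sketch of why a sampler must be reset between epochs.
local function make_permutation_sampler(n)
  local sampler = {}
  function sampler:reset_sampler()
    self.order = {}
    for i = 1, n do self.order[i] = i end
    for i = n, 2, -1 do                      -- Fisher-Yates shuffle
      local j = math.random(i)
      self.order[i], self.order[j] = self.order[j], self.order[i]
    end
    self.pos = 0
  end
  function sampler:next_index()
    self.pos = self.pos + 1
    return self.order[self.pos]               -- nil once exhausted
  end
  sampler:reset_sampler()
  return sampler
end

local hooks = {
  onEndEpoch = function(state)
    state.sampler:reset_sampler()             -- same role as the hook above
  end,
}

local state = { sampler = make_permutation_sampler(4) }
for epoch = 1, 2 do
  local seen = 0
  for _ = 1, 4 do
    if state.sampler:next_index() then seen = seen + 1 end
  end
  print("epoch " .. epoch .. ": " .. seen .. " samples")
  hooks.onEndEpoch(state)
end
-- epoch 1: 4 samples
-- epoch 2: 4 samples   (0 samples if the hook is removed)
```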
# A little about torchnet

While adapting torch-dataframe to torchnet I've learned to appreciate torchnet's brilliant structure. The dataset layers let you build up complexity as needed, and the possibilities are endless. The engine is elegant, and the hooks are easy to understand once you look at the code (from the [SGDEngine](https://github.com/torchnet/torchnet/blob/master/engine/sgdengine.lua)):

```lua
self.hooks("onStart", state)
while state.epoch < state.maxepoch do
   state.network:training()
   self.hooks("onStartEpoch", state)
   for sample in state.iterator() do
      state.sample = sample
      self.hooks("onSample", state)
      state.network:forward(sample.input)
      self.hooks("onForward", state)
      state.criterion:forward(state.network.output, sample.target)
      self.hooks("onForwardCriterion", state)
      state.network:zeroGradParameters()
      if state.criterion.zeroGradParameters then
         state.criterion:zeroGradParameters()
      end
      state.criterion:backward(state.network.output, sample.target)
      self.hooks("onBackwardCriterion", state)
      state.network:backward(sample.input, state.criterion.gradInput)
      self.hooks("onBackward", state)
      assert(state.lrcriterion >= 0, 'lrcriterion should be positive or zero')
      if state.lrcriterion > 0 and state.criterion.updateParameters then
         state.criterion:updateParameters(state.lrcriterion)
      end
      assert(state.lr >= 0, 'lr should be positive or zero')
      if state.lr > 0 then
         state.network:updateParameters(state.lr)
      end
      state.t = state.t + 1
      self.hooks("onUpdate", state)
   end
   state.epoch = state.epoch + 1
   self.hooks("onEndEpoch", state)
end
self.hooks("onEnd", state)
```
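The hook calls in that loop follow a simple pattern: the engine calls `self.hooks(name, state)` at fixed points, and whatever function you assigned under that name runs. This is a plain-Lua sketch of the same idea, not torchnet's implementation: `make_hooks` is hypothetical and treats any hook that was never set as a no-op, and the loop is a miniature of the engine's epoch structure with the training steps replaced by a counter.

```lua
-- Plain-Lua sketch of a callable hook table (not torchnet's code).
local function make_hooks()
  return setmetatable({}, {
    __call = function(hooks, name, state)
      local fn = rawget(hooks, name)
      if fn then fn(state) end               -- unset hooks are skipped
    end,
  })
end

local hooks = make_hooks()
local log = {}
hooks.onStartEpoch = function(state) log[#log + 1] = "start " .. state.epoch end
hooks.onEndEpoch   = function(state) log[#log + 1] = "end " .. state.epoch end

-- a miniature version of the engine loop: two epochs, three samples each
local state = { epoch = 0, maxepoch = 2, t = 0 }
hooks("onStart", state)                 -- not set, silently skipped
while state.epoch < state.maxepoch do
  hooks("onStartEpoch", state)
  for _ = 1, 3 do
    state.t = state.t + 1               -- stands in for forward/backward/update
    hooks("onUpdate", state)
  end
  state.epoch = state.epoch + 1
  hooks("onEndEpoch", state)
end
print(table.concat(log, ", "))  -- start 0, end 1, start 1, end 2
```

Because `state.epoch` is incremented before `onEndEpoch` fires, an end-of-epoch hook sees 1 after the first epoch, which matches the `print("End epoch no " .. state.epoch)` hook earlier in this post.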

I believe that torch-dataframe does have its place in this infrastructure, as it allows you to better visualize your data. It also brings some basic data operations to the table; for example, the simple `as_categorical` quickly helps you make sense of the network's outputs.

# Summary

In this post we’ve looked closer at the [mnist example][mnist ex] and what components it uses. Hopefully you can use this as a template in your own research.
