Top of the cake: a neural net for Titanic predictions in under 1 hour
🍰 Building a working neural network can be easy or hard, depending on where in the abstraction layer cake you are
Let’s try something simple and quick… how about a neural network? – said no one ever.
Neural networks are the arcane magic black box of artificial intelligence: they can do great things (ChatGPT, Midjourney, Google Translate are all based on neural networks), but they’re famously difficult to work with.
Let’s break some stereotypes today and see how we can build an end-to-end neural network system in under one hour.
Titanic survival predictions
We’re going to solve the Titanic survival prediction problem from Kaggle.
In this problem (which may look familiar), we’re given training data about 891 Titanic passengers. The information includes each passenger’s name, sex, age and port of embarkation, and (very importantly) their survival outcome.
Our task is to build an ML model that learns from this data and can later predict the survival outcome for other passengers (whose outcomes Kaggle knows, but we don’t).
The layer cake
If you already know how to program in Python, building a neural network can be fairly straightforward or very, very hard… depending on where in the abstraction layer cake you build it.
Let me explain: whenever someone uses a neural network, it runs on a stack of abstraction layers that help manage the complexity of the system. From the ground up, the stack may look like this:
💻 Hardware: your computer’s processor (CPU) and graphics card (GPU)
🎨 GPU support/drivers: the software that lets computations (including neural network training) run on the graphics card
🔥 PyTorch: one of the most popular libraries for building neural networks and running them on GPUs
People often stop there, and learn how to build neural networks using PyTorch. That’s fine, but no wonder it’s hard: PyTorch is a low-level deep learning library. You need to know a lot about neural network design, and be ready for some serious trial and error, to build a neural network that’s practically useful at this abstraction level.
To go faster, we will add one more layer to the cake:
🤖 fastai: a high-level deep learning library providing ready-to-use neural network architectures
Tabular
fastai is what it sounds like: a library that helps you build end-to-end AI systems fast. (There’s a lot of value in this proposition: fast prototyping lets us quickly try a few approaches and see which one works best, rather than sinking into a rabbit hole of one specific technique with no promise of success.)
fastai comes with ready-to-go architectures for common classes of problems. One of these classes is, guess what, making predictions based on tabular data :).
To use the Tabular learner from fastai, we need to:
give it the passenger data
indicate which passenger attributes should be treated as continuous variables (numeric values such as age), which are categorical (such as the port of embarkation), and which is the variable that we want to learn to predict (the survival outcome)
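Under the hood, categorical and continuous columns are prepared differently: each category is replaced by an integer code that will later index an embedding table, while continuous values are normalized to zero mean and unit variance. Here is a minimal pure-Python sketch of that preprocessing on a few made-up rows that borrow column names from Kaggle’s Titanic CSV. It only mimics what fastai’s preprocessing does and is not fastai code; reserving code 0 for missing or unseen values is my assumption, loosely modeled on the `#na#` placeholder fastai adds to each category list.

```python
from statistics import mean, pstdev

# Hypothetical toy rows: column names follow Kaggle's Titanic CSV, values are made up.
rows = [
    {"Sex": "male", "Embarked": "S", "Age": 22.0, "Survived": 0},
    {"Sex": "female", "Embarked": "C", "Age": 38.0, "Survived": 1},
    {"Sex": "female", "Embarked": None, "Age": 26.0, "Survived": 1},
    {"Sex": "male", "Embarked": "Q", "Age": 35.0, "Survived": 0},
]

cat_names = ["Sex", "Embarked"]  # categorical: each value becomes an integer code
cont_names = ["Age"]             # continuous: kept as a (normalized) number
y_name = "Survived"              # the outcome we want to predict


def build_vocab(values):
    """Assign an integer code to each distinct category; code 0 means missing/unseen."""
    vocab = {None: 0}
    for v in values:
        if v not in vocab:
            vocab[v] = len(vocab)
    return vocab


vocabs = {c: build_vocab([r[c] for r in rows]) for c in cat_names}
# Mean and (population) standard deviation of each continuous column.
stats = {c: (mean([r[c] for r in rows]), pstdev([r[c] for r in rows])) for c in cont_names}


def encode(row):
    """Turn one passenger row into (category codes, normalized numbers, target)."""
    cats = [vocabs[c].get(row[c], 0) for c in cat_names]
    conts = [(row[c] - stats[c][0]) / stats[c][1] for c in cont_names]
    return cats, conts, row[y_name]


encoded = [encode(r) for r in rows]
print(vocabs)  # {'Sex': {None: 0, 'male': 1, 'female': 2}, 'Embarked': {None: 0, 'S': 1, 'C': 2, 'Q': 3}}
```

Note how `Embarked` ends up with four codes (three ports plus the missing slot) and `Sex` with three: counts like these set the number of rows in each embedding table.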
What did a Deep Learning library ever do for us?
fastai does a lot of work behind the scenes to get us a trainable neural network for our data table. This isn’t simply pushing the data through a fixed, one-size-fits-all computational graph: fastai designs a network sized to our data, for example with one embedding per categorical column.
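Two of the rules of thumb behind that design are easy to reproduce. As far as I can tell from fastai’s source, embedding widths come from a helper called `emb_sz_rule`, and `tabular_learner` defaults to two hidden layers of 200 and 100 neurons. The pure-Python sketch below re-implements that arithmetic, assuming four categorical columns with 4, 3, 4 and 3 categories (counting the slot for missing values) and four continuous columns, which is what the model printed in this post contains. It counts trainable parameters only; BatchNorm’s running statistics are buffers, not parameters.

```python
def emb_sz_rule(n_cat: int) -> int:
    """Embedding width for a column with n_cat categories (fastai's rule of thumb)."""
    return min(600, round(1.6 * n_cat ** 0.56))


# Assumed column setup: category counts per column (including the missing-value
# slot) and the number of continuous columns.
n_cats = [4, 3, 4, 3]
n_cont = 4
hidden = [200, 100]  # tabular_learner's default hidden layer sizes
n_out = 2            # one score per outcome: perished / survived

emb_sizes = [emb_sz_rule(n) for n in n_cats]
in_features = sum(emb_sizes) + n_cont  # concatenated embeddings + continuous values

params = sum(n * d for n, d in zip(n_cats, emb_sizes))  # embedding tables
params += 2 * n_cont                                   # BatchNorm on continuous inputs
sizes = [in_features] + hidden
for a, b in zip(sizes[:-1], sizes[1:]):
    params += a * b + 2 * b  # Linear without bias, then BatchNorm weight and bias
params += hidden[-1] * n_out + n_out  # final Linear, with bias

print(emb_sizes, in_features, params)  # [3, 3, 3, 3] 16 24052
```

If these assumptions hold, `sum(p.numel() for p in learn.model.parameters())` on the real fastai model should report the same total.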
To peek behind the scenes and see this in practice, we can print out the model architecture:
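from fastai.tabular.all import *  # provides tabular_learner, accuracy, TabularDataLoaders
# dls: TabularDataLoaders built from the passenger table, as described above
learn = tabular_learner(dls, metrics=accuracy)
print(learn.model)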
The result is a PyTorch representation of the neural network (what we would have to build by hand if not for fastai):
TabularModel(
  (embeds): ModuleList(
    (0): Embedding(4, 3)
    (1): Embedding(3, 3)
    (2): Embedding(4, 3)
    (3): Embedding(3, 3)
  )
  (emb_drop): Dropout(p=0.0, inplace=False)
  (bn_cont): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layers): Sequential(
    (0): LinBnDrop(
      (0): Linear(in_features=16, out_features=200, bias=False)
      (1): ReLU(inplace=True)
      (2): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): LinBnDrop(
      (0): Linear(in_features=200, out_features=100, bias=False)
      (1): ReLU(inplace=True)
      (2): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): LinBnDrop(
      (0): Linear(in_features=100, out_features=2, bias=True)
    )
  )
)
This looks intimidating, so let’s go through the main pieces:
🖼️ The four “embedding” modules. There is one for each categorical variable in our input (such as sex, port of embarkation or passenger class). An embedding represents each category as a learned vector of numbers (here, 3 numbers per category), which lets the model learn hidden relations between categories. For example, if people who embarked in Southampton had survival outcomes similar to those who embarked in Cherbourg, but different from those who embarked in Queenstown, the model can learn that.
📈 Linear modules. Linear(in_features=16, out_features=200) is the core of the network. Here the 16 input numbers for each passenger (four 3-number embeddings plus four continuous values) are connected to 200 artificial neurons: simple mathematical functions whose weights are adjusted during training to capture relations between the passenger data and survival outcomes. A second layer of this kind connects the 200 neurons of layer 1 to 100 neurons in layer 2.
🎁 Output: finally, the last linear layer connects the 100 neurons of layer 2 to just 2 output features, one for each survival outcome: survived or perished.
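To make these pieces concrete, here is a minimal pure-Python sketch of one prediction flowing through a network of the same shape. The weights are random stand-ins rather than trained values, batch normalization, dropout and the output-layer bias are left out for brevity, and the category codes and continuous values in the last line are made up. It illustrates the data flow only; it is not fastai’s implementation.

```python
import math
import random

random.seed(0)


def rand_matrix(rows, cols):
    """A rows x cols table of small random weights (stand-ins for learned values)."""
    return [[random.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]


def linear(x, w):
    """Multiply input vector x (length n) by weight matrix w (n x m), no bias."""
    return [sum(xi * w[i][j] for i, xi in enumerate(x)) for j in range(len(w[0]))]


def relu(v):
    return [max(0.0, a) for a in v]


def softmax(v):
    m = max(v)
    exps = [math.exp(a - m) for a in v]
    total = sum(exps)
    return [e / total for e in exps]


# One embedding table per categorical column: n_categories rows x 3 numbers.
emb_tables = [rand_matrix(n, 3) for n in (4, 3, 4, 3)]
w1, w2, w_out = rand_matrix(16, 200), rand_matrix(200, 100), rand_matrix(100, 2)


def forward(cat_codes, cont_values):
    # 1. Embedding lookup: each category code picks one row (a 3-number vector).
    x = [a for table, code in zip(emb_tables, cat_codes) for a in table[code]]
    # 2. Append the (already normalized) continuous values: 4 * 3 + 4 = 16 inputs.
    x += cont_values
    # 3. Two hidden layers of "neurons": weighted sums followed by ReLU.
    h = relu(linear(relu(linear(x, w1)), w2))
    # 4. Two output scores, turned into probabilities (one per outcome).
    return softmax(linear(h, w_out))


probs = forward([1, 2, 1, 0], [0.3, -1.0, 0.0, 0.5])
```

With random weights the two probabilities are meaningless; training is what turns them into useful predictions.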
With the neural network in place, we need just one more line of code to train it on the training data:
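Whatever the exact call, training boils down to a loop that repeatedly nudges the weights to reduce the prediction error on the training data. Here is a minimal pure-Python sketch of that idea, assuming a single logistic-regression “neuron”, plain stochastic gradient descent and four made-up rows. fastai’s actual training is far more elaborate (backpropagation through the whole network above, mini-batches, a more sophisticated optimizer), so treat this as an illustration of the principle only.

```python
import math

# Hypothetical toy data: (normalized feature vector, survived label).
data = [([1.0, -0.5], 1), ([0.8, 0.2], 1), ([-1.0, 0.3], 0), ([-0.7, -0.9], 0)]

w = [0.0, 0.0]
b = 0.0
lr = 0.5  # learning rate: how big each nudge is


def predict(x):
    """Probability of survival from a single logistic 'neuron'."""
    z = sum(wi * xi for wi, xi in zip(w, x)) + b
    return 1.0 / (1.0 + math.exp(-z))


def loss():
    """Average cross-entropy: low when predictions are confident and right."""
    total = 0.0
    for x, y in data:
        p = predict(x)
        total -= math.log(p) if y == 1 else math.log(1 - p)
    return total / len(data)


start = loss()
for epoch in range(100):
    for x, y in data:
        err = predict(x) - y  # gradient of cross-entropy w.r.t. the pre-sigmoid score
        w = [wi - lr * err * xi for wi, xi in zip(w, x)]
        b -= lr * err
end = loss()
print(f"loss: {start:.3f} -> {end:.3f}")
```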
How does it do?
When submitted on Kaggle, the fastai neural network solution reaches an accuracy of 77% out of the box.
This is much better than the 62% baseline of a solution that simply predicts that everyone dies, and a bit better than the 76% I got a few months back when experimenting with decision forests. Not bad for something we put together in under 30 lines of code and under 1 hour (including reading the fastai documentation)!
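The 62% baseline is simply the accuracy of always predicting the majority outcome (death). The sketch below computes accuracy and that baseline on the training labels, assuming the class counts of Kaggle’s train.csv (549 perished, 342 survived, out of 891); the leaderboard figure comes from Kaggle’s separate test set, which is why it differs slightly.

```python
def accuracy(predictions, labels):
    """Fraction of passengers whose predicted outcome matches the true one."""
    return sum(p == y for p, y in zip(predictions, labels)) / len(labels)


# Label counts in Kaggle's Titanic train.csv: 549 perished (0), 342 survived (1).
labels = [0] * 549 + [1] * 342
everyone_dies = [0] * len(labels)

baseline = accuracy(everyone_dies, labels)
print(f"{baseline:.1%}")  # 61.6%
```

Any model worth submitting has to beat this number, which is why the majority-class baseline is a useful first sanity check.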
And this is just the first shot :). In a future post we may try to get more points.
More on this
🎞️ What have the Romans ever done for us? – Monty Python, the reference behind “What did a Deep Learning library ever do for us?”
In other news
⚙️ Tech companies are using up data faster than it is being produced. The New York Times wrote about the scramble to get more training data for LLMs.
📈 Meanwhile, a new paper suggests that sustaining linear gains in model performance will require exponentially more data, so all this scaling may soon run into diminishing returns.
🗺️ Melting Arctic ice opens up a potential new route for an Internet cable connecting Europe to Japan: north and then west, around Greenland, Canada and Alaska. This would be an alternative to the eastbound route via the Red Sea, where the current submarine cables are getting damaged.
Postcard from Paris
Evenings are getting warm and people are having impromptu picnics in front of the Pompidou Center, a sure sign that springtime is finally here 💫.
Have a great week,
– Przemek