Tensors and Nx are not just for machine learning

Image by Annie Ruygt

This post is about using Nx with Elixir, and how easily it can be used for everyday math! If you want to deploy your Phoenix LiveView app right now, then check out how to get started. You could be up and running in minutes.

The Elixir community has a sleeping giant in Nx, and not just for machine learning. Nx allows you to describe nearly any numerical operation, which can then be run in an environment optimized for such operations. The Nx README describes it as:

… a multidimensional tensors library for Elixir with multistaged compilation to the CPU/GPU.

It continues with an impressive list of features that certainly make sense to people coming from numerical programming, but what about us regular Elixir developers? Traditionally, the BEAM has not been a place for efficient math programming, but Nx changes the landscape.

You may not be looking specifically to handle large or multidimensional tensors—but consider this: if you can express it in tensor form, you can take advantage of Nx for faster calculations.

In this post, we’re going to try it out and walk through what’s possible. First, I recommend opening either an Elixir script or a Livebook with the following line:

Mix.install([
    :nx
])

What is a Tensor?

A tensor is a wildly flexible data structure. It is used to wrap computations over matrices, or over anything that can be described as a matrix. For example, this is a tensor:

Nx.tensor(42)
#Nx.Tensor<
s64
42
>

This simply means that we have a tensor of type (64)bit (s)igned integer, with value 42. It is a scalar, so its shape is the empty tuple {}: it has no dimensions at all. And we can do math with this tensor!
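Nx has small helper functions for inspecting these properties directly. Here is a minimal sketch; the variable name is just for illustration, and the type is pinned with the type: option so the result doesn’t depend on Nx’s default integer type. Strictly speaking, a scalar tensor has rank 0 and an empty shape, rather than being a 1x1 matrix.

```elixir
Mix.install([:nx]) # setup line from the top of the post; skip it if Nx is already installed

# Pin the type so the output doesn't depend on Nx's default integer type
t = Nx.tensor(42, type: {:s, 64})

Nx.type(t)      # {:s, 64}: (s)igned, 64 bits
Nx.shape(t)     # {}: a scalar has no axes
Nx.rank(t)      # 0
Nx.to_number(t) # 42, back as a plain Elixir integer
```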

Nx.tensor(42)
|> Nx.add(10)
#Nx.Tensor<
s64
52
>

You’ll notice that Nx was able to take our integer 10, coerce it to the correct shape, then do the addition, returning a new tensor with value 52. What happens if we try to add a float?

Nx.tensor(42)
|> Nx.add(10.5)
#Nx.Tensor<
f32
52.5
>

Here it coerced the tensor into a (32)bit (f)loat, then did the addition, returning a new tensor.
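To watch the type change for yourself, Nx.type/1 reports it as a {kind, bits} tuple. A small sketch, with illustrative variable names and the integer type pinned via the type: option; Nx.as_type/2 is there when you’d rather pick the float type yourself.

```elixir
Mix.install([:nx]) # setup line from the top of the post; skip it if Nx is already installed

int = Nx.tensor(42, type: {:s, 64})
result = Nx.add(int, 10.5)

Nx.type(result)      # {:f, 32}: integer + float promotes to Nx's default float type
Nx.to_number(result) # 52.5

# Cast explicitly when you want a specific float type
as_f64 = Nx.as_type(int, {:f, 64})
Nx.type(as_f64) # {:f, 64}
```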

I hope you’re still with me so far. Let’s take it up a notch: lists of integers!

Nx.tensor([1, 2, 3, 42, 5])
#Nx.Tensor<
s64[5]
[1, 2, 3, 42, 5]
>

Here we have a (64)bit (s)igned integer vector with 5 elements, so its shape is {5}. A vector is just a single row or column of a matrix. You can verify the shape using the Nx.shape/1 function. Let’s see what happens when we do some math on this tensor:
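A quick sketch of checking that shape, plus the related rank and size helpers. A {5} vector has a single axis; if you really want a 1x5 row matrix, Nx.reshape/2 gets you there. The variable names are only for illustration.

```elixir
Mix.install([:nx]) # setup line from the top of the post; skip it if Nx is already installed

v = Nx.tensor([1, 2, 3, 42, 5])

Nx.shape(v) # {5}: one axis holding 5 elements
Nx.rank(v)  # 1
Nx.size(v)  # 5 elements in total

# An explicit 1x5 row matrix is one reshape away
row = Nx.reshape(v, {1, 5})
Nx.shape(row) # {1, 5}
Nx.rank(row)  # 2
```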

Nx.tensor([1, 2, 3, 42, 5])
|> Nx.add(10)
#Nx.Tensor<
s64[5]
[11, 12, 13, 52, 15]
>

Is that what we expected? Nx saw that we had a tensor of shape {5} and a tensor of shape {} (the number 10), and automagically “broadcast” the scalar into shape {5}. This is equivalent to:

Nx.add(Nx.tensor([1, 2, 3, 42, 5]), Nx.tensor([10, 10, 10, 10, 10]))

Most operations will do their best to coerce the types and shapes of tensors for you. Nx provides just about every basic element-wise math function.
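Here is a small sketch that does the broadcast by hand with Nx.broadcast/2 and compares it with letting Nx.add/2 handle it, followed by a few other element-wise functions. Variable names are illustrative.

```elixir
Mix.install([:nx]) # setup line from the top of the post; skip it if Nx is already installed

v = Nx.tensor([1, 2, 3, 42, 5])

# Broadcast the scalar 10 to shape {5} explicitly...
tens = Nx.broadcast(10, {5})
explicit = Nx.add(v, tens)

# ...which matches letting Nx.add/2 broadcast for us
implicit = Nx.add(v, 10)
same? = Nx.to_flat_list(explicit) == Nx.to_flat_list(implicit) # true

# Other element-wise functions follow the same pattern
doubled = Nx.multiply(v, 2) |> Nx.to_flat_list()    # [2, 4, 6, 84, 10]
minus_one = Nx.subtract(v, 1) |> Nx.to_flat_list()  # [0, 1, 2, 41, 4]
odd_or_even = Nx.remainder(v, 2) |> Nx.to_flat_list() # [1, 0, 1, 0, 1]
```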

To relate this back to functional programming, we’re essentially doing an Enum.map/2 over the vector and adding 10 to each item:

Enum.map([1, 2, 3, 42, 5], fn n -> n + 10 end)

In fact, for all of our one-dimensional vectors so far, we could implement everything using Enum functions.
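As a quick sanity check, a sketch comparing the Enum version with the Nx version on the same list; Nx.to_flat_list/1 and Nx.to_number/1 bring the results back to plain Elixir values.

```elixir
Mix.install([:nx]) # setup line from the top of the post; skip it if Nx is already installed

list = [1, 2, 3, 42, 5]

with_enum = Enum.map(list, fn n -> n + 10 end)
with_nx = list |> Nx.tensor() |> Nx.add(10) |> Nx.to_flat_list()

same? = with_enum == with_nx # true, both are [11, 12, 13, 52, 15]

# Sums line up too; the Nx sum comes back as a scalar tensor
enum_total = Enum.sum(list)                                # 53
nx_total = list |> Nx.tensor() |> Nx.sum() |> Nx.to_number() # 53
```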

And it doesn’t stop there; we also have aggregate functions:

Nx.tensor([1, 2, 3, 42, 5])
|> Nx.mean()
#Nx.Tensor<
f32
10.600000381469727
>
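A few more aggregates on the same vector, as a quick sketch. Nx.mean/1 returns an f32 tensor, which is why 10.6 displays as 10.600000381469727 (the closest 32-bit float to 10.6); casting to f64 first with Nx.as_type/2 gives you the double instead.

```elixir
Mix.install([:nx]) # setup line from the top of the post; skip it if Nx is already installed

v = Nx.tensor([1, 2, 3, 42, 5])

total = Nx.sum(v) |> Nx.to_number()            # 53
largest = Nx.reduce_max(v) |> Nx.to_number()   # 42
smallest = Nx.reduce_min(v) |> Nx.to_number()  # 1
where_largest = Nx.argmax(v) |> Nx.to_number() # 3, the index of 42
product = Nx.product(v) |> Nx.to_number()      # 1260

# Average in f64 to get the double closest to 10.6
mean_f64 = v |> Nx.as_type({:f, 64}) |> Nx.mean() |> Nx.to_number() # 10.6
```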

If we just stopped now, we’d already have a powerful tool for working with simple lists of data. And thanks to Nx, once a compiler backend such as EXLA is configured, this can run optimized on a CPU or GPU, in parallel, with essentially zero changes to our code. Just as an example, here is a flow of computation that you could write with Nx:

k = Nx.Random.key(1) # PRNG key built from the seed 1

# roll 1000 six-sided dice into a vector of shape {1000}
# (the upper bound is exclusive, so 7 yields values 1 through 6)
{dice_rolls, _new_key} = Nx.Random.randint(k, 1, 7, shape: {1000})

Nx.divide(Nx.sum(dice_rolls), 1000) # close to 3.5
# or
Nx.mean(dice_rolls) # the same value

Nx.multiply(dice_rolls, dice_rolls) # {1000} vector of each roll squared
|> Nx.mean() # close to 91/6, about 15.17
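Nx.Random.randint/4 samples from the half-open range [min, max), so six-sided dice need an upper bound of 7. Here is a sketch that checks every roll lands on 1 through 6, and then computes the population variance two ways: E[X²] − E[X]² by hand versus Nx.variance/1. The tolerance passed to Nx.all_close/3 is an assumption chosen to absorb f32 rounding.

```elixir
Mix.install([:nx]) # setup line from the top of the post; skip it if Nx is already installed

key = Nx.Random.key(1)

# Half-open range [1, 7): values 1..6
{rolls, _next_key} = Nx.Random.randint(key, 1, 7, shape: {1000})

# Every roll should be between 1 and 6 inclusive
in_range = Nx.logical_and(Nx.greater_equal(rolls, 1), Nx.less_equal(rolls, 6))
Nx.all(in_range) |> Nx.to_number() # 1, meaning true

# Population variance: mean of squares minus square of the mean
mean = Nx.mean(rolls)
by_hand = Nx.subtract(Nx.mean(Nx.multiply(rolls, rolls)), Nx.multiply(mean, mean))
Nx.all_close(by_hand, Nx.variance(rolls), atol: 1.0e-4) |> Nx.to_number() # 1
```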

Really, the sky is the limit here!
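The CPU/GPU optimization comes from numerical definitions: functions declared with defn (from Nx.Defn) are traced into a computation that a compiler can optimize. Here is a minimal sketch; DiceStats and mean_of_squares are names made up for this example. With no compiler configured, it runs on Nx’s default evaluator; configuring a compiler such as EXLA (a separate package) is what moves it onto an optimized CPU or GPU path.

```elixir
Mix.install([:nx]) # setup line from the top of the post; skip it if Nx is already installed

defmodule DiceStats do
  import Nx.Defn

  # Inside defn, operators are rewritten to their Nx equivalents,
  # so t * t is an element-wise Nx.multiply, not Kernel.*
  defn mean_of_squares(t) do
    Nx.mean(t * t)
  end
end

DiceStats.mean_of_squares(Nx.tensor([1, 2, 3, 4, 5])) |> Nx.to_number() # 11.0
```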

Multiple Dimensions

This is where we finally catch up to the README’s description of Nx: it isn’t limited to our basic single dimension, it can work with N dimensions. Let’s start with 2:

Nx.tensor([
    [1, 2, 3, 4, 5], 
    [12, 22, 32, 42, 52]
])
#Nx.Tensor<
s64[2][5]
[
    [1, 2, 3, 4, 5],
    [12, 22, 32, 42, 52]
]
>

Here we have a (64)bit (s)igned integer tensor with 2 rows and 5 columns, shape {2, 5}, aka a 2x5 matrix. All the functions we used above apply here too, and this is where the matrix optimizations of the CPU/GPU really start to pay off.
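A sketch of what changes with a second dimension: aggregates can target a single axis, broadcasting works across rows, and Nx.dot/2 does matrix multiplication. Summing over axis 0 collapses the rows (one result per column); axis 1 collapses the columns (one result per row).

```elixir
Mix.install([:nx]) # setup line from the top of the post; skip it if Nx is already installed

m = Nx.tensor([
  [1, 2, 3, 4, 5],
  [12, 22, 32, 42, 52]
])

Nx.shape(m) # {2, 5}

col_sums = Nx.sum(m, axes: [0]) |> Nx.to_flat_list() # [13, 24, 35, 46, 57]
row_sums = Nx.sum(m, axes: [1]) |> Nx.to_flat_list() # [15, 160]

# The {5} vector is broadcast and added to each row
shifted = Nx.add(m, Nx.tensor([100, 200, 300, 400, 500])) |> Nx.to_list()
# [[101, 202, 303, 404, 505], [112, 222, 332, 442, 552]]

# {2, 5} dot {5, 2} gives a {2, 2} matrix
gram = Nx.dot(m, Nx.transpose(m)) |> Nx.to_list() # [[55, 580], [580, 6120]]
```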

You can see how this might be useful with images, which are width × height grids of color values. This is what the Image library helps with: give it an image and it will help you build a tensor that you can manipulate!
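To make the image idea concrete without the Image library, here is a sketch on a hand-built tensor. It assumes a hypothetical 2x2 RGB image laid out as {height, width, channels} with one unsigned byte per channel; a real image loaded through a library may use a different layout or type.

```elixir
Mix.install([:nx]) # setup line from the top of the post; skip it if Nx is already installed

# Hypothetical 2x2 RGB image: red, green / blue, white
img =
  Nx.tensor(
    [
      [[255, 0, 0], [0, 255, 0]],
      [[0, 0, 255], [255, 255, 255]]
    ],
    type: {:u, 8}
  )

Nx.shape(img) # {2, 2, 3}

# Invert colors: 255 - value for every channel of every pixel
inverted = Nx.subtract(255, img) |> Nx.to_list()
# [[[0, 255, 255], [255, 0, 255]], [[255, 255, 0], [0, 0, 0]]]

# Naive grayscale: average the three channels of each pixel
gray = Nx.mean(img, axes: [2]) |> Nx.to_list() # [[85.0, 85.0], [85.0, 255.0]]
```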

Conclusion

A tensor is a high-level way of describing mathematical operations over N dimensions. Be it 0, 1, or 1000 dimensions, we can operate over these tensors in a performant way with a single interface.
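As a small illustration of that single interface, here is a sketch of one anonymous function (double_then_total is a hypothetical name) applied unchanged to tensors of rank 0 through 3.

```elixir
Mix.install([:nx]) # setup line from the top of the post; skip it if Nx is already installed

# Same code path whatever the rank of the input
double_then_total = fn t -> t |> Nx.multiply(2) |> Nx.sum() |> Nx.to_number() end

r0 = double_then_total.(Nx.tensor(5))                # rank 0: 10
r1 = double_then_total.(Nx.tensor([1, 2, 3]))        # rank 1: 12
r2 = double_then_total.(Nx.tensor([[1, 2], [3, 4]])) # rank 2: 20
r3 = double_then_total.(Nx.iota({2, 3, 4}))          # rank 3, values 0..23: 552
```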

And that is truly amazing.

We are really only scratching the surface of what’s possible using Nx. The old adage that “the BEAM is not good at math” no longer holds, and the possibilities are now endless. Just having this ability has already created an explosion of projects such as:

  • Bumblebee builds a high-level interface around pre-trained AI and ML models.
  • Scholar gives us a tool chest of classic machine learning and statistical tensor functions for doing high-level math.
  • Explorer works with large datasets without exploding our memory usage.
  • Image is a high-level library for image manipulation; using Nx, it enables classification and low-level operations on raw image data.

Each of these is pretty math and science focused, but the only way to grow this list is for someone to take the first step!

Fly.io ❤️ Elixir

Fly.io is a great way to run your Phoenix LiveView app close to your users. It’s really easy to get started. You can be running in minutes.

Deploy a Phoenix app today!