In this blog post I want to talk about the research and development results for a library that I started working on more than a year ago: TensorFrost. Under the hood it’s a static optimizing tensor compiler with a focus on being able to do more “shader-like” things while still keeping the ability to do high-level linear algebra for ML in NumPy-like syntax with automatic differentiation support. (Click on the example GIFs for more details!)
For documentation on basic functionality, read the README file in the repo.
I started working on this library around 14 months ago. Initially I didn’t plan to do much more than a few matrix operations for an optimization algorithm I wanted to implement in Unity, but there were quite a few things that I wanted to have on top of that, and it sidetracked me into writing an entire compiler (hello scope creep 👋).
The thing is, it’s not the first time I’ve tried to make a tensor library. The first attempt was a whole 5 years ago and used OpenCL, as I didn’t have an Nvidia GPU at the time. To be honest, I was completely unprepared for the magnitude of what it required. While I did get it to a “somewhat” working state, with basic kernels and tape-based autodiff that mostly worked, the lack of good debug tools and of actual problems that I wanted to solve with it pretty much killed it. It was also completely unsuited for the things that I did want to work on: I usually write simulations or graphics, and the overhead of dispatching a kernel per operation (especially unoptimized kernels) is just too high to be useful.
Since that time I’ve had a lot of ideas about what I would like a library like that to even look like, and I wanted to try working on it again. However, I knew that for this project to survive I would need to make it suitable for the kind of projects I usually do, like the ones I make on Shadertoy. It might seem weird that I would make a specifically “tensor” library for something that is basically equivalent to writing shaders. But to be honest, shaders are not actually a perfect place for what I do, and a lot of simulation/rendering algorithms map quite well to high-level “tensor-like” operations. While the limitations of shaders might force you to come up with creative solutions, for really large or complicated projects they become more of a problem, as it’s very hard to iterate quickly.
This was also one of the main reasons I didn’t really touch ML for most of my pet projects (except stuff like Neural Implicit Representations). ML algorithms are usually quite orthogonal to the way you write shaders: they tend to be split into hundreds of kernel dispatches, while shader algorithms are effectively just one megakernel most of the time. The only reasonable way to integrate neural networks into shaders is to unroll the network into a single scalar function, which can be quite suboptimal and limits the network size. And training them there is completely out of the question.
This brings up another problem: shaders don’t have automatic differentiation, which is much more useful here than you might think. While it’s usually used for optimization algorithms like SGD, having the analytic gradient can also be useful for computing normals/curvature of analytic shapes, or forces from potentials in simulations.
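To make the normals case concrete, here is a minimal sketch in plain Python. It is not TensorFrost code: the `Dual` class and the `sphere_sdf`, `gradient`, and `normal` helpers are hypothetical names made up for this illustration, and it uses simple forward-mode differentiation with dual numbers rather than whatever a real library would do. The idea is only that once the distance function is differentiable, the surface normal is just its normalized gradient.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Dual:
    """A value together with its derivative (forward-mode autodiff)."""
    val: float
    dot: float = 0.0

    def __add__(self, other):
        other = _lift(other)
        return Dual(self.val + other.val, self.dot + other.dot)

    __radd__ = __add__

    def __sub__(self, other):
        other = _lift(other)
        return Dual(self.val - other.val, self.dot - other.dot)

    def __mul__(self, other):
        other = _lift(other)
        # Product rule: (a * b)' = a' * b + a * b'
        return Dual(self.val * other.val,
                    self.dot * other.val + self.val * other.dot)

    __rmul__ = __mul__


def _lift(x):
    """Treat plain numbers as constants (zero derivative)."""
    return x if isinstance(x, Dual) else Dual(float(x))


def sqrt(x):
    """Square root of a Dual; d/dx sqrt(x) = 1 / (2 * sqrt(x))."""
    root = math.sqrt(x.val)
    return Dual(root, x.dot / (2.0 * root))


def sphere_sdf(x, y, z, radius=1.0):
    """Signed distance to a sphere centered at the origin."""
    return sqrt(x * x + y * y + z * z) - radius


def gradient(f, point):
    """Gradient of f at point, using one forward pass per coordinate."""
    grad = []
    for i in range(len(point)):
        # Seed derivative 1 for coordinate i and 0 for all the others.
        args = [Dual(p, 1.0 if j == i else 0.0) for j, p in enumerate(point)]
        grad.append(f(*args).dot)
    return grad


def normal(f, point):
    """Surface normal of the implicit shape f: its normalized gradient."""
    g = gradient(f, point)
    length = math.sqrt(sum(c * c for c in g))
    return [c / length for c in g]
```

For example, `normal(sphere_sdf, [0.0, 3.0, 4.0])` gives approximately `[0.0, 0.6, 0.8]`, the direction from the sphere's center to the point, with no hand-derived normal formula. The same `gradient` helper applied to a potential energy function would give the force, up to a sign.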
So with this library, I hoped to extend the applicability range of a tensor library to more “shader-like” use cases like rendering and simulations.
And right now I can actually say that, at least to some degree, it did work out. Currently the library is a mix of slightly more low-level NumPy-like operations with shader-like control flow and operations (most of the built-in scalar shader functions are passed through to Python). In terms of where it stands right now, it’s more low-level than something like JAX or PyTorch, but still not as low-level as Taichi, as you still technically operate on something similar to tensors.
So why make a new library?
It does indeed take an inordinate amount of work to make a library from scratch and get it to a useful state, as I’ve already experienced. Why would I not just use an existing library, as there are seemingly thousands of them? There are a few reasons, mostly applicable to my use cases, which make using either pure ML libraries or pure shaders annoying.
1. Performance scales poorly for non-ML specific operations