Why JAX Could Be the Next Platform for HPC-AI

Many of Google’s technology developments have had major impacts in the commercial world, but the company’s impact on scientific computing has been far less pronounced. Teams at the search engine giant are trying to bridge that gap by creating software tools that blend the best of two worlds: traditional numerical simulations and modern machine learning.

Stephan Hoyer has been working at this intersection of scientific computing and deep learning at Google since 2016, specifically trying to find ways to move ML and HPC closer together in climate and weather modeling. Before that, he worked on climate and risk models at the Climate Corporation following a Ph.D. in physics at UC Berkeley. In short, he has a good sense of what works for traditional numerical simulations and where machine learning might fit.

Hoyer and his team in the applied sciences group at Google are focusing on “differentiable programming” as the route to meshing the worlds of traditional HPC and ML. “Numerical methods are interpretable and generalizable and neural networks work well for approximation,” he says. “In physics, there is always some level of approximation. Our goal is how best to artfully inject a bit of machine learning just where it’s needed to improve results.”

The balance between traditional and ML methods is trickier than mere speedups. As models grow more complex and resolution requirements rise, the real problem in weather and climate simulations is juggling accuracy and time to result. Traditional simulations are accurate but slow, and ML-only models cannot compete with the detail of fluid flow simulations, where factors like turbulence stretch the limits of even the largest supercomputers.

To blend both worlds, Hoyer and the Google team built on top of JAX, a Python library designed for numerical computing and HPC/AI. JAX was built at Google and is used as the foundation for various projects, including one for weather and climate simulations: JAX-CFD, the HPC-AI hybrid framework Hoyer and colleagues built and used to make weather and climate calculations more efficient.

Before we get into the CFD extension, it’s important to note that JAX might be the key to more widespread adoption of AI in HPC and physics-dominated areas. There are enough familiar hooks in JAX to make it palatable for HPC developers. Because the JAX API for numerical work is based on the familiar NumPy functions, which are already in use in HPC, JAX can be relatively easy to adopt, allowing faster uptake of ML functions.
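To make the NumPy point concrete, here is a minimal sketch (not taken from JAX-CFD) of one numerical kernel run unchanged through both NumPy and `jax.numpy`. The function name `diffusion_step`, the grid size, and the coefficient are illustrative assumptions.

```python
import numpy as np
import jax.numpy as jnp

def diffusion_step(xp, u, alpha=0.1):
    """One explicit diffusion step on a periodic 1D grid.

    `xp` is the array module (numpy or jax.numpy); the body is the same for both.
    This is a hypothetical toy kernel, not part of any Google library.
    """
    laplacian = xp.roll(u, 1) - 2.0 * u + xp.roll(u, -1)
    return u + alpha * laplacian

# Illustrative initial condition: one sine wave on 64 periodic grid points.
x = np.linspace(0.0, 2.0 * np.pi, 64, endpoint=False)
u0 = np.sin(x)

u_numpy = diffusion_step(np, u0)                # plain NumPy on the CPU
u_jax = diffusion_step(jnp, jnp.asarray(u0))    # same code via jax.numpy
```

One practical difference worth knowing: JAX uses 32-bit floats by default, so the two results agree only to single precision unless 64-bit mode is enabled.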

All of these capabilities are delivered as a system of what Google calls “composable function transformations” that can allow more seamless integration of ML elements in traditional HPC workflows (beyond CFD). These include differentiation (gradient-based optimization tools), vectorization and support for large-scale data parallelism, and compilation tools that allow scaling across accelerators (GPUs and TPUs, for example). The last one is a big deal since so many traditional numerical simulations are designed for CPU only and have not yet been able to take advantage of any acceleration.

  • Differentiation: Gradient-based optimisation is fundamental to ML. JAX natively supports both forward and reverse mode automatic differentiation of arbitrary numerical functions, via function transformations such as grad, hessian, jacfwd and jacrev.

  • Vectorisation: In ML research we often apply a single function to lots of data, e.g. calculating the loss across a batch or evaluating per-example gradients for differentially private learning. JAX provides automatic vectorisation via the vmap transformation that simplifies this form of programming. For example, researchers don’t need to reason about batching when implementing new algorithms. JAX also supports large scale data parallelism via the related pmap transformation, elegantly distributing data that is too large for the memory of a single accelerator.

  • JIT-compilation: XLA is used to just-in-time (JIT)-compile and execute JAX programs on GPU and Cloud TPU accelerators. JIT-compilation, together with JAX’s NumPy-consistent API, allows researchers with no previous experience in high-performance computing to easily scale to one or many accelerators.
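The transformations in the list above compose with one another, which a short sketch can show. The toy functions `energy` and `f_vec` and the input values are assumptions made for illustration; `jax.grad`, `jax.hessian`, `jax.jacfwd`, `jax.jacrev`, `jax.vmap` and `jax.jit` are the JAX transformations named in the list.

```python
import jax
import jax.numpy as jnp

def energy(x):
    """Toy scalar function of a vector (hypothetical): sum of x**3."""
    return jnp.sum(x ** 3)

def f_vec(x):
    """Toy vector-to-vector function (hypothetical)."""
    return jnp.sin(x) * x

x = jnp.array([1.0, 2.0, 3.0])

# Differentiation: gradient and Hessian of a scalar function.
g = jax.grad(energy)(x)         # analytically 3 * x**2
h = jax.hessian(energy)(x)      # analytically diag(6 * x)

# Forward- and reverse-mode Jacobians of the same function should agree.
j_fwd = jax.jacfwd(f_vec)(x)
j_rev = jax.jacrev(f_vec)(x)

# Vectorisation: per-example gradients over a batch, with no manual batching.
batch = jnp.stack([x, 2.0 * x])
per_example_grads = jax.vmap(jax.grad(energy))(batch)

# JIT compilation: the composed transformation is compiled by XLA.
fast_grads = jax.jit(jax.vmap(jax.grad(energy)))
per_example_grads_jit = fast_grads(batch)
```

`pmap` is left out of this sketch because its behaviour depends on how many accelerator devices are available.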

Google has already shown off some of JAX and its extensions for HPC application domains. For instance, there is Haiku, which is based on JAX and has been tested for computational chemistry problems. Other tools that could come into play in HPC/AI include Jraph, which is JAX-based and allows graph neural networks and physics to dance.

Hoyer says he and the Google team went through a number of steps in deciding where and how much neural networks could fit the CFD bill.

The team tested models ranging from pure CNNs, to learned corrections (a physics step followed by a CNN step), to ML inside physics, with the goal of designing an accurate simulator on coarse grids (versus ultra-high resolution). They found that pure ML models don’t generalize well and that a hybrid physics-ML approach worked best.
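The three model families can be sketched schematically as three kinds of time step. This is a hedged illustration only: a learnable periodic 3-point stencil stands in for the CNN, the physics is a toy 1D diffusion update, and every function and parameter name here is an assumption rather than the JAX-CFD API.

```python
import jax.numpy as jnp

def stencil(weights, u):
    """Periodic 3-point stencil; a linear stand-in for a CNN layer (assumption)."""
    return weights[0] * jnp.roll(u, 1) + weights[1] * u + weights[2] * jnp.roll(u, -1)

def physics_step(u, alpha=0.1):
    """Hand-written solver step: explicit diffusion with the standard Laplacian."""
    return u + alpha * stencil(jnp.array([1.0, -2.0, 1.0]), u)

def pure_ml_step(params, u):
    """Pure ML: the learned model predicts the next state directly."""
    return stencil(params["kernel"], u)

def learned_correction_step(params, u):
    """Learned correction: a physics step followed by a learned step."""
    u_phys = physics_step(u)
    return u_phys + stencil(params["kernel"], u_phys)

def ml_inside_physics_step(params, u, alpha=0.1):
    """ML inside physics: the solver keeps its form, but the stencil
    coefficients it uses are learned rather than fixed."""
    return u + alpha * stencil(params["kernel"], u)

# Illustrative state and parameter settings.
u = jnp.sin(jnp.linspace(0.0, 2.0 * jnp.pi, 32, endpoint=False))
zero_params = {"kernel": jnp.zeros(3)}
laplacian_params = {"kernel": jnp.array([1.0, -2.0, 1.0])}
identity_params = {"kernel": jnp.array([0.0, 1.0, 0.0])}
```

In the two hybrid variants the physics still carries the state forward when the learned weights are uninformative (with a zero correction kernel, `learned_correction_step` reduces to `physics_step`), whereas the pure ML step has nothing to fall back on.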

To do this, they established a high-resolution ground truth and down-sampled it to coarse grids to avoid the expensive, long runs of traditional simulations. With this ground truth as the target for an ML solver working on coarse grids, they optimized parameters using stochastic gradients and “unrolled” the simulation through time steps to do an end-to-end training run that matched the traditional simulation’s dynamics with error correction built in. They were able to show that their method, even though it used quite a bit more computational power (150X more FLOPS), was only 12X slower at the same resolution but 80X faster for the same accuracy. Those are nuanced tradeoffs but the speed/accuracy boost is notable.
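The training recipe described above can be sketched on a toy problem. This is a minimal illustration under stated assumptions, not the team's implementation: the physics is 1D periodic diffusion, the "network" is a learnable 3-point stencil, the coarse solver deliberately reuses the fine-grid coefficient so there is a visible error to correct, and all names and constants are hypothetical.

```python
import jax
import jax.numpy as jnp

FACTOR = 4            # fine cells averaged into one coarse cell (assumption)
N_FINE = 64           # fine-grid resolution (assumption)
STEPS = 8             # time steps unrolled during training (assumption)
ALPHA = 0.2           # diffusion number used by both solvers (assumption)
LEARNING_RATE = 0.01  # SGD step size (assumption)
X_FINE = jnp.linspace(0.0, 2.0 * jnp.pi, N_FINE, endpoint=False)

def diffusion_step(u):
    """Explicit diffusion step on a periodic 1D grid."""
    return u + ALPHA * (jnp.roll(u, 1) - 2.0 * u + jnp.roll(u, -1))

def downsample(u):
    """Average blocks of FACTOR fine cells into one coarse cell."""
    return u.reshape(-1, FACTOR).mean(axis=1)

def reference_trajectory(u0_fine):
    """High-resolution ground truth, down-sampled after every step."""
    def body(u, _):
        u_next = diffusion_step(u)
        return u_next, downsample(u_next)
    _, trajectory = jax.lax.scan(body, u0_fine, None, length=STEPS)
    return trajectory

def hybrid_step(params, u):
    """Coarse physics step (too diffusive by design) plus a learned correction."""
    u_phys = diffusion_step(u)
    w = params["kernel"]
    return u_phys + w[0] * jnp.roll(u_phys, 1) + w[1] * u_phys + w[2] * jnp.roll(u_phys, -1)

def unrolled_loss(params, u0_fine):
    """Mean squared error over the whole unrolled coarse trajectory."""
    target = reference_trajectory(u0_fine)
    def body(u, _):
        u_next = hybrid_step(params, u)
        return u_next, u_next
    _, predicted = jax.lax.scan(body, downsample(u0_fine), None, length=STEPS)
    return jnp.mean((predicted - target) ** 2)

@jax.jit
def sgd_update(params, u0_fine):
    """One stochastic gradient step, differentiating through all time steps."""
    loss, grads = jax.value_and_grad(unrolled_loss)(params, u0_fine)
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss

def random_initial_condition(key):
    """Random sine wave (mode 1 to 3, random phase) as a training sample."""
    key_mode, key_phase = jax.random.split(key)
    mode = jax.random.randint(key_mode, (), 1, 4)
    phase = jax.random.uniform(key_phase, (), maxval=2.0 * jnp.pi)
    return jnp.sin(mode * X_FINE + phase)

params = {"kernel": jnp.zeros(3)}
u0_val = jnp.sin(2.0 * X_FINE + 0.3)   # fixed held-out initial condition
initial_val_loss = float(unrolled_loss(params, u0_val))

key = jax.random.PRNGKey(0)
for _ in range(300):
    key, subkey = jax.random.split(key)
    params, _ = sgd_update(params, random_initial_condition(subkey))

final_val_loss = float(unrolled_loss(params, u0_val))
```

The design point is that the loss is taken over the whole unrolled trajectory, so the gradient passes through every solver step and the correction is trained on how its errors accumulate over time rather than on a single step in isolation.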

“This new toolkit lets us build better methods and holds great promise for fluids and more broadly in scientific computing for adding onto and adapting numerical methods,” Hoyer says.

While AI/ML is still slow in coming to HPC, it is the lack of tools that bridge the programmatic divide that is the most pressing bottleneck. The advantages of being able to augment traditional numerical methods as well as use acceleration more fully make JAX a framework to watch in both HPC and future AI.
