XLB: A Hardware-Accelerated Differentiable Lattice Boltzmann Simulation Framework based on JAX for Physics-based Machine Learning
XLB (Accelerated LB) is a fully differentiable 2D/3D Lattice Boltzmann Method (LBM) solver that leverages hardware acceleration. It is built on top of the JAX library and is designed to solve fluid dynamics problems in a computationally efficient and differentiable manner. This combination of features makes it particularly well suited to applications in physics-based machine learning. Key features:
- Integration with JAX Ecosystem: The solver can be easily integrated with JAX's robust ecosystem of machine learning libraries such as Flax, Haiku, Optax, and many more.
- Scalability: XLB is capable of scaling on distributed multi-GPU systems, enabling the execution of large-scale simulations with billions of voxels.
- Support for Various LBM Boundary Conditions and Kernels: XLB supports several LBM boundary conditions and collision kernels.
- User-Friendly Interface: Written entirely in Python, XLB emphasizes a highly accessible interface that allows users to extend the solver with ease and quickly set up and run new simulations.
- Leverages JAX Array and shard_map: The solver uses JAX's unified array type (jax.Array) and shard_map, providing users with a NumPy-like interface. This lets users focus on the semantics of a simulation and leave performance optimizations to the compiler.
- Platform Versatility: The same XLB code runs on multi-core CPUs, single- or multi-GPU systems, and TPUs, and supports distributed runs on multi-GPU systems and TPU Pod slices.
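To illustrate what end-to-end differentiability means in practice, here is a self-contained sketch in plain JAX (not XLB's API): a periodic D2Q9 BGK simulation of a decaying shear wave, differentiated with `jax.grad` with respect to the relaxation rate omega. All names and parameters (`step`, `kinetic_energy_after`, grid size, wave amplitude, step count) are hypothetical choices made for illustration.

```python
import jax
import jax.numpy as jnp

# D2Q9 velocities: Python ints for the static jnp.roll shifts, floats for the moments.
VELOCITIES = [(0, 0), (1, 0), (0, 1), (-1, 0), (0, -1),
              (1, 1), (-1, 1), (-1, -1), (1, -1)]
C = jnp.array(VELOCITIES, dtype=jnp.float32)
W = jnp.array([4 / 9] + [1 / 9] * 4 + [1 / 36] * 4)


def moments(f):
    """Density and velocity from populations f of shape (nx, ny, 9)."""
    rho = jnp.sum(f, axis=-1)
    u = jnp.einsum("xyq,qd->xyd", f, C) / rho[..., None]
    return rho, u


def equilibrium(rho, u):
    """Standard second-order polynomial equilibrium (c_s^2 = 1/3)."""
    cu = jnp.einsum("xyd,qd->xyq", u, C)
    usq = jnp.sum(u * u, axis=-1, keepdims=True)
    return rho[..., None] * W * (1 + 3 * cu + 4.5 * cu**2 - 1.5 * usq)


def step(f, omega):
    """One BGK collision followed by streaming on a periodic grid."""
    rho, u = moments(f)
    f = f - omega * (f - equilibrium(rho, u))
    # Shift population i by its lattice velocity along (x, y), wrapping around.
    return jnp.stack(
        [jnp.roll(f[..., i], shift=c, axis=(0, 1)) for i, c in enumerate(VELOCITIES)],
        axis=-1,
    )


def kinetic_energy_after(omega, n_steps=50, nx=32, ny=32):
    """Kinetic energy of a decaying shear wave u_x(y) after n_steps."""
    y = jnp.arange(ny)
    ux = 0.05 * jnp.sin(2 * jnp.pi * y / ny)
    u0 = jnp.stack([jnp.broadcast_to(ux, (nx, ny)), jnp.zeros((nx, ny))], axis=-1)
    f = equilibrium(jnp.ones((nx, ny)), u0)
    f = jax.lax.fori_loop(0, n_steps, lambda _, g: step(g, omega), f)
    rho, u = moments(f)
    return 0.5 * jnp.sum(rho * jnp.sum(u * u, axis=-1))


# Reverse-mode gradient of the final energy with respect to omega.
dE_domega = jax.jit(jax.grad(kinetic_energy_after))
print(dE_domega(1.0))
```

A larger omega means a lower viscosity and a slower decay, so the gradient should be positive; comparing it against a finite-difference estimate is a quick sanity check.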
Documentation is currently in preparation.
The following examples showcase the capabilities of XLB:
- Lid-driven cavity flow at Re = 100,000 (~25 million voxels)
- DrivAer model in a wind tunnel, simulated with the KBC Lattice Boltzmann model (~317 million voxels)
- Flow over a NACA airfoil, simulated with the KBC Lattice Boltzmann model (~100 million voxels)
Supported collision models:
- BGK collision model (the standard LBM collision model)
- KBC collision model (offers improved stability for flows at high Reynolds numbers)
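For reference, the BGK model relaxes every population toward a local equilibrium at a single rate omega = 1/tau, with kinematic viscosity nu = c_s^2 (tau - 1/2) in lattice units. The sketch below is a minimal D2Q9 version in plain JAX, assuming the standard second-order equilibrium and the velocity ordering given in the comment; the function names are hypothetical and this is not XLB's collision API. KBC is omitted because its entropic stabilizer is considerably more involved.

```python
import jax
import jax.numpy as jnp

# D2Q9 ordering: 0 rest, 1 E, 2 N, 3 W, 4 S, 5 NE, 6 NW, 7 SW, 8 SE
C = jnp.array([[0, 0], [1, 0], [0, 1], [-1, 0], [0, -1],
               [1, 1], [-1, 1], [-1, -1], [1, -1]], dtype=jnp.float32)
W = jnp.array([4 / 9] + [1 / 9] * 4 + [1 / 36] * 4)
CS2 = 1.0 / 3.0  # squared lattice speed of sound


def omega_from_viscosity(nu):
    """Relaxation rate omega = 1 / tau for kinematic viscosity nu (lattice units)."""
    return 1.0 / (nu / CS2 + 0.5)


def macroscopic(f):
    """Density and velocity from populations f of shape (nx, ny, 9)."""
    rho = jnp.sum(f, axis=-1)
    u = jnp.einsum("xyq,qd->xyd", f, C) / rho[..., None]
    return rho, u


def equilibrium(rho, u):
    """Second-order polynomial equilibrium written in terms of c_s^2."""
    cu = jnp.einsum("xyd,qd->xyq", u, C)
    usq = jnp.sum(u * u, axis=-1, keepdims=True)
    return rho[..., None] * W * (1 + cu / CS2 + 0.5 * (cu / CS2) ** 2 - 0.5 * usq / CS2)


@jax.jit
def bgk_collide(f, omega):
    """Relax every population toward equilibrium at rate omega."""
    rho, u = macroscopic(f)
    return f - omega * (f - equilibrium(rho, u))


# Usage: perturbed populations on a 16x16 grid, collided at nu = 0.05.
f = W * (1 + 0.1 * jax.random.uniform(jax.random.PRNGKey(0), (16, 16, 9)))
f_post = bgk_collide(f, omega_from_viscosity(0.05))
```

Because the equilibrium has the same density and momentum as the populations it is built from, the collision step conserves both, which is easy to verify on `f` and `f_post`.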
Supported lattice models:
- D2Q9
- D3Q19
- D3Q27 (required for KBC simulations)
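The three lattices differ only in their discrete velocity sets and weights. As a sketch, each set can be generated from the weight assigned to each squared speed |c|^2 and checked against the standard moment conditions: the weights sum to 1, the first moment vanishes, and the second moment equals c_s^2 delta_ab with c_s^2 = 1/3. The helper `make_lattice` is hypothetical, and its velocity ordering is arbitrary; it need not match XLB's internal ordering.

```python
import itertools

import numpy as np

# name -> (dimension, weight for each allowed squared speed |c|^2)
LATTICES = {
    "D2Q9": (2, {0: 4 / 9, 1: 1 / 9, 2: 1 / 36}),
    "D3Q19": (3, {0: 1 / 3, 1: 1 / 18, 2: 1 / 36}),
    "D3Q27": (3, {0: 8 / 27, 1: 2 / 27, 2: 1 / 54, 3: 1 / 216}),
}


def make_lattice(name):
    """Return (velocities, weights) for a DdQq lattice."""
    dim, weight_by_norm2 = LATTICES[name]
    velocities, weights = [], []
    for c in itertools.product((0, 1, -1), repeat=dim):
        norm2 = sum(ci * ci for ci in c)
        # D3Q19 has no entry for |c|^2 = 3, so its 8 corner velocities are skipped.
        if norm2 in weight_by_norm2:
            velocities.append(c)
            weights.append(weight_by_norm2[norm2])
    return np.array(velocities, dtype=np.int32), np.array(weights)


for name in LATTICES:
    c, w = make_lattice(name)
    assert np.isclose(w.sum(), 1.0)                        # zeroth moment
    assert np.allclose(w @ c, 0.0)                         # first moment
    second = np.einsum("q,qa,qb->ab", w, c, c)             # second moment
    assert np.allclose(second, np.eye(c.shape[1]) / 3.0)   # = c_s^2 * identity
```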
Output and geometry tools:
- Binary and ASCII VTK output (based on the PyVista library)
- Image output
- 3D mesh voxelizer (based on trimesh)
Supported boundary conditions:
- Equilibrium BC: The fluid populations at the boundary are set to their equilibrium values. Can be used to prescribe a velocity or pressure.
- Full-Way Bounceback BC: Populations arriving at wall (solid) nodes are reversed in direction and sent back to the fluid side of the boundary, enforcing a no-slip (zero-velocity) condition at the wall.
- Half-Way Bounceback BC: Similar to the Full-Way Bounceback BC, but populations are reflected within the same time step and the wall is effectively located halfway between the last fluid node and the first solid node. The fluid velocity at the boundary node itself is therefore generally non-zero, while the no-slip condition holds at the halfway wall location.
- Do Nothing BC: The fluid populations pass through the boundary without any reflection or modification.
- Zou-He BC: Imposes a prescribed velocity or pressure at the boundary by reconstructing the unknown incoming populations through bounce-back of their non-equilibrium part.
- Regularized BC: Also imposes a prescribed velocity or pressure at the boundary. It is more stable than the Zou-He BC, but computationally more expensive.
- Extrapolation Outflow BC: An outflow boundary condition that uses extrapolation to avoid strong wave reflections.
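Two of these conditions can be sketched compactly for D2Q9 in plain JAX. The sketch assumes the velocity ordering given in the comment and uses hypothetical function names; it is not XLB's boundary-condition API. Full-way bounceback swaps every population at a solid node with its opposite, and a Zou-He velocity inlet on a left (x = 0) wall rebuilds the three unknown incoming populations (E, NE, SE) from the known ones using the standard Zou-He closure.

```python
import jax.numpy as jnp

# D2Q9 ordering: 0 rest, 1 E, 2 N, 3 W, 4 S, 5 NE, 6 NW, 7 SW, 8 SE
C = jnp.array([[0, 0], [1, 0], [0, 1], [-1, 0], [0, -1],
               [1, 1], [-1, 1], [-1, -1], [1, -1]], dtype=jnp.float32)
W = jnp.array([4 / 9] + [1 / 9] * 4 + [1 / 36] * 4)
OPPOSITE = jnp.array([0, 3, 4, 1, 2, 7, 8, 5, 6])  # index of the velocity -c_i


def full_way_bounceback(f, solid_mask):
    """Reverse all populations at solid nodes; f has shape (nx, ny, 9)."""
    return jnp.where(solid_mask[..., None], f[..., OPPOSITE], f)


def equilibrium(rho, u):
    """Second-order D2Q9 equilibrium; rho has shape (n,), u has shape (n, 2)."""
    cu = u @ C.T
    usq = jnp.sum(u * u, axis=-1, keepdims=True)
    return rho[:, None] * W * (1 + 3 * cu + 4.5 * cu**2 - 1.5 * usq)


def zou_he_velocity_inlet_left(f_col, ux, uy=0.0):
    """Zou-He velocity inlet on a left (x = 0) boundary column of shape (ny, 9).

    Entries 1 (E), 5 (NE) and 8 (SE) are unknown after streaming and are rebuilt.
    """
    f0, f2, f3, f4, f6, f7 = (f_col[:, i] for i in (0, 2, 3, 4, 6, 7))
    # Density from mass and x-momentum conservation.
    rho = (f0 + f2 + f4 + 2 * (f3 + f6 + f7)) / (1 - ux)
    f1 = f3 + (2 / 3) * rho * ux
    f5 = f7 - 0.5 * (f2 - f4) + (1 / 6) * rho * ux + 0.5 * rho * uy
    f8 = f6 + 0.5 * (f2 - f4) + (1 / 6) * rho * ux - 0.5 * rho * uy
    return f_col.at[:, 1].set(f1).at[:, 5].set(f5).at[:, 8].set(f8)
```

Given equilibrium populations with the incoming entries erased, the Zou-He reconstruction returns the original equilibrium, and applying full-way bounceback twice restores the input; both make convenient unit checks.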
Compute capabilities:
- Distributed multi-GPU support
- JAX shard_map and JAX Array support
- Mixed-precision support (separate storage and compute precision)
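The store-versus-compute split can be illustrated with a minimal sketch: populations are kept in a low-precision storage dtype and upcast for the arithmetic inside a jitted kernel. XLB lists the jmp mixed-precision library among its dependencies; this sketch uses plain dtype casts instead, and the dtype choices and the `relax_mixed` function are hypothetical.

```python
import jax
import jax.numpy as jnp

STORE_DTYPE = jnp.float16    # hypothetical storage precision (half the memory of float32)
COMPUTE_DTYPE = jnp.float32  # hypothetical compute precision


@jax.jit
def relax_mixed(f_stored, f_eq_stored, omega):
    """BGK-style relaxation: upcast, compute in float32, downcast for storage."""
    f = f_stored.astype(COMPUTE_DTYPE)
    f_eq = f_eq_stored.astype(COMPUTE_DTYPE)
    f_post = f - omega * (f - f_eq)
    return f_post.astype(STORE_DTYPE)


# Example: a 64x64 grid with 9 populations per node (D2Q9-sized).
key_f, key_eq = jax.random.split(jax.random.PRNGKey(0))
f32 = 0.1 + 0.01 * jax.random.uniform(key_f, (64, 64, 9))
feq32 = 0.1 + 0.01 * jax.random.uniform(key_eq, (64, 64, 9))
out = relax_mixed(f32.astype(STORE_DTYPE), feq32.astype(STORE_DTYPE), 1.5)
print(out.dtype, out.nbytes, f32.nbytes)
```

Storing in a narrower type halves the memory traffic of the population arrays, which usually dominates LBM cost, while the float32 arithmetic limits round-off growth inside each step.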
To install XLB, you can run the following commands:
pip install --upgrade pip
# For CPU-only runs
pip install --upgrade "jax[cpu]"
# For GPU runs (install only the variant matching your CUDA version)
# CUDA 12 and cuDNN 8.8 or newer.
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# CUDA 11 and cuDNN 8.6 or newer.
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Please refer to https://github.com/google/jax for the latest installation documentation
# Runtime dependencies
pip install jmp pyvista numpy matplotlib Rtree trimesh
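After installation, a quick way to confirm that JAX picked up the intended backend is to query it from Python. This is a minimal check using standard JAX calls, not an XLB-specific command.

```python
import jax
import jax.numpy as jnp

# Report which backend JAX selected (e.g. "cpu", "gpu", or "tpu") and the visible devices.
print("backend:", jax.default_backend())
print("devices:", jax.devices())

# A tiny jitted computation confirms that the install can compile and run code.
x = jnp.arange(4.0)
print(jax.jit(lambda v: (v * v).sum())(x))  # 0 + 1 + 4 + 9 = 14.0
```

If a GPU install reports only CPU devices, the CUDA-enabled jaxlib wheel was most likely not picked up.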
Accompanying publication coming soon:
M. Ataei, H. Salehipour. XLB: Hardware-Accelerated, Scalable, and Differentiable Lattice Boltzmann Simulation Framework based on JAX. TBA