
Papers

Porting a large cosmology code to GPU, a case study examining JAX and OpenMP.

Authors: Nestor Demeure (Lawrence Berkeley National Laboratory/National Energy Research Scientific Computing Center), Theodore Kisner (Computational Cosmology Center, Lawrence Berkeley National Laboratory; Department of Physics, University of California Berkeley), Reijo Keskitalo (Computational Cosmology Center, Lawrence Berkeley National Laboratory; Department of Physics, University of California Berkeley), Rollin Thomas (Lawrence Berkeley National Laboratory/National Energy Research Scientific Computing Center), Julian Borrill (Computational Cosmology Center, Lawrence Berkeley National Laboratory; Space Sciences Laboratory, University of California Berkeley), Wahid Bhimji (Lawrence Berkeley National Laboratory/National Energy Research Scientific Computing Center)

Abstract: In recent years, a common pattern has emerged where numerical software is designed around a Python interface calling high-performance kernels written in a lower-level language.

With the advent of general-purpose graphics processing units (GPUs), many of those kernels now need to be rewritten, a task that can seem daunting to those new to GPU programming. Furthermore, developers need to ensure that their code will be both portable to future GPU architectures and flexible enough to evolve with their needs.

In this paper, we explore the possibility of using a higher-level framework, testing both JAX and OpenMP target offload, to produce straightforward, portable code while achieving good GPU performance. JAX is a Python library that allows us to write our kernels in pure Python, while OpenMP target offload is a directive-based strategy that integrates seamlessly with our existing OpenMP-accelerated C++ kernels.
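As a hedged illustration of the directive-based approach, the sketch below offloads a simple loop with OpenMP target offload. The kernel name `apply_gains`, its detector-major data layout, and the vector wrapper are hypothetical examples chosen for this page, not code from TOAST or from the paper; only the OpenMP directive and its clauses are standard.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical kernel for illustration (not taken from TOAST): scale each
// detector's timestream by a per-detector gain. Samples are stored
// detector-major, i.e. data[det * n_samp + s].
void apply_gains(int n_det, int n_samp, const double* gains, double* data) {
    // Offload the nested loop to the default device. collapse(2) fuses the
    // two loops into one iteration space spread over teams and threads.
    // map(to:) copies gains to the device; map(tofrom:) copies data to the
    // device and back. If no device is available, the region runs on the
    // host; if compiled without OpenMP, the pragma is ignored and the loop
    // runs serially.
    #pragma omp target teams distribute parallel for collapse(2) \
        map(to: gains[0:n_det]) map(tofrom: data[0:n_det * n_samp])
    for (int det = 0; det < n_det; ++det) {
        for (int s = 0; s < n_samp; ++s) {
            data[static_cast<std::size_t>(det) * n_samp + s] *= gains[det];
        }
    }
}

// Hypothetical convenience wrapper: the number of samples per detector is
// inferred from the buffer sizes.
void apply_gains(const std::vector<double>& gains, std::vector<double>& data) {
    assert(!gains.empty() && data.size() % gains.size() == 0);
    const int n_det = static_cast<int>(gains.size());
    const int n_samp = static_cast<int>(data.size() / gains.size());
    apply_gains(n_det, n_samp, gains.data(), data.data());
}
```

The loop body is ordinary C++; the only GPU-specific addition is the directive, which is what lets an existing OpenMP code base adopt offload incrementally.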

Experimenting on TOAST, a cosmology software framework designed to take full advantage of a supercomputer, we ported a dozen kernels to both frameworks in order to compare development cost and run times, and to study whether these frameworks can be used to port a complex numerical code.


Paper: PDF
