
JAX-based sampler #517

Open
@JohnGoertz


Feature description
It would be great to have a sampler that is both compatible with JAX-jitted functions and able to leverage JAX's parallelization tools.

Motivation/Application
I have a slow objective function that becomes significantly faster (by orders of magnitude) with JAX tools, in particular jit and vmap. However, JAX's multithreading clashes with pyABC's multithreaded samplers, and pickling the jitted function doesn't work either. Oddly, this isn't an issue with relatively simple versions of my objective function: those can use pyABC's default samplers, but more complex versions only work with pyABC's SingleCoreSampler.
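For context, here is a minimal sketch of how jit and vmap combine to evaluate many parameter vectors in a single compiled call. The model below is a hypothetical toy stand-in (a two-parameter exponential decay), not the objective function from this issue, whose code isn't shown.

```python
import jax
import jax.numpy as jnp


def model(theta):
    # Hypothetical toy model: maps a parameter vector of shape (2,) to a
    # simulated output of shape (3,). Stands in for the real objective.
    a, b = theta[0], theta[1]
    t = jnp.linspace(0.0, 1.0, 3)
    return a * jnp.exp(-b * t)


# vmap turns the per-sample model into one that accepts a leading batch
# axis; jit compiles the batched version once and reuses it afterwards.
batched_model = jax.jit(jax.vmap(model))

thetas = jnp.array([[1.0, 0.5], [2.0, 1.0], [0.5, 2.0], [1.5, 0.1]])
sims = batched_model(thetas)  # one row of output per parameter vector

# Same result as looping over the samples one at a time in Python.
looped = jnp.stack([model(th) for th in thetas])
print(sims.shape, bool(jnp.allclose(sims, looped)))
```

The per-sample model stays a pure function of arrays; that is what allows JAX to batch it with vmap and compile it with jit.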

I'd like to write an extension of the SingleCoreSampler that relies solely on JAX for vectorization/parallelization/multithreading. I have some ideas on how to get started, but I'd like some pointers. This would work best as a batch-sampling system, where an array of samples is submitted and the model function is mapped across the array using vmap or pmap. This evaluation could itself be jitted as well. My questions are:

  • How could I get the sampler to create a batch of samples?
  • Would mapping submit_one across the batch work?
  • Do you know if anything happens to the model function when it's assigned to submit_one that JAX might not like (namely NumPy operations)?
  • How should the samples be returned after evaluation?
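One possible shape for the batch-sampling idea is sketched below as a single rejection-ABC round written entirely in JAX, independent of pyABC's sampler classes. Everything here is an assumption for illustration: the toy model, the uniform prior bounds, the Euclidean distance, the fixed threshold eps, and the helper names simulate_batch and rejection_step are hypothetical, and wiring this into pyABC's Sampler interface is not shown. It touches on each question: the batch is drawn with one jax.random call, vmap maps a pure per-sample function across it (vmap needs a function it can trace, so Python-side effects won't carry over), the model uses jax.numpy rather than NumPy, and the accepted samples come back as NumPy arrays once the compiled part has finished.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


def model(theta):
    # Hypothetical toy model standing in for the user's objective function.
    # Uses only jax.numpy: plain NumPy functions generally cannot operate on
    # traced values inside jit/vmap, so NumPy code must be ported to jnp.
    t = jnp.linspace(0.0, 1.0, 5)
    return theta[0] * jnp.exp(-theta[1] * t)


@partial(jax.jit, static_argnums=1)  # batch_size fixes array shapes: static
def simulate_batch(key, batch_size, observed, lower, upper):
    # 1) Draw a whole batch of parameters from a uniform prior at once.
    thetas = jax.random.uniform(
        key, (batch_size, lower.shape[0]), minval=lower, maxval=upper
    )
    # 2) Evaluate the model on every sample in parallel with vmap.
    sims = jax.vmap(model)(thetas)
    # 3) Euclidean distance of each simulation to the observed data.
    distances = jnp.linalg.norm(sims - observed, axis=1)
    return thetas, distances


def rejection_step(key, batch_size, observed, lower, upper, eps):
    thetas, distances = simulate_batch(key, batch_size, observed, lower, upper)
    # 4) Boolean masking gives a data-dependent shape, which jit cannot
    # trace, so acceptance is applied outside the compiled function.
    accepted = np.asarray(distances) <= eps
    return np.asarray(thetas)[accepted], np.asarray(distances)[accepted]


# Usage: synthetic "observed" data generated from theta = (1.0, 0.5).
key = jax.random.PRNGKey(0)
observed = model(jnp.array([1.0, 0.5]))
lower = jnp.array([0.0, 0.0])
upper = jnp.array([2.0, 2.0])
acc_thetas, acc_dists = rejection_step(key, 10_000, observed, lower, upper, eps=0.5)
print(acc_thetas.shape)
```

Filtering by the mask happens outside jit because the number of accepted samples depends on the data, while jit requires static output shapes. If the filtering needs to stay inside compiled code, one option is to return the mask itself, or to use jnp.nonzero with its size argument to get a fixed-size index array.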
