wgsl: support atomics on f32 (in workgroup and storage buffer) #4894

Open
dneto0 opened this issue Sep 24, 2024 · 9 comments
Labels: wgsl (WebGPU Shading Language Issues)
Milestone: Milestone 2
Comments

@dneto0

dneto0 commented Sep 24, 2024

@jowens has a use case.

dneto0 added the wgsl (WebGPU Shading Language Issues) label Sep 24, 2024
dneto0 added this to the Milestone 2 milestone Sep 24, 2024
@petermcneeleychromium

As discussed, this can likely be worked around via bitcast and a u32 compare-and-swap (CAS), with the possible exception of NaNs:
https://www.w3.org/TR/WGSL/#bitcast-builtin
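The bitcast + CAS idea can be tried out on the CPU. Below is a minimal JavaScript sketch. It assumes that a `Float32Array` and a `Uint32Array` aliasing the same 4-byte buffer stand in for WGSL's `bitcast<f32>`/`bitcast<u32>`, and that `Atomics.compareExchange` stands in for `atomicCompareExchangeWeak`. The function names are hypothetical.

```javascript
// Scratch buffer used to reinterpret a 32-bit pattern (a CPU stand-in for WGSL bitcast).
const scratch = new ArrayBuffer(4);
const scratchF32 = new Float32Array(scratch);
const scratchU32 = new Uint32Array(scratch);

function bitcastToF32(bits) {
  scratchU32[0] = bits;
  return scratchF32[0];
}

function bitcastToU32(x) {
  scratchF32[0] = x;
  return scratchU32[0];
}

// Hypothetical helper: atomically adds f32 `value` to the float whose raw bits
// live in cells[index]. Retries until the CAS observes the bits it expected.
function atomicAddF32(cells, index, value) {
  let oldBits = Atomics.load(cells, index);
  for (;;) {
    // Round to f32 so the arithmetic matches a 32-bit float add.
    const newValue = Math.fround(bitcastToF32(oldBits) + Math.fround(value));
    const seen = Atomics.compareExchange(cells, index, oldBits, bitcastToU32(newValue));
    if (seen === oldBits) return newValue; // exchange happened
    oldBits = seen; // another writer got there first; retry from its value
  }
}

// One f32 cell stored as u32 bits; a zero-filled buffer is 0.0f.
const cells = new Uint32Array(new SharedArrayBuffer(4));
for (const v of [1.5, 2.25, -0.75]) atomicAddF32(cells, 0, v);
console.log(bitcastToF32(Atomics.load(cells, 0))); // 3
```

This sketch does not try to handle the NaN caveat. In JavaScript, storing a NaN through a `Float32Array` is allowed to change its exact bit pattern.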

@jowens

jowens commented Sep 24, 2024

Starting off with a description in English, then I'll move to pseudocode and actual code.

Consider a mesh of vertices and faces. Each face has computed a normal vector (this is called the "facet normal"). We wish to compute a normal vector for each vertex, which is the average of all of the normal vectors of the neighboring faces.

The straightforward data-parallel way to compute this is via scatter at cost O(V+E):

initialize each vertex's normal vector to (0,0,0)
parallel for each face in mesh:
  for each vertex in face:
    atomic-add face's normal to vertex's normal // this is scatter
parallel for each vertex in mesh:
  normalize(vertex normal)
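For illustration, here is a sequential JavaScript sketch of the scatter formulation above. The names `vertexNormalsScatter` and `facetNormals` are hypothetical. The layout is assumed: positions as `[x, y, z]` arrays, plus a flat list of three vertex indices per triangle, mirroring `triangle_indices`. Run serially, the `+=` into `acc[v]` is safe. On the GPU, with one invocation per face, that step is where an atomic f32 add would be needed.

```javascript
// Small vec3 helpers on plain [x, y, z] arrays.
const sub = (a, b) => [a[0] - b[0], a[1] - b[1], a[2] - b[2]];
const cross = (a, b) => [
  a[1] * b[2] - a[2] * b[1],
  a[2] * b[0] - a[0] * b[2],
  a[0] * b[1] - a[1] * b[0],
];
// Guard: a zero-length sum (e.g. an isolated vertex) stays (0,0,0) instead of NaN.
const normalize = (v) => {
  const len = Math.hypot(v[0], v[1], v[2]);
  return len > 0 ? v.map((x) => x / len) : [0, 0, 0];
};

// Facet normal of each triangle: normalize(cross(v1 - v0, v2 - v0)).
function facetNormals(positions, triangles) {
  const result = [];
  for (let f = 0; f < triangles.length / 3; f++) {
    const [v0, v1, v2] = [0, 1, 2].map((k) => positions[triangles[3 * f + k]]);
    result.push(normalize(cross(sub(v1, v0), sub(v2, v0))));
  }
  return result;
}

function vertexNormalsScatter(positions, triangles) {
  const faceN = facetNormals(positions, triangles);
  const acc = positions.map(() => [0, 0, 0]); // every vertex normal starts at (0,0,0)
  for (let f = 0; f < faceN.length; f++) {    // "parallel for each face"
    for (let k = 0; k < 3; k++) {             // "for each vertex in face"
      const v = triangles[3 * f + k];
      // Scatter: concurrent faces may hit the same v, so on the GPU this needs an atomic add.
      for (let c = 0; c < 3; c++) acc[v][c] += faceN[f][c];
    }
  }
  return acc.map(normalize);                  // "parallel for each vertex": normalize
}

// Two triangles meeting at a right angle along the edge between vertices 0 and 1.
const positions = [[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]];
const triangles = [0, 1, 2, 0, 3, 1];
console.log(vertexNormalsScatter(positions, triangles));
```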

@dneto0

dneto0 commented Sep 24, 2024

> As discussed this can likely be worked around via bitcast and a uint32 CAS with possible exception of NaNs https://www.w3.org/TR/WGSL/#bitcast-builtin

Something like this. This uses a u32 cell in workgroup memory.

@group(0) @binding(0) var<storage> inputs: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: f32;

// The cell holds u32 values.
var<workgroup> sum: atomic<u32>;

alias cell = ptr<workgroup,atomic<u32>>;

// Adds f32 'value' to the value in 'cell', atomically.
// Perform atomic compare-exchange with u32 type, and bitcast in and out of f32.
fn atomic_add(sum_cell: cell, value: f32) -> f32 {
  // Initializing to 0 forces second iteration in almost all cases.
  var old = 0u; // alternatively, atomicLoad(sum_cell);
  loop {
    let new_value = value + bitcast<f32>(old);
    let exchange_result = atomicCompareExchangeWeak(sum_cell, old, bitcast<u32>(new_value));
    if exchange_result.exchanged {
       return new_value;
    }
    old = exchange_result.old_value;
  }
}

@compute @workgroup_size(32)
fn main(@builtin(global_invocation_id) gid: vec3u) {
  atomic_add(&sum, inputs[gid.x]);
  workgroupBarrier();
  if gid.x == 0 {
    output = bitcast<f32>(atomicLoad(&sum));
  }
}

@jowens

jowens commented Sep 24, 2024

If we don't have an atomic add, we can recast this as a gather, but at cost O(VE):

parallel for each vertex in mesh:
  initialize my normal vector to (0,0,0)
  for each face in all faces:
    if vertex in face:
      add face.normal to my normal
  normalize(vertex normal)
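Here is a sequential JavaScript sketch of this gather formulation, using the same assumed data layout: positions as `[x, y, z]` arrays and a flat triangle index list. The function name is hypothetical. It follows the loop structure of `vertex_normals_kernel`: every vertex scans every face, so no atomics are needed, but the work is proportional to vertices times faces.

```javascript
const sub = (a, b) => [a[0] - b[0], a[1] - b[1], a[2] - b[2]];
const cross = (a, b) => [
  a[1] * b[2] - a[2] * b[1],
  a[2] * b[0] - a[0] * b[2],
  a[0] * b[1] - a[1] * b[0],
];
// Guard: a vertex touched by no face stays (0,0,0) instead of NaN.
const normalize = (v) => {
  const len = Math.hypot(v[0], v[1], v[2]);
  return len > 0 ? v.map((x) => x / len) : [0, 0, 0];
};

function vertexNormalsGather(positions, triangles) {
  const numFaces = triangles.length / 3;
  // Facet normals first (one pass over faces, no conflicts).
  const faceN = [];
  for (let f = 0; f < numFaces; f++) {
    const [v0, v1, v2] = [0, 1, 2].map((k) => positions[triangles[3 * f + k]]);
    faceN.push(normalize(cross(sub(v1, v0), sub(v2, v0))));
  }
  // "parallel for each vertex": each vertex only writes its own result.
  return positions.map((_, vtx) => {
    const acc = [0, 0, 0];                  // my normal starts at (0,0,0)
    for (let f = 0; f < numFaces; f++) {    // every face, for every vertex
      for (let k = 0; k < 3; k++) {
        if (triangles[3 * f + k] === vtx) { // "if vertex in face"
          for (let c = 0; c < 3; c++) acc[c] += faceN[f][c];
        }
      }
    }
    return normalize(acc);
  });
}

console.log(vertexNormalsGather([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]], [0, 1, 2, 0, 3, 1]));
```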

We could probably build a data structure that maps vertices to faces (key: vertex, value: list of faces), which would significantly reduce the cost. Note that the list of faces per vertex is variable-length (a vertex's valence can be quite high).
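One common way to store such a variable-length map is a compressed sparse row (CSR) layout: an `offsets` array of length V+1 plus one flat array of face indices. It can be filled by counting each vertex's valence, prefix-summing the counts, and then writing face ids into each vertex's slot range. Below is a minimal JavaScript sketch. The function names are hypothetical, and the facet normals are assumed to be precomputed, as `facet_normals_kernel` does.

```javascript
// Guard: a vertex with no faces stays (0,0,0) instead of NaN.
const normalize = (v) => {
  const len = Math.hypot(v[0], v[1], v[2]);
  return len > 0 ? v.map((x) => x / len) : [0, 0, 0];
};

// Hypothetical helper: vertex -> faces adjacency in CSR form.
function buildVertexToFaces(numVertices, triangles) {
  const offsets = new Uint32Array(numVertices + 1);
  // 1. Count each vertex's valence (stored shifted by one slot).
  for (const v of triangles) offsets[v + 1]++;
  // 2. Prefix sum: offsets[v] is now where vertex v's face list starts.
  for (let v = 0; v < numVertices; v++) offsets[v + 1] += offsets[v];
  // 3. Write face ids, advancing a per-vertex cursor.
  const faces = new Uint32Array(triangles.length);
  const cursor = offsets.slice(0, numVertices);
  for (let f = 0; f < triangles.length / 3; f++) {
    for (let k = 0; k < 3; k++) {
      const v = triangles[3 * f + k];
      faces[cursor[v]++] = f;
    }
  }
  return { offsets, faces };
}

// Gather using the map: each vertex visits only its own faces.
function vertexNormalsFromAdjacency(faceNormals, { offsets, faces }) {
  const result = [];
  for (let v = 0; v + 1 < offsets.length; v++) {
    const acc = [0, 0, 0];
    for (let i = offsets[v]; i < offsets[v + 1]; i++) {
      const fn = faceNormals[faces[i]];
      acc[0] += fn[0]; acc[1] += fn[1]; acc[2] += fn[2];
    }
    result.push(normalize(acc));
  }
  return result;
}

const triangles = [0, 1, 2, 0, 3, 1];
const adjacency = buildVertexToFaces(4, triangles);
console.log(vertexNormalsFromAdjacency([[0, 0, 1], [0, 1, 0]], adjacency));
```

On the GPU, the counting and cursor steps would need only `atomicAdd` on `atomic<u32>`, which WGSL already supports, plus a prefix sum. The final gather then touches each (vertex, face) incidence once.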

@jowens

jowens commented Sep 24, 2024

Here's the code I wrote to do this:

      const facet_normals_module = device.createShaderModule({
        label: "compute facet normals module",
        code: /* wgsl */ `
                    /* output */
                    @group(0) @binding(0) var<storage, read_write> facet_normals: array<vec3f>;
                    /* input */
                    @group(0) @binding(1) var<storage, read> vertices: array<vec3f>;
                    @group(0) @binding(2) var<storage, read> triangle_indices: array<u32>;

                     /** Algorithm:
                      * For tri in all triangles:
                      *   Fetch all 3 vertices of tri
                      *   Compute normalize(cross(v1-v0, v2-v0))
                      *   For each vertex in tri:
                      *     Atomically add it to vertex_normals[vertex]
                      *     /* Can't do this! No f32 atomics */
                      * For vertex in all vertices:
                      *   Normalize vertex_normals[vertex]
                      *
                      * OK, so we can't do this approach w/o f32 atomics
                      * So we will instead convert this scatter to gather
                      * This is wasteful; every vertex will walk the entire
                      *   index array looking for matches.
                      * Could alternately build a mapping of {vtx->facet}
                      *
                      * (1) For tri in all triangles:
                      *   Fetch all 3 vertices of tri
                      *   Compute normalize(cross(v1-v0, v2-v0))
                      *   Store that vector as a facet normal
                      * (2) For vertex in all vertices:
                      *   normal[vertex] = (0,0,0)
                      *   For tri in all triangles:
                      *     // note expensive doubly-nested loop!
                      *     if my vertex is in that triangle:
                      *       normal[vertex] += facet_normal[tri]
                      *   normalize(normal[vertex])
                      */
                    @compute @workgroup_size(${WORKGROUP_SIZE}) fn facet_normals_kernel(
                      @builtin(global_invocation_id) id: vec3u) {
                        let tri = id.x;
                        if (tri < arrayLength(&facet_normals)) {
                          /* note triangle_indices is u32 not vec3, do math accordingly */
                          let v0: vec3f = vertices[triangle_indices[tri * 3]];
                          let v1: vec3f = vertices[triangle_indices[tri * 3 + 1]];
                          let v2: vec3f = vertices[triangle_indices[tri * 3 + 2]];
                          facet_normals[tri] = normalize(cross(v1-v0, v2-v0));
                        }
                      }
                  `,
      });

      const vertex_normals_module = device.createShaderModule({
        label: "compute vertex normals module",
        code: /* wgsl */ `
                    /* output */
                    @group(0) @binding(0) var<storage, read_write> vertex_normals: array<vec3f>;
                    /* input */
                    @group(0) @binding(1) var<storage, read> facet_normals: array<vec3f>;
                    @group(0) @binding(2) var<storage, read> triangle_indices: array<u32>;

                    /* see facet_normals_module for algorithm */

                    @compute @workgroup_size(${WORKGROUP_SIZE}) fn vertex_normals_kernel(
                      @builtin(global_invocation_id) id: vec3u) {
                        let vtx = id.x;
                        if (vtx < arrayLength(&vertex_normals)) {
                          vertex_normals[vtx] = vec3f(0, 0, 0);
                          /* note triangle_indices is u32 not vec3, do math accordingly */
                          for (var tri: u32 = 0; tri < arrayLength(&triangle_indices) / 3; tri++) {
                            for (var tri_vtx: u32 = 0; tri_vtx < 3; tri_vtx++) { /* unroll */
                              if (vtx == triangle_indices[tri * 3 + tri_vtx]) {
                                vertex_normals[vtx] += facet_normals[tri];
                              }
                            }
                          }
                          vertex_normals[vtx] = normalize(vertex_normals[vtx]);
                        }
                    }
                  `,
      });

I agree that compare-and-swap is a viable way to do this and I will implement it.
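Here is a sketch of how the CAS-based scatter could be laid out, again emulated on the CPU in JavaScript. All names are hypothetical. WGSL atomics only exist for scalar `i32`/`u32`, so this assumes each vertex normal is stored as three consecutive u32 cells holding f32 bit patterns. Each component is added with the bitcast + CAS loop. Normalization is a separate pass that runs only after every add has finished.

```javascript
const scratch = new ArrayBuffer(4);
const asF32 = new Float32Array(scratch);
const asU32 = new Uint32Array(scratch);
const toF32 = (bits) => { asU32[0] = bits; return asF32[0]; };
const toU32 = (x) => { asF32[0] = x; return asU32[0]; };

// Bitcast + CAS loop: add an f32 into a u32 cell.
function atomicAddF32(cells, i, value) {
  let old = Atomics.load(cells, i);
  for (;;) {
    const sum = Math.fround(toF32(old) + Math.fround(value));
    const seen = Atomics.compareExchange(cells, i, old, toU32(sum));
    if (seen === old) return sum;
    old = seen; // lost the race; retry with the value that is there now
  }
}

// Hypothetical layout: cells[3*v + c] holds component c of vertex v's normal as f32 bits.
function scatterVertexNormals(numVertices, triangles, faceNormals) {
  const cells = new Uint32Array(new SharedArrayBuffer(4 * 3 * numVertices)); // all 0.0f
  // Pass 1: one "invocation" per face, adding its normal into each of its vertices.
  for (let f = 0; f < faceNormals.length; f++) {
    for (let k = 0; k < 3; k++) {
      const v = triangles[3 * f + k];
      for (let c = 0; c < 3; c++) atomicAddF32(cells, 3 * v + c, faceNormals[f][c]);
    }
  }
  // Pass 2: must start only after every add is done (a separate dispatch on the GPU).
  const normals = [];
  for (let v = 0; v < numVertices; v++) {
    const n = [0, 1, 2].map((c) => toF32(cells[3 * v + c]));
    const len = Math.hypot(n[0], n[1], n[2]);
    normals.push(len > 0 ? n.map((x) => x / len) : n);
  }
  return normals;
}

console.log(scatterVertexNormals(4, [0, 1, 2, 0, 3, 1], [[0, 0, 1], [0, 1, 0]]));
```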

@petermcneeleychromium

You can probably avoid the final pass

parallel for each vertex in mesh:
  normalize(vertex normal)

if you just normalize in the CAS. I wonder if there is an ABA problem with this?

@jowens

jowens commented Sep 24, 2024

> If you just normalize in the CAS.

Not sure I understand? I have to wait for all adds to complete before I can normalize (I can't normalize halfway through). My mental picture is that I would need a global barrier to make sure all adds into a vertex normal are complete before I can normalize the resulting normal.

@petermcneeleychromium

You are correct. You would need to store additional information (a magnitude) if you wanted to avoid this second pass.

@jowens

jowens commented Sep 25, 2024

I teach the ABA problem but have never encountered it in the wild, so this will be fun for me to learn about!

Projects
None yet
Development

No branches or pull requests

3 participants