Simplify distance/detail to make it easier to dispatch to different kernel implementations #1142

Merged 67 commits on Mar 10, 2023

@ahendriksen ahendriksen commented Jan 13, 2023

The pairwise distance metrics are quite varied. The table below summarizes the differences, in terms of

  • Epilog: whether the metric has a non-empty epilog operation.
  • Uses norms: whether the metric requires precalculation of the norms of the vectors.
  • Has params: whether the metric has additional parameters. The L2 metric, for instance, has the sqrt boolean parameter that determines whether to calculate the squared or the actual distance.
  • Pre- & post-processing: For some metrics, the norms have to be precalculated. For other metrics, the input matrices are transformed before the kernel launch, and "untransformed" after.
  • Expensive inner loop: some metrics use pow, log or other expensive functions in the inner loop.
  • Depends on row-major: the calculation of some metrics depends on whether the input is row-major.
  • CUTLASS: some metrics have an implementation using CUTLASS and tensor cores.
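| Metric | Epilog | Uses norms | Has params | Pre- & post-processing | Expensive inner loop | Depends on row-major | CUTLASS |
|---|---|---|---|---|---|---|---|
| Canberra | | | | | x | | |
| Chebyshev (Linf) | | | | | | | |
| Correlation | x | x (twice) | x (many) | compute norms | | x | |
| Cosine | x | x | | compute norms | | | x |
| Hamming | x | | x (k) | | | | |
| Hellinger | x | | | sqrt and square | | | |
| Jensen Shannon | x | | | | x | | |
| KL divergence | x | | x (row major, x == y) | yes | x | x | |
| L1 | | | | | | | |
| L2 expanded | x | x | x (sqrt) | compute norms | | | x |
| L2 unexpanded | x | | x (sqrt) | | | | |
| Minkowski (Lp) | x | | x (p) | | x | | |
| Russel-Rao | x | | x (k, 1/k) | | | | |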
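
To illustrate what the "Uses norms" and "Epilog" columns mean for the two L2 rows, here is the standard algebraic identity behind an "expanded" squared L2 distance. This is textbook algebra, not something taken from the RAFT source, so treat the mapping onto the columns as an assumption:

$$
\lVert x - y \rVert_2^2 \;=\; \lVert x \rVert_2^2 + \lVert y \rVert_2^2 - 2\, x \cdot y
$$

Read this way, the squared norms can be precomputed once per row, the inner loop only needs to accumulate the dot product $x \cdot y$, and an epilog combines the three terms (and applies the optional sqrt). The unexpanded form accumulates $(x_i - y_i)^2$ directly in the inner loop and therefore needs no norms.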

To keep the complexity that results from all these differences in check, there are several layers between the public API and the kernel launch, each with their own responsibility.

Before

  1. raft::distance::pairwise_distance takes distance type as a run-time argument and dispatches to raft::distance::detail::pairwise_distance_impl.
  2. raft::distance::detail::pairwise_distance_impl allocates workspace as necessary and calls raft::distance::detail::distance.
  3. raft::distance::detail::distance defines a default final operation (the identity) and calls an overload of itself.
  4. raft::distance::detail::distance (with fin_op) initializes a DistanceImpl zero-sized struct with the correct template arguments and runs the .run() method of the struct.
  5. raft::distance::detail::DistanceImpl<DistanceType>.run() calls raft::distance::detail::XX_Impl.
  6. raft::distance::detail::XX_Impl has the following responsibilities:
    • Pre-compute norms if necessary
    • Transform input if necessary
    • If metric supports a CUTLASS operation, dispatch if necessary.
    • Swap inputs if column-major.
    • Based on the runtime parameter row_major, dispatch to the function template raft::distance::detail::XX<bool row_major>.
  7. raft::distance::detail::XX dispatches, based on the alignment of the input data, to the function template raft::distance::detail::XX_Impl<int veclen> (a different overload of the previous raft::distance::detail::XX_Impl).
  8. raft::distance::detail::XX_Impl has the following responsibilities:
    • Define core_op and epilog_op
    • Define use_norms
    • Launch kernel pairwiseDistanceMatKernel with correct launch parameters

Observations:

  • Steps 6 and 7 both convert a runtime value to a compile-time constant (row-major layout and alignment).
  • Step 7 is repeated (copy-pasted) for each metric.
  • Steps 6 and 8 do a lot of different things, and the step in between does relatively little.
  • Steps 1-5 do fairly little (but require a lot of boilerplate).

Proposal:

  1. Collect as much as possible of the runtime behavior of each metric in a distance_op that contains:
    • The core_op
    • The epilog_op
    • The required shared memory
    • Whether the inner loop is expensive (and thus loop unrolling should be curtailed)
  2. Collect the runtime -> compile-time dispatch in one location (dispatch.cuh)
  3. Collect kernel launching in one location
  4. Remove some of the boilerplate in steps 1-5.

After

  1. raft::distance::pairwise_distance takes distance type as a run-time argument, allocates workspace as necessary, and dispatches to raft::distance::detail::distance.
  2. raft::distance::detail::distance defines a default final operation (the identity) and calls an overload of itself.
  3. raft::distance::detail::distance (with fin_op) calls an overload of raft::distance::detail::distance_impl for the correct distance type.
  4. raft::distance::detail::distance_impl has the following responsibilities:
    • Pre-compute norms if necessary
    • Initialize distance op with parameters as necessary, see below for more information.
    • Transform input if necessary
    • If metric supports a CUTLASS operation, dispatch if necessary.
    • Dispatch to raft::distance::detail::distance_matrix_dispatch
  5. raft::distance::detail::distance_matrix_dispatch has the following responsibilities:
    • Swap the x and y matrices if the input is column-major
    • Dispatch to the correct kernel based on the run-time parameters row_major and vec_len
    • Determine the kernel policy based on the parameters
    • Call raft::distance::detail::pairwise_matrix
  6. raft::distance::detail::pairwise_matrix launches the raft::distance::detail::pairwise_matrix_kernel with the correct launch parameters.
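
As a rough host-side sketch of what step 5 describes (turning the run-time values row_major and vec_len into template arguments in a single place), something like the following could work. All names ending in `_sketch`, the `launch_info` struct, and the simplified alignment rule are hypothetical and are not the RAFT API; the real dispatch also chooses a kernel policy and launches a CUDA kernel.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical record of which compile-time configuration was selected.
struct launch_info {
  bool row_major;
  int vec_len;
};

// Hypothetical stand-in for the templated launcher of step 6. A real
// implementation would launch a kernel; this one only reports its template
// arguments so the dispatch can be observed.
template <bool RowMajor, int VecLen>
launch_info pairwise_matrix_sketch(const float*, const float*, int, int, int)
{
  return launch_info{RowMajor, VecLen};
}

// Simplified alignment rule (an assumption): use 4, 2 or 1 floats per load,
// depending on the alignment of both pointers and of the row length in bytes.
inline int max_vec_len(const float* x, const float* y, int k)
{
  const std::size_t row_bytes = static_cast<std::size_t>(k) * sizeof(float);
  auto aligned = [&](std::size_t bytes) {
    return reinterpret_cast<std::uintptr_t>(x) % bytes == 0 &&
           reinterpret_cast<std::uintptr_t>(y) % bytes == 0 && row_bytes % bytes == 0;
  };
  if (aligned(16)) { return 4; }
  if (aligned(8)) { return 2; }
  return 1;
}

// Second level: run-time vec_len -> compile-time VecLen.
template <bool RowMajor>
launch_info dispatch_vec_len(const float* x, const float* y, int m, int n, int k)
{
  switch (max_vec_len(x, y, k)) {
    case 4: return pairwise_matrix_sketch<RowMajor, 4>(x, y, m, n, k);
    case 2: return pairwise_matrix_sketch<RowMajor, 2>(x, y, m, n, k);
    default: return pairwise_matrix_sketch<RowMajor, 1>(x, y, m, n, k);
  }
}

// First level: run-time row_major -> compile-time RowMajor. For column-major
// input the x and y matrices (and their row counts) are swapped, as in step 5.
inline launch_info distance_matrix_dispatch_sketch(
  const float* x, const float* y, int m, int n, int k, bool row_major)
{
  assert(m > 0 && n > 0 && k > 0);
  if (row_major) { return dispatch_vec_len<true>(x, y, m, n, k); }
  return dispatch_vec_len<false>(y, x, n, m, k);
}
```

The point of the design is that the nested switch lives in one function rather than being copy-pasted per metric; a metric only has to supply its op.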

Distance_op
raft::distance::detail::ops::XX_distance_op [example] has the following responsibilities:

  • Take any parameters (sqrt, k, etc)
  • Define core_op and epilog_op
  • Define use_norms, expensive_inner_loop, and shared_mem_size().
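
To make these responsibilities concrete, here is a minimal host-side sketch of two ops and a generic loop that consumes them. It is only an illustration under assumptions: the struct names, the `take_sqrt` member, the scalar `core`/`epilog` signatures, and `pairwise_matrix_host` are all hypothetical, and the real ops operate on per-thread accumulator tiles inside a CUDA kernel.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical L2 (unexpanded) op: one parameter, a core op and an epilog op.
struct l2_unexp_distance_op_sketch {
  bool take_sqrt;  // parameter: actual (true) or squared (false) distance

  static constexpr bool use_norms            = false;
  static constexpr bool expensive_inner_loop = false;

  // No norms are staged, so this sketch requests no extra shared memory.
  std::size_t shared_mem_size() const { return 0; }

  // core: fold one pair of elements into the accumulator.
  void core(float& acc, float x, float y) const
  {
    const float diff = x - y;
    acc += diff * diff;
  }

  // epilog: finalize the accumulated value.
  void epilog(float& acc) const
  {
    if (take_sqrt) { acc = std::sqrt(acc); }
  }
};

// Hypothetical L1 op: no parameters and an empty epilog.
struct l1_distance_op_sketch {
  static constexpr bool use_norms            = false;
  static constexpr bool expensive_inner_loop = false;

  std::size_t shared_mem_size() const { return 0; }
  void core(float& acc, float x, float y) const { acc += std::fabs(x - y); }
  void epilog(float&) const {}
};

// Generic host loop standing in for the single kernel. x is m x k and y is
// n x k, both row-major; the result is the m x n distance matrix.
template <typename Op>
std::vector<float> pairwise_matrix_host(
  const Op& op, const std::vector<float>& x, const std::vector<float>& y, int m, int n, int k)
{
  static_assert(!Op::use_norms, "this sketch does not precompute norms");
  assert(x.size() == static_cast<std::size_t>(m) * k);
  assert(y.size() == static_cast<std::size_t>(n) * k);

  std::vector<float> out(static_cast<std::size_t>(m) * n, 0.0f);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int l = 0; l < k; ++l) {
        op.core(acc, x[static_cast<std::size_t>(i) * k + l], y[static_cast<std::size_t>(j) * k + l]);
      }
      op.epilog(acc);
      out[static_cast<std::size_t>(i) * n + j] = acc;
    }
  }
  return out;
}
```

With this shape, adding a metric means writing one small struct; the loop (or kernel) and the dispatch code stay untouched.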

Still TODO:

  • Rename Minkowski and Chebyshev to Lp and Linf.
  • Do something with this note in the comments: "if workspace is passed as nullptr, this will return in worksize, the number of bytes of workspace required", which is wrong.
  • Add a mechanism to limit duplicate compilation when a CUTLASS kernel is available. This is done in the follow-up PR "Add dispatch based on compute architecture" (#1295).
  • Some distance_ops have additional template parameters. This must be cleared up.

@ahendriksen ahendriksen changed the title [WIP] Refactor distance/detail [WIP] Refactor distance/detail to make is easier to dispatch to different kernel implementations Jan 13, 2023
@ahendriksen ahendriksen force-pushed the wip-refactor-distance branch from 6d78314 to d58f0f3 Compare January 13, 2023 21:23
@github-actions github-actions bot removed the CMake label Jan 13, 2023
@ahendriksen ahendriksen force-pushed the wip-refactor-distance branch from d58f0f3 to a213ef3 Compare January 13, 2023 21:27
@ahendriksen ahendriksen changed the title [WIP] Refactor distance/detail to make is easier to dispatch to different kernel implementations [WIP] Simplify distance/detail to make is easier to dispatch to different kernel implementations Jan 13, 2023
@ahendriksen
Contributor Author

Reminder/TODO

Make the grid-stride loop variables local variables instead of member variables.

See: #838 (comment)
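
For reference, a host-side analogue of a grid-stride tile loop whose indices are plain loop-local variables might look like the sketch below. The function name, the `tile` struct, and the parameter names are hypothetical; in the actual kernel the block and grid indices would come from CUDA built-ins rather than function arguments.

```cpp
#include <cassert>
#include <vector>

// Hypothetical record of the top-left corner of a visited output tile.
struct tile {
  int row;
  int col;
};

// Lists the tiles that one simulated block visits. The tile indices `row` and
// `col` exist only inside the loops; no state is kept in a class between
// iterations.
inline std::vector<tile> tiles_for_block(
  int m, int n, int block_x, int block_y, int grid_x, int grid_y, int tile_m, int tile_n)
{
  assert(grid_x > 0 && grid_y > 0 && tile_m > 0 && tile_n > 0);
  assert(0 <= block_x && block_x < grid_x && 0 <= block_y && block_y < grid_y);

  std::vector<tile> visited;
  for (int row = block_y * tile_m; row < m; row += grid_y * tile_m) {
    for (int col = block_x * tile_n; col < n; col += grid_x * tile_n) {
      visited.push_back(tile{row, col});
    }
  }
  return visited;
}
```

Keeping the indices local makes it easier to see, from the loop alone, which tiles a block touches and in what order.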

@ahendriksen
Contributor Author

TODO: #838 (comment)

Also look at fused and maskedL2NN for shared memory calculations.

@ahendriksen ahendriksen mentioned this pull request Jan 25, 2023
@cjnolet cjnolet added improvement Improvement / enhancement to an existing function non-breaking Non-breaking change labels Jan 26, 2023
The calculation of the tile indices is now performed in ldgXY(). This
will make it possible to remove all state related to the tile index from
the class in the next commit.

Note that the calculation of the tile index can depend on which
overloaded constructor is called(!)

This commit moves all grid and tile indexing logic into the caller.
Contractions_NT is now only responsible for *intra*-tile indexing.

Due to the complexity of the epilog function, the ldgNextGridStride
function is not yet called from within the main loop. That is the next
goal so that we have all the grid and tile indexing localized in the
loop.

This commit removes the epilog function and moves its functionality into
the run loop. The next step might be to see if the ldgNextGridStride()
method has to be called at the current location, or if performance is the
same if it is called at the start of the loop.

This results in subtle issues with non-square KernelPolicy, as found in
fusedL2KNN.

This is more general than just for L1. Making use of it more is work in
progress.

This did remove support for the CUTLASS kernels. Has to be put back.

I wasted a lot of time because I had not replaced the op::core() method
of the l2_exp_distance_op after I copied it from l2_unexp_distance_op...

If I copy something from the template and forget to fill it in, I get a
compile error.

I am testing on CUDA 12, where it does not seem to work. Prior to my
commits, the CUTLASS kernels were also not working. So not sure what's
up.

In any case: consider this untested.

@ahendriksen
Contributor Author

The CI error does not seem related to the changes in the PR:

2023-02-22T19:37:49.5526382Z   Error compiling Cython file:
2023-02-22T19:37:49.5533415Z   ------------------------------------------------------------
2023-02-22T19:37:49.5539691Z   ...
2023-02-22T19:37:49.5546449Z   # cython: language_level = 3
2023-02-22T19:37:49.5552762Z 
2023-02-22T19:37:49.5559417Z 
2023-02-22T19:37:49.5566630Z   from libcpp.memory cimport shared_ptr, unique_ptr
2023-02-22T19:37:49.5573013Z 
2023-02-22T19:37:49.5580739Z   from rmm._lib.cuda_stream_pool cimport cuda_stream_pool

Is there a way to rerun the CI without pushing an empty commit?

@vyasr
Contributor

vyasr commented Feb 23, 2023

@ahendriksen if you click on the "Details" button it will take you to the actions tab where you can use the "Re-run jobs" dropdown. See https://docs.github.com/en/actions/managing-workflow-runs/re-running-workflows-and-jobs for more info. I've queued up a rerun.

@ahendriksen
Contributor Author

Thanks! And thanks for linking to the docs. I will have a look to see what else I am missing.

Member

@cjnolet cjnolet left a comment

I really like the idea of abstracting each unique distance formula into composable ops. In this case, not only does it allow us to specify and react differently to various performance conditions, but it allows us to capture the differences of the computations themselves in a single place. This is very much what we did w/ the sparse APIs as well- they are a series of composable operations that can be executed according to their needs (in the sparse case that's binary vs dot-product-based vs full-pairwise evaluation).

I think this looks great. I'd like to see the functions in the public API that accept explicit workspaces deprecated because the memory resource should be getting propagated through the raft::resources instance (and be controllable by the user). I think we have an opportunity here to expand our developer guide as well. Otherwise, I'm completely on board w/ this change.

Contributor

@tfeher tfeher left a comment

Thanks Allard for this PR! This looks great! I am very happy to see the reduction of code duplication. I have just a few smaller questions.

Contributor

@tfeher tfeher left a comment

Thanks Allard for the updates! The PR looks good to me!

@cjnolet
Member

cjnolet commented Mar 9, 2023

Just a heads up- we need to merge this cuml PR before this RAFT PR is merged, otherwise cuml will break downstream.

@ahendriksen
Contributor Author

Good catch. I should be more careful with the non-breaking label. Although that file has been deprecated for a couple of months now.

@cjnolet
Member

cjnolet commented Mar 9, 2023

> Good catch. I should be more careful with the non-breaking label. Although that file has been deprecated for a couple of months now.

No problem. To be fair, we had updated cuml in 23.02 to remove most of its uses of deprecated headers but this one slipped through the cracks. Before we merged this I just wanted to do a quick grep to be sure.

Member

@cjnolet cjnolet left a comment

LGTM!

@cjnolet
Member

cjnolet commented Mar 10, 2023

/merge

@rapids-bot rapids-bot bot merged commit e4aec7b into rapidsai:branch-23.04 Mar 10, 2023
lowener pushed a commit to lowener/raft that referenced this pull request Mar 15, 2023
…ernel implementations (rapidsai#1142)

Still TODO:

- [x] Rename Minkowski and Chebyshev to Lp and Linf.
- [x] Do something with this note in the comments: "if workspace is passed as nullptr, this will return in worksize, the number of bytes of workspace required", which is wrong.
- [x] Add a mechanism to limit duplicate compilation when a CUTLASS kernel is available. This is done in follow up PR rapidsai#1295.
- [x] Some distance_ops have additional template parameters. This must be cleared up.

Authors:
  - Allard Hendriksen (https://github.com/ahendriksen)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai#1142
@ahendriksen ahendriksen deleted the wip-refactor-distance branch March 17, 2023 09:35
rapids-bot bot pushed a commit that referenced this pull request Mar 17, 2023
This PR improves the ability to do dispatch based on compute architecture. It is a follow up to #1142.

It has two goals: 

1. Make it easier to specify which compute architectures a kernel is compatible with / should be compiled for.
2. Make it easier to compile a kernel only for the architectures for which it is used (if it is unused, the kernel should be empty).

We have a specific use case in RAFT for this feature. For the L2 pairwise distance kernel we have a CUTLASS-based implementation that works on SM80+ and a fallback kernel. Preferably, each kernel is only compiled for the architectures on which it is actually used.
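
A minimal host-side sketch of this kind of selection, assuming a hypothetical half-open range type over compute capabilities encoded as major * 10 + minor, could look as follows. None of these names come from the RAFT API, and the upper bound of the CUTLASS range is an arbitrary placeholder for "SM80 and newer".

```cpp
#include <cassert>

enum class kernel_choice { cutlass, fallback };

// Hypothetical half-open range [min_sm, max_sm) of compute capabilities.
struct sm_range_sketch {
  int min_sm;
  int max_sm;
  constexpr bool contains(int sm) const { return min_sm <= sm && sm < max_sm; }
};

// "SM80+" per the description above; 1000 is only a stand-in for "no upper limit".
constexpr sm_range_sketch cutlass_range{80, 1000};
// Everything older than SM80 uses the fallback kernel.
constexpr sm_range_sketch fallback_range{0, 80};

// The two ranges are chosen so that they do not overlap at the boundary.
static_assert(!fallback_range.contains(80) && cutlass_range.contains(80),
              "ranges must not overlap");

// Run-time choice of kernel for the architecture of the current device.
inline kernel_choice select_l2_kernel(int sm)
{
  assert(sm > 0);
  return cutlass_range.contains(sm) ? kernel_choice::cutlass : kernel_choice::fallback;
}
```

The second goal (compiling each kernel only where it is used) would additionally need a compile-time check inside the device code against the architecture currently being compiled for, so that the kernel body is empty outside its range; that part is not shown here.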

Authors:
  - Allard Hendriksen (https://github.com/ahendriksen)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1335
Labels
5 - Ready to Merge CMake cpp improvement Improvement / enhancement to an existing function non-breaking Non-breaking change