
@drisspg
Last active January 13, 2025 09:25
Scaled MM API

Summary

This doc serves as a quick reference for the _scaled_mm API and how it has changed over time across major versions of PyTorch.


NOTE: The leading underscore is intentional, and we make no forward- or backward-compatibility (FC/BC) guarantees for this API. That said, it is currently the only op in the PyTorch library with native support for FP8 matmuls. We plan to add an official public API; until then this op is subject to change, but you can use this doc as a reference.


torch==2.1.0

This is the first version with any support for the op. Signature: _scaled_mm(Tensor self, Tensor mat2, *, Tensor? bias=None, ScalarType? out_dtype=None, Tensor? scale_a=None, Tensor? scale_b=None, Tensor? scale_result=None) -> (Tensor, Tensor)

import torch
from typing import Optional, Tuple

def _scaled_mm(
    self: torch.Tensor,
    mat2: torch.Tensor,
    *,
    bias: Optional[torch.Tensor] = None,
    out_dtype: Optional[torch.dtype] = None,
    scale_a: Optional[torch.Tensor] = None,
    scale_b: Optional[torch.Tensor] = None,
    scale_result: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    ...

The notable difference from later versions is that scale_a and scale_b are optional; when they are None, the op implicitly uses a scale tensor with value 1 for both self and mat2. The op returns a tuple whose first value is the result of the matmul and whose second value is the global absolute maximum (amax) of the result tensor. scale_result is used to scale the accumulator, and it only has an effect when the output dtype is an FP8 type.
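To make the 2.1 return convention concrete, here is a minimal pure-Python sketch (no GPU or FP8 dtypes needed) of what the op computes under the description above: missing scales default to 1, and the second return value is the absolute maximum of the result. The function name `reference_scaled_mm_v21` is hypothetical, it works in full Python float precision rather than FP8, and it deliberately leaves out `bias` and `scale_result`.

```python
from typing import List, Optional, Tuple

Matrix = List[List[float]]

def reference_scaled_mm_v21(
    a: Matrix,
    b: Matrix,
    scale_a: Optional[float] = None,
    scale_b: Optional[float] = None,
) -> Tuple[Matrix, float]:
    """Hypothetical emulation of torch 2.1 _scaled_mm semantics.

    No FP8 rounding, no bias, no scale_result.
    """
    # In 2.1 a missing scale behaves like a scale tensor holding 1.
    sa = 1.0 if scale_a is None else scale_a
    sb = 1.0 if scale_b is None else scale_b
    k = len(b)
    n = len(b[0])
    out = [
        [sa * sb * sum(row[p] * b[p][j] for p in range(k)) for j in range(n)]
        for row in a
    ]
    # Second tuple element: global abs-max of the result (used for delayed scaling).
    amax = max(abs(v) for r in out for v in r)
    return out, amax

out, amax = reference_scaled_mm_v21([[1.0, -2.0]], [[3.0], [4.0]])
# out == [[-5.0]], amax == 5.0
```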

torch==2.2-2.4

Signature: _scaled_mm(Tensor self, Tensor mat2, *, Tensor? bias=None, ScalarType? out_dtype=None, Tensor? scale_a=None, Tensor? scale_b=None, Tensor? scale_result=None, bool use_fast_accum=False) -> (Tensor, Tensor)

import torch
from typing import Optional, Tuple

def _scaled_mm(
    self: torch.Tensor,
    mat2: torch.Tensor,
    *,
    bias: Optional[torch.Tensor] = None,
    out_dtype: Optional[torch.dtype] = None,
    scale_a: Optional[torch.Tensor] = None,
    scale_b: Optional[torch.Tensor] = None,
    scale_result: Optional[torch.Tensor] = None,
    use_fast_accum: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    ...

This version adds the use_fast_accum parameter (type bool, default False). For more details, see fast_accum_section. TL;DR: Tensor Core accumulation precision is not the same as IEEE 754 FP32. By default (use_fast_accum=False), while reducing along the K dimension the Tensor Core accumulator periodically flushes its running sum into a full-precision accumulator. With use_fast_accum=True, the entire reduction is done in Tensor Core precision, which is faster but less accurate.
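The accuracy difference can be illustrated with a toy model. The pure-Python sketch below stands in for a Tensor Core accumulator by rounding the running sum to a hypothetical `bits`-wide mantissa. `dot_low_precision` keeps the whole K reduction in that format (modeling use_fast_accum=True), while `dot_with_promotion` flushes the low-precision partial sum into a full-precision Python float every `interval` elements (modeling the default). The bit width and interval are illustrative assumptions, not the real hardware values.

```python
import math

def round_to_bits(x: float, bits: int) -> float:
    """Round x to `bits` significant mantissa bits (toy low-precision format)."""
    if x == 0.0 or not math.isfinite(x):
        return x
    m, e = math.frexp(x)  # x == m * 2**e with 0.5 <= |m| < 1
    return math.ldexp(round(m * 2**bits) / 2**bits, e)

def dot_low_precision(a, b, bits=10):
    """The entire K reduction stays in the low-precision accumulator."""
    acc = 0.0
    for x, y in zip(a, b):
        acc = round_to_bits(acc + x * y, bits)
    return acc

def dot_with_promotion(a, b, bits=10, interval=128):
    """Low-precision partial sums are promoted to full precision every `interval` steps."""
    full = 0.0     # full-precision accumulator
    partial = 0.0  # low-precision running sum
    for i, (x, y) in enumerate(zip(a, b), start=1):
        partial = round_to_bits(partial + x * y, bits)
        if i % interval == 0:
            full += partial
            partial = 0.0
    return full + partial

ones = [1.0] * 4096
```

With 4096 products of 1.0, the toy low-precision accumulator stalls at 1024 (1025 needs 11 significant bits, so it rounds back to 1024), while the promoted version returns the exact 4096. The chunked partial sums never grow large enough to lose precision.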

torch==2.5

Signature: _scaled_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor

import torch
from typing import Optional

def _scaled_mm(
    self: torch.Tensor,
    mat2: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    scale_result: Optional[torch.Tensor] = None,
    out_dtype: Optional[torch.dtype] = None,
    use_fast_accum: bool = False
) -> torch.Tensor:
    ...

The key differences from the previous versions are:

  • scale_a and scale_b are now required parameters (not optional)
  • The return type is a single torch.Tensor instead of a tuple of two tensors
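  • The parameter order has changed: scale_a and scale_b now come directly after mat2, followed by bias, scale_result, out_dtype, and use_fast_accum, and these parameters are no longer keyword-only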

Why did we do this? There are a few reasons. The main one is that we found little usage of returning the output in FP8, and the amax (the second value in the tuple) is needed primarily for delayed scaling. Another practical, albeit not great, reason is that this signature made it much easier to add Inductor lowerings.
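Because every argument after mat2 can be passed by keyword in both the 2.2-2.4 and 2.5 signatures, one way to write code that runs on either is to pass everything by keyword and normalize the return type. The wrapper below is a hypothetical helper, not part of PyTorch; `torch_module` is the imported `torch` module, passed in explicitly so the sketch can be exercised without a GPU. It does not cover 2.1, which has no use_fast_accum parameter.

```python
from typing import Any, Optional

def scaled_mm_compat(
    torch_module: Any,
    a: Any,
    b: Any,
    scale_a: Any,
    scale_b: Any,
    out_dtype: Any,
    bias: Optional[Any] = None,
    use_fast_accum: bool = False,
) -> Any:
    """Hypothetical wrapper: call _scaled_mm and always return only the matmul result."""
    # Keywords are accepted by both the keyword-only 2.2-2.4 signature
    # and the positional 2.5 signature.
    out = torch_module._scaled_mm(
        a,
        b,
        scale_a=scale_a,
        scale_b=scale_b,
        bias=bias,
        out_dtype=out_dtype,
        use_fast_accum=use_fast_accum,
    )
    # torch 2.2-2.4 return (result, amax); torch 2.5 returns only the result.
    return out[0] if isinstance(out, tuple) else out
```

A call would look like `scaled_mm_compat(torch, a_fp8, b_fp8, scale_a, scale_b, out_dtype=torch.bfloat16)`, where the FP8 inputs and float32 scales are prepared as described in the constraints below.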

Up until this version, the only supported scaling type was per-tensor (the scale tensor has a single element). Version 2.5 is the first to support per-row scaling.
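For illustration, the sketch below derives both kinds of scale from a high-precision matrix using a common amax-based recipe, assuming a float8_e4m3fn target whose largest finite value is 448. The helper names are hypothetical, and treating the scale passed to _scaled_mm as the dequantization factor (amax / 448, multiplied back into the result) is an assumption of this sketch rather than something this doc specifies.

```python
from typing import List

Matrix = List[List[float]]
E4M3_MAX = 448.0  # largest finite value of float8_e4m3fn

def per_tensor_scale(x: Matrix) -> float:
    """One scale for the whole matrix (a single-element scale tensor)."""
    amax = max(abs(v) for row in x for v in row)
    return amax / E4M3_MAX

def per_row_scales(x: Matrix) -> List[List[float]]:
    """One scale per row, shaped (M, 1) like scale_a for per-row scaling."""
    return [[max(abs(v) for v in row) / E4M3_MAX] for row in x]
```

A production recipe would also guard against an all-zero row or tensor, which would otherwise produce a zero scale.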

Matrix Layout and Shape Constraints for _scaled_mm

The function performs a scaled matrix multiplication with specific requirements for matrix layouts and shapes:

For input matrices:

  • First matrix (self): Shape (M, K), must be in row-major layout
  • Second matrix (mat2): Shape (K, N), must be in column-major layout
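The layout requirement can be stated in terms of strides, without touching any data. The pure-Python sketch below mirrors what `Tensor.stride()` reports in PyTorch: a column-major (K, N) mat2 is what you get by transposing a contiguous (N, K) tensor. The helper names are hypothetical, and the checks are simplified (for example, they do not special-case dimensions of size 1).

```python
from typing import Tuple

Shape = Tuple[int, int]
Strides = Tuple[int, int]

def row_major_strides(shape: Shape) -> Strides:
    """Strides (in elements) of a contiguous, row-major 2-D array."""
    _, cols = shape
    return (cols, 1)

def transpose(shape: Shape, strides: Strides) -> Tuple[Shape, Strides]:
    """Transposing a 2-D view swaps shape and strides without moving data."""
    return (shape[1], shape[0]), (strides[1], strides[0])

def is_row_major(shape: Shape, strides: Strides) -> bool:
    return strides == (shape[1], 1)

def is_column_major(shape: Shape, strides: Strides) -> bool:
    return strides == (1, shape[0])

# mat2 is logically (K, N); store it as a contiguous (N, K) buffer and transpose the view.
K, N = 64, 32
stored_shape = (N, K)
mat2_shape, mat2_strides = transpose(stored_shape, row_major_strides(stored_shape))
# mat2_shape == (64, 32), mat2_strides == (1, 64): column-major
```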

For scaling factors:

  • Per-tensor scaling: scale_a and scale_b should be shape (1)
  • Per-row scaling: scale_a can be shape (M, 1) and scale_b can be shape (1, N)

Note: We refer to this as "per-row" scaling because, even though scale_b is logically per-column, the second matrix is laid out in memory (column-major) so that each row of memory gets one scale.
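Assuming the scales multiply the accumulated product (consistent with the 2.1 behavior, where a missing scale acts as 1), entry (i, j) of the output is scaled by row i's entry of scale_a and column j's entry of scale_b. The following pure-Python sketch emulates both scaling modes in full precision; nested lists stand in for tensors, the function names are hypothetical, and FP8 rounding and bias are left out.

```python
from typing import List

Matrix = List[List[float]]

def _expand_row_scales(scale_a: list, m: int) -> List[float]:
    # Per-tensor: [s] (shape (1)); per-row: [[s0], [s1], ...] (shape (M, 1)).
    if isinstance(scale_a[0], list):
        return [row[0] for row in scale_a]
    return [scale_a[0]] * m

def _expand_col_scales(scale_b: list, n: int) -> List[float]:
    # Per-tensor: [t] (shape (1)); per-row: [[t0, t1, ...]] (shape (1, N)).
    if isinstance(scale_b[0], list):
        return list(scale_b[0])
    return [scale_b[0]] * n

def reference_scaled_mm(a: Matrix, b: Matrix, scale_a: list, scale_b: list) -> Matrix:
    """out[i][j] = scale_a[i] * scale_b[j] * sum_k a[i][k] * b[k][j]."""
    m, k, n = len(a), len(b), len(b[0])
    ra = _expand_row_scales(scale_a, m)
    cb = _expand_col_scales(scale_b, n)
    return [
        [ra[i] * cb[j] * sum(a[i][p] * b[p][j] for p in range(k)) for j in range(n)]
        for i in range(m)
    ]
```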

The operation will produce an output matrix of shape (M, N).
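The shape rules above can be collected into a small validator. This is a hypothetical helper, not a PyTorch function; it only checks shapes (layout is a property of strides, not shapes), and it rejects mixing a per-tensor scale with a per-row scale, which is an assumption since this doc lists only the two matching pairings.

```python
from math import prod
from typing import Sequence, Tuple

def check_scaled_mm_shapes(
    a_shape: Sequence[int],
    b_shape: Sequence[int],
    scale_a_shape: Sequence[int],
    scale_b_shape: Sequence[int],
) -> Tuple[int, int]:
    """Validate the documented shape rules and return the output shape (M, N)."""
    m, k = a_shape
    k2, n = b_shape
    if k != k2:
        raise ValueError(f"inner dimensions differ: {k} vs {k2}")
    # Per-tensor: single-element scales. Per-row: (M, 1) and (1, N).
    per_tensor = prod(scale_a_shape) == 1 and prod(scale_b_shape) == 1
    per_row = tuple(scale_a_shape) == (m, 1) and tuple(scale_b_shape) == (1, n)
    if not (per_tensor or per_row):
        raise ValueError(f"unsupported scale shapes {scale_a_shape} and {scale_b_shape}")
    return (m, n)
```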
