| | |
| |
|
| | |
| |
|
| | |
| | |
| |
|
| | |
| |
|
// Forward declarations of the per-architecture scaled-mm entry points.
// Each is defined in a separate translation unit (compiled for the
// corresponding SM target); the dispatcher below selects one at runtime
// based on the device's compute capability.
//
// All variants share the same contract: write the result into `c` from
// inputs `a` and `b`, applying `a_scales` / `b_scales` and an optional
// `bias` (exact scaling semantics live in the kernel TUs — not visible here).

void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            std::optional<torch::Tensor> const& bias);

void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            std::optional<torch::Tensor> const& bias);

void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            std::optional<torch::Tensor> const& bias);

void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            std::optional<torch::Tensor> const& bias);

void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
                             torch::Tensor const& b_scales,
                             std::optional<torch::Tensor> const& bias);
| | |
| |
|
// Forward declarations of the per-architecture azp ("azp" presumably =
// asymmetric zero point — confirm against the kernel TUs) scaled-mm entry
// points. Compared to the plain variants these additionally take `azp_adj`
// (validated by the dispatcher as one int32 per column of b) and an optional
// `azp` (one int32 per row of a). Note there is no sm100 azp variant; the
// dispatcher routes SM100+ devices to the sm90 implementation.

void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                std::optional<torch::Tensor> const& azp,
                                std::optional<torch::Tensor> const& bias);

void cutlass_scaled_mm_azp_sm80(torch::Tensor& c, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                std::optional<torch::Tensor> const& azp,
                                std::optional<torch::Tensor> const& bias);

void cutlass_scaled_mm_azp_sm89(torch::Tensor& c, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                std::optional<torch::Tensor> const& azp,
                                std::optional<torch::Tensor> const& bias);

void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                std::optional<torch::Tensor> const& azp,
                                std::optional<torch::Tensor> const& bias);
| | |
| |
|
| | bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) { |
| | |
| | |
| | |
| |
|
| | |
| | if (cuda_device_capability >= 90) { |
| | return CUDA_VERSION >= 12000; |
| | } else if (cuda_device_capability >= 89) { |
| | return CUDA_VERSION >= 12040; |
| | } |
| | |
| |
|
| | return false; |
| | } |
| |
|
| | bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) { |
| | |
| | |
| |
|
| | |
| | if (cuda_device_capability >= 90 && cuda_device_capability < 100) { |
| | return CUDA_VERSION >= 12000; |
| | } else if (cuda_device_capability >= 100) { |
| | return CUDA_VERSION >= 12080; |
| | } |
| | |
| |
|
| | return false; |
| | } |
| |
|
| | bool cutlass_group_gemm_supported(int64_t cuda_device_capability) { |
| | |
| | |
| |
|
| | |
| | if (cuda_device_capability == 90) { |
| | return CUDA_VERSION >= 12030; |
| | } |
| | |
| |
|
| | return false; |
| | } |
| |
|
| | void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, |
| | torch::Tensor const& b, torch::Tensor const& a_scales, |
| | torch::Tensor const& b_scales, |
| | std::optional<torch::Tensor> const& bias) { |
| | |
| | TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); |
| | TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && |
| | b.size(1) == c.size(1)); |
| |
|
| | |
| | TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); |
| | TORCH_CHECK(b.stride(0) == 1); |
| | TORCH_CHECK(c.stride(0) % 16 == 0 && |
| | b.stride(1) % 16 == 0); |
| |
|
| | if (bias) { |
| | TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() && |
| | bias->dim() == 1); |
| | } |
| |
|
| | at::cuda::OptionalCUDAGuard const device_guard(device_of(a)); |
| | int32_t version_num = get_sm_version_num(); |
| |
|
| | |
| | if (version_num >= 100) { |
| | cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias); |
| | return; |
| | } |
| | |
| |
|
| | |
| | |
| | if (version_num >= 90 && version_num < 100) { |
| | |
| | cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias); |
| | return; |
| | } |
| | |
| |
|
| | if (version_num == 89) { |
| | |
| | cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias); |
| | return; |
| | } |
| |
|
| | if (version_num >= 80) { |
| | |
| | cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias); |
| | return; |
| | } |
| |
|
| | if (version_num >= 75) { |
| | |
| | cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias); |
| | return; |
| | } |
| |
|
| | TORCH_CHECK_NOT_IMPLEMENTED( |
| | false, |
| | "No compiled cutlass_scaled_mm for a compute capability less than " |
| | "CUDA device capability: ", |
| | std::to_string(version_num)); |
| | } |
| |
|
| | void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, |
| | torch::Tensor const& b, |
| | torch::Tensor const& a_scales, |
| | torch::Tensor const& b_scales, |
| | torch::Tensor const& azp_adj, |
| | std::optional<torch::Tensor> const& azp, |
| | std::optional<torch::Tensor> const& bias) { |
| | |
| | TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); |
| | TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && |
| | b.size(1) == c.size(1)); |
| | TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); |
| | TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); |
| |
|
| | |
| | TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); |
| | TORCH_CHECK(b.stride(0) == 1); |
| | TORCH_CHECK(c.stride(0) % 16 == 0 && |
| | b.stride(1) % 16 == 0); |
| | TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); |
| |
|
| | |
| | |
| | if (bias) { |
| | TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous()); |
| | } |
| | if (azp) { |
| | TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous()); |
| | } |
| | TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous()); |
| |
|
| | |
| | TORCH_CHECK(azp_adj.dtype() == torch::kInt32); |
| | TORCH_CHECK(!azp || azp->dtype() == torch::kInt32); |
| | TORCH_CHECK(!bias || bias->dtype() == c.dtype(), |
| | "currently bias dtype must match output dtype ", c.dtype()); |
| |
|
| | at::cuda::OptionalCUDAGuard const device_guard(device_of(a)); |
| |
|
| | int32_t version_num = get_sm_version_num(); |
| |
|
| | |
| | if (version_num >= 90) { |
| | cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias); |
| | return; |
| | } |
| | |
| |
|
| | if (version_num == 89) { |
| | |
| | cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias); |
| | return; |
| | } |
| |
|
| | if (version_num >= 80) { |
| | |
| | cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias); |
| | return; |
| | } |
| |
|
| | |
| | TORCH_CHECK(version_num >= 75); |
| | cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias); |
| | return; |
| |
|
| | TORCH_CHECK_NOT_IMPLEMENTED( |
| | false, |
| | "No compiled cutlass_scaled_mm_azp for a compute capability less than " |
| | "CUDA device capability: ", |
| | std::to_string(version_num)); |
| | } |
| |
|