Skip to content

Commit 86edac4

Browse files
authored
Comm gemm fixes (#2818)
* Fix GemmRs B descriptor lld for transb=true With a row_major (1×P) grid, all rows are on a single process row, so the local leading dimension must be n (full row count), not block_size(n) which is n/P. Signed-off-by: Almog Segal <asegal@nvidia.com> * Set GemmRs communication type to output data type Match the UserBuffers behavior where the reduce-scatter operates in the output precision rather than FP32. Signed-off-by: Almog Segal <asegal@nvidia.com> --------- Signed-off-by: Almog Segal <asegal@nvidia.com>
1 parent ac96651 commit 86edac4

1 file changed

Lines changed: 8 additions & 3 deletions

File tree

transformer_engine/common/comm_gemm/comm_gemm.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,9 @@ void GemmRsInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n
186186
}
187187
if (transb) {
188188
NVTE_CHECK(b1 == n, "Unsupported tensor dimension in B: expected ", n, ", got ", b1);
189-
NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(
190-
n, k, block_size(ctx, n), block_size(ctx, k), 0, 0, block_size(ctx, n),
191-
get_cuda_dtype(b->dtype()), ctx->grid_row_major.get(), ctx->b_desc.get()));
189+
NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(n, k, block_size(ctx, n), block_size(ctx, k),
190+
0, 0, n, get_cuda_dtype(b->dtype()),
191+
ctx->grid_row_major.get(), ctx->b_desc.get()));
192192
} else {
193193
NVTE_CHECK(b0 == n, "Unsupported tensor dimension in B: expected ", n, ", got ", b0);
194194
NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(
@@ -200,6 +200,11 @@ void GemmRsInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n
200200
NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, n, m, block_size(ctx, n), 0, 0, *ldd,
201201
get_cuda_dtype(d->dtype()),
202202
ctx->grid_row_major.get(), ctx->d_desc.get()));
203+
204+
const cudaDataType_t comm_type = get_cuda_dtype(d->dtype());
205+
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
206+
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_COMMUNICATION_TYPE, &comm_type,
207+
sizeof comm_type));
203208
}
204209

205210
void GemmArInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n, int64_t k,

0 commit comments

Comments
 (0)