diff --git a/crates/core_arch/missing-x86.md b/crates/core_arch/missing-x86.md index e9f68eb9e6..3a82f9761f 100644 --- a/crates/core_arch/missing-x86.md +++ b/crates/core_arch/missing-x86.md @@ -1,41 +1,4 @@ -
["AMX-BF16"]

- - * [ ] [`__tile_dpbf16ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_dpbf16ps) -

- - -
["AMX-COMPLEX"]

- - * [ ] [`__tile_cmmimfp16ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_cmmimfp16ps) - * [ ] [`__tile_cmmrlfp16ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_cmmrlfp16ps) -

- - -
["AMX-FP16"]

- - * [ ] [`__tile_dpfp16ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_dpfp16ps) -

- - -
["AMX-INT8"]

- - * [ ] [`__tile_dpbssd`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_dpbssd) - * [ ] [`__tile_dpbsud`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_dpbsud) - * [ ] [`__tile_dpbusd`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_dpbusd) - * [ ] [`__tile_dpbuud`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_dpbuud) -

- - -
["AMX-TILE"]

- - * [ ] [`__tile_loadd`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_loadd) - * [ ] [`__tile_stored`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_stored) - * [ ] [`__tile_stream_loadd`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_stream_loadd) - * [ ] [`__tile_zero`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=__tile_zero) -

- -
["AVX512_FP16"]

* [ ] [`_mm256_set1_pch`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_set1_pch) diff --git a/crates/core_arch/src/x86_64/amx.rs b/crates/core_arch/src/x86_64/amx.rs index b3b3e86750..7888973fcc 100644 --- a/crates/core_arch/src/x86_64/amx.rs +++ b/crates/core_arch/src/x86_64/amx.rs @@ -1,3 +1,4 @@ +use crate::core_arch::x86_64::{__tile1024i, Tile}; use crate::core_arch::{simd::*, x86::*}; #[cfg(test)] @@ -44,6 +45,18 @@ pub unsafe fn _tile_loadd(base: *const u8, stride: usize) { tileloadd64(DST as i8, base, stride); } +/// Load tile rows from memory specified by base address and stride into destination tile dst. The shape +/// of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_loadd&ig_expand=6877) +#[inline] +#[target_feature(enable = "amx-tile")] +#[cfg_attr(test, assert_instr(tileloadd))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_loadd(dst: *mut __tile1024i, base: *const u8, stride: usize) { + (*dst).tile = tileloadd64_internal((*dst).rows, (*dst).colsb, base, stride as u64); +} + /// Release the tile configuration to return to the init state, which releases all storage it currently holds. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_release&ig_expand=6878) @@ -68,6 +81,18 @@ pub unsafe fn _tile_stored(base: *mut u8, stride: usize) { tilestored64(DST as i8, base, stride); } +/// Store the tile specified by src to memory specified by base address and stride. The shape of the tile +/// is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_stored&ig_expand=6881) +#[inline] +#[target_feature(enable = "amx-tile")] +#[cfg_attr(test, assert_instr(tilestored))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_stored(base: *mut u8, stride: usize, src: __tile1024i) { + tilestored64_internal(src.rows, src.colsb, base, stride as u64, src.tile); +} + /// Load tile rows from memory specifieid by base address and stride into destination tile dst using the tile configuration /// previously configured via _tile_loadconfig. This intrinsic provides a hint to the implementation that the data will /// likely not be reused in the near future and the data caching can be optimized accordingly. @@ -83,6 +108,20 @@ pub unsafe fn _tile_stream_loadd(base: *const u8, stride: usize) tileloaddt164(DST as i8, base, stride); } +/// Load tile rows from memory specified by base address and stride into destination tile dst. The shape +/// of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +/// This intrinsic provides a hint to the implementation that the data will likely not be reused in the +/// near future and the data caching can be optimized accordingly. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_stream_loadd&ig_expand=6883) +#[inline] +#[target_feature(enable = "amx-tile")] +#[cfg_attr(test, assert_instr(tileloaddt1))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_stream_loadd(dst: *mut __tile1024i, base: *const u8, stride: usize) { + (*dst).tile = tileloaddt164_internal((*dst).rows, (*dst).colsb, base, stride as u64); +} + /// Zero the tile specified by tdest. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_zero&ig_expand=6885) @@ -96,6 +135,18 @@ pub unsafe fn _tile_zero() { tilezero(DST as i8); } +/// Zero the tile specified by `dst`. The shape of the tile is specified in the struct of [`__tile1024i`]. +/// The register of the tile is allocated by the compiler. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_zero&ig_expand=6885) +#[inline] +#[target_feature(enable = "amx-tile")] +#[cfg_attr(test, assert_instr(tilezero))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_zero(dst: *mut __tile1024i) { + (*dst).tile = tilezero_internal((*dst).rows, (*dst).colsb); +} + /// Compute dot-product of BF16 (16-bit) floating-point pairs in tiles a and b, /// accumulating the intermediate single-precision (32-bit) floating-point elements /// with elements in dst, and store the 32-bit result back to tile dst. @@ -113,6 +164,20 @@ pub unsafe fn _tile_dpbf16ps() { tdpbf16ps(DST as i8, A as i8, B as i8); } +/// Compute dot-product of FP16 (16-bit) floating-point pairs in tiles a and b, +/// accumulating the intermediate single-precision (32-bit) floating-point elements +/// with elements in dst, and store the 32-bit result back to tile dst. The shape of the tile +/// is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_dpbf16ps&ig_expand=6864) +#[inline] +#[target_feature(enable = "amx-bf16")] +#[cfg_attr(test, assert_instr(tdpbf16ps))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_dpbf16ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tdpbf16ps_internal(a.rows, b.colsb, a.colsb, (*dst).tile, a.tile, b.tile); +} + /// Compute dot-product of bytes in tiles with a source/destination accumulator. /// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding /// signed 8-bit integers in b, producing 4 intermediate 32-bit results. @@ -131,6 +196,21 @@ pub unsafe fn _tile_dpbssd() { tdpbssd(DST as i8, A as i8, B as i8); } +/// Compute dot-product of bytes in tiles with a source/destination accumulator. +/// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding +/// signed 8-bit integers in b, producing 4 intermediate 32-bit results. +/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_dpbssd&ig_expand=6866) +#[inline] +#[target_feature(enable = "amx-int8")] +#[cfg_attr(test, assert_instr(tdpbssd))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_dpbssd(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tdpbssd_internal(a.rows, b.colsb, a.colsb, (*dst).tile, a.tile, b.tile); +} + /// Compute dot-product of bytes in tiles with a source/destination accumulator. /// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding /// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results. @@ -149,6 +229,21 @@ pub unsafe fn _tile_dpbsud() { tdpbsud(DST as i8, A as i8, B as i8); } +/// Compute dot-product of bytes in tiles with a source/destination accumulator. +/// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding +/// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results. +/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_dpbsud&ig_expand=6868) +#[inline] +#[target_feature(enable = "amx-int8")] +#[cfg_attr(test, assert_instr(tdpbsud))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_dpbsud(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tdpbsud_internal(a.rows, b.colsb, a.colsb, (*dst).tile, a.tile, b.tile); +} + /// Compute dot-product of bytes in tiles with a source/destination accumulator. /// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding /// signed 8-bit integers in b, producing 4 intermediate 32-bit results. @@ -167,6 +262,21 @@ pub unsafe fn _tile_dpbusd() { tdpbusd(DST as i8, A as i8, B as i8); } +/// Compute dot-product of bytes in tiles with a source/destination accumulator. +/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding +/// signed 8-bit integers in b, producing 4 intermediate 32-bit results. +/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_dpbusd&ig_expand=6870) +#[inline] +#[target_feature(enable = "amx-int8")] +#[cfg_attr(test, assert_instr(tdpbusd))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_dpbusd(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tdpbusd_internal(a.rows, b.colsb, a.colsb, (*dst).tile, a.tile, b.tile); +} + /// Compute dot-product of bytes in tiles with a source/destination accumulator. /// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding /// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results. @@ -185,6 +295,21 @@ pub unsafe fn _tile_dpbuud() { tdpbuud(DST as i8, A as i8, B as i8); } +/// Compute dot-product of bytes in tiles with a source/destination accumulator. +/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding +/// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results. +/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_dpbuud&ig_expand=6872) +#[inline] +#[target_feature(enable = "amx-int8")] +#[cfg_attr(test, assert_instr(tdpbuud))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_dpbuud(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tdpbuud_internal(a.rows, b.colsb, a.colsb, (*dst).tile, a.tile, b.tile); +} + /// Compute dot-product of FP16 (16-bit) floating-point pairs in tiles a and b, /// accumulating the intermediate single-precision (32-bit) floating-point elements /// with elements in dst, and store the 32-bit result back to tile dst. @@ -202,6 +327,20 @@ pub unsafe fn _tile_dpfp16ps() { tdpfp16ps(DST as i8, A as i8, B as i8); } +/// Compute dot-product of FP16 (16-bit) floating-point pairs in tiles a and b, +/// accumulating the intermediate single-precision (32-bit) floating-point elements +/// with elements in dst, and store the 32-bit result back to tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_dpfp16ps&ig_expand=6874) +#[inline] +#[target_feature(enable = "amx-fp16")] +#[cfg_attr(test, assert_instr(tdpfp16ps))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_dpfp16ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tdpfp16ps_internal(a.rows, b.colsb, a.colsb, (*dst).tile, a.tile, b.tile); +} + /// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile. /// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part. /// Calculates the imaginary part of the result. For each possible combination of (row of a, column of b), @@ -223,6 +362,24 @@ pub unsafe fn _tile_cmmimfp16ps() { tcmmimfp16ps(DST as i8, A as i8, B as i8); } +/// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile. +/// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part. +/// Calculates the imaginary part of the result. For each possible combination of (row of a, column of b), +/// it performs a set of multiplication and accumulations on all corresponding complex numbers (one from a and one from b). +/// The imaginary part of the a element is multiplied with the real part of the corresponding b element, and the real part of +/// the a element is multiplied with the imaginary part of the corresponding b elements. The two accumulated results are added, +/// and then accumulated into the corresponding row and column of dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_cmmimfp16ps&ig_expand=6860) +#[inline] +#[target_feature(enable = "amx-complex")] +#[cfg_attr(test, assert_instr(tcmmimfp16ps))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_cmmimfp16ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tcmmimfp16ps_internal(a.rows, b.colsb, a.colsb, (*dst).tile, a.tile, b.tile); +} + /// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile. /// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part. /// Calculates the real part of the result. For each possible combination of (row of a, column of b), @@ -244,6 +401,24 @@ pub unsafe fn _tile_cmmrlfp16ps() { tcmmrlfp16ps(DST as i8, A as i8, B as i8); } +/// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile. +/// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part. +/// Calculates the real part of the result. For each possible combination of (row of a, column of b), +/// it performs a set of multiplication and accumulations on all corresponding complex numbers (one from a and one from b). +/// The real part of the a element is multiplied with the real part of the corresponding b element, and the negated imaginary part of +/// the a element is multiplied with the imaginary part of the corresponding b elements. +/// The two accumulated results are added, and then accumulated into the corresponding row and column of dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +/// +/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=__tile_cmmrlfp16ps&ig_expand=6862) +#[inline] +#[target_feature(enable = "amx-complex")] +#[cfg_attr(test, assert_instr(tcmmrlfp16ps))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_cmmrlfp16ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tcmmrlfp16ps_internal(a.rows, b.colsb, a.colsb, (*dst).tile, a.tile, b.tile); +} + /// Compute dot-product of BF8 (8-bit E5M2) floating-point elements in tile a and BF8 (8-bit E5M2) /// floating-point elements in tile b, accumulating the intermediate single-precision /// (32-bit) floating-point elements with elements in dst, and store the 32-bit result @@ -263,6 +438,19 @@ pub unsafe fn _tile_dpbf8ps() { tdpbf8ps(DST as i8, A as i8, B as i8); } +/// Compute dot-product of BF8 (8-bit E5M2) floating-point elements in tile a and BF8 (8-bit E5M2) +/// floating-point elements in tile b, accumulating the intermediate single-precision +/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result +/// back to tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +#[inline] +#[target_feature(enable = "amx-fp8")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tdpbf8ps))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_dpbf8ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tdpbf8ps_internal(a.rows, b.colsb, a.colsb, (*dst).tile, a.tile, b.tile); +} + /// Compute dot-product of BF8 (8-bit E5M2) floating-point elements in tile a and HF8 /// (8-bit E4M3) floating-point elements in tile b, accumulating the intermediate single-precision /// (32-bit) floating-point elements with elements in dst, and store the 32-bit result @@ -282,6 +470,19 @@ pub unsafe fn _tile_dpbhf8ps() { tdpbhf8ps(DST as i8, A as i8, B as i8); } +/// Compute dot-product of BF8 (8-bit E5M2) floating-point elements in tile a and HF8 +/// (8-bit E4M3) floating-point elements in tile b, accumulating the intermediate single-precision +/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result +/// back to tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +#[inline] +#[target_feature(enable = "amx-fp8")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tdpbhf8ps))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_dpbhf8ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tdpbhf8ps_internal(a.rows, b.colsb, a.colsb, (*dst).tile, a.tile, b.tile); +} + /// Compute dot-product of HF8 (8-bit E4M3) floating-point elements in tile a and BF8 /// (8-bit E5M2) floating-point elements in tile b, accumulating the intermediate single-precision /// (32-bit) floating-point elements with elements in dst, and store the 32-bit result @@ -301,6 +502,19 @@ pub unsafe fn _tile_dphbf8ps() { tdphbf8ps(DST as i8, A as i8, B as i8); } +/// Compute dot-product of HF8 (8-bit E4M3) floating-point elements in tile a and BF8 +/// (8-bit E5M2) floating-point elements in tile b, accumulating the intermediate single-precision +/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result +/// back to tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +#[inline] +#[target_feature(enable = "amx-fp8")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tdphbf8ps))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_dphbf8ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tdphbf8ps_internal(a.rows, b.colsb, a.colsb, (*dst).tile, a.tile, b.tile); +} + /// Compute dot-product of HF8 (8-bit E4M3) floating-point elements in tile a and HF8 (8-bit E4M3) /// floating-point elements in tile b, accumulating the intermediate single-precision /// (32-bit) floating-point elements with elements in dst, and store the 32-bit result @@ -320,6 +534,19 @@ pub unsafe fn _tile_dphf8ps() { tdphf8ps(DST as i8, A as i8, B as i8); } +/// Compute dot-product of HF8 (8-bit E4M3) floating-point elements in tile a and HF8 (8-bit E4M3) +/// floating-point elements in tile b, accumulating the intermediate single-precision +/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result +/// back to tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +#[inline] +#[target_feature(enable = "amx-fp8")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tdphf8ps))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_dphf8ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tdphf8ps_internal(a.rows, b.colsb, a.colsb, (*dst).tile, a.tile, b.tile); +} + /// Load tile rows from memory specified by base address and stride into destination tile dst /// using the tile configuration previously configured via _tile_loadconfig. /// Additionally, this intrinsic indicates the source memory location is likely to become @@ -338,6 +565,19 @@ pub unsafe fn _tile_loaddrs(base: *const u8, stride: usize) { tileloaddrs64(DST as i8, base, stride); } +/// Load tile rows from memory specified by base address and stride into destination tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +/// Additionally, this intrinsic indicates the source memory location is likely to become +/// read-shared by multiple processors, i.e., read in the future by at least one other processor +/// before it is written, assuming it is ever written in the future. +#[inline] +#[target_feature(enable = "amx-movrs")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tileloaddrs))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_loaddrs(dst: *mut __tile1024i, base: *const u8, stride: usize) { + (*dst).tile = tileloaddrs64_internal((*dst).rows, (*dst).colsb, base, stride as u64); +} + /// Load tile rows from memory specified by base address and stride into destination tile dst /// using the tile configuration previously configured via _tile_loadconfig. /// Provides a hint to the implementation that the data would be reused but does not need @@ -358,6 +598,21 @@ pub unsafe fn _tile_stream_loaddrs(base: *const u8, stride: usiz tileloaddrst164(DST as i8, base, stride); } +/// Load tile rows from memory specified by base address and stride into destination tile dst. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +/// Provides a hint to the implementation that the data would be reused but does not need +/// to be resident in the nearest cache levels. +/// Additionally, this intrinsic indicates the source memory location is likely to become +/// read-shared by multiple processors, i.e., read in the future by at least one other processor +/// before it is written, assuming it is ever written in the future. +#[inline] +#[target_feature(enable = "amx-movrs")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tileloaddrst1))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_stream_loaddrs(dst: *mut __tile1024i, base: *const u8, stride: usize) { + (*dst).tile = tileloaddrst164_internal((*dst).rows, (*dst).colsb, base, stride as u64); +} + /// Perform matrix multiplication of two tiles a and b, containing packed single precision (32-bit) /// floating-point elements, which are converted to TF32 (tensor-float32) format, and accumulate the /// results into a packed single precision tile. @@ -383,6 +638,25 @@ pub unsafe fn _tile_mmultf32ps() { tmmultf32ps(DST as i8, A as i8, B as i8); } +/// Perform matrix multiplication of two tiles a and b, containing packed single precision (32-bit) +/// floating-point elements, which are converted to TF32 (tensor-float32) format, and accumulate the +/// results into a packed single precision tile. +/// For each possible combination of (row of a, column of b), it performs +/// - convert to TF32 +/// - multiply the corresponding elements of a and b +/// - accumulate the results into the corresponding row and column of dst using round-to-nearest-even +/// rounding mode. +/// Output FP32 denormals are always flushed to zero, input single precision denormals are always +/// handled and *not* treated as zero. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +#[inline] +#[target_feature(enable = "amx-tf32")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tmmultf32ps))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_mmultf32ps(dst: *mut __tile1024i, a: __tile1024i, b: __tile1024i) { + (*dst).tile = tmmultf32ps_internal(a.rows, b.colsb, a.colsb, (*dst).tile, a.tile, b.tile); +} + /// Moves a row from a tile register to a zmm register, converting the packed 32-bit signed integer /// elements to packed single-precision (32-bit) floating-point elements. #[inline] @@ -414,6 +688,17 @@ pub unsafe fn _tile_cvtrowd2psi() -> __m512 { tcvtrowd2psi(TILE as i8, ROW as u32).as_m512() } +/// Moves a row from a tile register to a zmm register, converting the packed 32-bit signed integer +/// elements to packed single-precision (32-bit) floating-point elements. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +#[inline] +#[target_feature(enable = "amx-avx512,avx10.2")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tcvtrowd2ps))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_cvtrowd2ps(src: __tile1024i, row: u32) -> __m512 { + tcvtrowd2ps_internal(src.rows, src.colsb, src.tile, row).as_m512() +} + /// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) /// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting /// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector. @@ -447,6 +732,18 @@ pub unsafe fn _tile_cvtrowps2phhi() -> __m512h tcvtrowps2phhi(TILE as i8, ROW as u32).as_m512h() } +/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) +/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting +/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +#[inline] +#[target_feature(enable = "amx-avx512,avx10.2")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tcvtrowps2phh))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_cvtrowps2phh(src: __tile1024i, row: u32) -> __m512h { + tcvtrowps2phh_internal(src.rows, src.colsb, src.tile, row).as_m512h() +} + /// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) /// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting /// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector. @@ -480,6 +777,18 @@ pub unsafe fn _tile_cvtrowps2phli() -> __m512h tcvtrowps2phli(TILE as i8, ROW as u32).as_m512h() } +/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) +/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting +/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +#[inline] +#[target_feature(enable = "amx-avx512,avx10.2")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tcvtrowps2phl))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_cvtrowps2phl(src: __tile1024i, row: u32) -> __m512h { + tcvtrowps2phl_internal(src.rows, src.colsb, src.tile, row).as_m512h() +} + /// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) /// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting /// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector. @@ -513,6 +822,18 @@ pub unsafe fn _tile_cvtrowps2bf16hi() -> __m512 tcvtrowps2bf16hi(TILE as i8, ROW as u32).as_m512bh() } +/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) +/// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting +/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +#[inline] +#[target_feature(enable = "amx-avx512,avx10.2")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tcvtrowps2bf16h))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_cvtrowps2bf16h(src: __tile1024i, row: u32) -> __m512bh { + tcvtrowps2bf16h_internal(src.rows, src.colsb, src.tile, row).as_m512bh() +} + /// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) /// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting /// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector. @@ -546,6 +867,18 @@ pub unsafe fn _tile_cvtrowps2bf16li() -> __m512 tcvtrowps2bf16li(TILE as i8, ROW as u32).as_m512bh() } +/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) +/// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting +/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector. +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +#[inline] +#[target_feature(enable = "amx-avx512,avx10.2")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tcvtrowps2bf16l))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_cvtrowps2bf16l(src: __tile1024i, row: u32) -> __m512bh { + tcvtrowps2bf16l_internal(src.rows, src.colsb, src.tile, row).as_m512bh() +} + /// Moves one row of tile data into a zmm vector register #[inline] #[rustc_legacy_const_generics(0)] @@ -575,82 +908,169 @@ pub unsafe fn _tile_movrowi() -> __m512i { tilemovrowi(TILE as i8, ROW as u32).as_m512i() } +/// Moves one row of tile data into a zmm vector register +/// The shape of the tile is specified in the struct of [`__tile1024i`]. The register of the tile is allocated by the compiler. +#[inline] +#[target_feature(enable = "amx-avx512,avx10.2")] +#[cfg_attr(all(test, not(target_vendor = "apple")), assert_instr(tilemovrow))] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub unsafe fn __tile_movrow(src: __tile1024i, row: u32) -> __m512i { + tilemovrow_internal(src.rows, src.colsb, src.tile, row).as_m512i() +} + #[allow(improper_ctypes)] -unsafe extern "C" { +unsafe extern "unadjusted" { #[link_name = "llvm.x86.ldtilecfg"] fn ldtilecfg(mem_addr: *const u8); #[link_name = "llvm.x86.sttilecfg"] fn sttilecfg(mem_addr: *mut u8); + #[link_name = "llvm.x86.tileloadd64"] fn tileloadd64(dst: i8, base: *const u8, stride: usize); + #[link_name = "llvm.x86.tileloadd64.internal"] + fn tileloadd64_internal(rows: u16, colsb: u16, base: *const u8, stride: u64) -> Tile; + #[link_name = "llvm.x86.tileloaddt164"] fn tileloaddt164(dst: i8, base: *const u8, stride: usize); + #[link_name = "llvm.x86.tileloaddt164.internal"] + fn tileloaddt164_internal(rows: u16, colsb: u16, base: *const u8, stride: u64) -> Tile; + #[link_name = "llvm.x86.tilerelease"] fn tilerelease(); + #[link_name = "llvm.x86.tilestored64"] fn tilestored64(dst: i8, base: *mut u8, stride: usize); + #[link_name = "llvm.x86.tilestored64.internal"] + fn tilestored64_internal(rows: u16, colsb: u16, base: *mut u8, stride: u64, src: Tile); + #[link_name = "llvm.x86.tilezero"] fn tilezero(dst: i8); + #[link_name = "llvm.x86.tilezero.internal"] + fn tilezero_internal(rows: u16, colsb: u16) -> Tile; + #[link_name = "llvm.x86.tdpbf16ps"] fn tdpbf16ps(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdpbf16ps.internal"] + fn tdpbf16ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tdpbuud"] fn tdpbuud(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdpbuud.internal"] + fn tdpbuud_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tdpbusd"] fn tdpbusd(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdpbusd.internal"] + fn tdpbusd_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tdpbsud"] fn tdpbsud(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdpbsud.internal"] + fn tdpbsud_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tdpbssd"] fn tdpbssd(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdpbssd.internal"] + fn tdpbssd_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tdpfp16ps"] fn tdpfp16ps(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdpfp16ps.internal"] + fn tdpfp16ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tcmmimfp16ps"] fn tcmmimfp16ps(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tcmmimfp16ps.internal"] + fn tcmmimfp16ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tcmmrlfp16ps"] fn tcmmrlfp16ps(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tcmmrlfp16ps.internal"] + fn tcmmrlfp16ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tdpbf8ps"] fn tdpbf8ps(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdpbf8ps.internal"] + fn tdpbf8ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tdpbhf8ps"] fn tdpbhf8ps(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdpbhf8ps.internal"] + fn tdpbhf8ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tdphbf8ps"] fn tdphbf8ps(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdphbf8ps.internal"] + fn tdphbf8ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tdphf8ps"] fn tdphf8ps(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tdphf8ps.internal"] + fn tdphf8ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tileloaddrs64"] fn tileloaddrs64(dst: i8, base: *const u8, stride: usize); + #[link_name = "llvm.x86.tileloaddrs64.internal"] + fn tileloaddrs64_internal(rows: u16, colsb: u16, base: *const u8, stride: u64) -> Tile; + #[link_name = "llvm.x86.tileloaddrst164"] fn tileloaddrst164(dst: i8, base: *const u8, stride: usize); + #[link_name = "llvm.x86.tileloaddrst164.internal"] + fn tileloaddrst164_internal(rows: u16, colsb: u16, base: *const u8, stride: u64) -> Tile; + #[link_name = "llvm.x86.tmmultf32ps"] fn tmmultf32ps(dst: i8, a: i8, b: i8); + #[link_name = "llvm.x86.tmmultf32ps.internal"] + fn tmmultf32ps_internal(m: u16, n: u16, k: u16, dst: Tile, a: Tile, b: Tile) -> Tile; + #[link_name = "llvm.x86.tcvtrowd2ps"] fn tcvtrowd2ps(tile: i8, row: u32) -> f32x16; #[link_name = "llvm.x86.tcvtrowd2psi"] fn tcvtrowd2psi(tile: i8, row: u32) -> f32x16; + #[link_name = "llvm.x86.tcvtrowd2ps.internal"] + fn tcvtrowd2ps_internal(rows: u16, colsb: u16, src: Tile, row: u32) -> f32x16; + #[link_name = "llvm.x86.tcvtrowps2phh"] fn tcvtrowps2phh(tile: i8, row: u32) -> f16x32; #[link_name = "llvm.x86.tcvtrowps2phhi"] fn tcvtrowps2phhi(tile: i8, row: u32) -> f16x32; + #[link_name = "llvm.x86.tcvtrowps2phh.internal"] + fn tcvtrowps2phh_internal(rows: u16, colsb: u16, src: Tile, row: u32) -> f16x32; + #[link_name = "llvm.x86.tcvtrowps2phl"] fn tcvtrowps2phl(tile: i8, row: u32) -> f16x32; #[link_name = "llvm.x86.tcvtrowps2phli"] fn tcvtrowps2phli(tile: i8, row: u32) -> f16x32; + #[link_name = "llvm.x86.tcvtrowps2phl.internal"] + fn tcvtrowps2phl_internal(rows: u16, colsb: u16, src: Tile, row: u32) -> f16x32; + #[link_name = "llvm.x86.tcvtrowps2bf16h"] fn tcvtrowps2bf16h(tile: i8, row: u32) -> u16x32; #[link_name = "llvm.x86.tcvtrowps2bf16hi"] fn tcvtrowps2bf16hi(tile: i8, row: u32) -> u16x32; + #[link_name = "llvm.x86.tcvtrowps2bf16h.internal"] + fn tcvtrowps2bf16h_internal(rows: u16, colsb: u16, src: Tile, row: u32) -> u16x32; + #[link_name = "llvm.x86.tcvtrowps2bf16l"] fn tcvtrowps2bf16l(tile: i8, row: u32) -> u16x32; #[link_name = "llvm.x86.tcvtrowps2bf16li"] fn tcvtrowps2bf16li(tile: i8, row: u32) -> u16x32; + #[link_name = "llvm.x86.tcvtrowps2bf16l.internal"] + fn tcvtrowps2bf16l_internal(rows: u16, colsb: u16, src: Tile, row: u32) -> u16x32; + #[link_name = "llvm.x86.tilemovrow"] fn tilemovrow(tile: i8, row: u32) -> i32x16; #[link_name = "llvm.x86.tilemovrowi"] fn tilemovrowi(tile: i8, row: u32) -> i32x16; + #[link_name = "llvm.x86.tilemovrow.internal"] + fn tilemovrow_internal(rows: u16, colsb: u16, src: Tile, row: u32) -> i32x16; } #[cfg(test)] mod tests { use crate::core_arch::x86::_mm_cvtness_sbh; use crate::core_arch::x86_64::*; + use core::mem::MaybeUninit; use core::{array, mem::transmute}; use stdarch_test::simd_test; #[cfg(target_os = "linux")] @@ -727,6 +1147,18 @@ mod tests { } } + impl __tile1024i { + #[inline] + #[target_feature(enable = "amx-tile")] + fn zeroed(rows: u16, colsb: u16) -> Self { + Self { + rows, + colsb, + tile: unsafe { super::tilezero_internal(rows, colsb) }, + } + } + } + #[simd_test(enable = "amx-tile")] fn test_tile_loadconfig() { unsafe { @@ -766,10 +1198,24 @@ mod tests { } #[simd_test(enable = "amx-tile")] - fn test_tile_stored() { + fn test__tile_zero() { unsafe { _init_amx(); - let mut config = __tilecfg::default(); + + let tile = __tile1024i::zeroed(16, 64); + + let mut out = [[1_i8; 64]; 16]; + __tile_stored(out.as_mut_ptr().cast(), 64, tile); + + assert_eq!(out, [[0; 64]; 16]); + } + } + + #[simd_test(enable = "amx-tile")] + fn test_tile_stored() { + unsafe { + _init_amx(); + let mut config = __tilecfg::default(); config.palette = 1; config.colsb[0] = 64; config.rows[0] = 16; @@ -782,6 +1228,20 @@ mod tests { } } + #[simd_test(enable = "amx-tile")] + fn test__tile_stored() { + unsafe { + _init_amx(); + + let tile = __tile1024i::zeroed(16, 64); + + let mut out = [[1_i8; 64]; 16]; + __tile_stored(out.as_mut_ptr().cast(), 64, tile); + + assert_eq!(out, [[0; 64]; 16]); + } + } + #[simd_test(enable = "amx-tile")] fn test_tile_loadd() { unsafe { @@ -801,6 +1261,22 @@ mod tests { } } + #[simd_test(enable = "amx-tile")] + fn test__tile_loadd() { + unsafe { + _init_amx(); + + let mut tile = __tile1024i::zeroed(16, 64); + + let mat = [1_i8; 1024]; + __tile_loadd(&mut tile, mat.as_ptr().cast(), 64); + let mut out = [[0_i8; 64]; 16]; + __tile_stored(out.as_mut_ptr().cast(), 64, tile); + + assert_eq!(out, [[1; 64]; 16]); + } + } + #[simd_test(enable = "amx-tile")] fn test_tile_stream_loadd() { unsafe { @@ -820,6 +1296,22 @@ mod tests { } } + #[simd_test(enable = "amx-tile")] + fn test__tile_stream_loadd() { + unsafe { + _init_amx(); + + let mut tile = __tile1024i::zeroed(16, 64); + + let mat = [1_i8; 1024]; + __tile_stream_loadd(&mut tile, mat.as_ptr().cast(), 64); + let mut out = [[0_i8; 64]; 16]; + __tile_stored(out.as_mut_ptr().cast(), 64, tile); + + assert_eq!(out, [[1; 64]; 16]); + } + } + #[simd_test(enable = "amx-tile")] fn test_tile_release() { unsafe { @@ -827,14 +1319,15 @@ mod tests { } } - #[simd_test(enable = "amx-bf16,avx512f")] + const BF16_1: u16 = 0x3f80; + const BF16_2: u16 = 0x4000; + + #[simd_test(enable = "amx-bf16")] fn test_tile_dpbf16ps() { unsafe { _init_amx(); - let bf16_1: u16 = _mm_cvtness_sbh(1.0).to_bits(); - let bf16_2: u16 = _mm_cvtness_sbh(2.0).to_bits(); - let ones: [u8; 1024] = transmute([bf16_1; 512]); - let twos: [u8; 1024] = transmute([bf16_2; 512]); + let ones: [u8; 1024] = transmute([BF16_1; 512]); + let twos: [u8; 1024] = transmute([BF16_2; 512]); let mut res = [[0f32; 16]; 16]; let mut config = __tilecfg::default(); config.palette = 1; @@ -853,6 +1346,27 @@ mod tests { } } + #[simd_test(enable = "amx-bf16")] + fn test__tile_dpbf16ps() { + unsafe { + _init_amx(); + let ones: [u8; 1024] = transmute([BF16_1; 512]); + let twos: [u8; 1024] = transmute([BF16_2; 512]); + let mut res = [[0f32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr().cast(), 64); + __tile_loadd(&mut b, twos.as_ptr().cast(), 64); + __tile_dpbf16ps(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[64f32; 16]; 16]); + } + } + #[simd_test(enable = "amx-int8")] fn test_tile_dpbssd() { unsafe { @@ -877,6 +1391,27 @@ mod tests { } } + #[simd_test(enable = "amx-int8")] + fn test__tile_dpbssd() { + unsafe { + _init_amx(); + let ones = [-1_i8; 1024]; + let twos = [-2_i8; 1024]; + let mut res = [[0_i32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr().cast(), 64); + __tile_loadd(&mut b, twos.as_ptr().cast(), 64); + __tile_dpbssd(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[128_i32; 16]; 16]); + } + } + #[simd_test(enable = "amx-int8")] fn test_tile_dpbsud() { unsafe { @@ -901,6 +1436,27 @@ mod tests { } } + #[simd_test(enable = "amx-int8")] + fn test__tile_dpbsud() { + unsafe { + _init_amx(); + let ones = [-1_i8; 1024]; + let twos = [2_u8; 1024]; + let mut res = [[0_i32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr().cast(), 64); + __tile_loadd(&mut b, twos.as_ptr(), 64); + __tile_dpbsud(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[-128_i32; 16]; 16]); + } + } + #[simd_test(enable = "amx-int8")] fn test_tile_dpbusd() { unsafe { @@ -925,6 +1481,27 @@ mod tests { } } + #[simd_test(enable = "amx-int8")] + fn test__tile_dpbusd() { + unsafe { + _init_amx(); + let ones = [1_u8; 1024]; + let twos = [-2_i8; 1024]; + let mut res = [[0_i32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr(), 64); + __tile_loadd(&mut b, twos.as_ptr().cast(), 64); + __tile_dpbusd(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[-128_i32; 16]; 16]); + } + } + #[simd_test(enable = "amx-int8")] fn test_tile_dpbuud() { unsafe { @@ -949,6 +1526,27 @@ mod tests { } } + #[simd_test(enable = "amx-int8")] + fn test__tile_dpbuud() { + unsafe { + _init_amx(); + let ones = [1_u8; 1024]; + let twos = [2_u8; 1024]; + let mut res = [[0_i32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr(), 64); + __tile_loadd(&mut b, twos.as_ptr(), 64); + __tile_dpbuud(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[128_i32; 16]; 16]); + } + } + #[simd_test(enable = "amx-fp16")] fn test_tile_dpfp16ps() { unsafe { @@ -973,6 +1571,27 @@ mod tests { } } + #[simd_test(enable = "amx-fp16")] + fn test__tile_dpfp16ps() { + unsafe { + _init_amx(); + let ones = [1f16; 512]; + let twos = [2f16; 512]; + let mut res = [[0f32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr().cast(), 64); + __tile_loadd(&mut b, twos.as_ptr().cast(), 64); + __tile_dpfp16ps(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[64f32; 16]; 16]); + } + } + #[simd_test(enable = "amx-complex")] fn test_tile_cmmimfp16ps() { unsafe { @@ -997,6 +1616,27 @@ mod tests { } } + #[simd_test(enable = "amx-complex")] + fn test__tile_cmmimfp16ps() { + unsafe { + _init_amx(); + let ones = [1f16; 512]; + let twos = [2f16; 512]; + let mut res = [[0f32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr().cast(), 64); + __tile_loadd(&mut b, twos.as_ptr().cast(), 64); + __tile_cmmimfp16ps(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[64f32; 16]; 16]); + } + } + #[simd_test(enable = "amx-complex")] fn test_tile_cmmrlfp16ps() { unsafe { @@ -1021,6 +1661,27 @@ mod tests { } } + #[simd_test(enable = "amx-complex")] + fn test__tile_cmmrlfp16ps() { + unsafe { + _init_amx(); + let ones = [1f16; 512]; + let twos = [2f16; 512]; + let mut res = [[0f32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr().cast(), 64); + __tile_loadd(&mut b, twos.as_ptr().cast(), 64); + __tile_cmmrlfp16ps(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[0f32; 16]; 16]); + } + } + const BF8_ONE: u8 = 0x3c; const BF8_TWO: u8 = 0x40; const HF8_ONE: u8 = 0x38; @@ -1050,6 +1711,27 @@ mod tests { } } + #[simd_test(enable = "amx-fp8")] + fn test__tile_dpbf8ps() { + unsafe { + _init_amx(); + let ones = [BF8_ONE; 1024]; + let twos = [BF8_TWO; 1024]; + let mut res = [[0.0_f32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr(), 64); + __tile_loadd(&mut b, twos.as_ptr(), 64); + __tile_dpbf8ps(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[128.0_f32; 16]; 16]); + } + } + #[simd_test(enable = "amx-fp8")] fn test_tile_dpbhf8ps() { unsafe { @@ -1074,6 +1756,27 @@ mod tests { } } + #[simd_test(enable = "amx-fp8")] + fn test__tile_dpbhf8ps() { + unsafe { + _init_amx(); + let ones = [BF8_ONE; 1024]; + let twos = [HF8_TWO; 1024]; + let mut res = [[0.0_f32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr(), 64); + __tile_loadd(&mut b, twos.as_ptr(), 64); + __tile_dpbhf8ps(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[128.0_f32; 16]; 16]); + } + } + #[simd_test(enable = "amx-fp8")] fn test_tile_dphbf8ps() { unsafe { @@ -1098,6 +1801,27 @@ mod tests { } } + #[simd_test(enable = "amx-fp8")] + fn test__tile_dphbf8ps() { + unsafe { + _init_amx(); + let ones = [HF8_ONE; 1024]; + let twos = [BF8_TWO; 1024]; + let mut res = [[0.0_f32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr(), 64); + __tile_loadd(&mut b, twos.as_ptr(), 64); + __tile_dphbf8ps(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[128.0_f32; 16]; 16]); + } + } + #[simd_test(enable = "amx-fp8")] fn test_tile_dphf8ps() { unsafe { @@ -1122,6 +1846,27 @@ mod tests { } } + #[simd_test(enable = "amx-fp8")] + fn test__tile_dphf8ps() { + unsafe { + _init_amx(); + let ones = [HF8_ONE; 1024]; + let twos = [HF8_TWO; 1024]; + let mut res = [[0.0_f32; 16]; 16]; + + let mut a = __tile1024i::zeroed(16, 64); + let mut b = __tile1024i::zeroed(16, 64); + let mut c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut a, ones.as_ptr(), 64); + __tile_loadd(&mut b, twos.as_ptr(), 64); + __tile_dphf8ps(&mut c, a, b); + __tile_stored(res.as_mut_ptr().cast(), 64, c); + + assert_eq!(res, [[128.0_f32; 16]; 16]); + } + } + #[simd_test(enable = "amx-movrs")] fn test_tile_loaddrs() { unsafe { @@ -1141,6 +1886,22 @@ mod tests { } } + #[simd_test(enable = "amx-movrs")] + fn test__tile_loaddrs() { + unsafe { + _init_amx(); + + let mut tile = __tile1024i::zeroed(16, 64); + + let mat = [1_i8; 1024]; + __tile_loaddrs(&mut tile, mat.as_ptr().cast(), 64); + let mut out = [[0_i8; 64]; 16]; + __tile_stored(out.as_mut_ptr().cast(), 64, tile); + + assert_eq!(out, [[1; 64]; 16]); + } + } + #[simd_test(enable = "amx-movrs")] fn test_tile_stream_loaddrs() { unsafe { @@ -1160,6 +1921,22 @@ mod tests { } } + #[simd_test(enable = "amx-movrs")] + fn test__tile_stream_loaddrs() { + unsafe { + _init_amx(); + + let mut tile = __tile1024i::zeroed(16, 64); + + let mat = [1_i8; 1024]; + __tile_stream_loaddrs(&mut tile, mat.as_ptr().cast(), 64); + let mut out = [[0_i8; 64]; 16]; + __tile_stored(out.as_mut_ptr().cast(), 64, tile); + + assert_eq!(out, [[1; 64]; 16]); + } + } + #[simd_test(enable = "amx-avx512,avx10.2")] fn test_tile_movrow() { unsafe { @@ -1223,6 +2000,22 @@ mod tests { } } + #[simd_test(enable = "amx-avx512,avx10.2")] + fn test__tile_movrow() { + unsafe { + _init_amx(); + let array: [[u8; 64]; 16] = array::from_fn(|i| [i as _; _]); + + let mut tile = __tile1024i::zeroed(16, 64); + __tile_loadd(&mut tile, array.as_ptr().cast(), 64); + + for i in 0..16 { + let row = __tile_movrow(tile, i); + assert_eq!(*row.as_u8x64().as_array(), [i as _; _]); + } + } + } + #[simd_test(enable = "amx-avx512,avx10.2")] fn test_tile_cvtrowd2ps() { unsafe { @@ -1262,6 +2055,22 @@ mod tests { } } + #[simd_test(enable = "amx-avx512,avx10.2")] + fn test__tile_cvtrowd2ps() { + unsafe { + _init_amx(); + let array: [[u32; 16]; 16] = array::from_fn(|i| [i as _; _]); + + let mut tile = __tile1024i::zeroed(16, 64); + __tile_loadd(&mut tile, array.as_ptr().cast(), 64); + + for i in 0..16 { + let row = __tile_cvtrowd2ps(tile, i); + assert_eq!(*row.as_f32x16().as_array(), [i as _; _]); + } + } + } + #[simd_test(enable = "amx-avx512,avx10.2")] fn test_tile_cvtrowps2phh() { unsafe { @@ -1306,6 +2115,25 @@ mod tests { } } + #[simd_test(enable = "amx-avx512,avx10.2")] + fn test__tile_cvtrowps2phh() { + unsafe { + _init_amx(); + let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]); + + let mut tile = __tile1024i::zeroed(16, 64); + __tile_loadd(&mut tile, array.as_ptr().cast(), 64); + + for i in 0..16 { + let row = __tile_cvtrowps2phh(tile, i); + assert_eq!( + *row.as_f16x32().as_array(), + array::from_fn(|j| if j & 1 == 0 { 0.0 } else { i as _ }) + ); + } + } + } + #[simd_test(enable = "amx-avx512,avx10.2")] fn test_tile_cvtrowps2phl() { unsafe { @@ -1350,6 +2178,25 @@ mod tests { } } + #[simd_test(enable = "amx-avx512,avx10.2")] + fn test__tile_cvtrowps2phl() { + unsafe { + _init_amx(); + let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]); + + let mut tile = __tile1024i::zeroed(16, 64); + __tile_loadd(&mut tile, array.as_ptr().cast(), 64); + + for i in 0..16 { + let row = __tile_cvtrowps2phl(tile, i); + assert_eq!( + *row.as_f16x32().as_array(), + array::from_fn(|j| if j & 1 == 0 { i as _ } else { 0.0 }) + ); + } + } + } + #[simd_test(enable = "amx-avx512,avx10.2")] fn test_tile_cvtrowps2bf16h() { unsafe { @@ -1402,6 +2249,29 @@ mod tests { } } + #[simd_test(enable = "amx-avx512,avx10.2")] + fn test__tile_cvtrowps2bf16h() { + unsafe { + _init_amx(); + let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]); + + let mut tile = __tile1024i::zeroed(16, 64); + __tile_loadd(&mut tile, array.as_ptr().cast(), 64); + + for i in 0..16 { + let row = __tile_cvtrowps2bf16h(tile, i); + assert_eq!( + *row.as_u16x32().as_array(), + array::from_fn(|j| if j & 1 == 0 { + 0 + } else { + _mm_cvtness_sbh(i as _).to_bits() + }) + ); + } + } + } + #[simd_test(enable = "amx-avx512,avx10.2")] fn test_tile_cvtrowps2bf16l() { unsafe { @@ -1454,6 +2324,29 @@ mod tests { } } + #[simd_test(enable = "amx-avx512,avx10.2")] + fn test__tile_cvtrowps2bf16l() { + unsafe { + _init_amx(); + let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]); + + let mut tile = __tile1024i::zeroed(16, 64); + __tile_loadd(&mut tile, array.as_ptr().cast(), 64); + + for i in 0..16 { + let row = __tile_cvtrowps2bf16l(tile, i); + assert_eq!( + *row.as_u16x32().as_array(), + array::from_fn(|j| if j & 1 == 0 { + _mm_cvtness_sbh(i as _).to_bits() + } else { + 0 + }) + ); + } + } + } + #[simd_test(enable = "amx-tf32")] fn test_tile_mmultf32ps() { unsafe { @@ -1480,4 +2373,26 @@ mod tests { assert_eq!(res, expected); } } + + #[simd_test(enable = "amx-tf32")] + fn test__tile_mmultf32ps() { + unsafe { + _init_amx(); + let a: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]); + let b: [[f32; 16]; 16] = [array::from_fn(|j| j as _); _]; + let mut res = [[0.0; 16]; 16]; + + let mut tile_a = __tile1024i::zeroed(16, 64); + let mut tile_b = __tile1024i::zeroed(16, 64); + let mut tile_c = __tile1024i::zeroed(16, 64); + + __tile_loadd(&mut tile_a, a.as_ptr().cast(), 64); + __tile_loadd(&mut tile_b, b.as_ptr().cast(), 64); + __tile_mmultf32ps(&mut tile_c, tile_a, tile_b); + __tile_stored(res.as_mut_ptr().cast(), 64, tile_c); + + let expected = array::from_fn(|i| array::from_fn(|j| 16.0 * i as f32 * j as f32)); + assert_eq!(res, expected); + } + } } diff --git a/crates/core_arch/src/x86_64/mod.rs b/crates/core_arch/src/x86_64/mod.rs index 46384176e0..6f1177ad94 100644 --- a/crates/core_arch/src/x86_64/mod.rs +++ b/crates/core_arch/src/x86_64/mod.rs @@ -3,6 +3,43 @@ #[macro_use] mod macros; +// Any 1024-byte vector should work +type Tile = crate::core_arch::simd::Simd; + +/// A tile register, used by AMX instructions. +/// +/// This type is the same as the `__tile1024i` type defined by Intel, representing a 1024-byte tile register. +/// Usage of this type typically corresponds to the `amx-tile` and up target features for x86_64. +/// +/// This struct contains the tile configuration information as well as the tile itself. +/// The tile configuration information consists of the row count and the size of each column in bytes, +/// with `row * colsb` never exceeding 1024. +/// +/// The typical usage pattern looks like +/// ```ignore +/// let tile = MaybeUninit::uninit(); +/// let tile_ptr = tile.as_mut_ptr(); +/// +/// (*tile_ptr).rows = rows; +/// (*tile_ptr).colsb = colsb; +/// __tile_zero(tile_ptr); +/// +/// let tile = tile.assume_init(); +/// ``` +/// Most intrinsics using `__tile1024i` (except for the store intrinsics) have a destination parameter +/// of type `*mut __tile1024i`, and it expects the `rows` and `colsb` fields of the destination +/// to be initialized. After the function call, the whole struct can be assumed to be initialized. +/// Moreover, for dot-product intrinsics, it is UB if the shape of two operands are not compatible +/// as a matrix product or if the shape of the destination doesn't match the expected shape. +#[derive(Copy, Clone, Debug)] +#[allow(non_camel_case_types)] +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] +pub struct __tile1024i { + pub rows: u16, + pub colsb: u16, + tile: Tile, +} + mod fxsr; #[stable(feature = "simd_x86", since = "1.27.0")] pub use self::fxsr::*; diff --git a/crates/stdarch-test/src/lib.rs b/crates/stdarch-test/src/lib.rs index ecaf95f617..c468ebd12b 100644 --- a/crates/stdarch-test/src/lib.rs +++ b/crates/stdarch-test/src/lib.rs @@ -172,6 +172,10 @@ pub fn assert(shim_addr: usize, fnname: &str, expected: &str) { // vst1q_p64_x4_nop : #instructions = 33 >= 22 (limit) "nop" if fnname.contains("vst1q_p64") => 34, + // AMX intrinsics generate a lot of move instructions to load/store the tile registers + // due to Rust ABI + _ if fnname.contains("___tile") => 165, + // Original limit was 20 instructions, but ARM DSP Intrinsics // are exactly 20 instructions long. So, bump the limit to 22 // instead of adding here a long list of exceptions. diff --git a/crates/stdarch-verify/src/lib.rs b/crates/stdarch-verify/src/lib.rs index f7304ab326..5412ab466a 100644 --- a/crates/stdarch-verify/src/lib.rs +++ b/crates/stdarch-verify/src/lib.rs @@ -202,6 +202,7 @@ fn to_type(t: &syn::Type) -> proc_macro2::TokenStream { "_MM_MANTISSA_NORM_ENUM" => quote! { &MM_MANTISSA_NORM_ENUM }, "_MM_MANTISSA_SIGN_ENUM" => quote! { &MM_MANTISSA_SIGN_ENUM }, "_MM_PERM_ENUM" => quote! { &MM_PERM_ENUM }, + "__tile1024i" => quote! { &TILE1024I }, "bool" => quote! { &BOOL }, "bf16" => quote! { &BF16 }, "f16" => quote! { &F16 }, diff --git a/crates/stdarch-verify/tests/x86-intel.rs b/crates/stdarch-verify/tests/x86-intel.rs index 024a873de1..aad19ca55a 100644 --- a/crates/stdarch-verify/tests/x86-intel.rs +++ b/crates/stdarch-verify/tests/x86-intel.rs @@ -62,6 +62,7 @@ static MM_CMPINT_ENUM: Type = Type::MM_CMPINT_ENUM; static MM_MANTISSA_NORM_ENUM: Type = Type::MM_MANTISSA_NORM_ENUM; static MM_MANTISSA_SIGN_ENUM: Type = Type::MM_MANTISSA_SIGN_ENUM; static MM_PERM_ENUM: Type = Type::MM_PERM_ENUM; +static TILE1024I: Type = Type::TILE1024I; static TUPLE: Type = Type::Tuple; static CPUID: Type = Type::CpuidResult; @@ -102,6 +103,7 @@ enum Type { CpuidResult, Never, Ordering, + TILE1024I, } stdarch_verify::x86_functions!(static FUNCTIONS); @@ -774,6 +776,7 @@ fn equate( (&Type::MMASK32, "__mmask32") => {} (&Type::MMASK16, "__mmask16") => {} (&Type::MMASK8, "__mmask8") => {} + (&Type::TILE1024I, "__tile1024i") => {} (&Type::MutPtr(_type), "void*") | (&Type::ConstPtr(_type), "void const*") => { let pointed_type = pointed_type(intrinsic)?; @@ -812,6 +815,7 @@ fn equate( (&Type::MutPtr(&Type::M512BH), "__m512bh*") => {} (&Type::MutPtr(&Type::M512I), "__m512i*") => {} (&Type::MutPtr(&Type::M512D), "__m512d*") => {} + (&Type::MutPtr(&Type::TILE1024I), "__tile1024i*") => {} (&Type::ConstPtr(&Type::PrimFloat(16)), "_Float16 const*") => {} (&Type::ConstPtr(&Type::PrimFloat(32)), "float const*") => {}