From 63305962ee3d2a8b31306534b85f88fc8a46783b Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Sat, 23 May 2026 17:02:49 -0400 Subject: [PATCH] arena: support permuted Hadamard add/subt/mult on Tensor ToT The permuted, arena ToT x arena ToT overloads of add, subt, and mult (scaled and unscaled) previously threw "permuted ... of a tensor-of-tensors is not yet supported". This blocked CSV/PNO-based coupled-cluster, whose residual evaluates permuted ToT Hadamard products at the tile-op level (a binary Mult/Add op calling left.mult(right, perm) etc.). By the time a permuted product reaches a tile op, the expression engine has already brought both operands to a common (congruent) layout, so the elementwise product/sum is valid and perm is purely the result permutation. Compute the unpermuted result, then apply perm as a post-pass via permute(), which already handles arena ToT: a shallow outer-cell reindex (arena_permute_shallow) plus an inner-slab rewrite (arena_inner_permute) when the bipartite permutation's inner part is non-trivial. This mirrors the existing numeric x arena permuted-mult branches. --- src/TiledArray/tensor/tensor.h | 67 ++++++++++++++++++---------------- 1 file changed, 36 insertions(+), 31 deletions(-) diff --git a/src/TiledArray/tensor/tensor.h b/src/TiledArray/tensor/tensor.h index ca67641e83..8399238f5f 100644 --- a/src/TiledArray/tensor/tensor.h +++ b/src/TiledArray/tensor/tensor.h @@ -1953,33 +1953,28 @@ class Tensor { return !static_cast(perm) || perm.is_identity(); } - /// Permuted add for `Tensor` ToT operands. A non-trivial - /// permutation of arena ToT tiles is not yet supported; an identity (or - /// null) permutation falls through to the plain element-wise add. + /// Permuted add for `Tensor` ToT operands. 
The operands are + /// congruent by the time a permuted add reaches a tile op, so the + /// elementwise `add(right)` is valid and `perm` is the result permutation; + /// `permute` applies it (shallow outer reindex + inner-slab rewrite). template requires(is_arena_tensor_v && is_arena_tensor_v && detail::is_permutation_v) Tensor add(const Right& right, const Perm& perm) const { - if (!arena_perm_is_trivial(perm)) - TA_EXCEPTION( - "TA::Tensor::add: permuted add of a tensor-of-tensors " - "is not yet supported"); - return add(right); + auto result = add(right); + return arena_perm_is_trivial(perm) ? result : result.permute(perm); } /// Permuted scaled add for `Tensor` ToT operands; see the - /// permuted-add overload above for the permutation restriction. + /// permuted-add overload above for the congruent-operand rationale. template requires(is_arena_tensor_v && is_arena_tensor_v && detail::is_numeric_v && detail::is_permutation_v) Tensor add(const Right& right, const Scalar factor, const Perm& perm) const { - if (!arena_perm_is_trivial(perm)) - TA_EXCEPTION( - "TA::Tensor::add: permuted scaled add of a " - "tensor-of-tensors is not yet supported"); - return add(right, factor); + auto result = add(right, factor); + return arena_perm_is_trivial(perm) ? result : result.permute(perm); } /// Add this and \c other to construct a new tensor @@ -2382,8 +2377,15 @@ class Tensor { typename std::enable_if::value && detail::is_permutation_v>::type* = nullptr> Tensor subt(const Right& right, const Perm& perm) const { - if constexpr (is_tensor_view_v) { - // Permutation isn't supported for view inner cells (fixed storage + if constexpr (is_arena_tensor_v && + is_arena_tensor_v) { + // arena ToT x arena ToT: operands are congruent at tile-op time, so the + // elementwise `subt(right)` is valid; apply the result permutation as a + // post-pass (shallow outer reindex + inner-slab rewrite). + auto result = subt(right); + return arena_perm_is_trivial(perm) ? 
result : result.permute(perm); + } else if constexpr (is_tensor_view_v) { + // Permutation isn't supported for other view inner cells (fixed storage // layout). Subt+permute would require materialization. TA_EXCEPTION( "Tensor::subt(right, perm): permutation is not " @@ -2443,11 +2445,10 @@ class Tensor { Tensor subt(const Right& right, const Scalar factor, const Perm& perm) const { if constexpr (is_arena_tensor_v && is_arena_tensor_v) { - if (!arena_perm_is_trivial(perm)) - TA_EXCEPTION( - "TA::Tensor::subt: permuted scaled subt of a " - "tensor-of-tensors is not yet supported"); - return subt(right, factor); + // arena ToT x arena ToT scaled subtraction; see the unscaled permuted + // subt overload above for the congruent-operand rationale. + auto result = subt(right, factor); + return arena_perm_is_trivial(perm) ? result : result.permute(perm); } else { return binary( right, @@ -2622,11 +2623,15 @@ class Tensor { decltype(auto) mult(const Right& right, const Perm& perm) const { if constexpr (is_arena_tensor_v && is_arena_tensor_v) { - if (!arena_perm_is_trivial(perm)) - TA_EXCEPTION( - "TA::Tensor::mult: permuted mult of a " - "tensor-of-tensors is not yet supported"); - return mult(right); + // arena ToT x arena ToT Hadamard product. By the time a permuted product + // reaches a tile op, the engine has already brought both operands to a + // common (congruent) layout, so the elementwise `mult(right)` is valid; + // `perm` is the result permutation (common layout -> target). Apply it + // as a post-pass: `permute` reindexes the outer cells shallowly + // (arena_permute_shallow) and rewrites the inner slab if the inner part + // of the permutation is non-trivial (arena_inner_permute). + auto result = mult(right); + return arena_perm_is_trivial(perm) ? result : result.permute(perm); } else if constexpr (detail::is_numeric_v && is_arena_tensor_v) { // t x tot: a plain scalar tile times an arena ToT tile. 
The 2-arg @@ -2697,11 +2702,11 @@ class Tensor { const Perm& perm) const { if constexpr (is_arena_tensor_v && is_arena_tensor_v) { - if (!arena_perm_is_trivial(perm)) - TA_EXCEPTION( - "TA::Tensor::mult: permuted scaled mult of a " - "tensor-of-tensors is not yet supported"); - return mult(right, factor); + // arena ToT x arena ToT scaled Hadamard product; see the unscaled + // permuted mult overload above for the congruent-operand rationale. + // Scale during the elementwise product, then permute the result. + auto result = mult(right, factor); + return arena_perm_is_trivial(perm) ? result : result.permute(perm); } else { return binary( right,