Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 36 additions & 31 deletions src/TiledArray/tensor/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -1953,33 +1953,28 @@ class Tensor {
return !static_cast<bool>(perm) || perm.is_identity();
}

/// Permuted add for `Tensor<ArenaTensor>` ToT operands. The operands are
/// congruent by the time a permuted product reaches a tile op, so the
/// elementwise `add(right)` is valid and `perm` is the result permutation;
/// `permute` applies it (shallow outer reindex + inner-slab rewrite).
/// An identity (or null) permutation is trivial and leaves the plain
/// element-wise sum unpermuted.
/// \tparam Right The right-hand tensor type (arena tensor-of-tensors)
/// \tparam Perm A permutation type
/// \param right The right-hand operand of the sum
/// \param perm The permutation to be applied to the result
/// \return The result of `add(right)`, permuted by \c perm when \c perm is
/// non-trivial
template <typename Right, typename Perm>
  requires(is_arena_tensor_v<value_type> &&
           is_arena_tensor_v<typename Right::value_type> &&
           detail::is_permutation_v<Perm>)
Tensor add(const Right& right, const Perm& perm) const {
  // Compute the unpermuted sum first; the permutation is a post-pass on the
  // result and is skipped entirely when it would be a no-op.
  auto result = add(right);
  return arena_perm_is_trivial(perm) ? result : result.permute(perm);
}

/// Permuted scaled add for `Tensor<ArenaTensor>` ToT operands. The operands
/// are congruent by the time a permuted product reaches a tile op, so the
/// elementwise scaled `add(right, factor)` is valid and `perm` is the result
/// permutation, applied afterwards via `permute`. An identity (or null)
/// permutation is trivial and leaves the scaled sum unpermuted.
/// \tparam Right The right-hand tensor type (arena tensor-of-tensors)
/// \tparam Scalar A numeric type
/// \tparam Perm A permutation type
/// \param right The right-hand operand of the sum
/// \param factor The scaling factor, forwarded to `add(right, factor)`
/// \param perm The permutation to be applied to the result
/// \return The result of `add(right, factor)`, permuted by \c perm when
/// \c perm is non-trivial
template <typename Right, typename Scalar, typename Perm>
  requires(is_arena_tensor_v<value_type> &&
           is_arena_tensor_v<typename Right::value_type> &&
           detail::is_numeric_v<Scalar> && detail::is_permutation_v<Perm>)
Tensor add(const Right& right, const Scalar factor, const Perm& perm) const {
  // Scale during the elementwise add, then permute the result only when the
  // permutation is not a no-op.
  auto result = add(right, factor);
  return arena_perm_is_trivial(perm) ? result : result.permute(perm);
}

/// Add this and \c other to construct a new tensor
Expand Down Expand Up @@ -2382,8 +2377,15 @@ class Tensor {
typename std::enable_if<is_tensor<Right>::value &&
detail::is_permutation_v<Perm>>::type* = nullptr>
Tensor subt(const Right& right, const Perm& perm) const {
if constexpr (is_tensor_view_v<value_type>) {
// Permutation isn't supported for view inner cells (fixed storage
if constexpr (is_arena_tensor_v<value_type> &&
is_arena_tensor_v<typename Right::value_type>) {
// arena ToT x arena ToT: operands are congruent at tile-op time, so the
// elementwise `subt(right)` is valid; apply the result permutation as a
// post-pass (shallow outer reindex + inner-slab rewrite).
auto result = subt(right);
return arena_perm_is_trivial(perm) ? result : result.permute(perm);
} else if constexpr (is_tensor_view_v<value_type>) {
// Permutation isn't supported for other view inner cells (fixed storage
// layout). Subt+permute would require materialization.
TA_EXCEPTION(
"Tensor<View>::subt(right, perm): permutation is not "
Expand Down Expand Up @@ -2443,11 +2445,10 @@ class Tensor {
Tensor subt(const Right& right, const Scalar factor, const Perm& perm) const {
if constexpr (is_arena_tensor_v<value_type> &&
is_arena_tensor_v<typename Right::value_type>) {
if (!arena_perm_is_trivial(perm))
TA_EXCEPTION(
"TA::Tensor<ArenaTensor>::subt: permuted scaled subt of a "
"tensor-of-tensors is not yet supported");
return subt(right, factor);
// arena ToT x arena ToT scaled subtraction; see the unscaled permuted
// subt overload above for the congruent-operand rationale.
auto result = subt(right, factor);
return arena_perm_is_trivial(perm) ? result : result.permute(perm);
} else {
return binary(
right,
Expand Down Expand Up @@ -2622,11 +2623,15 @@ class Tensor {
decltype(auto) mult(const Right& right, const Perm& perm) const {
if constexpr (is_arena_tensor_v<value_type> &&
is_arena_tensor_v<typename Right::value_type>) {
if (!arena_perm_is_trivial(perm))
TA_EXCEPTION(
"TA::Tensor<ArenaTensor>::mult: permuted mult of a "
"tensor-of-tensors is not yet supported");
return mult(right);
// arena ToT x arena ToT Hadamard product. By the time a permuted product
// reaches a tile op, the engine has already brought both operands to a
// common (congruent) layout, so the elementwise `mult(right)` is valid;
// `perm` is the result permutation (common layout -> target). Apply it
// as a post-pass: `permute` reindexes the outer cells shallowly
// (arena_permute_shallow) and rewrites the inner slab if the inner part
// of the permutation is non-trivial (arena_inner_permute).
auto result = mult(right);
return arena_perm_is_trivial(perm) ? result : result.permute(perm);
} else if constexpr (detail::is_numeric_v<value_type> &&
is_arena_tensor_v<typename Right::value_type>) {
// t x tot: a plain scalar tile times an arena ToT tile. The 2-arg
Expand Down Expand Up @@ -2697,11 +2702,11 @@ class Tensor {
const Perm& perm) const {
if constexpr (is_arena_tensor_v<value_type> &&
is_arena_tensor_v<typename Right::value_type>) {
if (!arena_perm_is_trivial(perm))
TA_EXCEPTION(
"TA::Tensor<ArenaTensor>::mult: permuted scaled mult of a "
"tensor-of-tensors is not yet supported");
return mult(right, factor);
// arena ToT x arena ToT scaled Hadamard product; see the unscaled
// permuted mult overload above for the congruent-operand rationale.
// Scale during the elementwise product, then permute the result.
auto result = mult(right, factor);
return arena_perm_is_trivial(perm) ? result : result.permute(perm);
} else {
return binary(
right,
Expand Down
Loading