Skip to content

Commit 35a30fe

Browse files
authored
Merge pull request #3 from pingelit/copilot/fix-nd-span-iterator-strides
Fix nd_span iterators to respect strides of non-contiguous views
2 parents d3e9064 + c5bf0a1 commit 35a30fe

3 files changed

Lines changed: 372 additions & 13 deletions

File tree

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
.vscode/
2-
build/
2+
build/
3+
_codeql_build_dir/
4+
_codeql_detected_source_root

include/nd_array/nd_array.hpp

Lines changed: 212 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <algorithm>
44
#include <array>
55
#include <initializer_list>
6+
#include <iterator>
67
#include <memory>
78
#include <numeric>
89
#include <stdexcept>
@@ -156,6 +157,202 @@ namespace cppa
156157
seen[axis] = true;
157158
}
158159
}
160+
/// \brief Stride-aware random access iterator for nd_span
161+
/// \tparam ElementType Element type (use const-qualified type for a const iterator)
162+
/// \tparam MaxRank Maximum number of dimensions
163+
template<typename ElementType, size_t MaxRank>
164+
class nd_iterator
165+
{
166+
public:
167+
using size_type = size_t;
168+
using pointer = ElementType*;
169+
using difference_type = std::ptrdiff_t;
170+
using value_type = std::remove_cv_t<ElementType>;
171+
using reference = ElementType&;
172+
using iterator_category = std::random_access_iterator_tag;
173+
174+
nd_iterator( ) = default;
175+
176+
/// \brief Constructs an iterator at a given flat position within the span
177+
/// \param t_data Pointer to the first element of the span
178+
/// \param t_extents Extent (size) of each dimension
179+
/// \param t_strides Stride of each dimension
180+
/// \param t_rank Number of active dimensions
181+
/// \param t_flat_start Starting flat index (0 = begin, total_size = end)
182+
nd_iterator( pointer t_data, const std::array<size_type, MaxRank>& t_extents, const std::array<size_type, MaxRank>& t_strides, size_type t_rank,
183+
size_type t_flat_start = 0 )
184+
: m_data( t_data )
185+
, m_extents( t_extents )
186+
, m_strides( t_strides )
187+
, m_rank( t_rank )
188+
, m_flat_size( compute_size<MaxRank>( t_extents, t_rank ) )
189+
, m_flat_index( t_flat_start )
190+
{
191+
if( m_flat_index > m_flat_size )
192+
{
193+
m_flat_index = m_flat_size;
194+
}
195+
update_from_flat_index( );
196+
}
197+
198+
/// \brief Allow implicit conversion from non-const to const iterator
199+
template<typename OtherTy, std::enable_if_t<std::is_const_v<ElementType> && !std::is_const_v<OtherTy>, int> = 0>
200+
nd_iterator( const nd_iterator<OtherTy, MaxRank>& t_other )
201+
: m_data( t_other.m_data )
202+
, m_ptr( t_other.m_ptr )
203+
, m_extents( t_other.m_extents )
204+
, m_strides( t_other.m_strides )
205+
, m_indices( t_other.m_indices )
206+
, m_rank( t_other.m_rank )
207+
, m_flat_size( t_other.m_flat_size )
208+
, m_flat_index( t_other.m_flat_index )
209+
{
210+
}
211+
212+
// --- Element access ---
213+
214+
[[nodiscard]] reference operator*( ) const { return *m_ptr; }
215+
[[nodiscard]] pointer operator->( ) const { return m_ptr; }
216+
217+
/// \brief Returns a reference to the element at offset n from this iterator
218+
[[nodiscard]] reference operator[]( difference_type t_n ) const
219+
{
220+
nd_iterator tmp = *this;
221+
tmp += t_n;
222+
return *tmp;
223+
}
224+
225+
// --- Increment / decrement ---
226+
227+
nd_iterator& operator++( )
228+
{
229+
advance( 1 );
230+
return *this;
231+
}
232+
233+
nd_iterator operator++( int )
234+
{
235+
nd_iterator tmp = *this;
236+
advance( 1 );
237+
return tmp;
238+
}
239+
240+
nd_iterator& operator--( )
241+
{
242+
advance( -1 );
243+
return *this;
244+
}
245+
246+
nd_iterator operator--( int )
247+
{
248+
nd_iterator tmp = *this;
249+
advance( -1 );
250+
return tmp;
251+
}
252+
253+
// --- Arithmetic ---
254+
255+
nd_iterator& operator+=( difference_type t_n )
256+
{
257+
advance( t_n );
258+
return *this;
259+
}
260+
261+
nd_iterator& operator-=( difference_type t_n )
262+
{
263+
advance( -t_n );
264+
return *this;
265+
}
266+
267+
[[nodiscard]] nd_iterator operator+( difference_type t_n ) const
268+
{
269+
nd_iterator tmp = *this;
270+
tmp += t_n;
271+
return tmp;
272+
}
273+
274+
[[nodiscard]] nd_iterator operator-( difference_type t_n ) const
275+
{
276+
nd_iterator tmp = *this;
277+
tmp -= t_n;
278+
return tmp;
279+
}
280+
281+
[[nodiscard]] friend nd_iterator operator+( difference_type t_n, const nd_iterator& t_it ) { return t_it + t_n; }
282+
283+
/// \brief Returns the signed distance between two iterators from the same span
284+
[[nodiscard]] difference_type operator-( const nd_iterator& t_other ) const
285+
{
286+
return static_cast<difference_type>( m_flat_index ) - static_cast<difference_type>( t_other.m_flat_index );
287+
}
288+
289+
// --- Comparison ---
290+
291+
[[nodiscard]] bool operator==( const nd_iterator& t_other ) const { return m_flat_index == t_other.m_flat_index; }
292+
[[nodiscard]] bool operator!=( const nd_iterator& t_other ) const { return !( *this == t_other ); }
293+
[[nodiscard]] bool operator<( const nd_iterator& t_other ) const { return m_flat_index < t_other.m_flat_index; }
294+
[[nodiscard]] bool operator>( const nd_iterator& t_other ) const { return t_other < *this; }
295+
[[nodiscard]] bool operator<=( const nd_iterator& t_other ) const { return !( *this > t_other ); }
296+
[[nodiscard]] bool operator>=( const nd_iterator& t_other ) const { return !( *this < t_other ); }
297+
298+
private:
299+
pointer m_data = nullptr;
300+
pointer m_ptr = nullptr;
301+
std::array<size_type, MaxRank> m_extents{ };
302+
std::array<size_type, MaxRank> m_strides{ };
303+
std::array<size_type, MaxRank> m_indices{ };
304+
size_type m_rank = 0;
305+
size_type m_flat_size = 0;
306+
size_type m_flat_index = 0;
307+
308+
// Grant access to the opposite const-ness specialisation for the converting constructor
309+
friend class nd_iterator<std::conditional_t<std::is_const_v<ElementType>, std::remove_const_t<ElementType>, const ElementType>, MaxRank>;
310+
311+
/// \brief Updates the multi-dimensional indices and element pointer from the current flat index
312+
void update_from_flat_index( )
313+
{
314+
if( m_flat_index >= m_flat_size )
315+
{
316+
m_ptr = nullptr;
317+
return;
318+
}
319+
// Decompose flat index into per-dimension indices (row-major)
320+
size_type f = m_flat_index;
321+
for( int i = static_cast<int>( m_rank ) - 1; i >= 0; --i )
322+
{
323+
m_indices[i] = f % m_extents[i];
324+
f /= m_extents[i];
325+
}
326+
// Compute element pointer using strides
327+
m_ptr = m_data;
328+
for( size_type i = 0; i < m_rank; ++i )
329+
{
330+
m_ptr += m_indices[i] * m_strides[i];
331+
}
332+
}
333+
334+
/// \brief Moves the iterator by t_n positions (positive = forward, negative = backward)
335+
void advance( difference_type t_n )
336+
{
337+
if( t_n == 0 )
338+
return;
339+
const auto signed_index = static_cast<difference_type>( m_flat_index ) + t_n;
340+
if( signed_index < 0 )
341+
{
342+
m_flat_index = 0;
343+
}
344+
else if( static_cast<size_type>( signed_index ) >= m_flat_size )
345+
{
346+
m_flat_index = m_flat_size;
347+
}
348+
else
349+
{
350+
m_flat_index = static_cast<size_type>( signed_index );
351+
}
352+
update_from_flat_index( );
353+
}
354+
};
355+
159356
} // namespace detail
160357

161358
/// \class nd_span
@@ -514,23 +711,26 @@ namespace cppa
514711
/// \return Const pointer to the first element
515712
[[nodiscard]] const_pointer data( ) const noexcept { return m_data; }
516713

517-
/// \brief Returns a pointer to the first element for flat iteration
518-
[[nodiscard]] pointer begin( ) noexcept { return m_data; }
714+
using iterator = detail::nd_iterator<Ty, MaxRank>; ///< Mutable stride-aware iterator
715+
using const_iterator = detail::nd_iterator<const Ty, MaxRank>; ///< Const stride-aware iterator
519716

520-
/// \brief Returns a pointer past the last element for flat iteration
521-
[[nodiscard]] pointer end( ) noexcept { return m_data + size( ); }
717+
/// \brief Returns a stride-aware iterator to the first element
718+
[[nodiscard]] iterator begin( ) noexcept { return iterator( m_data, m_extents, m_strides, m_rank ); }
522719

523-
/// \brief Returns a const pointer to the first element for flat iteration
524-
[[nodiscard]] const_pointer begin( ) const noexcept { return m_data; }
720+
/// \brief Returns a past-the-end iterator
721+
[[nodiscard]] iterator end( ) noexcept { return iterator( m_data, m_extents, m_strides, m_rank, size( ) ); }
525722

526-
/// \brief Returns a const pointer past the last element for flat iteration
527-
[[nodiscard]] const_pointer end( ) const noexcept { return m_data + size( ); }
723+
/// \brief Returns a stride-aware const iterator to the first element
724+
[[nodiscard]] const_iterator begin( ) const noexcept { return const_iterator( m_data, m_extents, m_strides, m_rank ); }
528725

529-
/// \brief Returns a const pointer to the first element for flat iteration
530-
[[nodiscard]] const_pointer cbegin( ) const noexcept { return m_data; }
726+
/// \brief Returns a past-the-end const iterator
727+
[[nodiscard]] const_iterator end( ) const noexcept { return const_iterator( m_data, m_extents, m_strides, m_rank, size( ) ); }
531728

532-
/// \brief Returns a const pointer past the last element for flat iteration
533-
[[nodiscard]] const_pointer cend( ) const noexcept { return m_data + size( ); }
729+
/// \brief Returns a stride-aware const iterator to the first element
730+
[[nodiscard]] const_iterator cbegin( ) const noexcept { return const_iterator( m_data, m_extents, m_strides, m_rank ); }
731+
732+
/// \brief Returns a past-the-end const iterator
733+
[[nodiscard]] const_iterator cend( ) const noexcept { return const_iterator( m_data, m_extents, m_strides, m_rank, size( ) ); }
534734

535735
private:
536736
pointer m_data; ///< Pointer to the first element

0 commit comments

Comments
 (0)