Created
December 13, 2018 16:27
-
-
Save tesch1/1126425eb7cb1dfea35c9f0480111908 to your computer and use it in GitHub Desktop.
Eigen circShift, fftshift, and ifftshift
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
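// circ_shift.hpp  (main.cpp includes this header as "circ_shift.hpp")
// Based on: https://stackoverflow.com/questions/46077242/eigen-modifyable-custom-expression/46301503#46301503
// This header implements circShift, fftshift, and ifftshift for Eigen vectors, matrices, and arrays.
// Each function returns a writable view: assigning through it modifies the original object.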
#pragma once
#include <cassert>
#include <cstdlib>
#include <type_traits>
#include <utility>
#include <Eigen/Core>
// Pre-C++17 stand-in for std::bool_constant: an integral_constant holding a bool.
template <bool Value>
using bool_constant = std::integral_constant<bool, Value>;
namespace helper | |
{ | |
namespace detail | |
{ | |
template <typename T> | |
constexpr std::true_type is_matrix(Eigen::MatrixBase<T>); | |
std::false_type constexpr is_matrix(...); | |
template <typename T> | |
constexpr std::true_type is_array(Eigen::ArrayBase<T>); | |
std::false_type constexpr is_array(...); | |
} | |
template <typename T> | |
struct is_matrix : decltype(detail::is_matrix(std::declval<std::remove_cv_t<T>>())) | |
{ | |
}; | |
template <typename T> | |
struct is_array : decltype(detail::is_array(std::declval<std::remove_cv_t<T>>())) | |
{ | |
}; | |
template <typename T> | |
using is_matrix_or_array = bool_constant<is_array<T>::value || is_matrix<T>::value>; | |
/* | |
* Index something if it's not an scalar | |
*/ | |
template <typename T, typename std::enable_if<is_matrix_or_array<T>::value, int>::type = 0> | |
auto index_if_necessary(T&& thing, Eigen::Index idx) | |
{ | |
return thing(idx); | |
} | |
/* | |
* Overload for scalar. | |
*/ | |
template <typename T, typename std::enable_if<std::is_scalar<std::decay_t<T>>::value, int>::type = 0> | |
auto index_if_necessary(T&& thing, Eigen::Index) | |
{ | |
return thing; | |
} | |
} | |
namespace Eigen | |
{ | |
template <typename XprType, typename RowIndices, typename ColIndices> | |
class CircShiftedView; | |
namespace internal | |
{ | |
template <typename XprType, typename RowIndices, typename ColIndices> | |
struct traits<CircShiftedView<XprType, RowIndices, ColIndices>> | |
: traits<XprType> | |
{ | |
enum | |
{ | |
RowsAtCompileTime = traits<XprType>::RowsAtCompileTime, | |
ColsAtCompileTime = traits<XprType>::ColsAtCompileTime, | |
MaxRowsAtCompileTime = (RowsAtCompileTime != Dynamic | |
? int(RowsAtCompileTime) | |
: int(traits<XprType>::MaxRowsAtCompileTime)), | |
MaxColsAtCompileTime = (ColsAtCompileTime != Dynamic | |
? int(ColsAtCompileTime) | |
: int(traits<XprType>::MaxColsAtCompileTime)), | |
XprTypeIsRowMajor = (int(traits<XprType>::Flags) & RowMajorBit) != 0, | |
IsRowMajor = ((MaxRowsAtCompileTime == 1 && MaxColsAtCompileTime != 1) ? 1 | |
: (MaxColsAtCompileTime == 1 && MaxRowsAtCompileTime != 1) ? 0 | |
: XprTypeIsRowMajor), | |
FlagsRowMajorBit = IsRowMajor ? RowMajorBit : 0, | |
FlagsLvalueBit = is_lvalue<XprType>::value ? LvalueBit : 0, | |
Flags = (traits<XprType>::Flags & HereditaryBits) | FlagsLvalueBit | FlagsRowMajorBit | |
}; | |
}; | |
} | |
template <typename XprType, typename RowShift, typename ColShift, typename StorageKind> | |
class CircShiftedViewImpl; | |
template <typename XprType, typename RowShift, typename ColShift> | |
class CircShiftedView : public CircShiftedViewImpl<XprType, RowShift, ColShift, | |
typename internal::traits<XprType>::StorageKind> | |
{ | |
public: | |
typedef typename CircShiftedViewImpl<XprType, RowShift, ColShift, | |
typename internal::traits<XprType>::StorageKind>::Base Base; | |
EIGEN_GENERIC_PUBLIC_INTERFACE(CircShiftedView) | |
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(CircShiftedView) | |
typedef typename internal::ref_selector<XprType>::non_const_type MatrixTypeNested; | |
typedef typename internal::remove_all<XprType>::type NestedExpression; | |
template <typename T0, typename T1> | |
CircShiftedView(XprType& xpr, const T0& rowShift, const T1& colShift) | |
: m_xpr(xpr), m_rowShift(rowShift), m_colShift(colShift) | |
{ | |
for (auto c = 0; c < xpr.cols(); ++c) | |
assert(std::abs(helper::index_if_necessary(m_rowShift, c)) < m_xpr.rows()); // row shift must be within +- rows()-1 | |
for (auto r = 0; r < xpr.rows(); ++r) | |
assert(std::abs(helper::index_if_necessary(m_colShift, r)) < m_xpr.cols()); // col shift must be within +- cols()-1 | |
} | |
/** \returns number of rows */ | |
Index rows() const { return m_xpr.rows(); } | |
/** \returns number of columns */ | |
Index cols() const { return m_xpr.cols(); } | |
/** \returns the nested expression */ | |
const typename internal::remove_all<XprType>::type& | |
nestedExpression() const { return m_xpr; } | |
/** \returns the nested expression */ | |
typename internal::remove_reference<XprType>::type& | |
nestedExpression() { return m_xpr.const_cast_derived(); } | |
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE | |
Index getRowIdx(Index row, Index col) const | |
{ | |
Index R = m_xpr.rows(); | |
assert(row >= 0 && row < R && col >= 0 && col < m_xpr.cols()); | |
Index r = row - helper::index_if_necessary(m_rowShift, col); | |
if (r >= R) | |
return r - R; | |
if (r < 0) | |
return r + R; | |
return r; | |
} | |
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE | |
Index getColIdx(Index row, Index col) const | |
{ | |
Index C = m_xpr.cols(); | |
assert(row >= 0 && row < m_xpr.rows() && col >= 0 && col < C); | |
Index c = col - helper::index_if_necessary(m_colShift, row); | |
if (c >= C) | |
return c - C; | |
if (c < 0) | |
return c + C; | |
return c; | |
} | |
protected: | |
MatrixTypeNested m_xpr; | |
RowShift m_rowShift; | |
ColShift m_colShift; | |
}; | |
// Generic API dispatcher | |
template <typename XprType, typename RowIndices, typename ColIndices, typename StorageKind> | |
class CircShiftedViewImpl | |
: public internal::generic_xpr_base<CircShiftedView<XprType, RowIndices, ColIndices>>::type | |
{ | |
public: | |
typedef typename internal::generic_xpr_base<CircShiftedView<XprType, RowIndices, ColIndices>>::type Base; | |
}; | |
namespace internal | |
{ | |
template <typename ArgType, typename RowIndices, typename ColIndices> | |
struct unary_evaluator<CircShiftedView<ArgType, RowIndices, ColIndices>, IndexBased> | |
: evaluator_base<CircShiftedView<ArgType, RowIndices, ColIndices>> | |
{ | |
typedef CircShiftedView<ArgType, RowIndices, ColIndices> XprType; | |
enum | |
{ | |
CoeffReadCost = (evaluator<ArgType>::CoeffReadCost | |
+ NumTraits<Index>::AddCost /* for comparison */ | |
+ NumTraits<Index>::AddCost) /* for addition */, | |
Flags = (evaluator<ArgType>::Flags & HereditaryBits), | |
Alignment = 0 | |
}; | |
EIGEN_DEVICE_FUNC explicit unary_evaluator(const XprType& xpr) : m_argImpl(xpr.nestedExpression()), m_xpr(xpr) | |
{ | |
EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost); | |
} | |
typedef typename XprType::Scalar Scalar; | |
typedef typename XprType::CoeffReturnType CoeffReturnType; | |
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE | |
CoeffReturnType coeff(Index row, Index col) const | |
{ | |
return m_argImpl.coeff(m_xpr.getRowIdx(row, col), m_xpr.getColIdx(row, col)); | |
} | |
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE | |
CoeffReturnType coeff(Index idx) const | |
{ | |
if (m_xpr.cols() == 1) | |
return m_argImpl.coeff(m_xpr.getRowIdx(idx, 1), 1); | |
if (m_xpr.rows() == 1) | |
return m_argImpl.coeff(1, m_xpr.getColIdx(1, idx)); | |
assert(m_xpr.cols() == 1 || m_xpr.rows() == 1); | |
// default no-assert case - assume col vector | |
return m_argImpl.coeff(m_xpr.getRowIdx(idx, 1), 1); | |
} | |
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE | |
Scalar& coeffRef(Index row, Index col) | |
{ | |
assert(row >= 0 && row < m_xpr.rows() && col >= 0 && col < m_xpr.cols()); | |
return m_argImpl.coeffRef(m_xpr.getRowIdx(row, col), m_xpr.getColIdx(row, col)); | |
} | |
protected: | |
evaluator<ArgType> m_argImpl; | |
const XprType& m_xpr; | |
}; | |
} // end namespace internal | |
} // end namespace Eigen | |
template <typename XprType, typename RowShift, typename ColShift> | |
auto circShift(Eigen::DenseBase<XprType>& x, RowShift r, ColShift c) | |
{ | |
return Eigen::CircShiftedView<XprType, RowShift, ColShift>(x.derived(), r, c); | |
} | |
template <typename XprType> | |
auto fftshift(Eigen::DenseBase<XprType>& x) | |
{ | |
Eigen::Index rs = x.rows() / 2; | |
Eigen::Index cs = x.cols() / 2; | |
return Eigen::CircShiftedView<XprType, Eigen::Index, Eigen::Index>(x.derived(), rs, cs); | |
} | |
template <typename XprType> | |
auto ifftshift(Eigen::DenseBase<XprType>& x) | |
{ | |
Eigen::Index rs = (x.rows() + 1) / 2; | |
Eigen::Index cs = (x.cols() + 1) / 2; | |
return Eigen::CircShiftedView<XprType, Eigen::Index, Eigen::Index>(x.derived(), rs, cs); | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// main.cpp | |
#include "circ_shift.hpp" | |
#include <iostream> | |
#include <Eigen/Core> | |
using namespace Eigen; | |
int main() | |
{ | |
ArrayXXf x(4, 2); | |
x.transpose() << 1, 2, 3, 4, 10, 20, 30, 40; | |
Vector2i rowShift; | |
rowShift << 3, -3; // rotate col 1 by 3 and col 2 by -3 | |
Index colShift = 1; // flip columns | |
std::cout << "original: " << std::endl << x << std::endl; | |
auto shifted = circShift(x, rowShift, colShift); | |
std::cout << "shifted: " << std::endl << shifted << std::endl; | |
shifted.block(2,0,2,1) << -1, -2; // will appear in row 3 and 0. | |
shifted.col(1) << 2,4,6,8; // shifted col 1 is col 0 of the original | |
std::cout << "modified original:" << std::endl << x << std::endl; | |
MatrixXf m(3,4); | |
m << 1,2,3,4, 5,6,7,8, 9,10,11,12; | |
std::cout << "m:" << std::endl << m << std::endl; | |
std::cout << "fftshift(m):" << std::endl << fftshift(m) << std::endl; | |
std::cout << "ifftshift(m):" << std::endl << ifftshift(m) << std::endl; | |
auto mm = fftshift(m); | |
std::cout << "ifftshift(fftshift(m)):" << std::endl << ifftshift(mm) << std::endl; | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
nice!!!