From 7a8d0d8bc2572707c9d35006f30ea835c86954b0 Mon Sep 17 00:00:00 2001 From: sotech117 Date: Tue, 9 Apr 2024 03:14:17 -0400 Subject: first draft to generate waves --- Eigen/src/Core/arch/AVX/Complex.h | 372 ++ Eigen/src/Core/arch/AVX/MathFunctions.h | 228 + Eigen/src/Core/arch/AVX/PacketMath.h | 1574 +++++++ Eigen/src/Core/arch/AVX/TypeCasting.h | 115 + Eigen/src/Core/arch/AVX512/Complex.h | 422 ++ Eigen/src/Core/arch/AVX512/MathFunctions.h | 362 ++ Eigen/src/Core/arch/AVX512/PacketMath.h | 2303 ++++++++++ Eigen/src/Core/arch/AVX512/TypeCasting.h | 89 + Eigen/src/Core/arch/AltiVec/Complex.h | 417 ++ Eigen/src/Core/arch/AltiVec/MathFunctions.h | 90 + Eigen/src/Core/arch/AltiVec/MatrixProduct.h | 2937 +++++++++++++ Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h | 221 + Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h | 629 +++ Eigen/src/Core/arch/AltiVec/PacketMath.h | 2711 ++++++++++++ Eigen/src/Core/arch/CUDA/Complex.h | 258 ++ Eigen/src/Core/arch/Default/BFloat16.h | 700 +++ Eigen/src/Core/arch/Default/ConjHelper.h | 117 + .../Core/arch/Default/GenericPacketMathFunctions.h | 1649 +++++++ .../arch/Default/GenericPacketMathFunctionsFwd.h | 110 + Eigen/src/Core/arch/Default/Half.h | 942 ++++ Eigen/src/Core/arch/Default/Settings.h | 49 + Eigen/src/Core/arch/Default/TypeCasting.h | 120 + Eigen/src/Core/arch/GPU/MathFunctions.h | 103 + Eigen/src/Core/arch/GPU/PacketMath.h | 1685 +++++++ Eigen/src/Core/arch/GPU/TypeCasting.h | 80 + Eigen/src/Core/arch/HIP/hcc/math_constants.h | 23 + Eigen/src/Core/arch/MSA/Complex.h | 648 +++ Eigen/src/Core/arch/MSA/MathFunctions.h | 387 ++ Eigen/src/Core/arch/MSA/PacketMath.h | 1233 ++++++ Eigen/src/Core/arch/NEON/Complex.h | 584 +++ Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h | 183 + Eigen/src/Core/arch/NEON/MathFunctions.h | 75 + Eigen/src/Core/arch/NEON/PacketMath.h | 4587 ++++++++++++++++++++ Eigen/src/Core/arch/NEON/TypeCasting.h | 1419 ++++++ Eigen/src/Core/arch/SSE/Complex.h | 351 ++ Eigen/src/Core/arch/SSE/MathFunctions.h | 199 + Eigen/src/Core/arch/SSE/PacketMath.h | 1505 +++++++ Eigen/src/Core/arch/SSE/TypeCasting.h | 142 + Eigen/src/Core/arch/SVE/MathFunctions.h | 44 + Eigen/src/Core/arch/SVE/PacketMath.h | 752 ++++ Eigen/src/Core/arch/SVE/TypeCasting.h | 49 + Eigen/src/Core/arch/SYCL/InteropHeaders.h | 232 + Eigen/src/Core/arch/SYCL/MathFunctions.h | 301 ++ Eigen/src/Core/arch/SYCL/PacketMath.h | 670 +++ Eigen/src/Core/arch/SYCL/SyclMemoryModel.h | 694 +++ Eigen/src/Core/arch/SYCL/TypeCasting.h | 85 + Eigen/src/Core/arch/ZVector/Complex.h | 426 ++ Eigen/src/Core/arch/ZVector/MathFunctions.h | 233 + Eigen/src/Core/arch/ZVector/PacketMath.h | 1060 +++++ 49 files changed, 34165 insertions(+) create mode 100644 Eigen/src/Core/arch/AVX/Complex.h create mode 100644 Eigen/src/Core/arch/AVX/MathFunctions.h create mode 100644 Eigen/src/Core/arch/AVX/PacketMath.h create mode 100644 Eigen/src/Core/arch/AVX/TypeCasting.h create mode 100644 Eigen/src/Core/arch/AVX512/Complex.h create mode 100644 Eigen/src/Core/arch/AVX512/MathFunctions.h create mode 100644 Eigen/src/Core/arch/AVX512/PacketMath.h create mode 100644 Eigen/src/Core/arch/AVX512/TypeCasting.h create mode 100644 Eigen/src/Core/arch/AltiVec/Complex.h create mode 100644 Eigen/src/Core/arch/AltiVec/MathFunctions.h create mode 100644 Eigen/src/Core/arch/AltiVec/MatrixProduct.h create mode 100644 Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h create mode 100644 Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h create mode 100755 Eigen/src/Core/arch/AltiVec/PacketMath.h create mode 100644 
Eigen/src/Core/arch/CUDA/Complex.h create mode 100644 Eigen/src/Core/arch/Default/BFloat16.h create mode 100644 Eigen/src/Core/arch/Default/ConjHelper.h create mode 100644 Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h create mode 100644 Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h create mode 100644 Eigen/src/Core/arch/Default/Half.h create mode 100644 Eigen/src/Core/arch/Default/Settings.h create mode 100644 Eigen/src/Core/arch/Default/TypeCasting.h create mode 100644 Eigen/src/Core/arch/GPU/MathFunctions.h create mode 100644 Eigen/src/Core/arch/GPU/PacketMath.h create mode 100644 Eigen/src/Core/arch/GPU/TypeCasting.h create mode 100644 Eigen/src/Core/arch/HIP/hcc/math_constants.h create mode 100644 Eigen/src/Core/arch/MSA/Complex.h create mode 100644 Eigen/src/Core/arch/MSA/MathFunctions.h create mode 100644 Eigen/src/Core/arch/MSA/PacketMath.h create mode 100644 Eigen/src/Core/arch/NEON/Complex.h create mode 100644 Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h create mode 100644 Eigen/src/Core/arch/NEON/MathFunctions.h create mode 100644 Eigen/src/Core/arch/NEON/PacketMath.h create mode 100644 Eigen/src/Core/arch/NEON/TypeCasting.h create mode 100644 Eigen/src/Core/arch/SSE/Complex.h create mode 100644 Eigen/src/Core/arch/SSE/MathFunctions.h create mode 100755 Eigen/src/Core/arch/SSE/PacketMath.h create mode 100644 Eigen/src/Core/arch/SSE/TypeCasting.h create mode 100644 Eigen/src/Core/arch/SVE/MathFunctions.h create mode 100644 Eigen/src/Core/arch/SVE/PacketMath.h create mode 100644 Eigen/src/Core/arch/SVE/TypeCasting.h create mode 100644 Eigen/src/Core/arch/SYCL/InteropHeaders.h create mode 100644 Eigen/src/Core/arch/SYCL/MathFunctions.h create mode 100644 Eigen/src/Core/arch/SYCL/PacketMath.h create mode 100644 Eigen/src/Core/arch/SYCL/SyclMemoryModel.h create mode 100644 Eigen/src/Core/arch/SYCL/TypeCasting.h create mode 100644 Eigen/src/Core/arch/ZVector/Complex.h create mode 100644 Eigen/src/Core/arch/ZVector/MathFunctions.h create mode 100755 Eigen/src/Core/arch/ZVector/PacketMath.h (limited to 'Eigen/src/Core/arch') diff --git a/Eigen/src/Core/arch/AVX/Complex.h b/Eigen/src/Core/arch/AVX/Complex.h new file mode 100644 index 0000000..ab7bd6c --- /dev/null +++ b/Eigen/src/Core/arch/AVX/Complex.h @@ -0,0 +1,372 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2014 Benoit Steiner (benoit.steiner.goog@gmail.com) +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
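For reference, a minimal scalar sketch of the per-lane arithmetic that the Packet4cf and Packet2cd specializations below vectorize; the helper names ref_cmul and ref_cdiv are illustrative only and not part of Eigen. pmul duplicates the real and imaginary parts with moveldup/movehdup and lets _mm256_addsub_ps supply the alternating -/+ signs, while pdiv multiplies by the conjugate and scales by 1/|b|^2.

#include <complex>

// One lane of the complex product; the AVX code performs four (float) or two
// (double) of these at once.
static std::complex<float> ref_cmul(std::complex<float> a, std::complex<float> b) {
  return std::complex<float>(a.real()*b.real() - a.imag()*b.imag(),   // even lanes: addsub subtracts
                             a.real()*b.imag() + a.imag()*b.real());  // odd lanes: addsub adds
}

// One lane of the complex quotient, matching pdiv: a * conj(b) / |b|^2.
static std::complex<float> ref_cdiv(std::complex<float> a, std::complex<float> b) {
  const float n = b.real()*b.real() + b.imag()*b.imag();  // the 'denom' packet in pdiv
  const std::complex<float> num = ref_cmul(a, std::conj(b));
  return std::complex<float>(num.real()/n, num.imag()/n);
}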
+ +#ifndef EIGEN_COMPLEX_AVX_H +#define EIGEN_COMPLEX_AVX_H + +namespace Eigen { + +namespace internal { + +//---------- float ---------- +struct Packet4cf +{ + EIGEN_STRONG_INLINE Packet4cf() {} + EIGEN_STRONG_INLINE explicit Packet4cf(const __m256& a) : v(a) {} + __m256 v; +}; + +#ifndef EIGEN_VECTORIZE_AVX512 +template<> struct packet_traits > : default_packet_traits +{ + typedef Packet4cf type; + typedef Packet2cf half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 4, + HasHalfPacket = 1, + + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 1, + HasNegate = 1, + HasSqrt = 1, + HasAbs = 0, + HasAbs2 = 0, + HasMin = 0, + HasMax = 0, + HasSetLinear = 0 + }; +}; +#endif + +template<> struct unpacket_traits { + typedef std::complex type; + typedef Packet2cf half; + typedef Packet8f as_real; + enum { + size=4, + alignment=Aligned32, + vectorizable=true, + masked_load_available=false, + masked_store_available=false + }; +}; + +template<> EIGEN_STRONG_INLINE Packet4cf padd(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_add_ps(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet4cf psub(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_sub_ps(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet4cf pnegate(const Packet4cf& a) +{ + return Packet4cf(pnegate(a.v)); +} +template<> EIGEN_STRONG_INLINE Packet4cf pconj(const Packet4cf& a) +{ + const __m256 mask = _mm256_castsi256_ps(_mm256_setr_epi32(0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000)); + return Packet4cf(_mm256_xor_ps(a.v,mask)); +} + +template<> EIGEN_STRONG_INLINE Packet4cf pmul(const Packet4cf& a, const Packet4cf& b) +{ + __m256 tmp1 = _mm256_mul_ps(_mm256_moveldup_ps(a.v), b.v); + __m256 tmp2 = _mm256_mul_ps(_mm256_movehdup_ps(a.v), _mm256_permute_ps(b.v, _MM_SHUFFLE(2,3,0,1))); + __m256 result = _mm256_addsub_ps(tmp1, tmp2); + return Packet4cf(result); +} + +template <> +EIGEN_STRONG_INLINE Packet4cf pcmp_eq(const Packet4cf& a, const Packet4cf& b) { + __m256 eq = _mm256_cmp_ps(a.v, b.v, _CMP_EQ_OQ); + return Packet4cf(_mm256_and_ps(eq, _mm256_permute_ps(eq, 0xb1))); +} + +template<> EIGEN_STRONG_INLINE Packet4cf ptrue(const Packet4cf& a) { return Packet4cf(ptrue(Packet8f(a.v))); } +template<> EIGEN_STRONG_INLINE Packet4cf pand (const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_and_ps(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet4cf por (const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_or_ps(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet4cf pxor (const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_xor_ps(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet4cf pandnot(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_andnot_ps(b.v,a.v)); } + +template<> EIGEN_STRONG_INLINE Packet4cf pload (const std::complex* from) { EIGEN_DEBUG_ALIGNED_LOAD return Packet4cf(pload(&numext::real_ref(*from))); } +template<> EIGEN_STRONG_INLINE Packet4cf ploadu(const std::complex* from) { EIGEN_DEBUG_UNALIGNED_LOAD return Packet4cf(ploadu(&numext::real_ref(*from))); } + + +template<> EIGEN_STRONG_INLINE Packet4cf pset1(const std::complex& from) +{ + return Packet4cf(_mm256_castpd_ps(_mm256_broadcast_sd((const double*)(const void*)&from))); +} + +template<> EIGEN_STRONG_INLINE Packet4cf ploaddup(const std::complex* from) +{ + // FIXME The following might be optimized using _mm256_movedup_pd + Packet2cf a = ploaddup(from); + Packet2cf b = ploaddup(from+1); + return 
Packet4cf(_mm256_insertf128_ps(_mm256_castps128_ps256(a.v), b.v, 1)); +} + +template<> EIGEN_STRONG_INLINE void pstore >(std::complex* to, const Packet4cf& from) { EIGEN_DEBUG_ALIGNED_STORE pstore(&numext::real_ref(*to), from.v); } +template<> EIGEN_STRONG_INLINE void pstoreu >(std::complex* to, const Packet4cf& from) { EIGEN_DEBUG_UNALIGNED_STORE pstoreu(&numext::real_ref(*to), from.v); } + +template<> EIGEN_DEVICE_FUNC inline Packet4cf pgather, Packet4cf>(const std::complex* from, Index stride) +{ + return Packet4cf(_mm256_set_ps(std::imag(from[3*stride]), std::real(from[3*stride]), + std::imag(from[2*stride]), std::real(from[2*stride]), + std::imag(from[1*stride]), std::real(from[1*stride]), + std::imag(from[0*stride]), std::real(from[0*stride]))); +} + +template<> EIGEN_DEVICE_FUNC inline void pscatter, Packet4cf>(std::complex* to, const Packet4cf& from, Index stride) +{ + __m128 low = _mm256_extractf128_ps(from.v, 0); + to[stride*0] = std::complex(_mm_cvtss_f32(_mm_shuffle_ps(low, low, 0)), + _mm_cvtss_f32(_mm_shuffle_ps(low, low, 1))); + to[stride*1] = std::complex(_mm_cvtss_f32(_mm_shuffle_ps(low, low, 2)), + _mm_cvtss_f32(_mm_shuffle_ps(low, low, 3))); + + __m128 high = _mm256_extractf128_ps(from.v, 1); + to[stride*2] = std::complex(_mm_cvtss_f32(_mm_shuffle_ps(high, high, 0)), + _mm_cvtss_f32(_mm_shuffle_ps(high, high, 1))); + to[stride*3] = std::complex(_mm_cvtss_f32(_mm_shuffle_ps(high, high, 2)), + _mm_cvtss_f32(_mm_shuffle_ps(high, high, 3))); + +} + +template<> EIGEN_STRONG_INLINE std::complex pfirst(const Packet4cf& a) +{ + return pfirst(Packet2cf(_mm256_castps256_ps128(a.v))); +} + +template<> EIGEN_STRONG_INLINE Packet4cf preverse(const Packet4cf& a) { + __m128 low = _mm256_extractf128_ps(a.v, 0); + __m128 high = _mm256_extractf128_ps(a.v, 1); + __m128d lowd = _mm_castps_pd(low); + __m128d highd = _mm_castps_pd(high); + low = _mm_castpd_ps(_mm_shuffle_pd(lowd,lowd,0x1)); + high = _mm_castpd_ps(_mm_shuffle_pd(highd,highd,0x1)); + __m256 result = _mm256_setzero_ps(); + result = _mm256_insertf128_ps(result, low, 1); + result = _mm256_insertf128_ps(result, high, 0); + return Packet4cf(result); +} + +template<> EIGEN_STRONG_INLINE std::complex predux(const Packet4cf& a) +{ + return predux(padd(Packet2cf(_mm256_extractf128_ps(a.v,0)), + Packet2cf(_mm256_extractf128_ps(a.v,1)))); +} + +template<> EIGEN_STRONG_INLINE std::complex predux_mul(const Packet4cf& a) +{ + return predux_mul(pmul(Packet2cf(_mm256_extractf128_ps(a.v, 0)), + Packet2cf(_mm256_extractf128_ps(a.v, 1)))); +} + +EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet4cf,Packet8f) + +template<> EIGEN_STRONG_INLINE Packet4cf pdiv(const Packet4cf& a, const Packet4cf& b) +{ + Packet4cf num = pmul(a, pconj(b)); + __m256 tmp = _mm256_mul_ps(b.v, b.v); + __m256 tmp2 = _mm256_shuffle_ps(tmp,tmp,0xB1); + __m256 denom = _mm256_add_ps(tmp, tmp2); + return Packet4cf(_mm256_div_ps(num.v, denom)); +} + +template<> EIGEN_STRONG_INLINE Packet4cf pcplxflip(const Packet4cf& x) +{ + return Packet4cf(_mm256_shuffle_ps(x.v, x.v, _MM_SHUFFLE(2, 3, 0 ,1))); +} + +//---------- double ---------- +struct Packet2cd +{ + EIGEN_STRONG_INLINE Packet2cd() {} + EIGEN_STRONG_INLINE explicit Packet2cd(const __m256d& a) : v(a) {} + __m256d v; +}; + +#ifndef EIGEN_VECTORIZE_AVX512 +template<> struct packet_traits > : default_packet_traits +{ + typedef Packet2cd type; + typedef Packet1cd half; + enum { + Vectorizable = 1, + AlignedOnScalar = 0, + size = 2, + HasHalfPacket = 1, + + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 1, + HasNegate = 1, + HasSqrt = 1, 
+ HasAbs = 0, + HasAbs2 = 0, + HasMin = 0, + HasMax = 0, + HasSetLinear = 0 + }; +}; +#endif + +template<> struct unpacket_traits { + typedef std::complex type; + typedef Packet1cd half; + typedef Packet4d as_real; + enum { + size=2, + alignment=Aligned32, + vectorizable=true, + masked_load_available=false, + masked_store_available=false + }; +}; + +template<> EIGEN_STRONG_INLINE Packet2cd padd(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_add_pd(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet2cd psub(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_sub_pd(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet2cd pnegate(const Packet2cd& a) { return Packet2cd(pnegate(a.v)); } +template<> EIGEN_STRONG_INLINE Packet2cd pconj(const Packet2cd& a) +{ + const __m256d mask = _mm256_castsi256_pd(_mm256_set_epi32(0x80000000,0x0,0x0,0x0,0x80000000,0x0,0x0,0x0)); + return Packet2cd(_mm256_xor_pd(a.v,mask)); +} + +template<> EIGEN_STRONG_INLINE Packet2cd pmul(const Packet2cd& a, const Packet2cd& b) +{ + __m256d tmp1 = _mm256_shuffle_pd(a.v,a.v,0x0); + __m256d even = _mm256_mul_pd(tmp1, b.v); + __m256d tmp2 = _mm256_shuffle_pd(a.v,a.v,0xF); + __m256d tmp3 = _mm256_shuffle_pd(b.v,b.v,0x5); + __m256d odd = _mm256_mul_pd(tmp2, tmp3); + return Packet2cd(_mm256_addsub_pd(even, odd)); +} + +template <> +EIGEN_STRONG_INLINE Packet2cd pcmp_eq(const Packet2cd& a, const Packet2cd& b) { + __m256d eq = _mm256_cmp_pd(a.v, b.v, _CMP_EQ_OQ); + return Packet2cd(pand(eq, _mm256_permute_pd(eq, 0x5))); +} + +template<> EIGEN_STRONG_INLINE Packet2cd ptrue(const Packet2cd& a) { return Packet2cd(ptrue(Packet4d(a.v))); } +template<> EIGEN_STRONG_INLINE Packet2cd pand (const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_and_pd(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet2cd por (const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_or_pd(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet2cd pxor (const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_xor_pd(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet2cd pandnot(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_andnot_pd(b.v,a.v)); } + +template<> EIGEN_STRONG_INLINE Packet2cd pload (const std::complex* from) +{ EIGEN_DEBUG_ALIGNED_LOAD return Packet2cd(pload((const double*)from)); } +template<> EIGEN_STRONG_INLINE Packet2cd ploadu(const std::complex* from) +{ EIGEN_DEBUG_UNALIGNED_LOAD return Packet2cd(ploadu((const double*)from)); } + +template<> EIGEN_STRONG_INLINE Packet2cd pset1(const std::complex& from) +{ + // in case casting to a __m128d* is really not safe, then we can still fallback to this version: (much slower though) +// return Packet2cd(_mm256_loadu2_m128d((const double*)&from,(const double*)&from)); + return Packet2cd(_mm256_broadcast_pd((const __m128d*)(const void*)&from)); +} + +template<> EIGEN_STRONG_INLINE Packet2cd ploaddup(const std::complex* from) { return pset1(*from); } + +template<> EIGEN_STRONG_INLINE void pstore >(std::complex * to, const Packet2cd& from) { EIGEN_DEBUG_ALIGNED_STORE pstore((double*)to, from.v); } +template<> EIGEN_STRONG_INLINE void pstoreu >(std::complex * to, const Packet2cd& from) { EIGEN_DEBUG_UNALIGNED_STORE pstoreu((double*)to, from.v); } + +template<> EIGEN_DEVICE_FUNC inline Packet2cd pgather, Packet2cd>(const std::complex* from, Index stride) +{ + return Packet2cd(_mm256_set_pd(std::imag(from[1*stride]), std::real(from[1*stride]), + std::imag(from[0*stride]), std::real(from[0*stride]))); +} + +template<> 
EIGEN_DEVICE_FUNC inline void pscatter, Packet2cd>(std::complex* to, const Packet2cd& from, Index stride) +{ + __m128d low = _mm256_extractf128_pd(from.v, 0); + to[stride*0] = std::complex(_mm_cvtsd_f64(low), _mm_cvtsd_f64(_mm_shuffle_pd(low, low, 1))); + __m128d high = _mm256_extractf128_pd(from.v, 1); + to[stride*1] = std::complex(_mm_cvtsd_f64(high), _mm_cvtsd_f64(_mm_shuffle_pd(high, high, 1))); +} + +template<> EIGEN_STRONG_INLINE std::complex pfirst(const Packet2cd& a) +{ + __m128d low = _mm256_extractf128_pd(a.v, 0); + EIGEN_ALIGN16 double res[2]; + _mm_store_pd(res, low); + return std::complex(res[0],res[1]); +} + +template<> EIGEN_STRONG_INLINE Packet2cd preverse(const Packet2cd& a) { + __m256d result = _mm256_permute2f128_pd(a.v, a.v, 1); + return Packet2cd(result); +} + +template<> EIGEN_STRONG_INLINE std::complex predux(const Packet2cd& a) +{ + return predux(padd(Packet1cd(_mm256_extractf128_pd(a.v,0)), + Packet1cd(_mm256_extractf128_pd(a.v,1)))); +} + +template<> EIGEN_STRONG_INLINE std::complex predux_mul(const Packet2cd& a) +{ + return predux(pmul(Packet1cd(_mm256_extractf128_pd(a.v,0)), + Packet1cd(_mm256_extractf128_pd(a.v,1)))); +} + +EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cd,Packet4d) + +template<> EIGEN_STRONG_INLINE Packet2cd pdiv(const Packet2cd& a, const Packet2cd& b) +{ + Packet2cd num = pmul(a, pconj(b)); + __m256d tmp = _mm256_mul_pd(b.v, b.v); + __m256d denom = _mm256_hadd_pd(tmp, tmp); + return Packet2cd(_mm256_div_pd(num.v, denom)); +} + +template<> EIGEN_STRONG_INLINE Packet2cd pcplxflip(const Packet2cd& x) +{ + return Packet2cd(_mm256_shuffle_pd(x.v, x.v, 0x5)); +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + __m256d P0 = _mm256_castps_pd(kernel.packet[0].v); + __m256d P1 = _mm256_castps_pd(kernel.packet[1].v); + __m256d P2 = _mm256_castps_pd(kernel.packet[2].v); + __m256d P3 = _mm256_castps_pd(kernel.packet[3].v); + + __m256d T0 = _mm256_shuffle_pd(P0, P1, 15); + __m256d T1 = _mm256_shuffle_pd(P0, P1, 0); + __m256d T2 = _mm256_shuffle_pd(P2, P3, 15); + __m256d T3 = _mm256_shuffle_pd(P2, P3, 0); + + kernel.packet[1].v = _mm256_castpd_ps(_mm256_permute2f128_pd(T0, T2, 32)); + kernel.packet[3].v = _mm256_castpd_ps(_mm256_permute2f128_pd(T0, T2, 49)); + kernel.packet[0].v = _mm256_castpd_ps(_mm256_permute2f128_pd(T1, T3, 32)); + kernel.packet[2].v = _mm256_castpd_ps(_mm256_permute2f128_pd(T1, T3, 49)); +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + __m256d tmp = _mm256_permute2f128_pd(kernel.packet[0].v, kernel.packet[1].v, 0+(2<<4)); + kernel.packet[1].v = _mm256_permute2f128_pd(kernel.packet[0].v, kernel.packet[1].v, 1+(3<<4)); + kernel.packet[0].v = tmp; +} + +template<> EIGEN_STRONG_INLINE Packet2cd psqrt(const Packet2cd& a) { + return psqrt_complex(a); +} + +template<> EIGEN_STRONG_INLINE Packet4cf psqrt(const Packet4cf& a) { + return psqrt_complex(a); +} + +} // end namespace internal + +} // end namespace Eigen + +#endif // EIGEN_COMPLEX_AVX_H diff --git a/Eigen/src/Core/arch/AVX/MathFunctions.h b/Eigen/src/Core/arch/AVX/MathFunctions.h new file mode 100644 index 0000000..67041c8 --- /dev/null +++ b/Eigen/src/Core/arch/AVX/MathFunctions.h @@ -0,0 +1,228 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2014 Pedro Gonnet (pedro.gonnet@gmail.com) +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. 
If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_MATH_FUNCTIONS_AVX_H +#define EIGEN_MATH_FUNCTIONS_AVX_H + +/* The sin and cos functions of this file are loosely derived from + * Julien Pommier's sse math library: http://gruntthepeon.free.fr/ssemath/ + */ + +namespace Eigen { + +namespace internal { + +template <> +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f +psin(const Packet8f& _x) { + return psin_float(_x); +} + +template <> +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f +pcos(const Packet8f& _x) { + return pcos_float(_x); +} + +template <> +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f +plog(const Packet8f& _x) { + return plog_float(_x); +} + +template <> +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4d +plog(const Packet4d& _x) { + return plog_double(_x); +} + +template <> +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f +plog2(const Packet8f& _x) { + return plog2_float(_x); +} + +template <> +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4d +plog2(const Packet4d& _x) { + return plog2_double(_x); +} + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet8f plog1p(const Packet8f& _x) { + return generic_plog1p(_x); +} + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet8f pexpm1(const Packet8f& _x) { + return generic_expm1(_x); +} + +// Exponential function. Works by writing "x = m*log(2) + r" where +// "m = floor(x/log(2)+1/2)" and "r" is the remainder. The result is then +// "exp(x) = 2^m*exp(r)" where exp(r) is in the range [-1,1). +template <> +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f +pexp(const Packet8f& _x) { + return pexp_float(_x); +} + +// Hyperbolic Tangent function. +template <> +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f +ptanh(const Packet8f& _x) { + return internal::generic_fast_tanh_float(_x); +} + +// Exponential function for doubles. +template <> +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4d +pexp(const Packet4d& _x) { + return pexp_double(_x); +} + +// Functions for sqrt. +// The EIGEN_FAST_MATH version uses the _mm_rsqrt_ps approximation and one step +// of Newton's method, at a cost of 1-2 bits of precision as opposed to the +// exact solution. It does not handle +inf, or denormalized numbers correctly. +// The main advantage of this approach is not just speed, but also the fact that +// it can be inlined and pipelined with other computations, further reducing its +// effective latency. This is similar to Quake3's fast inverse square root. +// For detail see here: http://www.beyond3d.com/content/articles/8/ +#if EIGEN_FAST_MATH +template <> +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet8f psqrt(const Packet8f& _x) { + Packet8f minus_half_x = pmul(_x, pset1(-0.5f)); + Packet8f denormal_mask = pandnot( + pcmp_lt(_x, pset1((std::numeric_limits::min)())), + pcmp_lt(_x, pzero(_x))); + + // Compute approximate reciprocal sqrt. + Packet8f x = _mm256_rsqrt_ps(_x); + // Do a single step of Newton's iteration. + x = pmul(x, pmadd(minus_half_x, pmul(x,x), pset1(1.5f))); + // Flush results for denormals to zero. 
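  // Note that sqrt(x) is recovered as x * rsqrt(x). The pandnot with
  // denormal_mask zeroes lanes whose input was zero or a positive denormal,
  // where x * rsqrt(x) would otherwise yield NaN (0 * inf) or garbage.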
+ return pandnot(pmul(_x,x), denormal_mask); +} + +#else + +template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet8f psqrt(const Packet8f& _x) { + return _mm256_sqrt_ps(_x); +} + +#endif + +template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet4d psqrt(const Packet4d& _x) { + return _mm256_sqrt_pd(_x); +} + +#if EIGEN_FAST_MATH +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet8f prsqrt(const Packet8f& _x) { + _EIGEN_DECLARE_CONST_Packet8f_FROM_INT(inf, 0x7f800000); + _EIGEN_DECLARE_CONST_Packet8f(one_point_five, 1.5f); + _EIGEN_DECLARE_CONST_Packet8f(minus_half, -0.5f); + _EIGEN_DECLARE_CONST_Packet8f_FROM_INT(flt_min, 0x00800000); + + Packet8f neg_half = pmul(_x, p8f_minus_half); + + // select only the inverse sqrt of positive normal inputs (denormals are + // flushed to zero and cause infs as well). + Packet8f lt_min_mask = _mm256_cmp_ps(_x, p8f_flt_min, _CMP_LT_OQ); + Packet8f inf_mask = _mm256_cmp_ps(_x, p8f_inf, _CMP_EQ_OQ); + Packet8f not_normal_finite_mask = _mm256_or_ps(lt_min_mask, inf_mask); + + // Compute an approximate result using the rsqrt intrinsic. + Packet8f y_approx = _mm256_rsqrt_ps(_x); + + // Do a single step of Newton-Raphson iteration to improve the approximation. + // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n). + // It is essential to evaluate the inner term like this because forming + // y_n^2 may over- or underflow. + Packet8f y_newton = pmul(y_approx, pmadd(y_approx, pmul(neg_half, y_approx), p8f_one_point_five)); + + // Select the result of the Newton-Raphson step for positive normal arguments. + // For other arguments, choose the output of the intrinsic. This will + // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if + // x is zero or a positive denormalized float (equivalent to flushing positive + // denormalized inputs to zero). 
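  // (The update above is one Newton step on f(y) = 1/(y*y) - x, whose root is
  //  y = 1/sqrt(x): y_{n+1} = y_n - f(y_n)/f'(y_n) = y_n * (1.5 - 0.5*x*y_n*y_n).)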
+ return pselect(not_normal_finite_mask, y_approx, y_newton); +} + +#else +template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet8f prsqrt(const Packet8f& _x) { + _EIGEN_DECLARE_CONST_Packet8f(one, 1.0f); + return _mm256_div_ps(p8f_one, _mm256_sqrt_ps(_x)); +} +#endif + +template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet4d prsqrt(const Packet4d& _x) { + _EIGEN_DECLARE_CONST_Packet4d(one, 1.0); + return _mm256_div_pd(p4d_one, _mm256_sqrt_pd(_x)); +} + +F16_PACKET_FUNCTION(Packet8f, Packet8h, psin) +F16_PACKET_FUNCTION(Packet8f, Packet8h, pcos) +F16_PACKET_FUNCTION(Packet8f, Packet8h, plog) +F16_PACKET_FUNCTION(Packet8f, Packet8h, plog2) +F16_PACKET_FUNCTION(Packet8f, Packet8h, plog1p) +F16_PACKET_FUNCTION(Packet8f, Packet8h, pexpm1) +F16_PACKET_FUNCTION(Packet8f, Packet8h, pexp) +F16_PACKET_FUNCTION(Packet8f, Packet8h, ptanh) +F16_PACKET_FUNCTION(Packet8f, Packet8h, psqrt) +F16_PACKET_FUNCTION(Packet8f, Packet8h, prsqrt) + +template <> +EIGEN_STRONG_INLINE Packet8h pfrexp(const Packet8h& a, Packet8h& exponent) { + Packet8f fexponent; + const Packet8h out = float2half(pfrexp(half2float(a), fexponent)); + exponent = float2half(fexponent); + return out; +} + +template <> +EIGEN_STRONG_INLINE Packet8h pldexp(const Packet8h& a, const Packet8h& exponent) { + return float2half(pldexp(half2float(a), half2float(exponent))); +} + +BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psin) +BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pcos) +BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog) +BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog2) +BF16_PACKET_FUNCTION(Packet8f, Packet8bf, plog1p) +BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pexpm1) +BF16_PACKET_FUNCTION(Packet8f, Packet8bf, pexp) +BF16_PACKET_FUNCTION(Packet8f, Packet8bf, ptanh) +BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psqrt) +BF16_PACKET_FUNCTION(Packet8f, Packet8bf, prsqrt) + +template <> +EIGEN_STRONG_INLINE Packet8bf pfrexp(const Packet8bf& a, Packet8bf& exponent) { + Packet8f fexponent; + const Packet8bf out = F32ToBf16(pfrexp(Bf16ToF32(a), fexponent)); + exponent = F32ToBf16(fexponent); + return out; +} + +template <> +EIGEN_STRONG_INLINE Packet8bf pldexp(const Packet8bf& a, const Packet8bf& exponent) { + return F32ToBf16(pldexp(Bf16ToF32(a), Bf16ToF32(exponent))); +} + +} // end namespace internal + +} // end namespace Eigen + +#endif // EIGEN_MATH_FUNCTIONS_AVX_H diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h new file mode 100644 index 0000000..7fc32fd --- /dev/null +++ b/Eigen/src/Core/arch/AVX/PacketMath.h @@ -0,0 +1,1574 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2014 Benoit Steiner (benoit.steiner.goog@gmail.com) +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
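For orientation, a minimal sketch of how the Packet8f primitives defined below are typically composed: broadcast with pset1, load with ploadu, fused multiply-add with pmadd, store with pstoreu, eight floats per step. The helper axpy8 is illustrative only (not part of Eigen) and assumes the translation unit is compiled with AVX enabled so that Packet8f maps to __m256.

#include <Eigen/Core>

// Illustrative helper built on the internal AVX packet API declared in this file:
// y[0..n) += a * x[0..n).
void axpy8(float* y, const float* x, float a, int n) {
  using namespace Eigen::internal;
  const Packet8f pa = pset1<Packet8f>(a);     // broadcast a to all 8 lanes
  int i = 0;
  for (; i + 8 <= n; i += 8) {
    Packet8f px = ploadu<Packet8f>(x + i);    // unaligned 8-float load
    Packet8f py = ploadu<Packet8f>(y + i);
    pstoreu(y + i, pmadd(pa, px, py));        // y[i..i+7] += a * x[i..i+7]
  }
  for (; i < n; ++i) y[i] += a * x[i];        // scalar tail
}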
+ +#ifndef EIGEN_PACKET_MATH_AVX_H +#define EIGEN_PACKET_MATH_AVX_H + +namespace Eigen { + +namespace internal { + +#ifndef EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD +#define EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD 8 +#endif + +#if !defined(EIGEN_VECTORIZE_AVX512) && !defined(EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS) +#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 16 +#endif + +#ifdef EIGEN_VECTORIZE_FMA +#ifndef EIGEN_HAS_SINGLE_INSTRUCTION_MADD +#define EIGEN_HAS_SINGLE_INSTRUCTION_MADD +#endif +#endif + +typedef __m256 Packet8f; +typedef __m256i Packet8i; +typedef __m256d Packet4d; +typedef eigen_packet_wrapper<__m128i, 2> Packet8h; +typedef eigen_packet_wrapper<__m128i, 3> Packet8bf; + +template<> struct is_arithmetic<__m256> { enum { value = true }; }; +template<> struct is_arithmetic<__m256i> { enum { value = true }; }; +template<> struct is_arithmetic<__m256d> { enum { value = true }; }; +template<> struct is_arithmetic { enum { value = true }; }; +template<> struct is_arithmetic { enum { value = true }; }; + +#define _EIGEN_DECLARE_CONST_Packet8f(NAME,X) \ + const Packet8f p8f_##NAME = pset1(X) + +#define _EIGEN_DECLARE_CONST_Packet4d(NAME,X) \ + const Packet4d p4d_##NAME = pset1(X) + +#define _EIGEN_DECLARE_CONST_Packet8f_FROM_INT(NAME,X) \ + const Packet8f p8f_##NAME = _mm256_castsi256_ps(pset1(X)) + +#define _EIGEN_DECLARE_CONST_Packet8i(NAME,X) \ + const Packet8i p8i_##NAME = pset1(X) + +// Use the packet_traits defined in AVX512/PacketMath.h instead if we're going +// to leverage AVX512 instructions. +#ifndef EIGEN_VECTORIZE_AVX512 +template<> struct packet_traits : default_packet_traits +{ + typedef Packet8f type; + typedef Packet4f half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 8, + HasHalfPacket = 1, + + HasCmp = 1, + HasDiv = 1, + HasSin = EIGEN_FAST_MATH, + HasCos = EIGEN_FAST_MATH, + HasLog = 1, + HasLog1p = 1, + HasExpm1 = 1, + HasExp = 1, + HasNdtri = 1, + HasBessel = 1, + HasSqrt = 1, + HasRsqrt = 1, + HasTanh = EIGEN_FAST_MATH, + HasErf = EIGEN_FAST_MATH, + HasBlend = 1, + HasRound = 1, + HasFloor = 1, + HasCeil = 1, + HasRint = 1 + }; +}; +template<> struct packet_traits : default_packet_traits +{ + typedef Packet4d type; + typedef Packet2d half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size=4, + HasHalfPacket = 1, + + HasCmp = 1, + HasDiv = 1, + HasLog = 1, + HasExp = 1, + HasSqrt = 1, + HasRsqrt = 1, + HasBlend = 1, + HasRound = 1, + HasFloor = 1, + HasCeil = 1, + HasRint = 1 + }; +}; + +template <> +struct packet_traits : default_packet_traits { + typedef Packet8h type; + // There is no half-size packet for Packet8h. + typedef Packet8h half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 8, + HasHalfPacket = 0, + + HasCmp = 1, + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 1, + HasSin = EIGEN_FAST_MATH, + HasCos = EIGEN_FAST_MATH, + HasNegate = 1, + HasAbs = 1, + HasAbs2 = 0, + HasMin = 1, + HasMax = 1, + HasConj = 1, + HasSetLinear = 0, + HasLog = 1, + HasLog1p = 1, + HasExpm1 = 1, + HasExp = 1, + HasSqrt = 1, + HasRsqrt = 1, + HasTanh = EIGEN_FAST_MATH, + HasErf = EIGEN_FAST_MATH, + HasBlend = 0, + HasRound = 1, + HasFloor = 1, + HasCeil = 1, + HasRint = 1, + HasBessel = 1, + HasNdtri = 1 + }; +}; + +template <> +struct packet_traits : default_packet_traits { + typedef Packet8bf type; + // There is no half-size packet for current Packet8bf. + // TODO: support as SSE path. 
+ typedef Packet8bf half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 8, + HasHalfPacket = 0, + + HasCmp = 1, + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 1, + HasSin = EIGEN_FAST_MATH, + HasCos = EIGEN_FAST_MATH, + HasNegate = 1, + HasAbs = 1, + HasAbs2 = 0, + HasMin = 1, + HasMax = 1, + HasConj = 1, + HasSetLinear = 0, + HasLog = 1, + HasLog1p = 1, + HasExpm1 = 1, + HasExp = 1, + HasSqrt = 1, + HasRsqrt = 1, + HasTanh = EIGEN_FAST_MATH, + HasErf = EIGEN_FAST_MATH, + HasBlend = 0, + HasRound = 1, + HasFloor = 1, + HasCeil = 1, + HasRint = 1, + HasBessel = 1, + HasNdtri = 1 + }; +}; +#endif + +template<> struct scalar_div_cost { enum { value = 14 }; }; +template<> struct scalar_div_cost { enum { value = 16 }; }; + +/* Proper support for integers is only provided by AVX2. In the meantime, we'll + use SSE instructions and packets to deal with integers. +template<> struct packet_traits : default_packet_traits +{ + typedef Packet8i type; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size=8 + }; +}; +*/ + +template<> struct unpacket_traits { + typedef float type; + typedef Packet4f half; + typedef Packet8i integer_packet; + typedef uint8_t mask_t; + enum {size=8, alignment=Aligned32, vectorizable=true, masked_load_available=true, masked_store_available=true}; +}; +template<> struct unpacket_traits { + typedef double type; + typedef Packet2d half; + enum {size=4, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false}; +}; +template<> struct unpacket_traits { typedef int type; typedef Packet4i half; enum {size=8, alignment=Aligned32, vectorizable=false, masked_load_available=false, masked_store_available=false}; }; +template<> struct unpacket_traits { typedef bfloat16 type; typedef Packet8bf half; enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; }; + +// Helper function for bit packing snippet of low precision comparison. +// It packs the flags from 16x16 to 8x16. 
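// (Each 32-bit all-ones/all-zeros comparison flag is narrowed, with signed
// saturation, to a 16-bit flag so that the eight results fit in one __m128i.)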
+EIGEN_STRONG_INLINE __m128i Pack16To8(Packet8f rf) { + return _mm_packs_epi32(_mm256_extractf128_si256(_mm256_castps_si256(rf), 0), + _mm256_extractf128_si256(_mm256_castps_si256(rf), 1)); +} + + +template<> EIGEN_STRONG_INLINE Packet8f pset1(const float& from) { return _mm256_set1_ps(from); } +template<> EIGEN_STRONG_INLINE Packet4d pset1(const double& from) { return _mm256_set1_pd(from); } +template<> EIGEN_STRONG_INLINE Packet8i pset1(const int& from) { return _mm256_set1_epi32(from); } + +template<> EIGEN_STRONG_INLINE Packet8f pset1frombits(unsigned int from) { return _mm256_castsi256_ps(pset1(from)); } +template<> EIGEN_STRONG_INLINE Packet4d pset1frombits(uint64_t from) { return _mm256_castsi256_pd(_mm256_set1_epi64x(from)); } + +template<> EIGEN_STRONG_INLINE Packet8f pzero(const Packet8f& /*a*/) { return _mm256_setzero_ps(); } +template<> EIGEN_STRONG_INLINE Packet4d pzero(const Packet4d& /*a*/) { return _mm256_setzero_pd(); } +template<> EIGEN_STRONG_INLINE Packet8i pzero(const Packet8i& /*a*/) { return _mm256_setzero_si256(); } + + +template<> EIGEN_STRONG_INLINE Packet8f peven_mask(const Packet8f& /*a*/) { return _mm256_castsi256_ps(_mm256_set_epi32(0, -1, 0, -1, 0, -1, 0, -1)); } +template<> EIGEN_STRONG_INLINE Packet8i peven_mask(const Packet8i& /*a*/) { return _mm256_set_epi32(0, -1, 0, -1, 0, -1, 0, -1); } +template<> EIGEN_STRONG_INLINE Packet4d peven_mask(const Packet4d& /*a*/) { return _mm256_castsi256_pd(_mm256_set_epi32(0, 0, -1, -1, 0, 0, -1, -1)); } + +template<> EIGEN_STRONG_INLINE Packet8f pload1(const float* from) { return _mm256_broadcast_ss(from); } +template<> EIGEN_STRONG_INLINE Packet4d pload1(const double* from) { return _mm256_broadcast_sd(from); } + +template<> EIGEN_STRONG_INLINE Packet8f plset(const float& a) { return _mm256_add_ps(_mm256_set1_ps(a), _mm256_set_ps(7.0,6.0,5.0,4.0,3.0,2.0,1.0,0.0)); } +template<> EIGEN_STRONG_INLINE Packet4d plset(const double& a) { return _mm256_add_pd(_mm256_set1_pd(a), _mm256_set_pd(3.0,2.0,1.0,0.0)); } + +template<> EIGEN_STRONG_INLINE Packet8f padd(const Packet8f& a, const Packet8f& b) { return _mm256_add_ps(a,b); } +template<> EIGEN_STRONG_INLINE Packet4d padd(const Packet4d& a, const Packet4d& b) { return _mm256_add_pd(a,b); } +template<> EIGEN_STRONG_INLINE Packet8i padd(const Packet8i& a, const Packet8i& b) { +#ifdef EIGEN_VECTORIZE_AVX2 + return _mm256_add_epi32(a,b); +#else + __m128i lo = _mm_add_epi32(_mm256_extractf128_si256(a, 0), _mm256_extractf128_si256(b, 0)); + __m128i hi = _mm_add_epi32(_mm256_extractf128_si256(a, 1), _mm256_extractf128_si256(b, 1)); + return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1); +#endif +} + +template<> EIGEN_STRONG_INLINE Packet8f psub(const Packet8f& a, const Packet8f& b) { return _mm256_sub_ps(a,b); } +template<> EIGEN_STRONG_INLINE Packet4d psub(const Packet4d& a, const Packet4d& b) { return _mm256_sub_pd(a,b); } +template<> EIGEN_STRONG_INLINE Packet8i psub(const Packet8i& a, const Packet8i& b) { +#ifdef EIGEN_VECTORIZE_AVX2 + return _mm256_sub_epi32(a,b); +#else + __m128i lo = _mm_sub_epi32(_mm256_extractf128_si256(a, 0), _mm256_extractf128_si256(b, 0)); + __m128i hi = _mm_sub_epi32(_mm256_extractf128_si256(a, 1), _mm256_extractf128_si256(b, 1)); + return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1); +#endif +} + +template<> EIGEN_STRONG_INLINE Packet8f pnegate(const Packet8f& a) +{ + return _mm256_sub_ps(_mm256_set1_ps(0.0),a); +} +template<> EIGEN_STRONG_INLINE Packet4d pnegate(const Packet4d& a) +{ + return 
_mm256_sub_pd(_mm256_set1_pd(0.0),a); +} + +template<> EIGEN_STRONG_INLINE Packet8f pconj(const Packet8f& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet4d pconj(const Packet4d& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet8i pconj(const Packet8i& a) { return a; } + +template<> EIGEN_STRONG_INLINE Packet8f pmul(const Packet8f& a, const Packet8f& b) { return _mm256_mul_ps(a,b); } +template<> EIGEN_STRONG_INLINE Packet4d pmul(const Packet4d& a, const Packet4d& b) { return _mm256_mul_pd(a,b); } +template<> EIGEN_STRONG_INLINE Packet8i pmul(const Packet8i& a, const Packet8i& b) { +#ifdef EIGEN_VECTORIZE_AVX2 + return _mm256_mullo_epi32(a,b); +#else + const __m128i lo = _mm_mullo_epi32(_mm256_extractf128_si256(a, 0), _mm256_extractf128_si256(b, 0)); + const __m128i hi = _mm_mullo_epi32(_mm256_extractf128_si256(a, 1), _mm256_extractf128_si256(b, 1)); + return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1); +#endif +} + +template<> EIGEN_STRONG_INLINE Packet8f pdiv(const Packet8f& a, const Packet8f& b) { return _mm256_div_ps(a,b); } +template<> EIGEN_STRONG_INLINE Packet4d pdiv(const Packet4d& a, const Packet4d& b) { return _mm256_div_pd(a,b); } +template<> EIGEN_STRONG_INLINE Packet8i pdiv(const Packet8i& /*a*/, const Packet8i& /*b*/) +{ eigen_assert(false && "packet integer division are not supported by AVX"); + return pset1(0); +} + +#ifdef EIGEN_VECTORIZE_FMA +template<> EIGEN_STRONG_INLINE Packet8f pmadd(const Packet8f& a, const Packet8f& b, const Packet8f& c) { +#if ( (EIGEN_COMP_GNUC_STRICT && EIGEN_COMP_GNUC<80) || (EIGEN_COMP_CLANG) ) + // Clang stupidly generates a vfmadd213ps instruction plus some vmovaps on registers, + // and even register spilling with clang>=6.0 (bug 1637). + // Gcc stupidly generates a vfmadd132ps instruction. + // So let's enforce it to generate a vfmadd231ps instruction since the most common use + // case is to accumulate the result of the product. 
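  // res starts as a copy of c and is bound read-write ("+x"), so the single
  // vfmadd231ps below accumulates a*b directly into it.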
+ Packet8f res = c; + __asm__("vfmadd231ps %[a], %[b], %[c]" : [c] "+x" (res) : [a] "x" (a), [b] "x" (b)); + return res; +#else + return _mm256_fmadd_ps(a,b,c); +#endif +} +template<> EIGEN_STRONG_INLINE Packet4d pmadd(const Packet4d& a, const Packet4d& b, const Packet4d& c) { +#if ( (EIGEN_COMP_GNUC_STRICT && EIGEN_COMP_GNUC<80) || (EIGEN_COMP_CLANG) ) + // see above + Packet4d res = c; + __asm__("vfmadd231pd %[a], %[b], %[c]" : [c] "+x" (res) : [a] "x" (a), [b] "x" (b)); + return res; +#else + return _mm256_fmadd_pd(a,b,c); +#endif +} +#endif + +template<> EIGEN_STRONG_INLINE Packet8f pcmp_le(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_LE_OQ); } +template<> EIGEN_STRONG_INLINE Packet8f pcmp_lt(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_LT_OQ); } +template<> EIGEN_STRONG_INLINE Packet8f pcmp_lt_or_nan(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a, b, _CMP_NGE_UQ); } +template<> EIGEN_STRONG_INLINE Packet8f pcmp_eq(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_EQ_OQ); } + +template<> EIGEN_STRONG_INLINE Packet4d pcmp_le(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_LE_OQ); } +template<> EIGEN_STRONG_INLINE Packet4d pcmp_lt(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_LT_OQ); } +template<> EIGEN_STRONG_INLINE Packet4d pcmp_lt_or_nan(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a, b, _CMP_NGE_UQ); } +template<> EIGEN_STRONG_INLINE Packet4d pcmp_eq(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_EQ_OQ); } + + +template<> EIGEN_STRONG_INLINE Packet8i pcmp_eq(const Packet8i& a, const Packet8i& b) { +#ifdef EIGEN_VECTORIZE_AVX2 + return _mm256_cmpeq_epi32(a,b); +#else + __m128i lo = _mm_cmpeq_epi32(_mm256_extractf128_si256(a, 0), _mm256_extractf128_si256(b, 0)); + __m128i hi = _mm_cmpeq_epi32(_mm256_extractf128_si256(a, 1), _mm256_extractf128_si256(b, 1)); + return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1); +#endif +} + +template<> EIGEN_STRONG_INLINE Packet8f pmin(const Packet8f& a, const Packet8f& b) { +#if EIGEN_COMP_GNUC && EIGEN_COMP_GNUC < 63 + // There appears to be a bug in GCC, by which the optimizer may flip + // the argument order in calls to _mm_min_ps/_mm_max_ps, so we have to + // resort to inline ASM here. This is supposed to be fixed in gcc6.3, + // see also: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=72867 + Packet8f res; + asm("vminps %[a], %[b], %[res]" : [res] "=x" (res) : [a] "x" (a), [b] "x" (b)); + return res; +#else + // Arguments are swapped to match NaN propagation behavior of std::min. + return _mm256_min_ps(b,a); +#endif +} +template<> EIGEN_STRONG_INLINE Packet4d pmin(const Packet4d& a, const Packet4d& b) { +#if EIGEN_COMP_GNUC && EIGEN_COMP_GNUC < 63 + // See pmin above + Packet4d res; + asm("vminpd %[a], %[b], %[res]" : [res] "=x" (res) : [a] "x" (a), [b] "x" (b)); + return res; +#else + // Arguments are swapped to match NaN propagation behavior of std::min. + return _mm256_min_pd(b,a); +#endif +} + +template<> EIGEN_STRONG_INLINE Packet8f pmax(const Packet8f& a, const Packet8f& b) { +#if EIGEN_COMP_GNUC && EIGEN_COMP_GNUC < 63 + // See pmin above + Packet8f res; + asm("vmaxps %[a], %[b], %[res]" : [res] "=x" (res) : [a] "x" (a), [b] "x" (b)); + return res; +#else + // Arguments are swapped to match NaN propagation behavior of std::max. 
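  // (maxps returns its second source operand when the inputs are unordered, so
  //  calling it as (b,a) makes pmax(a,b) return its first argument when either
  //  input is NaN, exactly like std::max(a,b).)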
+ return _mm256_max_ps(b,a); +#endif +} +template<> EIGEN_STRONG_INLINE Packet4d pmax(const Packet4d& a, const Packet4d& b) { +#if EIGEN_COMP_GNUC && EIGEN_COMP_GNUC < 63 + // See pmin above + Packet4d res; + asm("vmaxpd %[a], %[b], %[res]" : [res] "=x" (res) : [a] "x" (a), [b] "x" (b)); + return res; +#else + // Arguments are swapped to match NaN propagation behavior of std::max. + return _mm256_max_pd(b,a); +#endif +} + +// Add specializations for min/max with prescribed NaN progation. +template<> +EIGEN_STRONG_INLINE Packet8f pmin(const Packet8f& a, const Packet8f& b) { + return pminmax_propagate_numbers(a, b, pmin); +} +template<> +EIGEN_STRONG_INLINE Packet4d pmin(const Packet4d& a, const Packet4d& b) { + return pminmax_propagate_numbers(a, b, pmin); +} +template<> +EIGEN_STRONG_INLINE Packet8f pmax(const Packet8f& a, const Packet8f& b) { + return pminmax_propagate_numbers(a, b, pmax); +} +template<> +EIGEN_STRONG_INLINE Packet4d pmax(const Packet4d& a, const Packet4d& b) { + return pminmax_propagate_numbers(a, b, pmax); +} +template<> +EIGEN_STRONG_INLINE Packet8f pmin(const Packet8f& a, const Packet8f& b) { + return pminmax_propagate_nan(a, b, pmin); +} +template<> +EIGEN_STRONG_INLINE Packet4d pmin(const Packet4d& a, const Packet4d& b) { + return pminmax_propagate_nan(a, b, pmin); +} +template<> +EIGEN_STRONG_INLINE Packet8f pmax(const Packet8f& a, const Packet8f& b) { + return pminmax_propagate_nan(a, b, pmax); +} +template<> +EIGEN_STRONG_INLINE Packet4d pmax(const Packet4d& a, const Packet4d& b) { + return pminmax_propagate_nan(a, b, pmax); +} + +template<> EIGEN_STRONG_INLINE Packet8f print(const Packet8f& a) { return _mm256_round_ps(a, _MM_FROUND_CUR_DIRECTION); } +template<> EIGEN_STRONG_INLINE Packet4d print(const Packet4d& a) { return _mm256_round_pd(a, _MM_FROUND_CUR_DIRECTION); } + +template<> EIGEN_STRONG_INLINE Packet8f pceil(const Packet8f& a) { return _mm256_ceil_ps(a); } +template<> EIGEN_STRONG_INLINE Packet4d pceil(const Packet4d& a) { return _mm256_ceil_pd(a); } + +template<> EIGEN_STRONG_INLINE Packet8f pfloor(const Packet8f& a) { return _mm256_floor_ps(a); } +template<> EIGEN_STRONG_INLINE Packet4d pfloor(const Packet4d& a) { return _mm256_floor_pd(a); } + + +template<> EIGEN_STRONG_INLINE Packet8i ptrue(const Packet8i& a) { +#ifdef EIGEN_VECTORIZE_AVX2 + // vpcmpeqd has lower latency than the more general vcmpps + return _mm256_cmpeq_epi32(a,a); +#else + const __m256 b = _mm256_castsi256_ps(a); + return _mm256_castps_si256(_mm256_cmp_ps(b,b,_CMP_TRUE_UQ)); +#endif +} + +template<> EIGEN_STRONG_INLINE Packet8f ptrue(const Packet8f& a) { +#ifdef EIGEN_VECTORIZE_AVX2 + // vpcmpeqd has lower latency than the more general vcmpps + const __m256i b = _mm256_castps_si256(a); + return _mm256_castsi256_ps(_mm256_cmpeq_epi32(b,b)); +#else + return _mm256_cmp_ps(a,a,_CMP_TRUE_UQ); +#endif +} + +template<> EIGEN_STRONG_INLINE Packet4d ptrue(const Packet4d& a) { +#ifdef EIGEN_VECTORIZE_AVX2 + // vpcmpeqq has lower latency than the more general vcmppd + const __m256i b = _mm256_castpd_si256(a); + return _mm256_castsi256_pd(_mm256_cmpeq_epi64(b,b)); +#else + return _mm256_cmp_pd(a,a,_CMP_TRUE_UQ); +#endif +} + +template<> EIGEN_STRONG_INLINE Packet8f pand(const Packet8f& a, const Packet8f& b) { return _mm256_and_ps(a,b); } +template<> EIGEN_STRONG_INLINE Packet4d pand(const Packet4d& a, const Packet4d& b) { return _mm256_and_pd(a,b); } +template<> EIGEN_STRONG_INLINE Packet8i pand(const Packet8i& a, const Packet8i& b) { +#ifdef EIGEN_VECTORIZE_AVX2 + return 
_mm256_and_si256(a,b); +#else + return _mm256_castps_si256(_mm256_and_ps(_mm256_castsi256_ps(a),_mm256_castsi256_ps(b))); +#endif +} + +template<> EIGEN_STRONG_INLINE Packet8f por(const Packet8f& a, const Packet8f& b) { return _mm256_or_ps(a,b); } +template<> EIGEN_STRONG_INLINE Packet4d por(const Packet4d& a, const Packet4d& b) { return _mm256_or_pd(a,b); } +template<> EIGEN_STRONG_INLINE Packet8i por(const Packet8i& a, const Packet8i& b) { +#ifdef EIGEN_VECTORIZE_AVX2 + return _mm256_or_si256(a,b); +#else + return _mm256_castps_si256(_mm256_or_ps(_mm256_castsi256_ps(a),_mm256_castsi256_ps(b))); +#endif +} + +template<> EIGEN_STRONG_INLINE Packet8f pxor(const Packet8f& a, const Packet8f& b) { return _mm256_xor_ps(a,b); } +template<> EIGEN_STRONG_INLINE Packet4d pxor(const Packet4d& a, const Packet4d& b) { return _mm256_xor_pd(a,b); } +template<> EIGEN_STRONG_INLINE Packet8i pxor(const Packet8i& a, const Packet8i& b) { +#ifdef EIGEN_VECTORIZE_AVX2 + return _mm256_xor_si256(a,b); +#else + return _mm256_castps_si256(_mm256_xor_ps(_mm256_castsi256_ps(a),_mm256_castsi256_ps(b))); +#endif +} + +template<> EIGEN_STRONG_INLINE Packet8f pandnot(const Packet8f& a, const Packet8f& b) { return _mm256_andnot_ps(b,a); } +template<> EIGEN_STRONG_INLINE Packet4d pandnot(const Packet4d& a, const Packet4d& b) { return _mm256_andnot_pd(b,a); } +template<> EIGEN_STRONG_INLINE Packet8i pandnot(const Packet8i& a, const Packet8i& b) { +#ifdef EIGEN_VECTORIZE_AVX2 + return _mm256_andnot_si256(b,a); +#else + return _mm256_castps_si256(_mm256_andnot_ps(_mm256_castsi256_ps(b),_mm256_castsi256_ps(a))); +#endif +} + +template<> EIGEN_STRONG_INLINE Packet8f pround(const Packet8f& a) +{ + const Packet8f mask = pset1frombits(static_cast(0x80000000u)); + const Packet8f prev0dot5 = pset1frombits(static_cast(0x3EFFFFFFu)); + return _mm256_round_ps(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO); +} +template<> EIGEN_STRONG_INLINE Packet4d pround(const Packet4d& a) +{ + const Packet4d mask = pset1frombits(static_cast(0x8000000000000000ull)); + const Packet4d prev0dot5 = pset1frombits(static_cast(0x3FDFFFFFFFFFFFFFull)); + return _mm256_round_pd(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO); +} + +template<> EIGEN_STRONG_INLINE Packet8f pselect(const Packet8f& mask, const Packet8f& a, const Packet8f& b) +{ return _mm256_blendv_ps(b,a,mask); } +template<> EIGEN_STRONG_INLINE Packet4d pselect(const Packet4d& mask, const Packet4d& a, const Packet4d& b) +{ return _mm256_blendv_pd(b,a,mask); } + +template EIGEN_STRONG_INLINE Packet8i parithmetic_shift_right(Packet8i a) { +#ifdef EIGEN_VECTORIZE_AVX2 + return _mm256_srai_epi32(a, N); +#else + __m128i lo = _mm_srai_epi32(_mm256_extractf128_si256(a, 0), N); + __m128i hi = _mm_srai_epi32(_mm256_extractf128_si256(a, 1), N); + return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1); +#endif +} + +template EIGEN_STRONG_INLINE Packet8i plogical_shift_right(Packet8i a) { +#ifdef EIGEN_VECTORIZE_AVX2 + return _mm256_srli_epi32(a, N); +#else + __m128i lo = _mm_srli_epi32(_mm256_extractf128_si256(a, 0), N); + __m128i hi = _mm_srli_epi32(_mm256_extractf128_si256(a, 1), N); + return _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1); +#endif +} + +template EIGEN_STRONG_INLINE Packet8i plogical_shift_left(Packet8i a) { +#ifdef EIGEN_VECTORIZE_AVX2 + return _mm256_slli_epi32(a, N); +#else + __m128i lo = _mm_slli_epi32(_mm256_extractf128_si256(a, 0), N); + __m128i hi = _mm_slli_epi32(_mm256_extractf128_si256(a, 1), N); + return 
_mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1); +#endif +} + +template<> EIGEN_STRONG_INLINE Packet8f pload(const float* from) { EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_ps(from); } +template<> EIGEN_STRONG_INLINE Packet4d pload(const double* from) { EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_pd(from); } +template<> EIGEN_STRONG_INLINE Packet8i pload(const int* from) { EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_si256(reinterpret_cast(from)); } + +template<> EIGEN_STRONG_INLINE Packet8f ploadu(const float* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_ps(from); } +template<> EIGEN_STRONG_INLINE Packet4d ploadu(const double* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_pd(from); } +template<> EIGEN_STRONG_INLINE Packet8i ploadu(const int* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_si256(reinterpret_cast(from)); } + +template<> EIGEN_STRONG_INLINE Packet8f ploadu(const float* from, uint8_t umask) { + Packet8i mask = _mm256_set1_epi8(static_cast(umask)); + const Packet8i bit_mask = _mm256_set_epi32(0xffffff7f, 0xffffffbf, 0xffffffdf, 0xffffffef, 0xfffffff7, 0xfffffffb, 0xfffffffd, 0xfffffffe); + mask = por(mask, bit_mask); + mask = pcmp_eq(mask, _mm256_set1_epi32(0xffffffff)); + EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_maskload_ps(from, mask); +} + +// Loads 4 floats from memory a returns the packet {a0, a0 a1, a1, a2, a2, a3, a3} +template<> EIGEN_STRONG_INLINE Packet8f ploaddup(const float* from) +{ + // TODO try to find a way to avoid the need of a temporary register +// Packet8f tmp = _mm256_castps128_ps256(_mm_loadu_ps(from)); +// tmp = _mm256_insertf128_ps(tmp, _mm_movehl_ps(_mm256_castps256_ps128(tmp),_mm256_castps256_ps128(tmp)), 1); +// return _mm256_unpacklo_ps(tmp,tmp); + + // _mm256_insertf128_ps is very slow on Haswell, thus: + Packet8f tmp = _mm256_broadcast_ps((const __m128*)(const void*)from); + // mimic an "inplace" permutation of the lower 128bits using a blend + tmp = _mm256_blend_ps(tmp,_mm256_castps128_ps256(_mm_permute_ps( _mm256_castps256_ps128(tmp), _MM_SHUFFLE(1,0,1,0))), 15); + // then we can perform a consistent permutation on the global register to get everything in shape: + return _mm256_permute_ps(tmp, _MM_SHUFFLE(3,3,2,2)); +} +// Loads 2 doubles from memory a returns the packet {a0, a0 a1, a1} +template<> EIGEN_STRONG_INLINE Packet4d ploaddup(const double* from) +{ + Packet4d tmp = _mm256_broadcast_pd((const __m128d*)(const void*)from); + return _mm256_permute_pd(tmp, 3<<2); +} + +// Loads 2 floats from memory a returns the packet {a0, a0 a0, a0, a1, a1, a1, a1} +template<> EIGEN_STRONG_INLINE Packet8f ploadquad(const float* from) +{ + Packet8f tmp = _mm256_castps128_ps256(_mm_broadcast_ss(from)); + return _mm256_insertf128_ps(tmp, _mm_broadcast_ss(from+1), 1); +} + +template<> EIGEN_STRONG_INLINE void pstore(float* to, const Packet8f& from) { EIGEN_DEBUG_ALIGNED_STORE _mm256_store_ps(to, from); } +template<> EIGEN_STRONG_INLINE void pstore(double* to, const Packet4d& from) { EIGEN_DEBUG_ALIGNED_STORE _mm256_store_pd(to, from); } +template<> EIGEN_STRONG_INLINE void pstore(int* to, const Packet8i& from) { EIGEN_DEBUG_ALIGNED_STORE _mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from); } + +template<> EIGEN_STRONG_INLINE void pstoreu(float* to, const Packet8f& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_ps(to, from); } +template<> EIGEN_STRONG_INLINE void pstoreu(double* to, const Packet4d& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_pd(to, from); } +template<> EIGEN_STRONG_INLINE void pstoreu(int* 
to, const Packet8i& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from); } + +template<> EIGEN_STRONG_INLINE void pstoreu(float* to, const Packet8f& from, uint8_t umask) { + Packet8i mask = _mm256_set1_epi8(static_cast(umask)); + const Packet8i bit_mask = _mm256_set_epi32(0xffffff7f, 0xffffffbf, 0xffffffdf, 0xffffffef, 0xfffffff7, 0xfffffffb, 0xfffffffd, 0xfffffffe); + mask = por(mask, bit_mask); + mask = pcmp_eq(mask, _mm256_set1_epi32(0xffffffff)); + EIGEN_DEBUG_UNALIGNED_STORE return _mm256_maskstore_ps(to, mask, from); +} + +// NOTE: leverage _mm256_i32gather_ps and _mm256_i32gather_pd if AVX2 instructions are available +// NOTE: for the record the following seems to be slower: return _mm256_i32gather_ps(from, _mm256_set1_epi32(stride), 4); +template<> EIGEN_DEVICE_FUNC inline Packet8f pgather(const float* from, Index stride) +{ + return _mm256_set_ps(from[7*stride], from[6*stride], from[5*stride], from[4*stride], + from[3*stride], from[2*stride], from[1*stride], from[0*stride]); +} +template<> EIGEN_DEVICE_FUNC inline Packet4d pgather(const double* from, Index stride) +{ + return _mm256_set_pd(from[3*stride], from[2*stride], from[1*stride], from[0*stride]); +} + +template<> EIGEN_DEVICE_FUNC inline void pscatter(float* to, const Packet8f& from, Index stride) +{ + __m128 low = _mm256_extractf128_ps(from, 0); + to[stride*0] = _mm_cvtss_f32(low); + to[stride*1] = _mm_cvtss_f32(_mm_shuffle_ps(low, low, 1)); + to[stride*2] = _mm_cvtss_f32(_mm_shuffle_ps(low, low, 2)); + to[stride*3] = _mm_cvtss_f32(_mm_shuffle_ps(low, low, 3)); + + __m128 high = _mm256_extractf128_ps(from, 1); + to[stride*4] = _mm_cvtss_f32(high); + to[stride*5] = _mm_cvtss_f32(_mm_shuffle_ps(high, high, 1)); + to[stride*6] = _mm_cvtss_f32(_mm_shuffle_ps(high, high, 2)); + to[stride*7] = _mm_cvtss_f32(_mm_shuffle_ps(high, high, 3)); +} +template<> EIGEN_DEVICE_FUNC inline void pscatter(double* to, const Packet4d& from, Index stride) +{ + __m128d low = _mm256_extractf128_pd(from, 0); + to[stride*0] = _mm_cvtsd_f64(low); + to[stride*1] = _mm_cvtsd_f64(_mm_shuffle_pd(low, low, 1)); + __m128d high = _mm256_extractf128_pd(from, 1); + to[stride*2] = _mm_cvtsd_f64(high); + to[stride*3] = _mm_cvtsd_f64(_mm_shuffle_pd(high, high, 1)); +} + +template<> EIGEN_STRONG_INLINE void pstore1(float* to, const float& a) +{ + Packet8f pa = pset1(a); + pstore(to, pa); +} +template<> EIGEN_STRONG_INLINE void pstore1(double* to, const double& a) +{ + Packet4d pa = pset1(a); + pstore(to, pa); +} +template<> EIGEN_STRONG_INLINE void pstore1(int* to, const int& a) +{ + Packet8i pa = pset1(a); + pstore(to, pa); +} + +#ifndef EIGEN_VECTORIZE_AVX512 +template<> EIGEN_STRONG_INLINE void prefetch(const float* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); } +template<> EIGEN_STRONG_INLINE void prefetch(const double* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); } +template<> EIGEN_STRONG_INLINE void prefetch(const int* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); } +#endif + +template<> EIGEN_STRONG_INLINE float pfirst(const Packet8f& a) { + return _mm_cvtss_f32(_mm256_castps256_ps128(a)); +} +template<> EIGEN_STRONG_INLINE double pfirst(const Packet4d& a) { + return _mm_cvtsd_f64(_mm256_castpd256_pd128(a)); +} +template<> EIGEN_STRONG_INLINE int pfirst(const Packet8i& a) { + return _mm_cvtsi128_si32(_mm256_castsi256_si128(a)); +} + + +template<> EIGEN_STRONG_INLINE Packet8f preverse(const Packet8f& a) +{ + __m256 tmp = _mm256_shuffle_ps(a,a,0x1b); + return 
_mm256_permute2f128_ps(tmp, tmp, 1); +} +template<> EIGEN_STRONG_INLINE Packet4d preverse(const Packet4d& a) +{ + __m256d tmp = _mm256_shuffle_pd(a,a,5); + return _mm256_permute2f128_pd(tmp, tmp, 1); + #if 0 + // This version is unlikely to be faster as _mm256_shuffle_ps and _mm256_permute_pd + // exhibit the same latency/throughput, but it is here for future reference/benchmarking... + __m256d swap_halves = _mm256_permute2f128_pd(a,a,1); + return _mm256_permute_pd(swap_halves,5); + #endif +} + +// pabs should be ok +template<> EIGEN_STRONG_INLINE Packet8f pabs(const Packet8f& a) +{ + const Packet8f mask = _mm256_castsi256_ps(_mm256_setr_epi32(0x7FFFFFFF,0x7FFFFFFF,0x7FFFFFFF,0x7FFFFFFF,0x7FFFFFFF,0x7FFFFFFF,0x7FFFFFFF,0x7FFFFFFF)); + return _mm256_and_ps(a,mask); +} +template<> EIGEN_STRONG_INLINE Packet4d pabs(const Packet4d& a) +{ + const Packet4d mask = _mm256_castsi256_pd(_mm256_setr_epi32(0xFFFFFFFF,0x7FFFFFFF,0xFFFFFFFF,0x7FFFFFFF,0xFFFFFFFF,0x7FFFFFFF,0xFFFFFFFF,0x7FFFFFFF)); + return _mm256_and_pd(a,mask); +} + +template<> EIGEN_STRONG_INLINE Packet8f pfrexp(const Packet8f& a, Packet8f& exponent) { + return pfrexp_generic(a,exponent); +} + +// Extract exponent without existence of Packet4l. +template<> +EIGEN_STRONG_INLINE +Packet4d pfrexp_generic_get_biased_exponent(const Packet4d& a) { + const Packet4d cst_exp_mask = pset1frombits(static_cast(0x7ff0000000000000ull)); + __m256i a_expo = _mm256_castpd_si256(pand(a, cst_exp_mask)); +#ifdef EIGEN_VECTORIZE_AVX2 + a_expo = _mm256_srli_epi64(a_expo, 52); + __m128i lo = _mm256_extractf128_si256(a_expo, 0); + __m128i hi = _mm256_extractf128_si256(a_expo, 1); +#else + __m128i lo = _mm256_extractf128_si256(a_expo, 0); + __m128i hi = _mm256_extractf128_si256(a_expo, 1); + lo = _mm_srli_epi64(lo, 52); + hi = _mm_srli_epi64(hi, 52); +#endif + Packet2d exponent_lo = _mm_cvtepi32_pd(vec4i_swizzle1(lo, 0, 2, 1, 3)); + Packet2d exponent_hi = _mm_cvtepi32_pd(vec4i_swizzle1(hi, 0, 2, 1, 3)); + Packet4d exponent = _mm256_insertf128_pd(_mm256_setzero_pd(), exponent_lo, 0); + exponent = _mm256_insertf128_pd(exponent, exponent_hi, 1); + return exponent; +} + + +template<> EIGEN_STRONG_INLINE Packet4d pfrexp(const Packet4d& a, Packet4d& exponent) { + return pfrexp_generic(a, exponent); +} + +template<> EIGEN_STRONG_INLINE Packet8f pldexp(const Packet8f& a, const Packet8f& exponent) { + return pldexp_generic(a, exponent); +} + +template<> EIGEN_STRONG_INLINE Packet4d pldexp(const Packet4d& a, const Packet4d& exponent) { + // Clamp exponent to [-2099, 2099] + const Packet4d max_exponent = pset1(2099.0); + const Packet4i e = _mm256_cvtpd_epi32(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent)); + + // Split 2^e into four factors and multiply. 
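+  // Roughly, with b = floor(e/4) the product below evaluates
+  //   a * 2^b * 2^b * 2^b * 2^(e-3b),
+  // and for |e| <= 2099 each partial factor stays finite (|b| <= 525 and
+  // |e-3b| <= 527), even though 2^e on its own could overflow or underflow.
+  // A scalar sketch of the same idea (hypothetical helper, not Eigen API):
+  //   double ldexp_ref(double a, int e) {
+  //     int b = e >> 2;                               // floor(e/4)
+  //     double two_b = std::ldexp(1.0, b);            // 2^b, finite here
+  //     return ((a * two_b) * two_b) * two_b * std::ldexp(1.0, e - 3*b);
+  //   }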
+ const Packet4i bias = pset1(1023); + Packet4i b = parithmetic_shift_right<2>(e); // floor(e/4) + + // 2^b + Packet4i hi = vec4i_swizzle1(padd(b, bias), 0, 2, 1, 3); + Packet4i lo = _mm_slli_epi64(hi, 52); + hi = _mm_slli_epi64(_mm_srli_epi64(hi, 32), 52); + Packet4d c = _mm256_castsi256_pd(_mm256_insertf128_si256(_mm256_castsi128_si256(lo), hi, 1)); + Packet4d out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b) + + // 2^(e - 3b) + b = psub(psub(psub(e, b), b), b); // e - 3b + hi = vec4i_swizzle1(padd(b, bias), 0, 2, 1, 3); + lo = _mm_slli_epi64(hi, 52); + hi = _mm_slli_epi64(_mm_srli_epi64(hi, 32), 52); + c = _mm256_castsi256_pd(_mm256_insertf128_si256(_mm256_castsi128_si256(lo), hi, 1)); + out = pmul(out, c); // a * 2^e + return out; +} + +template<> EIGEN_STRONG_INLINE float predux(const Packet8f& a) +{ + return predux(Packet4f(_mm_add_ps(_mm256_castps256_ps128(a),_mm256_extractf128_ps(a,1)))); +} +template<> EIGEN_STRONG_INLINE double predux(const Packet4d& a) +{ + return predux(Packet2d(_mm_add_pd(_mm256_castpd256_pd128(a),_mm256_extractf128_pd(a,1)))); +} + +template<> EIGEN_STRONG_INLINE Packet4f predux_half_dowto4(const Packet8f& a) +{ + return _mm_add_ps(_mm256_castps256_ps128(a),_mm256_extractf128_ps(a,1)); +} + +template<> EIGEN_STRONG_INLINE float predux_mul(const Packet8f& a) +{ + Packet8f tmp; + tmp = _mm256_mul_ps(a, _mm256_permute2f128_ps(a,a,1)); + tmp = _mm256_mul_ps(tmp, _mm256_shuffle_ps(tmp,tmp,_MM_SHUFFLE(1,0,3,2))); + return pfirst(_mm256_mul_ps(tmp, _mm256_shuffle_ps(tmp,tmp,1))); +} +template<> EIGEN_STRONG_INLINE double predux_mul(const Packet4d& a) +{ + Packet4d tmp; + tmp = _mm256_mul_pd(a, _mm256_permute2f128_pd(a,a,1)); + return pfirst(_mm256_mul_pd(tmp, _mm256_shuffle_pd(tmp,tmp,1))); +} + +template<> EIGEN_STRONG_INLINE float predux_min(const Packet8f& a) +{ + Packet8f tmp = _mm256_min_ps(a, _mm256_permute2f128_ps(a,a,1)); + tmp = _mm256_min_ps(tmp, _mm256_shuffle_ps(tmp,tmp,_MM_SHUFFLE(1,0,3,2))); + return pfirst(_mm256_min_ps(tmp, _mm256_shuffle_ps(tmp,tmp,1))); +} +template<> EIGEN_STRONG_INLINE double predux_min(const Packet4d& a) +{ + Packet4d tmp = _mm256_min_pd(a, _mm256_permute2f128_pd(a,a,1)); + return pfirst(_mm256_min_pd(tmp, _mm256_shuffle_pd(tmp, tmp, 1))); +} + +template<> EIGEN_STRONG_INLINE float predux_max(const Packet8f& a) +{ + Packet8f tmp = _mm256_max_ps(a, _mm256_permute2f128_ps(a,a,1)); + tmp = _mm256_max_ps(tmp, _mm256_shuffle_ps(tmp,tmp,_MM_SHUFFLE(1,0,3,2))); + return pfirst(_mm256_max_ps(tmp, _mm256_shuffle_ps(tmp,tmp,1))); +} + +template<> EIGEN_STRONG_INLINE double predux_max(const Packet4d& a) +{ + Packet4d tmp = _mm256_max_pd(a, _mm256_permute2f128_pd(a,a,1)); + return pfirst(_mm256_max_pd(tmp, _mm256_shuffle_pd(tmp, tmp, 1))); +} + +// not needed yet +// template<> EIGEN_STRONG_INLINE bool predux_all(const Packet8f& x) +// { +// return _mm256_movemask_ps(x)==0xFF; +// } + +template<> EIGEN_STRONG_INLINE bool predux_any(const Packet8f& x) +{ + return _mm256_movemask_ps(x)!=0; +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + __m256 T0 = _mm256_unpacklo_ps(kernel.packet[0], kernel.packet[1]); + __m256 T1 = _mm256_unpackhi_ps(kernel.packet[0], kernel.packet[1]); + __m256 T2 = _mm256_unpacklo_ps(kernel.packet[2], kernel.packet[3]); + __m256 T3 = _mm256_unpackhi_ps(kernel.packet[2], kernel.packet[3]); + __m256 T4 = _mm256_unpacklo_ps(kernel.packet[4], kernel.packet[5]); + __m256 T5 = _mm256_unpackhi_ps(kernel.packet[4], kernel.packet[5]); + __m256 T6 = _mm256_unpacklo_ps(kernel.packet[6], kernel.packet[7]); + 
__m256 T7 = _mm256_unpackhi_ps(kernel.packet[6], kernel.packet[7]); + __m256 S0 = _mm256_shuffle_ps(T0,T2,_MM_SHUFFLE(1,0,1,0)); + __m256 S1 = _mm256_shuffle_ps(T0,T2,_MM_SHUFFLE(3,2,3,2)); + __m256 S2 = _mm256_shuffle_ps(T1,T3,_MM_SHUFFLE(1,0,1,0)); + __m256 S3 = _mm256_shuffle_ps(T1,T3,_MM_SHUFFLE(3,2,3,2)); + __m256 S4 = _mm256_shuffle_ps(T4,T6,_MM_SHUFFLE(1,0,1,0)); + __m256 S5 = _mm256_shuffle_ps(T4,T6,_MM_SHUFFLE(3,2,3,2)); + __m256 S6 = _mm256_shuffle_ps(T5,T7,_MM_SHUFFLE(1,0,1,0)); + __m256 S7 = _mm256_shuffle_ps(T5,T7,_MM_SHUFFLE(3,2,3,2)); + kernel.packet[0] = _mm256_permute2f128_ps(S0, S4, 0x20); + kernel.packet[1] = _mm256_permute2f128_ps(S1, S5, 0x20); + kernel.packet[2] = _mm256_permute2f128_ps(S2, S6, 0x20); + kernel.packet[3] = _mm256_permute2f128_ps(S3, S7, 0x20); + kernel.packet[4] = _mm256_permute2f128_ps(S0, S4, 0x31); + kernel.packet[5] = _mm256_permute2f128_ps(S1, S5, 0x31); + kernel.packet[6] = _mm256_permute2f128_ps(S2, S6, 0x31); + kernel.packet[7] = _mm256_permute2f128_ps(S3, S7, 0x31); +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + __m256 T0 = _mm256_unpacklo_ps(kernel.packet[0], kernel.packet[1]); + __m256 T1 = _mm256_unpackhi_ps(kernel.packet[0], kernel.packet[1]); + __m256 T2 = _mm256_unpacklo_ps(kernel.packet[2], kernel.packet[3]); + __m256 T3 = _mm256_unpackhi_ps(kernel.packet[2], kernel.packet[3]); + + __m256 S0 = _mm256_shuffle_ps(T0,T2,_MM_SHUFFLE(1,0,1,0)); + __m256 S1 = _mm256_shuffle_ps(T0,T2,_MM_SHUFFLE(3,2,3,2)); + __m256 S2 = _mm256_shuffle_ps(T1,T3,_MM_SHUFFLE(1,0,1,0)); + __m256 S3 = _mm256_shuffle_ps(T1,T3,_MM_SHUFFLE(3,2,3,2)); + + kernel.packet[0] = _mm256_permute2f128_ps(S0, S1, 0x20); + kernel.packet[1] = _mm256_permute2f128_ps(S2, S3, 0x20); + kernel.packet[2] = _mm256_permute2f128_ps(S0, S1, 0x31); + kernel.packet[3] = _mm256_permute2f128_ps(S2, S3, 0x31); +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + __m256d T0 = _mm256_shuffle_pd(kernel.packet[0], kernel.packet[1], 15); + __m256d T1 = _mm256_shuffle_pd(kernel.packet[0], kernel.packet[1], 0); + __m256d T2 = _mm256_shuffle_pd(kernel.packet[2], kernel.packet[3], 15); + __m256d T3 = _mm256_shuffle_pd(kernel.packet[2], kernel.packet[3], 0); + + kernel.packet[1] = _mm256_permute2f128_pd(T0, T2, 32); + kernel.packet[3] = _mm256_permute2f128_pd(T0, T2, 49); + kernel.packet[0] = _mm256_permute2f128_pd(T1, T3, 32); + kernel.packet[2] = _mm256_permute2f128_pd(T1, T3, 49); +} + +template<> EIGEN_STRONG_INLINE Packet8f pblend(const Selector<8>& ifPacket, const Packet8f& thenPacket, const Packet8f& elsePacket) { + const __m256 zero = _mm256_setzero_ps(); + const __m256 select = _mm256_set_ps(ifPacket.select[7], ifPacket.select[6], ifPacket.select[5], ifPacket.select[4], ifPacket.select[3], ifPacket.select[2], ifPacket.select[1], ifPacket.select[0]); + __m256 false_mask = _mm256_cmp_ps(select, zero, _CMP_EQ_UQ); + return _mm256_blendv_ps(thenPacket, elsePacket, false_mask); +} +template<> EIGEN_STRONG_INLINE Packet4d pblend(const Selector<4>& ifPacket, const Packet4d& thenPacket, const Packet4d& elsePacket) { + const __m256d zero = _mm256_setzero_pd(); + const __m256d select = _mm256_set_pd(ifPacket.select[3], ifPacket.select[2], ifPacket.select[1], ifPacket.select[0]); + __m256d false_mask = _mm256_cmp_pd(select, zero, _CMP_EQ_UQ); + return _mm256_blendv_pd(thenPacket, elsePacket, false_mask); +} + +// Packet math for Eigen::half + +template<> struct unpacket_traits { typedef Eigen::half type; enum {size=8, alignment=Aligned16, vectorizable=true, 
masked_load_available=false, masked_store_available=false}; typedef Packet8h half; }; + +template<> EIGEN_STRONG_INLINE Packet8h pset1(const Eigen::half& from) { + return _mm_set1_epi16(numext::bit_cast(from)); +} + +template<> EIGEN_STRONG_INLINE Eigen::half pfirst(const Packet8h& from) { + return numext::bit_cast(static_cast(_mm_extract_epi16(from, 0))); +} + +template<> EIGEN_STRONG_INLINE Packet8h pload(const Eigen::half* from) { + return _mm_load_si128(reinterpret_cast(from)); +} + +template<> EIGEN_STRONG_INLINE Packet8h ploadu(const Eigen::half* from) { + return _mm_loadu_si128(reinterpret_cast(from)); +} + +template<> EIGEN_STRONG_INLINE void pstore(Eigen::half* to, const Packet8h& from) { + _mm_store_si128(reinterpret_cast<__m128i*>(to), from); +} + +template<> EIGEN_STRONG_INLINE void pstoreu(Eigen::half* to, const Packet8h& from) { + _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from); +} + +template<> EIGEN_STRONG_INLINE Packet8h +ploaddup(const Eigen::half* from) { + const numext::uint16_t a = numext::bit_cast(from[0]); + const numext::uint16_t b = numext::bit_cast(from[1]); + const numext::uint16_t c = numext::bit_cast(from[2]); + const numext::uint16_t d = numext::bit_cast(from[3]); + return _mm_set_epi16(d, d, c, c, b, b, a, a); +} + +template<> EIGEN_STRONG_INLINE Packet8h +ploadquad(const Eigen::half* from) { + const numext::uint16_t a = numext::bit_cast(from[0]); + const numext::uint16_t b = numext::bit_cast(from[1]); + return _mm_set_epi16(b, b, b, b, a, a, a, a); +} + +template<> EIGEN_STRONG_INLINE Packet8h ptrue(const Packet8h& a) { + return _mm_cmpeq_epi32(a, a); +} + +template <> +EIGEN_STRONG_INLINE Packet8h pabs(const Packet8h& a) { + const __m128i sign_mask = _mm_set1_epi16(static_cast(0x8000)); + return _mm_andnot_si128(sign_mask, a); +} + +EIGEN_STRONG_INLINE Packet8f half2float(const Packet8h& a) { +#ifdef EIGEN_HAS_FP16_C + return _mm256_cvtph_ps(a); +#else + EIGEN_ALIGN32 Eigen::half aux[8]; + pstore(aux, a); + float f0(aux[0]); + float f1(aux[1]); + float f2(aux[2]); + float f3(aux[3]); + float f4(aux[4]); + float f5(aux[5]); + float f6(aux[6]); + float f7(aux[7]); + + return _mm256_set_ps(f7, f6, f5, f4, f3, f2, f1, f0); +#endif +} + +EIGEN_STRONG_INLINE Packet8h float2half(const Packet8f& a) { +#ifdef EIGEN_HAS_FP16_C + return _mm256_cvtps_ph(a, _MM_FROUND_TO_NEAREST_INT|_MM_FROUND_NO_EXC); +#else + EIGEN_ALIGN32 float aux[8]; + pstore(aux, a); + const numext::uint16_t s0 = numext::bit_cast(Eigen::half(aux[0])); + const numext::uint16_t s1 = numext::bit_cast(Eigen::half(aux[1])); + const numext::uint16_t s2 = numext::bit_cast(Eigen::half(aux[2])); + const numext::uint16_t s3 = numext::bit_cast(Eigen::half(aux[3])); + const numext::uint16_t s4 = numext::bit_cast(Eigen::half(aux[4])); + const numext::uint16_t s5 = numext::bit_cast(Eigen::half(aux[5])); + const numext::uint16_t s6 = numext::bit_cast(Eigen::half(aux[6])); + const numext::uint16_t s7 = numext::bit_cast(Eigen::half(aux[7])); + return _mm_set_epi16(s7, s6, s5, s4, s3, s2, s1, s0); +#endif +} + +template <> +EIGEN_STRONG_INLINE Packet8h pmin(const Packet8h& a, + const Packet8h& b) { + return float2half(pmin(half2float(a), half2float(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet8h pmax(const Packet8h& a, + const Packet8h& b) { + return float2half(pmax(half2float(a), half2float(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet8h plset(const half& a) { + return float2half(plset(static_cast(a))); +} + +template<> EIGEN_STRONG_INLINE Packet8h por(const Packet8h& a,const Packet8h& b) { + // 
in some cases Packet4i is a wrapper around __m128i, so we either need to + // cast to Packet4i to directly call the intrinsics as below: + return _mm_or_si128(a,b); +} +template<> EIGEN_STRONG_INLINE Packet8h pxor(const Packet8h& a,const Packet8h& b) { + return _mm_xor_si128(a,b); +} +template<> EIGEN_STRONG_INLINE Packet8h pand(const Packet8h& a,const Packet8h& b) { + return _mm_and_si128(a,b); +} +template<> EIGEN_STRONG_INLINE Packet8h pandnot(const Packet8h& a,const Packet8h& b) { + return _mm_andnot_si128(b,a); +} + +template<> EIGEN_STRONG_INLINE Packet8h pselect(const Packet8h& mask, const Packet8h& a, const Packet8h& b) { + return _mm_blendv_epi8(b, a, mask); +} + +template<> EIGEN_STRONG_INLINE Packet8h pround(const Packet8h& a) { + return float2half(pround(half2float(a))); +} + +template<> EIGEN_STRONG_INLINE Packet8h print(const Packet8h& a) { + return float2half(print(half2float(a))); +} + +template<> EIGEN_STRONG_INLINE Packet8h pceil(const Packet8h& a) { + return float2half(pceil(half2float(a))); +} + +template<> EIGEN_STRONG_INLINE Packet8h pfloor(const Packet8h& a) { + return float2half(pfloor(half2float(a))); +} + +template<> EIGEN_STRONG_INLINE Packet8h pcmp_eq(const Packet8h& a,const Packet8h& b) { + return Pack16To8(pcmp_eq(half2float(a), half2float(b))); +} + +template<> EIGEN_STRONG_INLINE Packet8h pcmp_le(const Packet8h& a,const Packet8h& b) { + return Pack16To8(pcmp_le(half2float(a), half2float(b))); +} + +template<> EIGEN_STRONG_INLINE Packet8h pcmp_lt(const Packet8h& a,const Packet8h& b) { + return Pack16To8(pcmp_lt(half2float(a), half2float(b))); +} + +template<> EIGEN_STRONG_INLINE Packet8h pcmp_lt_or_nan(const Packet8h& a,const Packet8h& b) { + return Pack16To8(pcmp_lt_or_nan(half2float(a), half2float(b))); +} + +template<> EIGEN_STRONG_INLINE Packet8h pconj(const Packet8h& a) { return a; } + +template<> EIGEN_STRONG_INLINE Packet8h pnegate(const Packet8h& a) { + Packet8h sign_mask = _mm_set1_epi16(static_cast(0x8000)); + return _mm_xor_si128(a, sign_mask); +} + +template<> EIGEN_STRONG_INLINE Packet8h padd(const Packet8h& a, const Packet8h& b) { + Packet8f af = half2float(a); + Packet8f bf = half2float(b); + Packet8f rf = padd(af, bf); + return float2half(rf); +} + +template<> EIGEN_STRONG_INLINE Packet8h psub(const Packet8h& a, const Packet8h& b) { + Packet8f af = half2float(a); + Packet8f bf = half2float(b); + Packet8f rf = psub(af, bf); + return float2half(rf); +} + +template<> EIGEN_STRONG_INLINE Packet8h pmul(const Packet8h& a, const Packet8h& b) { + Packet8f af = half2float(a); + Packet8f bf = half2float(b); + Packet8f rf = pmul(af, bf); + return float2half(rf); +} + +template<> EIGEN_STRONG_INLINE Packet8h pdiv(const Packet8h& a, const Packet8h& b) { + Packet8f af = half2float(a); + Packet8f bf = half2float(b); + Packet8f rf = pdiv(af, bf); + return float2half(rf); +} + +template<> EIGEN_STRONG_INLINE Packet8h pgather(const Eigen::half* from, Index stride) +{ + const numext::uint16_t s0 = numext::bit_cast(from[0*stride]); + const numext::uint16_t s1 = numext::bit_cast(from[1*stride]); + const numext::uint16_t s2 = numext::bit_cast(from[2*stride]); + const numext::uint16_t s3 = numext::bit_cast(from[3*stride]); + const numext::uint16_t s4 = numext::bit_cast(from[4*stride]); + const numext::uint16_t s5 = numext::bit_cast(from[5*stride]); + const numext::uint16_t s6 = numext::bit_cast(from[6*stride]); + const numext::uint16_t s7 = numext::bit_cast(from[7*stride]); + return _mm_set_epi16(s7, s6, s5, s4, s3, s2, s1, s0); +} + +template<> EIGEN_STRONG_INLINE 
void pscatter(Eigen::half* to, const Packet8h& from, Index stride) +{ + EIGEN_ALIGN32 Eigen::half aux[8]; + pstore(aux, from); + to[stride*0] = aux[0]; + to[stride*1] = aux[1]; + to[stride*2] = aux[2]; + to[stride*3] = aux[3]; + to[stride*4] = aux[4]; + to[stride*5] = aux[5]; + to[stride*6] = aux[6]; + to[stride*7] = aux[7]; +} + +template<> EIGEN_STRONG_INLINE Eigen::half predux(const Packet8h& a) { + Packet8f af = half2float(a); + float reduced = predux(af); + return Eigen::half(reduced); +} + +template<> EIGEN_STRONG_INLINE Eigen::half predux_max(const Packet8h& a) { + Packet8f af = half2float(a); + float reduced = predux_max(af); + return Eigen::half(reduced); +} + +template<> EIGEN_STRONG_INLINE Eigen::half predux_min(const Packet8h& a) { + Packet8f af = half2float(a); + float reduced = predux_min(af); + return Eigen::half(reduced); +} + +template<> EIGEN_STRONG_INLINE Eigen::half predux_mul(const Packet8h& a) { + Packet8f af = half2float(a); + float reduced = predux_mul(af); + return Eigen::half(reduced); +} + +template<> EIGEN_STRONG_INLINE Packet8h preverse(const Packet8h& a) +{ + __m128i m = _mm_setr_epi8(14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1); + return _mm_shuffle_epi8(a,m); +} + +EIGEN_STRONG_INLINE void +ptranspose(PacketBlock& kernel) { + __m128i a = kernel.packet[0]; + __m128i b = kernel.packet[1]; + __m128i c = kernel.packet[2]; + __m128i d = kernel.packet[3]; + __m128i e = kernel.packet[4]; + __m128i f = kernel.packet[5]; + __m128i g = kernel.packet[6]; + __m128i h = kernel.packet[7]; + + __m128i a03b03 = _mm_unpacklo_epi16(a, b); + __m128i c03d03 = _mm_unpacklo_epi16(c, d); + __m128i e03f03 = _mm_unpacklo_epi16(e, f); + __m128i g03h03 = _mm_unpacklo_epi16(g, h); + __m128i a47b47 = _mm_unpackhi_epi16(a, b); + __m128i c47d47 = _mm_unpackhi_epi16(c, d); + __m128i e47f47 = _mm_unpackhi_epi16(e, f); + __m128i g47h47 = _mm_unpackhi_epi16(g, h); + + __m128i a01b01c01d01 = _mm_unpacklo_epi32(a03b03, c03d03); + __m128i a23b23c23d23 = _mm_unpackhi_epi32(a03b03, c03d03); + __m128i e01f01g01h01 = _mm_unpacklo_epi32(e03f03, g03h03); + __m128i e23f23g23h23 = _mm_unpackhi_epi32(e03f03, g03h03); + __m128i a45b45c45d45 = _mm_unpacklo_epi32(a47b47, c47d47); + __m128i a67b67c67d67 = _mm_unpackhi_epi32(a47b47, c47d47); + __m128i e45f45g45h45 = _mm_unpacklo_epi32(e47f47, g47h47); + __m128i e67f67g67h67 = _mm_unpackhi_epi32(e47f47, g47h47); + + __m128i a0b0c0d0e0f0g0h0 = _mm_unpacklo_epi64(a01b01c01d01, e01f01g01h01); + __m128i a1b1c1d1e1f1g1h1 = _mm_unpackhi_epi64(a01b01c01d01, e01f01g01h01); + __m128i a2b2c2d2e2f2g2h2 = _mm_unpacklo_epi64(a23b23c23d23, e23f23g23h23); + __m128i a3b3c3d3e3f3g3h3 = _mm_unpackhi_epi64(a23b23c23d23, e23f23g23h23); + __m128i a4b4c4d4e4f4g4h4 = _mm_unpacklo_epi64(a45b45c45d45, e45f45g45h45); + __m128i a5b5c5d5e5f5g5h5 = _mm_unpackhi_epi64(a45b45c45d45, e45f45g45h45); + __m128i a6b6c6d6e6f6g6h6 = _mm_unpacklo_epi64(a67b67c67d67, e67f67g67h67); + __m128i a7b7c7d7e7f7g7h7 = _mm_unpackhi_epi64(a67b67c67d67, e67f67g67h67); + + kernel.packet[0] = a0b0c0d0e0f0g0h0; + kernel.packet[1] = a1b1c1d1e1f1g1h1; + kernel.packet[2] = a2b2c2d2e2f2g2h2; + kernel.packet[3] = a3b3c3d3e3f3g3h3; + kernel.packet[4] = a4b4c4d4e4f4g4h4; + kernel.packet[5] = a5b5c5d5e5f5g5h5; + kernel.packet[6] = a6b6c6d6e6f6g6h6; + kernel.packet[7] = a7b7c7d7e7f7g7h7; +} + +EIGEN_STRONG_INLINE void +ptranspose(PacketBlock& kernel) { + EIGEN_ALIGN32 Eigen::half in[4][8]; + pstore(in[0], kernel.packet[0]); + pstore(in[1], kernel.packet[1]); + pstore(in[2], kernel.packet[2]); + pstore(in[3], kernel.packet[3]); 
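+  // The staging loop below is a 4x8 shuffle rather than a square transpose:
+  // output packet i gathers element 2*i of every input packet in its low
+  // half and element 2*i+1 in its high half, i.e.
+  //   packet[i] = { in[0][2i], in[1][2i], in[2][2i], in[3][2i],
+  //                 in[0][2i+1], in[1][2i+1], in[2][2i+1], in[3][2i+1] }.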
+ + EIGEN_ALIGN32 Eigen::half out[4][8]; + + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + out[i][j] = in[j][2*i]; + } + for (int j = 0; j < 4; ++j) { + out[i][j+4] = in[j][2*i+1]; + } + } + + kernel.packet[0] = pload(out[0]); + kernel.packet[1] = pload(out[1]); + kernel.packet[2] = pload(out[2]); + kernel.packet[3] = pload(out[3]); +} + +// BFloat16 implementation. + +EIGEN_STRONG_INLINE Packet8f Bf16ToF32(const Packet8bf& a) { +#ifdef EIGEN_VECTORIZE_AVX2 + __m256i extend = _mm256_cvtepu16_epi32(a); + return _mm256_castsi256_ps(_mm256_slli_epi32(extend, 16)); +#else + __m128i lo = _mm_cvtepu16_epi32(a); + __m128i hi = _mm_cvtepu16_epi32(_mm_srli_si128(a, 8)); + __m128i lo_shift = _mm_slli_epi32(lo, 16); + __m128i hi_shift = _mm_slli_epi32(hi, 16); + return _mm256_castsi256_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(lo_shift), hi_shift, 1)); +#endif +} + +// Convert float to bfloat16 according to round-to-nearest-even/denormals algorithm. +EIGEN_STRONG_INLINE Packet8bf F32ToBf16(const Packet8f& a) { + Packet8bf r; + + __m256i input = _mm256_castps_si256(a); + +#ifdef EIGEN_VECTORIZE_AVX2 + // uint32_t lsb = (input >> 16); + __m256i t = _mm256_srli_epi32(input, 16); + // uint32_t lsb = lsb & 1; + t = _mm256_and_si256(t, _mm256_set1_epi32(1)); + // uint32_t rounding_bias = 0x7fff + lsb; + t = _mm256_add_epi32(t, _mm256_set1_epi32(0x7fff)); + // input += rounding_bias; + t = _mm256_add_epi32(t, input); + // input = input >> 16; + t = _mm256_srli_epi32(t, 16); + // Check NaN before converting back to bf16 + __m256 mask = _mm256_cmp_ps(a, a, _CMP_ORD_Q); + __m256i nan = _mm256_set1_epi32(0x7fc0); + t = _mm256_blendv_epi8(nan, t, _mm256_castps_si256(mask)); + // output = numext::bit_cast(input); + return _mm_packus_epi32(_mm256_extractf128_si256(t, 0), + _mm256_extractf128_si256(t, 1)); +#else + // uint32_t lsb = (input >> 16); + __m128i lo = _mm_srli_epi32(_mm256_extractf128_si256(input, 0), 16); + __m128i hi = _mm_srli_epi32(_mm256_extractf128_si256(input, 1), 16); + // uint32_t lsb = lsb & 1; + lo = _mm_and_si128(lo, _mm_set1_epi32(1)); + hi = _mm_and_si128(hi, _mm_set1_epi32(1)); + // uint32_t rounding_bias = 0x7fff + lsb; + lo = _mm_add_epi32(lo, _mm_set1_epi32(0x7fff)); + hi = _mm_add_epi32(hi, _mm_set1_epi32(0x7fff)); + // input += rounding_bias; + lo = _mm_add_epi32(lo, _mm256_extractf128_si256(input, 0)); + hi = _mm_add_epi32(hi, _mm256_extractf128_si256(input, 1)); + // input = input >> 16; + lo = _mm_srli_epi32(lo, 16); + hi = _mm_srli_epi32(hi, 16); + // Check NaN before converting back to bf16 + __m256 mask = _mm256_cmp_ps(a, a, _CMP_ORD_Q); + __m128i nan = _mm_set1_epi32(0x7fc0); + lo = _mm_blendv_epi8(nan, lo, _mm_castps_si128(_mm256_castps256_ps128(mask))); + hi = _mm_blendv_epi8(nan, hi, _mm_castps_si128(_mm256_extractf128_ps(mask, 1))); + // output = numext::bit_cast(input); + return _mm_packus_epi32(lo, hi); +#endif +} + +template<> EIGEN_STRONG_INLINE Packet8bf pset1(const bfloat16& from) { + return _mm_set1_epi16(numext::bit_cast(from)); +} + +template<> EIGEN_STRONG_INLINE bfloat16 pfirst(const Packet8bf& from) { + return numext::bit_cast(static_cast(_mm_extract_epi16(from, 0))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pload(const bfloat16* from) { + return _mm_load_si128(reinterpret_cast(from)); +} + +template<> EIGEN_STRONG_INLINE Packet8bf ploadu(const bfloat16* from) { + return _mm_loadu_si128(reinterpret_cast(from)); +} + +template<> EIGEN_STRONG_INLINE void pstore(bfloat16* to, const Packet8bf& from) { + 
_mm_store_si128(reinterpret_cast<__m128i*>(to), from); +} + +template<> EIGEN_STRONG_INLINE void pstoreu(bfloat16* to, const Packet8bf& from) { + _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from); +} + +template<> EIGEN_STRONG_INLINE Packet8bf +ploaddup(const bfloat16* from) { + const numext::uint16_t a = numext::bit_cast(from[0]); + const numext::uint16_t b = numext::bit_cast(from[1]); + const numext::uint16_t c = numext::bit_cast(from[2]); + const numext::uint16_t d = numext::bit_cast(from[3]); + return _mm_set_epi16(d, d, c, c, b, b, a, a); +} + +template<> EIGEN_STRONG_INLINE Packet8bf +ploadquad(const bfloat16* from) { + const numext::uint16_t a = numext::bit_cast(from[0]); + const numext::uint16_t b = numext::bit_cast(from[1]); + return _mm_set_epi16(b, b, b, b, a, a, a, a); +} + +template<> EIGEN_STRONG_INLINE Packet8bf ptrue(const Packet8bf& a) { + return _mm_cmpeq_epi32(a, a); +} + +template <> +EIGEN_STRONG_INLINE Packet8bf pabs(const Packet8bf& a) { + const __m128i sign_mask = _mm_set1_epi16(static_cast(0x8000)); + return _mm_andnot_si128(sign_mask, a); +} + +template <> +EIGEN_STRONG_INLINE Packet8bf pmin(const Packet8bf& a, + const Packet8bf& b) { + return F32ToBf16(pmin(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet8bf pmax(const Packet8bf& a, + const Packet8bf& b) { + return F32ToBf16(pmax(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet8bf plset(const bfloat16& a) { + return F32ToBf16(plset(static_cast(a))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf por(const Packet8bf& a,const Packet8bf& b) { + return _mm_or_si128(a,b); +} +template<> EIGEN_STRONG_INLINE Packet8bf pxor(const Packet8bf& a,const Packet8bf& b) { + return _mm_xor_si128(a,b); +} +template<> EIGEN_STRONG_INLINE Packet8bf pand(const Packet8bf& a,const Packet8bf& b) { + return _mm_and_si128(a,b); +} +template<> EIGEN_STRONG_INLINE Packet8bf pandnot(const Packet8bf& a,const Packet8bf& b) { + return _mm_andnot_si128(b,a); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pselect(const Packet8bf& mask, const Packet8bf& a, const Packet8bf& b) { + return _mm_blendv_epi8(b, a, mask); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pround(const Packet8bf& a) +{ + return F32ToBf16(pround(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf print(const Packet8bf& a) { + return F32ToBf16(print(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pceil(const Packet8bf& a) { + return F32ToBf16(pceil(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pfloor(const Packet8bf& a) { + return F32ToBf16(pfloor(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pcmp_eq(const Packet8bf& a,const Packet8bf& b) { + return Pack16To8(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pcmp_le(const Packet8bf& a,const Packet8bf& b) { + return Pack16To8(pcmp_le(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt(const Packet8bf& a,const Packet8bf& b) { + return Pack16To8(pcmp_lt(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt_or_nan(const Packet8bf& a,const Packet8bf& b) { + return Pack16To8(pcmp_lt_or_nan(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pconj(const Packet8bf& a) { return a; } + +template<> EIGEN_STRONG_INLINE Packet8bf pnegate(const Packet8bf& a) { + Packet8bf sign_mask = _mm_set1_epi16(static_cast(0x8000)); + return _mm_xor_si128(a, sign_mask); +} + +template<> EIGEN_STRONG_INLINE Packet8bf 
padd(const Packet8bf& a, const Packet8bf& b) { + return F32ToBf16(padd(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf psub(const Packet8bf& a, const Packet8bf& b) { + return F32ToBf16(psub(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pmul(const Packet8bf& a, const Packet8bf& b) { + return F32ToBf16(pmul(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pdiv(const Packet8bf& a, const Packet8bf& b) { + return F32ToBf16(pdiv(Bf16ToF32(a), Bf16ToF32(b))); +} + + +template<> EIGEN_STRONG_INLINE Packet8bf pgather(const bfloat16* from, Index stride) +{ + const numext::uint16_t s0 = numext::bit_cast(from[0*stride]); + const numext::uint16_t s1 = numext::bit_cast(from[1*stride]); + const numext::uint16_t s2 = numext::bit_cast(from[2*stride]); + const numext::uint16_t s3 = numext::bit_cast(from[3*stride]); + const numext::uint16_t s4 = numext::bit_cast(from[4*stride]); + const numext::uint16_t s5 = numext::bit_cast(from[5*stride]); + const numext::uint16_t s6 = numext::bit_cast(from[6*stride]); + const numext::uint16_t s7 = numext::bit_cast(from[7*stride]); + return _mm_set_epi16(s7, s6, s5, s4, s3, s2, s1, s0); +} + +template<> EIGEN_STRONG_INLINE void pscatter(bfloat16* to, const Packet8bf& from, Index stride) +{ + EIGEN_ALIGN32 bfloat16 aux[8]; + pstore(aux, from); + to[stride*0] = aux[0]; + to[stride*1] = aux[1]; + to[stride*2] = aux[2]; + to[stride*3] = aux[3]; + to[stride*4] = aux[4]; + to[stride*5] = aux[5]; + to[stride*6] = aux[6]; + to[stride*7] = aux[7]; +} + +template<> EIGEN_STRONG_INLINE bfloat16 predux(const Packet8bf& a) { + return static_cast(predux(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE bfloat16 predux_max(const Packet8bf& a) { + return static_cast(predux_max(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE bfloat16 predux_min(const Packet8bf& a) { + return static_cast(predux_min(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE bfloat16 predux_mul(const Packet8bf& a) { + return static_cast(predux_mul(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet8bf preverse(const Packet8bf& a) +{ + __m128i m = _mm_setr_epi8(14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1); + return _mm_shuffle_epi8(a,m); +} + +EIGEN_STRONG_INLINE void +ptranspose(PacketBlock& kernel) { + __m128i a = kernel.packet[0]; + __m128i b = kernel.packet[1]; + __m128i c = kernel.packet[2]; + __m128i d = kernel.packet[3]; + __m128i e = kernel.packet[4]; + __m128i f = kernel.packet[5]; + __m128i g = kernel.packet[6]; + __m128i h = kernel.packet[7]; + + __m128i a03b03 = _mm_unpacklo_epi16(a, b); + __m128i c03d03 = _mm_unpacklo_epi16(c, d); + __m128i e03f03 = _mm_unpacklo_epi16(e, f); + __m128i g03h03 = _mm_unpacklo_epi16(g, h); + __m128i a47b47 = _mm_unpackhi_epi16(a, b); + __m128i c47d47 = _mm_unpackhi_epi16(c, d); + __m128i e47f47 = _mm_unpackhi_epi16(e, f); + __m128i g47h47 = _mm_unpackhi_epi16(g, h); + + __m128i a01b01c01d01 = _mm_unpacklo_epi32(a03b03, c03d03); + __m128i a23b23c23d23 = _mm_unpackhi_epi32(a03b03, c03d03); + __m128i e01f01g01h01 = _mm_unpacklo_epi32(e03f03, g03h03); + __m128i e23f23g23h23 = _mm_unpackhi_epi32(e03f03, g03h03); + __m128i a45b45c45d45 = _mm_unpacklo_epi32(a47b47, c47d47); + __m128i a67b67c67d67 = _mm_unpackhi_epi32(a47b47, c47d47); + __m128i e45f45g45h45 = _mm_unpacklo_epi32(e47f47, g47h47); + __m128i e67f67g67h67 = _mm_unpackhi_epi32(e47f47, g47h47); + + kernel.packet[0] = _mm_unpacklo_epi64(a01b01c01d01, e01f01g01h01); + kernel.packet[1] = _mm_unpackhi_epi64(a01b01c01d01, 
e01f01g01h01); + kernel.packet[2] = _mm_unpacklo_epi64(a23b23c23d23, e23f23g23h23); + kernel.packet[3] = _mm_unpackhi_epi64(a23b23c23d23, e23f23g23h23); + kernel.packet[4] = _mm_unpacklo_epi64(a45b45c45d45, e45f45g45h45); + kernel.packet[5] = _mm_unpackhi_epi64(a45b45c45d45, e45f45g45h45); + kernel.packet[6] = _mm_unpacklo_epi64(a67b67c67d67, e67f67g67h67); + kernel.packet[7] = _mm_unpackhi_epi64(a67b67c67d67, e67f67g67h67); +} + +EIGEN_STRONG_INLINE void +ptranspose(PacketBlock& kernel) { + __m128i a = kernel.packet[0]; + __m128i b = kernel.packet[1]; + __m128i c = kernel.packet[2]; + __m128i d = kernel.packet[3]; + + __m128i ab_03 = _mm_unpacklo_epi16(a, b); + __m128i cd_03 = _mm_unpacklo_epi16(c, d); + __m128i ab_47 = _mm_unpackhi_epi16(a, b); + __m128i cd_47 = _mm_unpackhi_epi16(c, d); + + kernel.packet[0] = _mm_unpacklo_epi32(ab_03, cd_03); + kernel.packet[1] = _mm_unpackhi_epi32(ab_03, cd_03); + kernel.packet[2] = _mm_unpacklo_epi32(ab_47, cd_47); + kernel.packet[3] = _mm_unpackhi_epi32(ab_47, cd_47); +} + +} // end namespace internal + +} // end namespace Eigen + +#endif // EIGEN_PACKET_MATH_AVX_H diff --git a/Eigen/src/Core/arch/AVX/TypeCasting.h b/Eigen/src/Core/arch/AVX/TypeCasting.h new file mode 100644 index 0000000..d507fb6 --- /dev/null +++ b/Eigen/src/Core/arch/AVX/TypeCasting.h @@ -0,0 +1,115 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2015 Benoit Steiner +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_TYPE_CASTING_AVX_H +#define EIGEN_TYPE_CASTING_AVX_H + +namespace Eigen { + +namespace internal { + +// For now we use SSE to handle integers, so we can't use AVX instructions to cast +// from int to float +template <> +struct type_casting_traits { + enum { + VectorizedCast = 0, + SrcCoeffRatio = 1, + TgtCoeffRatio = 1 + }; +}; + +template <> +struct type_casting_traits { + enum { + VectorizedCast = 0, + SrcCoeffRatio = 1, + TgtCoeffRatio = 1 + }; +}; + + +#ifndef EIGEN_VECTORIZE_AVX512 + +template <> +struct type_casting_traits { + enum { + VectorizedCast = 1, + SrcCoeffRatio = 1, + TgtCoeffRatio = 1 + }; +}; + + +template <> +struct type_casting_traits { + enum { + VectorizedCast = 1, + SrcCoeffRatio = 1, + TgtCoeffRatio = 1 + }; +}; + +template <> +struct type_casting_traits { + enum { + VectorizedCast = 1, + SrcCoeffRatio = 1, + TgtCoeffRatio = 1 + }; +}; + +template <> +struct type_casting_traits { + enum { + VectorizedCast = 1, + SrcCoeffRatio = 1, + TgtCoeffRatio = 1 + }; +}; + +#endif // EIGEN_VECTORIZE_AVX512 + +template<> EIGEN_STRONG_INLINE Packet8i pcast(const Packet8f& a) { + return _mm256_cvttps_epi32(a); +} + +template<> EIGEN_STRONG_INLINE Packet8f pcast(const Packet8i& a) { + return _mm256_cvtepi32_ps(a); +} + +template<> EIGEN_STRONG_INLINE Packet8i preinterpret(const Packet8f& a) { + return _mm256_castps_si256(a); +} + +template<> EIGEN_STRONG_INLINE Packet8f preinterpret(const Packet8i& a) { + return _mm256_castsi256_ps(a); +} + +template<> EIGEN_STRONG_INLINE Packet8f pcast(const Packet8h& a) { + return half2float(a); +} + +template<> EIGEN_STRONG_INLINE Packet8f pcast(const Packet8bf& a) { + return Bf16ToF32(a); +} + +template<> EIGEN_STRONG_INLINE Packet8h pcast(const Packet8f& a) { + return float2half(a); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pcast(const Packet8f& a) { + return F32ToBf16(a); +} + +} 
// end namespace internal + +} // end namespace Eigen + +#endif // EIGEN_TYPE_CASTING_AVX_H diff --git a/Eigen/src/Core/arch/AVX512/Complex.h b/Eigen/src/Core/arch/AVX512/Complex.h new file mode 100644 index 0000000..49c72b3 --- /dev/null +++ b/Eigen/src/Core/arch/AVX512/Complex.h @@ -0,0 +1,422 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2018 Gael Guennebaud +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_COMPLEX_AVX512_H +#define EIGEN_COMPLEX_AVX512_H + +namespace Eigen { + +namespace internal { + +//---------- float ---------- +struct Packet8cf +{ + EIGEN_STRONG_INLINE Packet8cf() {} + EIGEN_STRONG_INLINE explicit Packet8cf(const __m512& a) : v(a) {} + __m512 v; +}; + +template<> struct packet_traits > : default_packet_traits +{ + typedef Packet8cf type; + typedef Packet4cf half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 8, + HasHalfPacket = 1, + + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 1, + HasNegate = 1, + HasSqrt = 1, + HasAbs = 0, + HasAbs2 = 0, + HasMin = 0, + HasMax = 0, + HasSetLinear = 0 + }; +}; + +template<> struct unpacket_traits { + typedef std::complex type; + typedef Packet4cf half; + typedef Packet16f as_real; + enum { + size = 8, + alignment=unpacket_traits::alignment, + vectorizable=true, + masked_load_available=false, + masked_store_available=false + }; +}; + +template<> EIGEN_STRONG_INLINE Packet8cf ptrue(const Packet8cf& a) { return Packet8cf(ptrue(Packet16f(a.v))); } +template<> EIGEN_STRONG_INLINE Packet8cf padd(const Packet8cf& a, const Packet8cf& b) { return Packet8cf(_mm512_add_ps(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet8cf psub(const Packet8cf& a, const Packet8cf& b) { return Packet8cf(_mm512_sub_ps(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet8cf pnegate(const Packet8cf& a) +{ + return Packet8cf(pnegate(a.v)); +} +template<> EIGEN_STRONG_INLINE Packet8cf pconj(const Packet8cf& a) +{ + const __m512 mask = _mm512_castsi512_ps(_mm512_setr_epi32( + 0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000, + 0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000)); + return Packet8cf(pxor(a.v,mask)); +} + +template<> EIGEN_STRONG_INLINE Packet8cf pmul(const Packet8cf& a, const Packet8cf& b) +{ + __m512 tmp2 = _mm512_mul_ps(_mm512_movehdup_ps(a.v), _mm512_permute_ps(b.v, _MM_SHUFFLE(2,3,0,1))); + return Packet8cf(_mm512_fmaddsub_ps(_mm512_moveldup_ps(a.v), b.v, tmp2)); +} + +template<> EIGEN_STRONG_INLINE Packet8cf pand (const Packet8cf& a, const Packet8cf& b) { return Packet8cf(pand(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet8cf por (const Packet8cf& a, const Packet8cf& b) { return Packet8cf(por(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet8cf pxor (const Packet8cf& a, const Packet8cf& b) { return Packet8cf(pxor(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet8cf pandnot(const Packet8cf& a, const Packet8cf& b) { return Packet8cf(pandnot(a.v,b.v)); } + +template <> +EIGEN_STRONG_INLINE Packet8cf pcmp_eq(const Packet8cf& a, const Packet8cf& b) { + __m512 eq = pcmp_eq(a.v, b.v); + return Packet8cf(pand(eq, _mm512_permute_ps(eq, 0xB1))); +} + +template<> EIGEN_STRONG_INLINE Packet8cf pload (const std::complex* from) { EIGEN_DEBUG_ALIGNED_LOAD return Packet8cf(pload(&numext::real_ref(*from))); } 
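+// For reference, the Packet8cf pmul above is the textbook complex product on
+// interleaved {re, im} pairs: _mm512_moveldup_ps broadcasts the real parts of
+// a, _mm512_movehdup_ps broadcasts the imaginary parts, the permute swaps each
+// {re, im} pair of b, and fmaddsub combines the partial products so that each
+// pair ends up as (scalar sketch):
+//   result.re = a.re * b.re - a.im * b.im;
+//   result.im = a.re * b.im + a.im * b.re;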
+template<> EIGEN_STRONG_INLINE Packet8cf ploadu(const std::complex* from) { EIGEN_DEBUG_UNALIGNED_LOAD return Packet8cf(ploadu(&numext::real_ref(*from))); } + + +template<> EIGEN_STRONG_INLINE Packet8cf pset1(const std::complex& from) +{ + return Packet8cf(_mm512_castpd_ps(pload1((const double*)(const void*)&from))); +} + +template<> EIGEN_STRONG_INLINE Packet8cf ploaddup(const std::complex* from) +{ + return Packet8cf( _mm512_castpd_ps( ploaddup((const double*)(const void*)from )) ); +} +template<> EIGEN_STRONG_INLINE Packet8cf ploadquad(const std::complex* from) +{ + return Packet8cf( _mm512_castpd_ps( ploadquad((const double*)(const void*)from )) ); +} + +template<> EIGEN_STRONG_INLINE void pstore >(std::complex* to, const Packet8cf& from) { EIGEN_DEBUG_ALIGNED_STORE pstore(&numext::real_ref(*to), from.v); } +template<> EIGEN_STRONG_INLINE void pstoreu >(std::complex* to, const Packet8cf& from) { EIGEN_DEBUG_UNALIGNED_STORE pstoreu(&numext::real_ref(*to), from.v); } + +template<> EIGEN_DEVICE_FUNC inline Packet8cf pgather, Packet8cf>(const std::complex* from, Index stride) +{ + return Packet8cf(_mm512_castpd_ps(pgather((const double*)(const void*)from, stride))); +} + +template<> EIGEN_DEVICE_FUNC inline void pscatter, Packet8cf>(std::complex* to, const Packet8cf& from, Index stride) +{ + pscatter((double*)(void*)to, _mm512_castps_pd(from.v), stride); +} + +template<> EIGEN_STRONG_INLINE std::complex pfirst(const Packet8cf& a) +{ + return pfirst(Packet2cf(_mm512_castps512_ps128(a.v))); +} + +template<> EIGEN_STRONG_INLINE Packet8cf preverse(const Packet8cf& a) { + return Packet8cf(_mm512_castsi512_ps( + _mm512_permutexvar_epi64( _mm512_set_epi32(0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7), + _mm512_castps_si512(a.v)))); +} + +template<> EIGEN_STRONG_INLINE std::complex predux(const Packet8cf& a) +{ + return predux(padd(Packet4cf(extract256<0>(a.v)), + Packet4cf(extract256<1>(a.v)))); +} + +template<> EIGEN_STRONG_INLINE std::complex predux_mul(const Packet8cf& a) +{ + return predux_mul(pmul(Packet4cf(extract256<0>(a.v)), + Packet4cf(extract256<1>(a.v)))); +} + +template <> +EIGEN_STRONG_INLINE Packet4cf predux_half_dowto4(const Packet8cf& a) { + __m256 lane0 = extract256<0>(a.v); + __m256 lane1 = extract256<1>(a.v); + __m256 res = _mm256_add_ps(lane0, lane1); + return Packet4cf(res); +} + +EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet8cf,Packet16f) + +template<> EIGEN_STRONG_INLINE Packet8cf pdiv(const Packet8cf& a, const Packet8cf& b) +{ + Packet8cf num = pmul(a, pconj(b)); + __m512 tmp = _mm512_mul_ps(b.v, b.v); + __m512 tmp2 = _mm512_shuffle_ps(tmp,tmp,0xB1); + __m512 denom = _mm512_add_ps(tmp, tmp2); + return Packet8cf(_mm512_div_ps(num.v, denom)); +} + +template<> EIGEN_STRONG_INLINE Packet8cf pcplxflip(const Packet8cf& x) +{ + return Packet8cf(_mm512_shuffle_ps(x.v, x.v, _MM_SHUFFLE(2, 3, 0 ,1))); +} + +//---------- double ---------- +struct Packet4cd +{ + EIGEN_STRONG_INLINE Packet4cd() {} + EIGEN_STRONG_INLINE explicit Packet4cd(const __m512d& a) : v(a) {} + __m512d v; +}; + +template<> struct packet_traits > : default_packet_traits +{ + typedef Packet4cd type; + typedef Packet2cd half; + enum { + Vectorizable = 1, + AlignedOnScalar = 0, + size = 4, + HasHalfPacket = 1, + + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 1, + HasNegate = 1, + HasSqrt = 1, + HasAbs = 0, + HasAbs2 = 0, + HasMin = 0, + HasMax = 0, + HasSetLinear = 0 + }; +}; + +template<> struct unpacket_traits { + typedef std::complex type; + typedef Packet2cd half; + typedef Packet8d as_real; + enum { + size = 
4, + alignment = unpacket_traits::alignment, + vectorizable=true, + masked_load_available=false, + masked_store_available=false + }; +}; + +template<> EIGEN_STRONG_INLINE Packet4cd padd(const Packet4cd& a, const Packet4cd& b) { return Packet4cd(_mm512_add_pd(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet4cd psub(const Packet4cd& a, const Packet4cd& b) { return Packet4cd(_mm512_sub_pd(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet4cd pnegate(const Packet4cd& a) { return Packet4cd(pnegate(a.v)); } +template<> EIGEN_STRONG_INLINE Packet4cd pconj(const Packet4cd& a) +{ + const __m512d mask = _mm512_castsi512_pd( + _mm512_set_epi32(0x80000000,0x0,0x0,0x0,0x80000000,0x0,0x0,0x0, + 0x80000000,0x0,0x0,0x0,0x80000000,0x0,0x0,0x0)); + return Packet4cd(pxor(a.v,mask)); +} + +template<> EIGEN_STRONG_INLINE Packet4cd pmul(const Packet4cd& a, const Packet4cd& b) +{ + __m512d tmp1 = _mm512_shuffle_pd(a.v,a.v,0x0); + __m512d tmp2 = _mm512_shuffle_pd(a.v,a.v,0xFF); + __m512d tmp3 = _mm512_shuffle_pd(b.v,b.v,0x55); + __m512d odd = _mm512_mul_pd(tmp2, tmp3); + return Packet4cd(_mm512_fmaddsub_pd(tmp1, b.v, odd)); +} + +template<> EIGEN_STRONG_INLINE Packet4cd ptrue(const Packet4cd& a) { return Packet4cd(ptrue(Packet8d(a.v))); } +template<> EIGEN_STRONG_INLINE Packet4cd pand (const Packet4cd& a, const Packet4cd& b) { return Packet4cd(pand(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet4cd por (const Packet4cd& a, const Packet4cd& b) { return Packet4cd(por(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet4cd pxor (const Packet4cd& a, const Packet4cd& b) { return Packet4cd(pxor(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet4cd pandnot(const Packet4cd& a, const Packet4cd& b) { return Packet4cd(pandnot(a.v,b.v)); } + +template <> +EIGEN_STRONG_INLINE Packet4cd pcmp_eq(const Packet4cd& a, const Packet4cd& b) { + __m512d eq = pcmp_eq(a.v, b.v); + return Packet4cd(pand(eq, _mm512_permute_pd(eq, 0x55))); +} + +template<> EIGEN_STRONG_INLINE Packet4cd pload (const std::complex* from) +{ EIGEN_DEBUG_ALIGNED_LOAD return Packet4cd(pload((const double*)from)); } +template<> EIGEN_STRONG_INLINE Packet4cd ploadu(const std::complex* from) +{ EIGEN_DEBUG_UNALIGNED_LOAD return Packet4cd(ploadu((const double*)from)); } + +template<> EIGEN_STRONG_INLINE Packet4cd pset1(const std::complex& from) +{ + #ifdef EIGEN_VECTORIZE_AVX512DQ + return Packet4cd(_mm512_broadcast_f64x2(pset1(from).v)); + #else + return Packet4cd(_mm512_castps_pd(_mm512_broadcast_f32x4( _mm_castpd_ps(pset1(from).v)))); + #endif +} + +template<> EIGEN_STRONG_INLINE Packet4cd ploaddup(const std::complex* from) { + return Packet4cd(_mm512_insertf64x4( + _mm512_castpd256_pd512(ploaddup(from).v), ploaddup(from+1).v, 1)); +} + +template<> EIGEN_STRONG_INLINE void pstore >(std::complex * to, const Packet4cd& from) { EIGEN_DEBUG_ALIGNED_STORE pstore((double*)to, from.v); } +template<> EIGEN_STRONG_INLINE void pstoreu >(std::complex * to, const Packet4cd& from) { EIGEN_DEBUG_UNALIGNED_STORE pstoreu((double*)to, from.v); } + +template<> EIGEN_DEVICE_FUNC inline Packet4cd pgather, Packet4cd>(const std::complex* from, Index stride) +{ + return Packet4cd(_mm512_insertf64x4(_mm512_castpd256_pd512( + _mm256_insertf128_pd(_mm256_castpd128_pd256(ploadu(from+0*stride).v), ploadu(from+1*stride).v,1)), + _mm256_insertf128_pd(_mm256_castpd128_pd256(ploadu(from+2*stride).v), ploadu(from+3*stride).v,1), 1)); +} + +template<> EIGEN_DEVICE_FUNC inline void pscatter, Packet4cd>(std::complex* to, const Packet4cd& from, Index stride) +{ + __m512i fromi = 
_mm512_castpd_si512(from.v); + double* tod = (double*)(void*)to; + _mm_storeu_pd(tod+0*stride, _mm_castsi128_pd(_mm512_extracti32x4_epi32(fromi,0)) ); + _mm_storeu_pd(tod+2*stride, _mm_castsi128_pd(_mm512_extracti32x4_epi32(fromi,1)) ); + _mm_storeu_pd(tod+4*stride, _mm_castsi128_pd(_mm512_extracti32x4_epi32(fromi,2)) ); + _mm_storeu_pd(tod+6*stride, _mm_castsi128_pd(_mm512_extracti32x4_epi32(fromi,3)) ); +} + +template<> EIGEN_STRONG_INLINE std::complex pfirst(const Packet4cd& a) +{ + __m128d low = extract128<0>(a.v); + EIGEN_ALIGN16 double res[2]; + _mm_store_pd(res, low); + return std::complex(res[0],res[1]); +} + +template<> EIGEN_STRONG_INLINE Packet4cd preverse(const Packet4cd& a) { + return Packet4cd(_mm512_shuffle_f64x2(a.v, a.v, (shuffle_mask<3,2,1,0>::mask))); +} + +template<> EIGEN_STRONG_INLINE std::complex predux(const Packet4cd& a) +{ + return predux(padd(Packet2cd(_mm512_extractf64x4_pd(a.v,0)), + Packet2cd(_mm512_extractf64x4_pd(a.v,1)))); +} + +template<> EIGEN_STRONG_INLINE std::complex predux_mul(const Packet4cd& a) +{ + return predux_mul(pmul(Packet2cd(_mm512_extractf64x4_pd(a.v,0)), + Packet2cd(_mm512_extractf64x4_pd(a.v,1)))); +} + +template<> struct conj_helper +{ + EIGEN_STRONG_INLINE Packet4cd pmadd(const Packet4cd& x, const Packet4cd& y, const Packet4cd& c) const + { return padd(pmul(x,y),c); } + + EIGEN_STRONG_INLINE Packet4cd pmul(const Packet4cd& a, const Packet4cd& b) const + { + return internal::pmul(a, pconj(b)); + } +}; + +template<> struct conj_helper +{ + EIGEN_STRONG_INLINE Packet4cd pmadd(const Packet4cd& x, const Packet4cd& y, const Packet4cd& c) const + { return padd(pmul(x,y),c); } + + EIGEN_STRONG_INLINE Packet4cd pmul(const Packet4cd& a, const Packet4cd& b) const + { + return internal::pmul(pconj(a), b); + } +}; + +template<> struct conj_helper +{ + EIGEN_STRONG_INLINE Packet4cd pmadd(const Packet4cd& x, const Packet4cd& y, const Packet4cd& c) const + { return padd(pmul(x,y),c); } + + EIGEN_STRONG_INLINE Packet4cd pmul(const Packet4cd& a, const Packet4cd& b) const + { + return pconj(internal::pmul(a, b)); + } +}; + +EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet4cd,Packet8d) + +template<> EIGEN_STRONG_INLINE Packet4cd pdiv(const Packet4cd& a, const Packet4cd& b) +{ + Packet4cd num = pmul(a, pconj(b)); + __m512d tmp = _mm512_mul_pd(b.v, b.v); + __m512d denom = padd(_mm512_permute_pd(tmp,0x55), tmp); + return Packet4cd(_mm512_div_pd(num.v, denom)); +} + +template<> EIGEN_STRONG_INLINE Packet4cd pcplxflip(const Packet4cd& x) +{ + return Packet4cd(_mm512_permute_pd(x.v,0x55)); +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + PacketBlock pb; + + pb.packet[0] = _mm512_castps_pd(kernel.packet[0].v); + pb.packet[1] = _mm512_castps_pd(kernel.packet[1].v); + pb.packet[2] = _mm512_castps_pd(kernel.packet[2].v); + pb.packet[3] = _mm512_castps_pd(kernel.packet[3].v); + ptranspose(pb); + kernel.packet[0].v = _mm512_castpd_ps(pb.packet[0]); + kernel.packet[1].v = _mm512_castpd_ps(pb.packet[1]); + kernel.packet[2].v = _mm512_castpd_ps(pb.packet[2]); + kernel.packet[3].v = _mm512_castpd_ps(pb.packet[3]); +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + PacketBlock pb; + + pb.packet[0] = _mm512_castps_pd(kernel.packet[0].v); + pb.packet[1] = _mm512_castps_pd(kernel.packet[1].v); + pb.packet[2] = _mm512_castps_pd(kernel.packet[2].v); + pb.packet[3] = _mm512_castps_pd(kernel.packet[3].v); + pb.packet[4] = _mm512_castps_pd(kernel.packet[4].v); + pb.packet[5] = _mm512_castps_pd(kernel.packet[5].v); + pb.packet[6] = 
_mm512_castps_pd(kernel.packet[6].v); + pb.packet[7] = _mm512_castps_pd(kernel.packet[7].v); + ptranspose(pb); + kernel.packet[0].v = _mm512_castpd_ps(pb.packet[0]); + kernel.packet[1].v = _mm512_castpd_ps(pb.packet[1]); + kernel.packet[2].v = _mm512_castpd_ps(pb.packet[2]); + kernel.packet[3].v = _mm512_castpd_ps(pb.packet[3]); + kernel.packet[4].v = _mm512_castpd_ps(pb.packet[4]); + kernel.packet[5].v = _mm512_castpd_ps(pb.packet[5]); + kernel.packet[6].v = _mm512_castpd_ps(pb.packet[6]); + kernel.packet[7].v = _mm512_castpd_ps(pb.packet[7]); +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + __m512d T0 = _mm512_shuffle_f64x2(kernel.packet[0].v, kernel.packet[1].v, (shuffle_mask<0,1,0,1>::mask)); // [a0 a1 b0 b1] + __m512d T1 = _mm512_shuffle_f64x2(kernel.packet[0].v, kernel.packet[1].v, (shuffle_mask<2,3,2,3>::mask)); // [a2 a3 b2 b3] + __m512d T2 = _mm512_shuffle_f64x2(kernel.packet[2].v, kernel.packet[3].v, (shuffle_mask<0,1,0,1>::mask)); // [c0 c1 d0 d1] + __m512d T3 = _mm512_shuffle_f64x2(kernel.packet[2].v, kernel.packet[3].v, (shuffle_mask<2,3,2,3>::mask)); // [c2 c3 d2 d3] + + kernel.packet[3] = Packet4cd(_mm512_shuffle_f64x2(T1, T3, (shuffle_mask<1,3,1,3>::mask))); // [a3 b3 c3 d3] + kernel.packet[2] = Packet4cd(_mm512_shuffle_f64x2(T1, T3, (shuffle_mask<0,2,0,2>::mask))); // [a2 b2 c2 d2] + kernel.packet[1] = Packet4cd(_mm512_shuffle_f64x2(T0, T2, (shuffle_mask<1,3,1,3>::mask))); // [a1 b1 c1 d1] + kernel.packet[0] = Packet4cd(_mm512_shuffle_f64x2(T0, T2, (shuffle_mask<0,2,0,2>::mask))); // [a0 b0 c0 d0] +} + +template<> EIGEN_STRONG_INLINE Packet4cd psqrt(const Packet4cd& a) { + return psqrt_complex(a); +} + +template<> EIGEN_STRONG_INLINE Packet8cf psqrt(const Packet8cf& a) { + return psqrt_complex(a); +} + +} // end namespace internal +} // end namespace Eigen + +#endif // EIGEN_COMPLEX_AVX512_H diff --git a/Eigen/src/Core/arch/AVX512/MathFunctions.h b/Eigen/src/Core/arch/AVX512/MathFunctions.h new file mode 100644 index 0000000..6fd726d --- /dev/null +++ b/Eigen/src/Core/arch/AVX512/MathFunctions.h @@ -0,0 +1,362 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2016 Pedro Gonnet (pedro.gonnet@gmail.com) +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef THIRD_PARTY_EIGEN3_EIGEN_SRC_CORE_ARCH_AVX512_MATHFUNCTIONS_H_ +#define THIRD_PARTY_EIGEN3_EIGEN_SRC_CORE_ARCH_AVX512_MATHFUNCTIONS_H_ + +namespace Eigen { + +namespace internal { + +// Disable the code for older versions of gcc that don't support many of the required avx512 instrinsics. 
+#if EIGEN_GNUC_AT_LEAST(5, 3) || EIGEN_COMP_CLANG || EIGEN_COMP_MSVC >= 1923 + +#define _EIGEN_DECLARE_CONST_Packet16f(NAME, X) \ + const Packet16f p16f_##NAME = pset1(X) + +#define _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(NAME, X) \ + const Packet16f p16f_##NAME = preinterpret(pset1(X)) + +#define _EIGEN_DECLARE_CONST_Packet8d(NAME, X) \ + const Packet8d p8d_##NAME = pset1(X) + +#define _EIGEN_DECLARE_CONST_Packet8d_FROM_INT64(NAME, X) \ + const Packet8d p8d_##NAME = _mm512_castsi512_pd(_mm512_set1_epi64(X)) + +#define _EIGEN_DECLARE_CONST_Packet16bf(NAME, X) \ + const Packet16bf p16bf_##NAME = pset1(X) + +#define _EIGEN_DECLARE_CONST_Packet16bf_FROM_INT(NAME, X) \ + const Packet16bf p16bf_##NAME = preinterpret(pset1(X)) + +template <> +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f +plog(const Packet16f& _x) { + return plog_float(_x); +} + +template <> +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8d +plog(const Packet8d& _x) { + return plog_double(_x); +} + +F16_PACKET_FUNCTION(Packet16f, Packet16h, plog) +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog) + +template <> +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f +plog2(const Packet16f& _x) { + return plog2_float(_x); +} + +template <> +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8d +plog2(const Packet8d& _x) { + return plog2_double(_x); +} + +F16_PACKET_FUNCTION(Packet16f, Packet16h, plog2) +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog2) + +// Exponential function. Works by writing "x = m*log(2) + r" where +// "m = floor(x/log(2)+1/2)" and "r" is the remainder. The result is then +// "exp(x) = 2^m*exp(r)" where exp(r) is in the range [-1,1). +template <> +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f +pexp(const Packet16f& _x) { + _EIGEN_DECLARE_CONST_Packet16f(1, 1.0f); + _EIGEN_DECLARE_CONST_Packet16f(half, 0.5f); + _EIGEN_DECLARE_CONST_Packet16f(127, 127.0f); + + _EIGEN_DECLARE_CONST_Packet16f(exp_hi, 88.3762626647950f); + _EIGEN_DECLARE_CONST_Packet16f(exp_lo, -88.3762626647949f); + + _EIGEN_DECLARE_CONST_Packet16f(cephes_LOG2EF, 1.44269504088896341f); + + _EIGEN_DECLARE_CONST_Packet16f(cephes_exp_p0, 1.9875691500E-4f); + _EIGEN_DECLARE_CONST_Packet16f(cephes_exp_p1, 1.3981999507E-3f); + _EIGEN_DECLARE_CONST_Packet16f(cephes_exp_p2, 8.3334519073E-3f); + _EIGEN_DECLARE_CONST_Packet16f(cephes_exp_p3, 4.1665795894E-2f); + _EIGEN_DECLARE_CONST_Packet16f(cephes_exp_p4, 1.6666665459E-1f); + _EIGEN_DECLARE_CONST_Packet16f(cephes_exp_p5, 5.0000001201E-1f); + + // Clamp x. + Packet16f x = pmax(pmin(_x, p16f_exp_hi), p16f_exp_lo); + + // Express exp(x) as exp(m*ln(2) + r), start by extracting + // m = floor(x/ln(2) + 0.5). + Packet16f m = _mm512_floor_ps(pmadd(x, p16f_cephes_LOG2EF, p16f_half)); + + // Get r = x - m*ln(2). Note that we can do this without losing more than one + // ulp precision due to the FMA instruction. + _EIGEN_DECLARE_CONST_Packet16f(nln2, -0.6931471805599453f); + Packet16f r = _mm512_fmadd_ps(m, p16f_nln2, x); + Packet16f r2 = pmul(r, r); + Packet16f r3 = pmul(r2, r); + + // Evaluate the polynomial approximant,improved by instruction-level parallelism. + Packet16f y, y1, y2; + y = pmadd(p16f_cephes_exp_p0, r, p16f_cephes_exp_p1); + y1 = pmadd(p16f_cephes_exp_p3, r, p16f_cephes_exp_p4); + y2 = padd(r, p16f_1); + y = pmadd(y, r, p16f_cephes_exp_p2); + y1 = pmadd(y1, r, p16f_cephes_exp_p5); + y = pmadd(y, r3, y1); + y = pmadd(y, r2, y2); + + // Build emm0 = 2^m. 
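+  // In scalar terms the next two lines write the biased exponent m + 127 into
+  // the exponent field (bits 23..30) of an IEEE-754 float, e.g.
+  // (hypothetical helper, not Eigen API):
+  //   float pow2i(int m) {
+  //     uint32_t bits = static_cast<uint32_t>(m + 127) << 23;
+  //     float r;
+  //     std::memcpy(&r, &bits, sizeof(r));
+  //     return r;
+  //   }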
+ Packet16i emm0 = _mm512_cvttps_epi32(padd(m, p16f_127)); + emm0 = _mm512_slli_epi32(emm0, 23); + + // Return 2^m * exp(r). + return pmax(pmul(y, _mm512_castsi512_ps(emm0)), _x); +} + +template <> +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8d +pexp(const Packet8d& _x) { + return pexp_double(_x); +} + +F16_PACKET_FUNCTION(Packet16f, Packet16h, pexp) +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexp) + +template <> +EIGEN_STRONG_INLINE Packet16h pfrexp(const Packet16h& a, Packet16h& exponent) { + Packet16f fexponent; + const Packet16h out = float2half(pfrexp(half2float(a), fexponent)); + exponent = float2half(fexponent); + return out; +} + +template <> +EIGEN_STRONG_INLINE Packet16h pldexp(const Packet16h& a, const Packet16h& exponent) { + return float2half(pldexp(half2float(a), half2float(exponent))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pfrexp(const Packet16bf& a, Packet16bf& exponent) { + Packet16f fexponent; + const Packet16bf out = F32ToBf16(pfrexp(Bf16ToF32(a), fexponent)); + exponent = F32ToBf16(fexponent); + return out; +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pldexp(const Packet16bf& a, const Packet16bf& exponent) { + return F32ToBf16(pldexp(Bf16ToF32(a), Bf16ToF32(exponent))); +} + +// Functions for sqrt. +// The EIGEN_FAST_MATH version uses the _mm_rsqrt_ps approximation and one step +// of Newton's method, at a cost of 1-2 bits of precision as opposed to the +// exact solution. The main advantage of this approach is not just speed, but +// also the fact that it can be inlined and pipelined with other computations, +// further reducing its effective latency. +#if EIGEN_FAST_MATH +template <> +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f +psqrt(const Packet16f& _x) { + Packet16f neg_half = pmul(_x, pset1(-.5f)); + __mmask16 denormal_mask = _mm512_kand( + _mm512_cmp_ps_mask(_x, pset1((std::numeric_limits::min)()), + _CMP_LT_OQ), + _mm512_cmp_ps_mask(_x, _mm512_setzero_ps(), _CMP_GE_OQ)); + + Packet16f x = _mm512_rsqrt14_ps(_x); + + // Do a single step of Newton's iteration. + x = pmul(x, pmadd(neg_half, pmul(x, x), pset1(1.5f))); + + // Flush results for denormals to zero. + return _mm512_mask_blend_ps(denormal_mask, pmul(_x,x), _mm512_setzero_ps()); +} + +template <> +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8d +psqrt(const Packet8d& _x) { + Packet8d neg_half = pmul(_x, pset1(-.5)); + __mmask16 denormal_mask = _mm512_kand( + _mm512_cmp_pd_mask(_x, pset1((std::numeric_limits::min)()), + _CMP_LT_OQ), + _mm512_cmp_pd_mask(_x, _mm512_setzero_pd(), _CMP_GE_OQ)); + + Packet8d x = _mm512_rsqrt14_pd(_x); + + // Do a single step of Newton's iteration. + x = pmul(x, pmadd(neg_half, pmul(x, x), pset1(1.5))); + + // Do a second step of Newton's iteration. + x = pmul(x, pmadd(neg_half, pmul(x, x), pset1(1.5))); + + return _mm512_mask_blend_pd(denormal_mask, pmul(_x,x), _mm512_setzero_pd()); +} +#else +template <> +EIGEN_STRONG_INLINE Packet16f psqrt(const Packet16f& x) { + return _mm512_sqrt_ps(x); +} + +template <> +EIGEN_STRONG_INLINE Packet8d psqrt(const Packet8d& x) { + return _mm512_sqrt_pd(x); +} +#endif + +F16_PACKET_FUNCTION(Packet16f, Packet16h, psqrt) +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psqrt) + +// prsqrt for float. 
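+// Three variants follow: with AVX512ER the hardware rsqrt28 estimate is
+// accurate enough to be returned directly; with EIGEN_FAST_MATH the rsqrt14
+// estimate is refined by one Newton-Raphson step (with special handling of
+// zero, negative and infinite inputs); otherwise the result is computed
+// exactly as 1/sqrt(x) with a division.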
+#if defined(EIGEN_VECTORIZE_AVX512ER)
+
+template <>
+EIGEN_STRONG_INLINE Packet16f prsqrt<Packet16f>(const Packet16f& x) {
+  return _mm512_rsqrt28_ps(x);
+}
+#elif EIGEN_FAST_MATH
+
+template <>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
+prsqrt<Packet16f>(const Packet16f& _x) {
+  _EIGEN_DECLARE_CONST_Packet16f_FROM_INT(inf, 0x7f800000);
+  _EIGEN_DECLARE_CONST_Packet16f(one_point_five, 1.5f);
+  _EIGEN_DECLARE_CONST_Packet16f(minus_half, -0.5f);
+
+  Packet16f neg_half = pmul(_x, p16f_minus_half);
+
+  // Identify infinite, negative and denormal arguments.
+  __mmask16 inf_mask = _mm512_cmp_ps_mask(_x, p16f_inf, _CMP_EQ_OQ);
+  __mmask16 not_pos_mask = _mm512_cmp_ps_mask(_x, _mm512_setzero_ps(), _CMP_LE_OQ);
+  __mmask16 not_finite_pos_mask = not_pos_mask | inf_mask;
+
+  // Compute an approximate result using the rsqrt intrinsic, forcing +inf
+  // for denormals for consistency with AVX and SSE implementations.
+  Packet16f y_approx = _mm512_rsqrt14_ps(_x);
+
+  // Do a single step of Newton-Raphson iteration to improve the approximation.
+  // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n).
+  // It is essential to evaluate the inner term like this because forming
+  // y_n^2 may over- or underflow.
+  Packet16f y_newton = pmul(y_approx, pmadd(y_approx, pmul(neg_half, y_approx), p16f_one_point_five));
+
+  // Select the result of the Newton-Raphson step for positive finite arguments.
+  // For other arguments, choose the output of the intrinsic. This will
+  // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(0) = +inf.
+  return _mm512_mask_blend_ps(not_finite_pos_mask, y_newton, y_approx);
+}
+#else
+
+template <>
+EIGEN_STRONG_INLINE Packet16f prsqrt<Packet16f>(const Packet16f& x) {
+  _EIGEN_DECLARE_CONST_Packet16f(one, 1.0f);
+  return _mm512_div_ps(p16f_one, _mm512_sqrt_ps(x));
+}
+#endif
+
+F16_PACKET_FUNCTION(Packet16f, Packet16h, prsqrt)
+BF16_PACKET_FUNCTION(Packet16f, Packet16bf, prsqrt)
+
+// prsqrt for double.
+#if EIGEN_FAST_MATH
+template <>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8d
+prsqrt<Packet8d>(const Packet8d& _x) {
+  _EIGEN_DECLARE_CONST_Packet8d(one_point_five, 1.5);
+  _EIGEN_DECLARE_CONST_Packet8d(minus_half, -0.5);
+  _EIGEN_DECLARE_CONST_Packet8d_FROM_INT64(inf, 0x7ff0000000000000LL);
+
+  Packet8d neg_half = pmul(_x, p8d_minus_half);
+
+  // Identify infinite, negative and denormal arguments.
+  __mmask8 inf_mask = _mm512_cmp_pd_mask(_x, p8d_inf, _CMP_EQ_OQ);
+  __mmask8 not_pos_mask = _mm512_cmp_pd_mask(_x, _mm512_setzero_pd(), _CMP_LE_OQ);
+  __mmask8 not_finite_pos_mask = not_pos_mask | inf_mask;
+
+  // Compute an approximate result using the rsqrt intrinsic, forcing +inf
+  // for denormals for consistency with AVX and SSE implementations.
+#if defined(EIGEN_VECTORIZE_AVX512ER)
+  Packet8d y_approx = _mm512_rsqrt28_pd(_x);
+#else
+  Packet8d y_approx = _mm512_rsqrt14_pd(_x);
+#endif
+  // Do one or two steps of Newton-Raphson's method to improve the approximation,
+  // depending on the starting accuracy (either 2^-14 or 2^-28, depending on
+  // whether AVX512ER is available). The Newton-Raphson algorithm has quadratic
+  // convergence and roughly doubles the number of correct digits for each step.
+  // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n).
+  // It is essential to evaluate the inner term like this because forming
+  // y_n^2 may over- or underflow.
+ Packet8d y_newton = pmul(y_approx, pmadd(neg_half, pmul(y_approx, y_approx), p8d_one_point_five)); +#if !defined(EIGEN_VECTORIZE_AVX512ER) + y_newton = pmul(y_newton, pmadd(y_newton, pmul(neg_half, y_newton), p8d_one_point_five)); +#endif + // Select the result of the Newton-Raphson step for positive finite arguments. + // For other arguments, choose the output of the intrinsic. This will + // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(0) = +inf. + return _mm512_mask_blend_pd(not_finite_pos_mask, y_newton, y_approx); +} +#else +template <> +EIGEN_STRONG_INLINE Packet8d prsqrt(const Packet8d& x) { + _EIGEN_DECLARE_CONST_Packet8d(one, 1.0f); + return _mm512_div_pd(p8d_one, _mm512_sqrt_pd(x)); +} +#endif + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet16f plog1p(const Packet16f& _x) { + return generic_plog1p(_x); +} + +F16_PACKET_FUNCTION(Packet16f, Packet16h, plog1p) +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, plog1p) + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet16f pexpm1(const Packet16f& _x) { + return generic_expm1(_x); +} + +F16_PACKET_FUNCTION(Packet16f, Packet16h, pexpm1) +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pexpm1) + +#endif + + +template <> +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f +psin(const Packet16f& _x) { + return psin_float(_x); +} + +template <> +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f +pcos(const Packet16f& _x) { + return pcos_float(_x); +} + +template <> +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f +ptanh(const Packet16f& _x) { + return internal::generic_fast_tanh_float(_x); +} + +F16_PACKET_FUNCTION(Packet16f, Packet16h, psin) +F16_PACKET_FUNCTION(Packet16f, Packet16h, pcos) +F16_PACKET_FUNCTION(Packet16f, Packet16h, ptanh) + +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psin) +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, pcos) +BF16_PACKET_FUNCTION(Packet16f, Packet16bf, ptanh) + +} // end namespace internal + +} // end namespace Eigen + +#endif // THIRD_PARTY_EIGEN3_EIGEN_SRC_CORE_ARCH_AVX512_MATHFUNCTIONS_H_ diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h new file mode 100644 index 0000000..34d49ab --- /dev/null +++ b/Eigen/src/Core/arch/AVX512/PacketMath.h @@ -0,0 +1,2303 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2016 Benoit Steiner (benoit.steiner.goog@gmail.com) +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+ +#ifndef EIGEN_PACKET_MATH_AVX512_H +#define EIGEN_PACKET_MATH_AVX512_H + +namespace Eigen { + +namespace internal { + +#ifndef EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD +#define EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD 8 +#endif + +#ifndef EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS +#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 32 +#endif + +#ifdef EIGEN_VECTORIZE_FMA +#ifndef EIGEN_HAS_SINGLE_INSTRUCTION_MADD +#define EIGEN_HAS_SINGLE_INSTRUCTION_MADD +#endif +#endif + +typedef __m512 Packet16f; +typedef __m512i Packet16i; +typedef __m512d Packet8d; +typedef eigen_packet_wrapper<__m256i, 1> Packet16h; +typedef eigen_packet_wrapper<__m256i, 2> Packet16bf; + +template <> +struct is_arithmetic<__m512> { + enum { value = true }; +}; +template <> +struct is_arithmetic<__m512i> { + enum { value = true }; +}; +template <> +struct is_arithmetic<__m512d> { + enum { value = true }; +}; + +template<> struct is_arithmetic { enum { value = true }; }; + +template <> +struct packet_traits : default_packet_traits { + typedef Packet16h type; + // There is no half-size packet for Packet16h. + typedef Packet16h half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 16, + HasHalfPacket = 1, + + HasCmp = 1, + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 1, + HasNegate = 1, + HasAbs = 1, + HasAbs2 = 0, + HasMin = 1, + HasMax = 1, + HasConj = 1, + HasSetLinear = 0, + HasLog = 1, + HasLog1p = 1, + HasExpm1 = 1, + HasExp = 1, + HasSqrt = 1, + HasRsqrt = 1, + HasSin = EIGEN_FAST_MATH, + HasCos = EIGEN_FAST_MATH, + HasTanh = EIGEN_FAST_MATH, + HasErf = EIGEN_FAST_MATH, + HasBlend = 0, + HasRound = 1, + HasFloor = 1, + HasCeil = 1, + HasRint = 1, + HasBessel = 1, + HasNdtri = 1 + }; +}; + +template<> struct packet_traits : default_packet_traits +{ + typedef Packet16f type; + typedef Packet8f half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 16, + HasHalfPacket = 1, + + HasAbs = 1, + HasMin = 1, + HasMax = 1, + HasConj = 1, + HasBlend = 0, + HasSin = EIGEN_FAST_MATH, + HasCos = EIGEN_FAST_MATH, +#if EIGEN_GNUC_AT_LEAST(5, 3) || (!EIGEN_COMP_GNUC_STRICT) + HasLog = 1, + HasLog1p = 1, + HasExpm1 = 1, + HasNdtri = 1, + HasBessel = 1, + HasExp = 1, + HasSqrt = EIGEN_FAST_MATH, + HasRsqrt = EIGEN_FAST_MATH, + HasTanh = EIGEN_FAST_MATH, + HasErf = EIGEN_FAST_MATH, +#endif + HasCmp = 1, + HasDiv = 1, + HasRound = 1, + HasFloor = 1, + HasCeil = 1, + HasRint = 1 + }; + }; +template<> struct packet_traits : default_packet_traits +{ + typedef Packet8d type; + typedef Packet4d half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 8, + HasHalfPacket = 1, +#if EIGEN_GNUC_AT_LEAST(5, 3) || (!EIGEN_COMP_GNUC_STRICT) + HasLog = 1, + HasExp = 1, + HasSqrt = EIGEN_FAST_MATH, + HasRsqrt = EIGEN_FAST_MATH, +#endif + HasCmp = 1, + HasDiv = 1, + HasRound = 1, + HasFloor = 1, + HasCeil = 1, + HasRint = 1 + }; +}; + +/* TODO Implement AVX512 for integers +template<> struct packet_traits : default_packet_traits +{ + typedef Packet16i type; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size=8 + }; +}; +*/ + +template <> +struct unpacket_traits { + typedef float type; + typedef Packet8f half; + typedef Packet16i integer_packet; + typedef uint16_t mask_t; + enum { size = 16, alignment=Aligned64, vectorizable=true, masked_load_available=true, masked_store_available=true }; +}; +template <> +struct unpacket_traits { + typedef double type; + typedef Packet4d half; + enum { size = 8, alignment=Aligned64, vectorizable=true, masked_load_available=false, masked_store_available=false }; +}; +template <> 
+struct unpacket_traits { + typedef int type; + typedef Packet8i half; + enum { size = 16, alignment=Aligned64, vectorizable=false, masked_load_available=false, masked_store_available=false }; +}; + +template<> +struct unpacket_traits { + typedef Eigen::half type; + typedef Packet8h half; + enum {size=16, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false}; +}; + +template <> +EIGEN_STRONG_INLINE Packet16f pset1(const float& from) { + return _mm512_set1_ps(from); +} +template <> +EIGEN_STRONG_INLINE Packet8d pset1(const double& from) { + return _mm512_set1_pd(from); +} +template <> +EIGEN_STRONG_INLINE Packet16i pset1(const int& from) { + return _mm512_set1_epi32(from); +} + +template <> +EIGEN_STRONG_INLINE Packet16f pset1frombits(unsigned int from) { + return _mm512_castsi512_ps(_mm512_set1_epi32(from)); +} + +template <> +EIGEN_STRONG_INLINE Packet8d pset1frombits(const numext::uint64_t from) { + return _mm512_castsi512_pd(_mm512_set1_epi64(from)); +} + +template<> EIGEN_STRONG_INLINE Packet16f pzero(const Packet16f& /*a*/) { return _mm512_setzero_ps(); } +template<> EIGEN_STRONG_INLINE Packet8d pzero(const Packet8d& /*a*/) { return _mm512_setzero_pd(); } +template<> EIGEN_STRONG_INLINE Packet16i pzero(const Packet16i& /*a*/) { return _mm512_setzero_si512(); } + +template<> EIGEN_STRONG_INLINE Packet16f peven_mask(const Packet16f& /*a*/) { + return _mm512_castsi512_ps(_mm512_set_epi32(0, -1, 0, -1, 0, -1, 0, -1, + 0, -1, 0, -1, 0, -1, 0, -1)); +} +template<> EIGEN_STRONG_INLINE Packet16i peven_mask(const Packet16i& /*a*/) { + return _mm512_set_epi32(0, -1, 0, -1, 0, -1, 0, -1, + 0, -1, 0, -1, 0, -1, 0, -1); +} +template<> EIGEN_STRONG_INLINE Packet8d peven_mask(const Packet8d& /*a*/) { + return _mm512_castsi512_pd(_mm512_set_epi32(0, 0, -1, -1, 0, 0, -1, -1, + 0, 0, -1, -1, 0, 0, -1, -1)); +} + +template <> +EIGEN_STRONG_INLINE Packet16f pload1(const float* from) { + return _mm512_broadcastss_ps(_mm_load_ps1(from)); +} +template <> +EIGEN_STRONG_INLINE Packet8d pload1(const double* from) { + return _mm512_set1_pd(*from); +} + +template <> +EIGEN_STRONG_INLINE Packet16f plset(const float& a) { + return _mm512_add_ps( + _mm512_set1_ps(a), + _mm512_set_ps(15.0f, 14.0f, 13.0f, 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 6.0f, 5.0f, + 4.0f, 3.0f, 2.0f, 1.0f, 0.0f)); +} +template <> +EIGEN_STRONG_INLINE Packet8d plset(const double& a) { + return _mm512_add_pd(_mm512_set1_pd(a), + _mm512_set_pd(7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 0.0)); +} + +template <> +EIGEN_STRONG_INLINE Packet16f padd(const Packet16f& a, + const Packet16f& b) { + return _mm512_add_ps(a, b); +} +template <> +EIGEN_STRONG_INLINE Packet8d padd(const Packet8d& a, + const Packet8d& b) { + return _mm512_add_pd(a, b); +} +template <> +EIGEN_STRONG_INLINE Packet16i padd(const Packet16i& a, + const Packet16i& b) { + return _mm512_add_epi32(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet16f psub(const Packet16f& a, + const Packet16f& b) { + return _mm512_sub_ps(a, b); +} +template <> +EIGEN_STRONG_INLINE Packet8d psub(const Packet8d& a, + const Packet8d& b) { + return _mm512_sub_pd(a, b); +} +template <> +EIGEN_STRONG_INLINE Packet16i psub(const Packet16i& a, + const Packet16i& b) { + return _mm512_sub_epi32(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet16f pnegate(const Packet16f& a) { + return _mm512_sub_ps(_mm512_set1_ps(0.0), a); +} +template <> +EIGEN_STRONG_INLINE Packet8d pnegate(const Packet8d& a) { + return _mm512_sub_pd(_mm512_set1_pd(0.0), a); +} + +template <> 
+EIGEN_STRONG_INLINE Packet16f pconj(const Packet16f& a) { + return a; +} +template <> +EIGEN_STRONG_INLINE Packet8d pconj(const Packet8d& a) { + return a; +} +template <> +EIGEN_STRONG_INLINE Packet16i pconj(const Packet16i& a) { + return a; +} + +template <> +EIGEN_STRONG_INLINE Packet16f pmul(const Packet16f& a, + const Packet16f& b) { + return _mm512_mul_ps(a, b); +} +template <> +EIGEN_STRONG_INLINE Packet8d pmul(const Packet8d& a, + const Packet8d& b) { + return _mm512_mul_pd(a, b); +} +template <> +EIGEN_STRONG_INLINE Packet16i pmul(const Packet16i& a, + const Packet16i& b) { + return _mm512_mullo_epi32(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet16f pdiv(const Packet16f& a, + const Packet16f& b) { + return _mm512_div_ps(a, b); +} +template <> +EIGEN_STRONG_INLINE Packet8d pdiv(const Packet8d& a, + const Packet8d& b) { + return _mm512_div_pd(a, b); +} + +#ifdef EIGEN_VECTORIZE_FMA +template <> +EIGEN_STRONG_INLINE Packet16f pmadd(const Packet16f& a, const Packet16f& b, + const Packet16f& c) { + return _mm512_fmadd_ps(a, b, c); +} +template <> +EIGEN_STRONG_INLINE Packet8d pmadd(const Packet8d& a, const Packet8d& b, + const Packet8d& c) { + return _mm512_fmadd_pd(a, b, c); +} +#endif + +template <> +EIGEN_DEVICE_FUNC inline Packet16f pselect(const Packet16f& mask, + const Packet16f& a, + const Packet16f& b) { + __mmask16 mask16 = _mm512_cmp_epi32_mask( + _mm512_castps_si512(mask), _mm512_setzero_epi32(), _MM_CMPINT_EQ); + return _mm512_mask_blend_ps(mask16, a, b); +} + +template <> +EIGEN_DEVICE_FUNC inline Packet8d pselect(const Packet8d& mask, + const Packet8d& a, + const Packet8d& b) { + __mmask8 mask8 = _mm512_cmp_epi64_mask(_mm512_castpd_si512(mask), + _mm512_setzero_epi32(), _MM_CMPINT_EQ); + return _mm512_mask_blend_pd(mask8, a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet16f pmin(const Packet16f& a, + const Packet16f& b) { + // Arguments are reversed to match NaN propagation behavior of std::min. + return _mm512_min_ps(b, a); +} +template <> +EIGEN_STRONG_INLINE Packet8d pmin(const Packet8d& a, + const Packet8d& b) { + // Arguments are reversed to match NaN propagation behavior of std::min. + return _mm512_min_pd(b, a); +} + +template <> +EIGEN_STRONG_INLINE Packet16f pmax(const Packet16f& a, + const Packet16f& b) { + // Arguments are reversed to match NaN propagation behavior of std::max. + return _mm512_max_ps(b, a); +} +template <> +EIGEN_STRONG_INLINE Packet8d pmax(const Packet8d& a, + const Packet8d& b) { + // Arguments are reversed to match NaN propagation behavior of std::max. + return _mm512_max_pd(b, a); +} + +// Add specializations for min/max with prescribed NaN progation. 
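+// Intended semantics (sketch): with PropagateNumbers a NaN lane yields the
+// other, numeric operand, matching std::fmin/std::fmax, e.g.
+// pmin<PropagateNumbers>(NaN, 2.f) == 2.f, while with PropagateNaN any NaN
+// input produces NaN in that lane, e.g. pmin<PropagateNaN>(NaN, 2.f) == NaN.
+// Both helpers are implemented generically in terms of the plain pmin/pmax
+// above.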
+template<> +EIGEN_STRONG_INLINE Packet16f pmin(const Packet16f& a, const Packet16f& b) { + return pminmax_propagate_numbers(a, b, pmin); +} +template<> +EIGEN_STRONG_INLINE Packet8d pmin(const Packet8d& a, const Packet8d& b) { + return pminmax_propagate_numbers(a, b, pmin); +} +template<> +EIGEN_STRONG_INLINE Packet16f pmax(const Packet16f& a, const Packet16f& b) { + return pminmax_propagate_numbers(a, b, pmax); +} +template<> +EIGEN_STRONG_INLINE Packet8d pmax(const Packet8d& a, const Packet8d& b) { + return pminmax_propagate_numbers(a, b, pmax); +} +template<> +EIGEN_STRONG_INLINE Packet16f pmin(const Packet16f& a, const Packet16f& b) { + return pminmax_propagate_nan(a, b, pmin); +} +template<> +EIGEN_STRONG_INLINE Packet8d pmin(const Packet8d& a, const Packet8d& b) { + return pminmax_propagate_nan(a, b, pmin); +} +template<> +EIGEN_STRONG_INLINE Packet16f pmax(const Packet16f& a, const Packet16f& b) { + return pminmax_propagate_nan(a, b, pmax); +} +template<> +EIGEN_STRONG_INLINE Packet8d pmax(const Packet8d& a, const Packet8d& b) { + return pminmax_propagate_nan(a, b, pmax); +} + + +#ifdef EIGEN_VECTORIZE_AVX512DQ +template EIGEN_STRONG_INLINE Packet8f extract256(Packet16f x) { return _mm512_extractf32x8_ps(x,I_); } +template EIGEN_STRONG_INLINE Packet2d extract128(Packet8d x) { return _mm512_extractf64x2_pd(x,I_); } +EIGEN_STRONG_INLINE Packet16f cat256(Packet8f a, Packet8f b) { return _mm512_insertf32x8(_mm512_castps256_ps512(a),b,1); } +#else +// AVX512F does not define _mm512_extractf32x8_ps to extract _m256 from _m512 +template EIGEN_STRONG_INLINE Packet8f extract256(Packet16f x) { + return _mm256_castsi256_ps(_mm512_extracti64x4_epi64( _mm512_castps_si512(x),I_)); +} + +// AVX512F does not define _mm512_extractf64x2_pd to extract _m128 from _m512 +template EIGEN_STRONG_INLINE Packet2d extract128(Packet8d x) { + return _mm_castsi128_pd(_mm512_extracti32x4_epi32( _mm512_castpd_si512(x),I_)); +} + +EIGEN_STRONG_INLINE Packet16f cat256(Packet8f a, Packet8f b) { + return _mm512_castsi512_ps(_mm512_inserti64x4(_mm512_castsi256_si512(_mm256_castps_si256(a)), + _mm256_castps_si256(b),1)); +} +#endif + +// Helper function for bit packing snippet of low precision comparison. +// It packs the flags from 32x16 to 16x16. +EIGEN_STRONG_INLINE __m256i Pack32To16(Packet16f rf) { + // Split data into small pieces and handle with AVX instructions + // to guarantee internal order of vector. + // Operation: + // dst[15:0] := Saturate16(rf[31:0]) + // dst[31:16] := Saturate16(rf[63:32]) + // ... 
+ // dst[255:240] := Saturate16(rf[255:224]) + __m256i lo = _mm256_castps_si256(extract256<0>(rf)); + __m256i hi = _mm256_castps_si256(extract256<1>(rf)); + __m128i result_lo = _mm_packs_epi32(_mm256_extractf128_si256(lo, 0), + _mm256_extractf128_si256(lo, 1)); + __m128i result_hi = _mm_packs_epi32(_mm256_extractf128_si256(hi, 0), + _mm256_extractf128_si256(hi, 1)); + return _mm256_insertf128_si256(_mm256_castsi128_si256(result_lo), result_hi, 1); +} + +template <> +EIGEN_STRONG_INLINE Packet16f pcmp_eq(const Packet16f& a, const Packet16f& b) { + __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_EQ_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu)); +} +template<> EIGEN_STRONG_INLINE Packet16f pcmp_le(const Packet16f& a, const Packet16f& b) { + __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LE_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu)); +} + +template<> EIGEN_STRONG_INLINE Packet16f pcmp_lt(const Packet16f& a, const Packet16f& b) { + __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LT_OQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu)); +} + +template<> EIGEN_STRONG_INLINE Packet16f pcmp_lt_or_nan(const Packet16f& a, const Packet16f& b) { + __mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_NGE_UQ); + return _mm512_castsi512_ps( + _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu)); +} + +template<> EIGEN_STRONG_INLINE Packet16i pcmp_eq(const Packet16i& a, const Packet16i& b) { + __mmask16 mask = _mm512_cmp_epi32_mask(a, b, _CMP_EQ_OQ); + return _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu); +} + + +template <> +EIGEN_STRONG_INLINE Packet8d pcmp_eq(const Packet8d& a, const Packet8d& b) { + __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_EQ_OQ); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu)); +} +template <> +EIGEN_STRONG_INLINE Packet8d pcmp_le(const Packet8d& a, const Packet8d& b) { + __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_LE_OQ); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu)); +} +template <> +EIGEN_STRONG_INLINE Packet8d pcmp_lt(const Packet8d& a, const Packet8d& b) { + __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_LT_OQ); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu)); +} +template <> +EIGEN_STRONG_INLINE Packet8d pcmp_lt_or_nan(const Packet8d& a, const Packet8d& b) { + __mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_NGE_UQ); + return _mm512_castsi512_pd( + _mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu)); +} + +template<> EIGEN_STRONG_INLINE Packet16f print(const Packet16f& a) { return _mm512_roundscale_ps(a, _MM_FROUND_CUR_DIRECTION); } +template<> EIGEN_STRONG_INLINE Packet8d print(const Packet8d& a) { return _mm512_roundscale_pd(a, _MM_FROUND_CUR_DIRECTION); } + +template<> EIGEN_STRONG_INLINE Packet16f pceil(const Packet16f& a) { return _mm512_roundscale_ps(a, _MM_FROUND_TO_POS_INF); } +template<> EIGEN_STRONG_INLINE Packet8d pceil(const Packet8d& a) { return _mm512_roundscale_pd(a, _MM_FROUND_TO_POS_INF); } + +template<> EIGEN_STRONG_INLINE Packet16f pfloor(const Packet16f& a) { return _mm512_roundscale_ps(a, _MM_FROUND_TO_NEG_INF); } +template<> EIGEN_STRONG_INLINE Packet8d pfloor(const Packet8d& a) { return _mm512_roundscale_pd(a, _MM_FROUND_TO_NEG_INF); } + +template <> +EIGEN_STRONG_INLINE Packet16i 
ptrue(const Packet16i& /*a*/) { + return _mm512_set1_epi32(0xffffffffu); +} + +template <> +EIGEN_STRONG_INLINE Packet16f ptrue(const Packet16f& a) { + return _mm512_castsi512_ps(ptrue(_mm512_castps_si512(a))); +} + +template <> +EIGEN_STRONG_INLINE Packet8d ptrue(const Packet8d& a) { + return _mm512_castsi512_pd(ptrue(_mm512_castpd_si512(a))); +} + +template <> +EIGEN_STRONG_INLINE Packet16i pand(const Packet16i& a, + const Packet16i& b) { + return _mm512_and_si512(a,b); +} + +template <> +EIGEN_STRONG_INLINE Packet16f pand(const Packet16f& a, + const Packet16f& b) { +#ifdef EIGEN_VECTORIZE_AVX512DQ + return _mm512_and_ps(a, b); +#else + return _mm512_castsi512_ps(pand(_mm512_castps_si512(a),_mm512_castps_si512(b))); +#endif +} +template <> +EIGEN_STRONG_INLINE Packet8d pand(const Packet8d& a, + const Packet8d& b) { +#ifdef EIGEN_VECTORIZE_AVX512DQ + return _mm512_and_pd(a, b); +#else + Packet8d res = _mm512_undefined_pd(); + Packet4d lane0_a = _mm512_extractf64x4_pd(a, 0); + Packet4d lane0_b = _mm512_extractf64x4_pd(b, 0); + res = _mm512_insertf64x4(res, _mm256_and_pd(lane0_a, lane0_b), 0); + + Packet4d lane1_a = _mm512_extractf64x4_pd(a, 1); + Packet4d lane1_b = _mm512_extractf64x4_pd(b, 1); + return _mm512_insertf64x4(res, _mm256_and_pd(lane1_a, lane1_b), 1); +#endif +} + +template <> +EIGEN_STRONG_INLINE Packet16i por(const Packet16i& a, const Packet16i& b) { + return _mm512_or_si512(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet16f por(const Packet16f& a, const Packet16f& b) { +#ifdef EIGEN_VECTORIZE_AVX512DQ + return _mm512_or_ps(a, b); +#else + return _mm512_castsi512_ps(por(_mm512_castps_si512(a),_mm512_castps_si512(b))); +#endif +} + +template <> +EIGEN_STRONG_INLINE Packet8d por(const Packet8d& a, + const Packet8d& b) { +#ifdef EIGEN_VECTORIZE_AVX512DQ + return _mm512_or_pd(a, b); +#else + return _mm512_castsi512_pd(por(_mm512_castpd_si512(a),_mm512_castpd_si512(b))); +#endif +} + +template <> +EIGEN_STRONG_INLINE Packet16i pxor(const Packet16i& a, const Packet16i& b) { + return _mm512_xor_si512(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet16f pxor(const Packet16f& a, const Packet16f& b) { +#ifdef EIGEN_VECTORIZE_AVX512DQ + return _mm512_xor_ps(a, b); +#else + return _mm512_castsi512_ps(pxor(_mm512_castps_si512(a),_mm512_castps_si512(b))); +#endif +} + +template <> +EIGEN_STRONG_INLINE Packet8d pxor(const Packet8d& a, const Packet8d& b) { +#ifdef EIGEN_VECTORIZE_AVX512DQ + return _mm512_xor_pd(a, b); +#else + return _mm512_castsi512_pd(pxor(_mm512_castpd_si512(a),_mm512_castpd_si512(b))); +#endif +} + +template <> +EIGEN_STRONG_INLINE Packet16i pandnot(const Packet16i& a, const Packet16i& b) { + return _mm512_andnot_si512(b, a); +} + +template <> +EIGEN_STRONG_INLINE Packet16f pandnot(const Packet16f& a, const Packet16f& b) { +#ifdef EIGEN_VECTORIZE_AVX512DQ + return _mm512_andnot_ps(b, a); +#else + return _mm512_castsi512_ps(pandnot(_mm512_castps_si512(a),_mm512_castps_si512(b))); +#endif +} +template <> +EIGEN_STRONG_INLINE Packet8d pandnot(const Packet8d& a,const Packet8d& b) { +#ifdef EIGEN_VECTORIZE_AVX512DQ + return _mm512_andnot_pd(b, a); +#else + return _mm512_castsi512_pd(pandnot(_mm512_castpd_si512(a),_mm512_castpd_si512(b))); +#endif +} + +template<> EIGEN_STRONG_INLINE Packet16f pround(const Packet16f& a) +{ + // Work-around for default std::round rounding mode. 
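+  // Sketch of the work-around: 0x3EFFFFFF is the largest float strictly below
+  // 0.5. Adding a copy of it carrying the sign of `a`, then truncating toward
+  // zero, reproduces std::round's round-half-away-from-zero behaviour instead
+  // of the round-to-nearest-even default. A scalar equivalent (illustration
+  // only):
+  //   const float pre = 0.49999997f;                  // bits 0x3EFFFFFF
+  //   float rounded = truncf(a + copysignf(pre, a));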
+ const Packet16f mask = pset1frombits(static_cast(0x80000000u)); + const Packet16f prev0dot5 = pset1frombits(static_cast(0x3EFFFFFFu)); + return _mm512_roundscale_ps(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO); +} +template<> EIGEN_STRONG_INLINE Packet8d pround(const Packet8d& a) +{ + // Work-around for default std::round rounding mode. + const Packet8d mask = pset1frombits(static_cast(0x8000000000000000ull)); + const Packet8d prev0dot5 = pset1frombits(static_cast(0x3FDFFFFFFFFFFFFFull)); + return _mm512_roundscale_pd(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO); +} + +template EIGEN_STRONG_INLINE Packet16i parithmetic_shift_right(Packet16i a) { + return _mm512_srai_epi32(a, N); +} + +template EIGEN_STRONG_INLINE Packet16i plogical_shift_right(Packet16i a) { + return _mm512_srli_epi32(a, N); +} + +template EIGEN_STRONG_INLINE Packet16i plogical_shift_left(Packet16i a) { + return _mm512_slli_epi32(a, N); +} + +template <> +EIGEN_STRONG_INLINE Packet16f pload(const float* from) { + EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_ps(from); +} +template <> +EIGEN_STRONG_INLINE Packet8d pload(const double* from) { + EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_pd(from); +} +template <> +EIGEN_STRONG_INLINE Packet16i pload(const int* from) { + EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_si512( + reinterpret_cast(from)); +} + +template <> +EIGEN_STRONG_INLINE Packet16f ploadu(const float* from) { + EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_ps(from); +} +template <> +EIGEN_STRONG_INLINE Packet8d ploadu(const double* from) { + EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_pd(from); +} +template <> +EIGEN_STRONG_INLINE Packet16i ploadu(const int* from) { + EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_si512( + reinterpret_cast(from)); +} + +template <> +EIGEN_STRONG_INLINE Packet16f ploadu(const float* from, uint16_t umask) { + __mmask16 mask = static_cast<__mmask16>(umask); + EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_maskz_loadu_ps(mask, from); +} + +// Loads 8 floats from memory a returns the packet +// {a0, a0 a1, a1, a2, a2, a3, a3, a4, a4, a5, a5, a6, a6, a7, a7} +template <> +EIGEN_STRONG_INLINE Packet16f ploaddup(const float* from) { + // an unaligned load is required here as there is no requirement + // on the alignment of input pointer 'from' + __m256i low_half = _mm256_loadu_si256(reinterpret_cast(from)); + __m512 even_elements = _mm512_castsi512_ps(_mm512_cvtepu32_epi64(low_half)); + __m512 pairs = _mm512_permute_ps(even_elements, _MM_SHUFFLE(2, 2, 0, 0)); + return pairs; +} + +#ifdef EIGEN_VECTORIZE_AVX512DQ +// FIXME: this does not look optimal, better load a Packet4d and shuffle... 
+// Loads 4 doubles from memory a returns the packet {a0, a0 a1, a1, a2, a2, a3, +// a3} +template <> +EIGEN_STRONG_INLINE Packet8d ploaddup(const double* from) { + __m512d x = _mm512_setzero_pd(); + x = _mm512_insertf64x2(x, _mm_loaddup_pd(&from[0]), 0); + x = _mm512_insertf64x2(x, _mm_loaddup_pd(&from[1]), 1); + x = _mm512_insertf64x2(x, _mm_loaddup_pd(&from[2]), 2); + x = _mm512_insertf64x2(x, _mm_loaddup_pd(&from[3]), 3); + return x; +} +#else +template <> +EIGEN_STRONG_INLINE Packet8d ploaddup(const double* from) { + __m512d x = _mm512_setzero_pd(); + x = _mm512_mask_broadcastsd_pd(x, 0x3<<0, _mm_load_sd(from+0)); + x = _mm512_mask_broadcastsd_pd(x, 0x3<<2, _mm_load_sd(from+1)); + x = _mm512_mask_broadcastsd_pd(x, 0x3<<4, _mm_load_sd(from+2)); + x = _mm512_mask_broadcastsd_pd(x, 0x3<<6, _mm_load_sd(from+3)); + return x; +} +#endif + +// Loads 4 floats from memory a returns the packet +// {a0, a0 a0, a0, a1, a1, a1, a1, a2, a2, a2, a2, a3, a3, a3, a3} +template <> +EIGEN_STRONG_INLINE Packet16f ploadquad(const float* from) { + Packet16f tmp = _mm512_castps128_ps512(ploadu(from)); + const Packet16i scatter_mask = _mm512_set_epi32(3,3,3,3, 2,2,2,2, 1,1,1,1, 0,0,0,0); + return _mm512_permutexvar_ps(scatter_mask, tmp); +} + +// Loads 2 doubles from memory a returns the packet +// {a0, a0 a0, a0, a1, a1, a1, a1} +template <> +EIGEN_STRONG_INLINE Packet8d ploadquad(const double* from) { + __m256d lane0 = _mm256_set1_pd(*from); + __m256d lane1 = _mm256_set1_pd(*(from+1)); + __m512d tmp = _mm512_undefined_pd(); + tmp = _mm512_insertf64x4(tmp, lane0, 0); + return _mm512_insertf64x4(tmp, lane1, 1); +} + +template <> +EIGEN_STRONG_INLINE void pstore(float* to, const Packet16f& from) { + EIGEN_DEBUG_ALIGNED_STORE _mm512_store_ps(to, from); +} +template <> +EIGEN_STRONG_INLINE void pstore(double* to, const Packet8d& from) { + EIGEN_DEBUG_ALIGNED_STORE _mm512_store_pd(to, from); +} +template <> +EIGEN_STRONG_INLINE void pstore(int* to, const Packet16i& from) { + EIGEN_DEBUG_ALIGNED_STORE _mm512_storeu_si512(reinterpret_cast<__m512i*>(to), + from); +} + +template <> +EIGEN_STRONG_INLINE void pstoreu(float* to, const Packet16f& from) { + EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_ps(to, from); +} +template <> +EIGEN_STRONG_INLINE void pstoreu(double* to, const Packet8d& from) { + EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_pd(to, from); +} +template <> +EIGEN_STRONG_INLINE void pstoreu(int* to, const Packet16i& from) { + EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_si512( + reinterpret_cast<__m512i*>(to), from); +} +template <> +EIGEN_STRONG_INLINE void pstoreu(float* to, const Packet16f& from, uint16_t umask) { + __mmask16 mask = static_cast<__mmask16>(umask); + EIGEN_DEBUG_UNALIGNED_STORE return _mm512_mask_storeu_ps(to, mask, from); +} + +template <> +EIGEN_DEVICE_FUNC inline Packet16f pgather(const float* from, + Index stride) { + Packet16i stride_vector = _mm512_set1_epi32(convert_index(stride)); + Packet16i stride_multiplier = + _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0); + Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier); + + return _mm512_i32gather_ps(indices, from, 4); +} +template <> +EIGEN_DEVICE_FUNC inline Packet8d pgather(const double* from, + Index stride) { + Packet8i stride_vector = _mm256_set1_epi32(convert_index(stride)); + Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); + Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier); + + return _mm512_i32gather_pd(indices, from, 8); +} + +template <> 
+EIGEN_DEVICE_FUNC inline void pscatter(float* to, + const Packet16f& from, + Index stride) { + Packet16i stride_vector = _mm512_set1_epi32(convert_index(stride)); + Packet16i stride_multiplier = + _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0); + Packet16i indices = _mm512_mullo_epi32(stride_vector, stride_multiplier); + _mm512_i32scatter_ps(to, indices, from, 4); +} +template <> +EIGEN_DEVICE_FUNC inline void pscatter(double* to, + const Packet8d& from, + Index stride) { + Packet8i stride_vector = _mm256_set1_epi32(convert_index(stride)); + Packet8i stride_multiplier = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); + Packet8i indices = _mm256_mullo_epi32(stride_vector, stride_multiplier); + _mm512_i32scatter_pd(to, indices, from, 8); +} + +template <> +EIGEN_STRONG_INLINE void pstore1(float* to, const float& a) { + Packet16f pa = pset1(a); + pstore(to, pa); +} +template <> +EIGEN_STRONG_INLINE void pstore1(double* to, const double& a) { + Packet8d pa = pset1(a); + pstore(to, pa); +} +template <> +EIGEN_STRONG_INLINE void pstore1(int* to, const int& a) { + Packet16i pa = pset1(a); + pstore(to, pa); +} + +template<> EIGEN_STRONG_INLINE void prefetch(const float* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); } +template<> EIGEN_STRONG_INLINE void prefetch(const double* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); } +template<> EIGEN_STRONG_INLINE void prefetch(const int* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); } + +template <> +EIGEN_STRONG_INLINE float pfirst(const Packet16f& a) { + return _mm_cvtss_f32(_mm512_extractf32x4_ps(a, 0)); +} +template <> +EIGEN_STRONG_INLINE double pfirst(const Packet8d& a) { + return _mm_cvtsd_f64(_mm256_extractf128_pd(_mm512_extractf64x4_pd(a, 0), 0)); +} +template <> +EIGEN_STRONG_INLINE int pfirst(const Packet16i& a) { + return _mm_extract_epi32(_mm512_extracti32x4_epi32(a, 0), 0); +} + +template<> EIGEN_STRONG_INLINE Packet16f preverse(const Packet16f& a) +{ + return _mm512_permutexvar_ps(_mm512_set_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), a); +} + +template<> EIGEN_STRONG_INLINE Packet8d preverse(const Packet8d& a) +{ + return _mm512_permutexvar_pd(_mm512_set_epi32(0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7), a); +} + +template<> EIGEN_STRONG_INLINE Packet16f pabs(const Packet16f& a) +{ + // _mm512_abs_ps intrinsic not found, so hack around it + return _mm512_castsi512_ps(_mm512_and_si512(_mm512_castps_si512(a), _mm512_set1_epi32(0x7fffffff))); +} +template <> +EIGEN_STRONG_INLINE Packet8d pabs(const Packet8d& a) { + // _mm512_abs_ps intrinsic not found, so hack around it + return _mm512_castsi512_pd(_mm512_and_si512(_mm512_castpd_si512(a), + _mm512_set1_epi64(0x7fffffffffffffff))); +} + +template<> +EIGEN_STRONG_INLINE Packet16f pfrexp(const Packet16f& a, Packet16f& exponent){ + return pfrexp_generic(a, exponent); +} + +// Extract exponent without existence of Packet8l. 
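+// Sketch of what the helper below computes (illustration only): for a finite,
+// normalized double the biased exponent occupies bits 62..52 of the IEEE-754
+// representation, so it can be read with a mask and a shift:
+//   uint64_t bits;
+//   memcpy(&bits, &x, sizeof(bits));
+//   double biased = (double)((bits & 0x7ff0000000000000ull) >> 52);
+// pfrexp_generic then removes the bias so that the returned mantissa lies in
+// [0.5, 1), as frexp does. Lacking a native 64-bit integer packet (Packet8l),
+// the shifted value is converted to double via epi32 unless AVX512DQ provides
+// a direct epi64-to-pd conversion.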
+template<> +EIGEN_STRONG_INLINE +Packet8d pfrexp_generic_get_biased_exponent(const Packet8d& a) { + const Packet8d cst_exp_mask = pset1frombits(static_cast(0x7ff0000000000000ull)); + #ifdef EIGEN_VECTORIZE_AVX512DQ + return _mm512_cvtepi64_pd(_mm512_srli_epi64(_mm512_castpd_si512(pand(a, cst_exp_mask)), 52)); + #else + return _mm512_cvtepi32_pd(_mm512_cvtepi64_epi32(_mm512_srli_epi64(_mm512_castpd_si512(pand(a, cst_exp_mask)), 52))); + #endif +} + +template<> +EIGEN_STRONG_INLINE Packet8d pfrexp(const Packet8d& a, Packet8d& exponent) { + return pfrexp_generic(a, exponent); +} + +template<> EIGEN_STRONG_INLINE Packet16f pldexp(const Packet16f& a, const Packet16f& exponent) { + return pldexp_generic(a, exponent); +} + +template<> EIGEN_STRONG_INLINE Packet8d pldexp(const Packet8d& a, const Packet8d& exponent) { + // Clamp exponent to [-2099, 2099] + const Packet8d max_exponent = pset1(2099.0); + const Packet8i e = _mm512_cvtpd_epi32(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent)); + + // Split 2^e into four factors and multiply. + const Packet8i bias = pset1(1023); + Packet8i b = parithmetic_shift_right<2>(e); // floor(e/4) + + // 2^b + const Packet8i permute_idx = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7); + Packet8i hi = _mm256_permutevar8x32_epi32(padd(b, bias), permute_idx); + Packet8i lo = _mm256_slli_epi64(hi, 52); + hi = _mm256_slli_epi64(_mm256_srli_epi64(hi, 32), 52); + Packet8d c = _mm512_castsi512_pd(_mm512_inserti64x4(_mm512_castsi256_si512(lo), hi, 1)); + Packet8d out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b) + + // 2^(e - 3b) + b = psub(psub(psub(e, b), b), b); // e - 3b + hi = _mm256_permutevar8x32_epi32(padd(b, bias), permute_idx); + lo = _mm256_slli_epi64(hi, 52); + hi = _mm256_slli_epi64(_mm256_srli_epi64(hi, 32), 52); + c = _mm512_castsi512_pd(_mm512_inserti64x4(_mm512_castsi256_si512(lo), hi, 1)); + out = pmul(out, c); // a * 2^e + return out; +} + +#ifdef EIGEN_VECTORIZE_AVX512DQ +// AVX512F does not define _mm512_extractf32x8_ps to extract _m256 from _m512 +#define EIGEN_EXTRACT_8f_FROM_16f(INPUT, OUTPUT) \ + __m256 OUTPUT##_0 = _mm512_extractf32x8_ps(INPUT, 0); \ + __m256 OUTPUT##_1 = _mm512_extractf32x8_ps(INPUT, 1) +#else +#define EIGEN_EXTRACT_8f_FROM_16f(INPUT, OUTPUT) \ + __m256 OUTPUT##_0 = _mm256_insertf128_ps( \ + _mm256_castps128_ps256(_mm512_extractf32x4_ps(INPUT, 0)), \ + _mm512_extractf32x4_ps(INPUT, 1), 1); \ + __m256 OUTPUT##_1 = _mm256_insertf128_ps( \ + _mm256_castps128_ps256(_mm512_extractf32x4_ps(INPUT, 2)), \ + _mm512_extractf32x4_ps(INPUT, 3), 1); +#endif + +#ifdef EIGEN_VECTORIZE_AVX512DQ +#define EIGEN_INSERT_8f_INTO_16f(OUTPUT, INPUTA, INPUTB) \ + OUTPUT = _mm512_insertf32x8(_mm512_castps256_ps512(INPUTA), INPUTB, 1); +#else +#define EIGEN_INSERT_8f_INTO_16f(OUTPUT, INPUTA, INPUTB) \ + OUTPUT = _mm512_undefined_ps(); \ + OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTA, 0), 0); \ + OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTA, 1), 1); \ + OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTB, 0), 2); \ + OUTPUT = _mm512_insertf32x4(OUTPUT, _mm256_extractf128_ps(INPUTB, 1), 3); +#endif + +template <> +EIGEN_STRONG_INLINE float predux(const Packet16f& a) { +#ifdef EIGEN_VECTORIZE_AVX512DQ + __m256 lane0 = _mm512_extractf32x8_ps(a, 0); + __m256 lane1 = _mm512_extractf32x8_ps(a, 1); + Packet8f x = _mm256_add_ps(lane0, lane1); + return predux(x); +#else + __m128 lane0 = _mm512_extractf32x4_ps(a, 0); + __m128 lane1 = _mm512_extractf32x4_ps(a, 1); + __m128 lane2 = _mm512_extractf32x4_ps(a, 
2); + __m128 lane3 = _mm512_extractf32x4_ps(a, 3); + __m128 sum = _mm_add_ps(_mm_add_ps(lane0, lane1), _mm_add_ps(lane2, lane3)); + sum = _mm_hadd_ps(sum, sum); + sum = _mm_hadd_ps(sum, _mm_permute_ps(sum, 1)); + return _mm_cvtss_f32(sum); +#endif +} +template <> +EIGEN_STRONG_INLINE double predux(const Packet8d& a) { + __m256d lane0 = _mm512_extractf64x4_pd(a, 0); + __m256d lane1 = _mm512_extractf64x4_pd(a, 1); + __m256d sum = _mm256_add_pd(lane0, lane1); + __m256d tmp0 = _mm256_hadd_pd(sum, _mm256_permute2f128_pd(sum, sum, 1)); + return _mm_cvtsd_f64(_mm256_castpd256_pd128(_mm256_hadd_pd(tmp0, tmp0))); +} + +template <> +EIGEN_STRONG_INLINE Packet8f predux_half_dowto4(const Packet16f& a) { +#ifdef EIGEN_VECTORIZE_AVX512DQ + __m256 lane0 = _mm512_extractf32x8_ps(a, 0); + __m256 lane1 = _mm512_extractf32x8_ps(a, 1); + return _mm256_add_ps(lane0, lane1); +#else + __m128 lane0 = _mm512_extractf32x4_ps(a, 0); + __m128 lane1 = _mm512_extractf32x4_ps(a, 1); + __m128 lane2 = _mm512_extractf32x4_ps(a, 2); + __m128 lane3 = _mm512_extractf32x4_ps(a, 3); + __m128 sum0 = _mm_add_ps(lane0, lane2); + __m128 sum1 = _mm_add_ps(lane1, lane3); + return _mm256_insertf128_ps(_mm256_castps128_ps256(sum0), sum1, 1); +#endif +} +template <> +EIGEN_STRONG_INLINE Packet4d predux_half_dowto4(const Packet8d& a) { + __m256d lane0 = _mm512_extractf64x4_pd(a, 0); + __m256d lane1 = _mm512_extractf64x4_pd(a, 1); + return _mm256_add_pd(lane0, lane1); +} + +template <> +EIGEN_STRONG_INLINE float predux_mul(const Packet16f& a) { +//#ifdef EIGEN_VECTORIZE_AVX512DQ +#if 0 + Packet8f lane0 = _mm512_extractf32x8_ps(a, 0); + Packet8f lane1 = _mm512_extractf32x8_ps(a, 1); + Packet8f res = pmul(lane0, lane1); + res = pmul(res, _mm256_permute2f128_ps(res, res, 1)); + res = pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2))); + return pfirst(pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1)))); +#else + __m128 lane0 = _mm512_extractf32x4_ps(a, 0); + __m128 lane1 = _mm512_extractf32x4_ps(a, 1); + __m128 lane2 = _mm512_extractf32x4_ps(a, 2); + __m128 lane3 = _mm512_extractf32x4_ps(a, 3); + __m128 res = pmul(pmul(lane0, lane1), pmul(lane2, lane3)); + res = pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2))); + return pfirst(pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1)))); +#endif +} +template <> +EIGEN_STRONG_INLINE double predux_mul(const Packet8d& a) { + __m256d lane0 = _mm512_extractf64x4_pd(a, 0); + __m256d lane1 = _mm512_extractf64x4_pd(a, 1); + __m256d res = pmul(lane0, lane1); + res = pmul(res, _mm256_permute2f128_pd(res, res, 1)); + return pfirst(pmul(res, _mm256_shuffle_pd(res, res, 1))); +} + +template <> +EIGEN_STRONG_INLINE float predux_min(const Packet16f& a) { + __m128 lane0 = _mm512_extractf32x4_ps(a, 0); + __m128 lane1 = _mm512_extractf32x4_ps(a, 1); + __m128 lane2 = _mm512_extractf32x4_ps(a, 2); + __m128 lane3 = _mm512_extractf32x4_ps(a, 3); + __m128 res = _mm_min_ps(_mm_min_ps(lane0, lane1), _mm_min_ps(lane2, lane3)); + res = _mm_min_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2))); + return pfirst(_mm_min_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1)))); +} +template <> +EIGEN_STRONG_INLINE double predux_min(const Packet8d& a) { + __m256d lane0 = _mm512_extractf64x4_pd(a, 0); + __m256d lane1 = _mm512_extractf64x4_pd(a, 1); + __m256d res = _mm256_min_pd(lane0, lane1); + res = _mm256_min_pd(res, _mm256_permute2f128_pd(res, res, 1)); + return pfirst(_mm256_min_pd(res, _mm256_shuffle_pd(res, res, 1))); +} + +template <> +EIGEN_STRONG_INLINE float predux_max(const Packet16f& a) { + __m128 
lane0 = _mm512_extractf32x4_ps(a, 0); + __m128 lane1 = _mm512_extractf32x4_ps(a, 1); + __m128 lane2 = _mm512_extractf32x4_ps(a, 2); + __m128 lane3 = _mm512_extractf32x4_ps(a, 3); + __m128 res = _mm_max_ps(_mm_max_ps(lane0, lane1), _mm_max_ps(lane2, lane3)); + res = _mm_max_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2))); + return pfirst(_mm_max_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1)))); +} + +template <> +EIGEN_STRONG_INLINE double predux_max(const Packet8d& a) { + __m256d lane0 = _mm512_extractf64x4_pd(a, 0); + __m256d lane1 = _mm512_extractf64x4_pd(a, 1); + __m256d res = _mm256_max_pd(lane0, lane1); + res = _mm256_max_pd(res, _mm256_permute2f128_pd(res, res, 1)); + return pfirst(_mm256_max_pd(res, _mm256_shuffle_pd(res, res, 1))); +} + +template<> EIGEN_STRONG_INLINE bool predux_any(const Packet16f& x) +{ + Packet16i xi = _mm512_castps_si512(x); + __mmask16 tmp = _mm512_test_epi32_mask(xi,xi); + return !_mm512_kortestz(tmp,tmp); +} + + + +#define PACK_OUTPUT(OUTPUT, INPUT, INDEX, STRIDE) \ + EIGEN_INSERT_8f_INTO_16f(OUTPUT[INDEX], INPUT[INDEX], INPUT[INDEX + STRIDE]); + +EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { + __m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]); + __m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]); + __m512 T2 = _mm512_unpacklo_ps(kernel.packet[2], kernel.packet[3]); + __m512 T3 = _mm512_unpackhi_ps(kernel.packet[2], kernel.packet[3]); + __m512 T4 = _mm512_unpacklo_ps(kernel.packet[4], kernel.packet[5]); + __m512 T5 = _mm512_unpackhi_ps(kernel.packet[4], kernel.packet[5]); + __m512 T6 = _mm512_unpacklo_ps(kernel.packet[6], kernel.packet[7]); + __m512 T7 = _mm512_unpackhi_ps(kernel.packet[6], kernel.packet[7]); + __m512 T8 = _mm512_unpacklo_ps(kernel.packet[8], kernel.packet[9]); + __m512 T9 = _mm512_unpackhi_ps(kernel.packet[8], kernel.packet[9]); + __m512 T10 = _mm512_unpacklo_ps(kernel.packet[10], kernel.packet[11]); + __m512 T11 = _mm512_unpackhi_ps(kernel.packet[10], kernel.packet[11]); + __m512 T12 = _mm512_unpacklo_ps(kernel.packet[12], kernel.packet[13]); + __m512 T13 = _mm512_unpackhi_ps(kernel.packet[12], kernel.packet[13]); + __m512 T14 = _mm512_unpacklo_ps(kernel.packet[14], kernel.packet[15]); + __m512 T15 = _mm512_unpackhi_ps(kernel.packet[14], kernel.packet[15]); + __m512 S0 = _mm512_shuffle_ps(T0, T2, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 S1 = _mm512_shuffle_ps(T0, T2, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 S2 = _mm512_shuffle_ps(T1, T3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 S3 = _mm512_shuffle_ps(T1, T3, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 S4 = _mm512_shuffle_ps(T4, T6, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 S5 = _mm512_shuffle_ps(T4, T6, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 S6 = _mm512_shuffle_ps(T5, T7, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 S7 = _mm512_shuffle_ps(T5, T7, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 S8 = _mm512_shuffle_ps(T8, T10, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 S9 = _mm512_shuffle_ps(T8, T10, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 S10 = _mm512_shuffle_ps(T9, T11, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 S11 = _mm512_shuffle_ps(T9, T11, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 S12 = _mm512_shuffle_ps(T12, T14, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 S13 = _mm512_shuffle_ps(T12, T14, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 S14 = _mm512_shuffle_ps(T13, T15, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 S15 = _mm512_shuffle_ps(T13, T15, _MM_SHUFFLE(3, 2, 3, 2)); + + EIGEN_EXTRACT_8f_FROM_16f(S0, S0); + EIGEN_EXTRACT_8f_FROM_16f(S1, S1); + EIGEN_EXTRACT_8f_FROM_16f(S2, S2); + EIGEN_EXTRACT_8f_FROM_16f(S3, S3); + 
EIGEN_EXTRACT_8f_FROM_16f(S4, S4); + EIGEN_EXTRACT_8f_FROM_16f(S5, S5); + EIGEN_EXTRACT_8f_FROM_16f(S6, S6); + EIGEN_EXTRACT_8f_FROM_16f(S7, S7); + EIGEN_EXTRACT_8f_FROM_16f(S8, S8); + EIGEN_EXTRACT_8f_FROM_16f(S9, S9); + EIGEN_EXTRACT_8f_FROM_16f(S10, S10); + EIGEN_EXTRACT_8f_FROM_16f(S11, S11); + EIGEN_EXTRACT_8f_FROM_16f(S12, S12); + EIGEN_EXTRACT_8f_FROM_16f(S13, S13); + EIGEN_EXTRACT_8f_FROM_16f(S14, S14); + EIGEN_EXTRACT_8f_FROM_16f(S15, S15); + + PacketBlock tmp; + + tmp.packet[0] = _mm256_permute2f128_ps(S0_0, S4_0, 0x20); + tmp.packet[1] = _mm256_permute2f128_ps(S1_0, S5_0, 0x20); + tmp.packet[2] = _mm256_permute2f128_ps(S2_0, S6_0, 0x20); + tmp.packet[3] = _mm256_permute2f128_ps(S3_0, S7_0, 0x20); + tmp.packet[4] = _mm256_permute2f128_ps(S0_0, S4_0, 0x31); + tmp.packet[5] = _mm256_permute2f128_ps(S1_0, S5_0, 0x31); + tmp.packet[6] = _mm256_permute2f128_ps(S2_0, S6_0, 0x31); + tmp.packet[7] = _mm256_permute2f128_ps(S3_0, S7_0, 0x31); + + tmp.packet[8] = _mm256_permute2f128_ps(S0_1, S4_1, 0x20); + tmp.packet[9] = _mm256_permute2f128_ps(S1_1, S5_1, 0x20); + tmp.packet[10] = _mm256_permute2f128_ps(S2_1, S6_1, 0x20); + tmp.packet[11] = _mm256_permute2f128_ps(S3_1, S7_1, 0x20); + tmp.packet[12] = _mm256_permute2f128_ps(S0_1, S4_1, 0x31); + tmp.packet[13] = _mm256_permute2f128_ps(S1_1, S5_1, 0x31); + tmp.packet[14] = _mm256_permute2f128_ps(S2_1, S6_1, 0x31); + tmp.packet[15] = _mm256_permute2f128_ps(S3_1, S7_1, 0x31); + + // Second set of _m256 outputs + tmp.packet[16] = _mm256_permute2f128_ps(S8_0, S12_0, 0x20); + tmp.packet[17] = _mm256_permute2f128_ps(S9_0, S13_0, 0x20); + tmp.packet[18] = _mm256_permute2f128_ps(S10_0, S14_0, 0x20); + tmp.packet[19] = _mm256_permute2f128_ps(S11_0, S15_0, 0x20); + tmp.packet[20] = _mm256_permute2f128_ps(S8_0, S12_0, 0x31); + tmp.packet[21] = _mm256_permute2f128_ps(S9_0, S13_0, 0x31); + tmp.packet[22] = _mm256_permute2f128_ps(S10_0, S14_0, 0x31); + tmp.packet[23] = _mm256_permute2f128_ps(S11_0, S15_0, 0x31); + + tmp.packet[24] = _mm256_permute2f128_ps(S8_1, S12_1, 0x20); + tmp.packet[25] = _mm256_permute2f128_ps(S9_1, S13_1, 0x20); + tmp.packet[26] = _mm256_permute2f128_ps(S10_1, S14_1, 0x20); + tmp.packet[27] = _mm256_permute2f128_ps(S11_1, S15_1, 0x20); + tmp.packet[28] = _mm256_permute2f128_ps(S8_1, S12_1, 0x31); + tmp.packet[29] = _mm256_permute2f128_ps(S9_1, S13_1, 0x31); + tmp.packet[30] = _mm256_permute2f128_ps(S10_1, S14_1, 0x31); + tmp.packet[31] = _mm256_permute2f128_ps(S11_1, S15_1, 0x31); + + // Pack them into the output + PACK_OUTPUT(kernel.packet, tmp.packet, 0, 16); + PACK_OUTPUT(kernel.packet, tmp.packet, 1, 16); + PACK_OUTPUT(kernel.packet, tmp.packet, 2, 16); + PACK_OUTPUT(kernel.packet, tmp.packet, 3, 16); + + PACK_OUTPUT(kernel.packet, tmp.packet, 4, 16); + PACK_OUTPUT(kernel.packet, tmp.packet, 5, 16); + PACK_OUTPUT(kernel.packet, tmp.packet, 6, 16); + PACK_OUTPUT(kernel.packet, tmp.packet, 7, 16); + + PACK_OUTPUT(kernel.packet, tmp.packet, 8, 16); + PACK_OUTPUT(kernel.packet, tmp.packet, 9, 16); + PACK_OUTPUT(kernel.packet, tmp.packet, 10, 16); + PACK_OUTPUT(kernel.packet, tmp.packet, 11, 16); + + PACK_OUTPUT(kernel.packet, tmp.packet, 12, 16); + PACK_OUTPUT(kernel.packet, tmp.packet, 13, 16); + PACK_OUTPUT(kernel.packet, tmp.packet, 14, 16); + PACK_OUTPUT(kernel.packet, tmp.packet, 15, 16); +} +#define PACK_OUTPUT_2(OUTPUT, INPUT, INDEX, STRIDE) \ + EIGEN_INSERT_8f_INTO_16f(OUTPUT[INDEX], INPUT[2 * INDEX], \ + INPUT[2 * INDEX + STRIDE]); + +EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { + __m512 T0 = 
_mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]); + __m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]); + __m512 T2 = _mm512_unpacklo_ps(kernel.packet[2], kernel.packet[3]); + __m512 T3 = _mm512_unpackhi_ps(kernel.packet[2], kernel.packet[3]); + + __m512 S0 = _mm512_shuffle_ps(T0, T2, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 S1 = _mm512_shuffle_ps(T0, T2, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 S2 = _mm512_shuffle_ps(T1, T3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 S3 = _mm512_shuffle_ps(T1, T3, _MM_SHUFFLE(3, 2, 3, 2)); + + EIGEN_EXTRACT_8f_FROM_16f(S0, S0); + EIGEN_EXTRACT_8f_FROM_16f(S1, S1); + EIGEN_EXTRACT_8f_FROM_16f(S2, S2); + EIGEN_EXTRACT_8f_FROM_16f(S3, S3); + + PacketBlock tmp; + + tmp.packet[0] = _mm256_permute2f128_ps(S0_0, S1_0, 0x20); + tmp.packet[1] = _mm256_permute2f128_ps(S2_0, S3_0, 0x20); + tmp.packet[2] = _mm256_permute2f128_ps(S0_0, S1_0, 0x31); + tmp.packet[3] = _mm256_permute2f128_ps(S2_0, S3_0, 0x31); + + tmp.packet[4] = _mm256_permute2f128_ps(S0_1, S1_1, 0x20); + tmp.packet[5] = _mm256_permute2f128_ps(S2_1, S3_1, 0x20); + tmp.packet[6] = _mm256_permute2f128_ps(S0_1, S1_1, 0x31); + tmp.packet[7] = _mm256_permute2f128_ps(S2_1, S3_1, 0x31); + + PACK_OUTPUT_2(kernel.packet, tmp.packet, 0, 1); + PACK_OUTPUT_2(kernel.packet, tmp.packet, 1, 1); + PACK_OUTPUT_2(kernel.packet, tmp.packet, 2, 1); + PACK_OUTPUT_2(kernel.packet, tmp.packet, 3, 1); +} + +#define PACK_OUTPUT_SQ_D(OUTPUT, INPUT, INDEX, STRIDE) \ + OUTPUT[INDEX] = _mm512_insertf64x4(OUTPUT[INDEX], INPUT[INDEX], 0); \ + OUTPUT[INDEX] = _mm512_insertf64x4(OUTPUT[INDEX], INPUT[INDEX + STRIDE], 1); + +#define PACK_OUTPUT_D(OUTPUT, INPUT, INDEX, STRIDE) \ + OUTPUT[INDEX] = _mm512_insertf64x4(OUTPUT[INDEX], INPUT[(2 * INDEX)], 0); \ + OUTPUT[INDEX] = \ + _mm512_insertf64x4(OUTPUT[INDEX], INPUT[(2 * INDEX) + STRIDE], 1); + +EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { + __m512d T0 = _mm512_shuffle_pd(kernel.packet[0], kernel.packet[1], 0); + __m512d T1 = _mm512_shuffle_pd(kernel.packet[0], kernel.packet[1], 0xff); + __m512d T2 = _mm512_shuffle_pd(kernel.packet[2], kernel.packet[3], 0); + __m512d T3 = _mm512_shuffle_pd(kernel.packet[2], kernel.packet[3], 0xff); + + PacketBlock tmp; + + tmp.packet[0] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 0), + _mm512_extractf64x4_pd(T2, 0), 0x20); + tmp.packet[1] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 0), + _mm512_extractf64x4_pd(T3, 0), 0x20); + tmp.packet[2] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 0), + _mm512_extractf64x4_pd(T2, 0), 0x31); + tmp.packet[3] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 0), + _mm512_extractf64x4_pd(T3, 0), 0x31); + + tmp.packet[4] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 1), + _mm512_extractf64x4_pd(T2, 1), 0x20); + tmp.packet[5] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 1), + _mm512_extractf64x4_pd(T3, 1), 0x20); + tmp.packet[6] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 1), + _mm512_extractf64x4_pd(T2, 1), 0x31); + tmp.packet[7] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 1), + _mm512_extractf64x4_pd(T3, 1), 0x31); + + PACK_OUTPUT_D(kernel.packet, tmp.packet, 0, 1); + PACK_OUTPUT_D(kernel.packet, tmp.packet, 1, 1); + PACK_OUTPUT_D(kernel.packet, tmp.packet, 2, 1); + PACK_OUTPUT_D(kernel.packet, tmp.packet, 3, 1); +} + +EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { + __m512d T0 = _mm512_unpacklo_pd(kernel.packet[0], kernel.packet[1]); + __m512d T1 = _mm512_unpackhi_pd(kernel.packet[0], kernel.packet[1]); + __m512d T2 = 
_mm512_unpacklo_pd(kernel.packet[2], kernel.packet[3]); + __m512d T3 = _mm512_unpackhi_pd(kernel.packet[2], kernel.packet[3]); + __m512d T4 = _mm512_unpacklo_pd(kernel.packet[4], kernel.packet[5]); + __m512d T5 = _mm512_unpackhi_pd(kernel.packet[4], kernel.packet[5]); + __m512d T6 = _mm512_unpacklo_pd(kernel.packet[6], kernel.packet[7]); + __m512d T7 = _mm512_unpackhi_pd(kernel.packet[6], kernel.packet[7]); + + PacketBlock tmp; + + tmp.packet[0] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 0), + _mm512_extractf64x4_pd(T2, 0), 0x20); + tmp.packet[1] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 0), + _mm512_extractf64x4_pd(T3, 0), 0x20); + tmp.packet[2] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 0), + _mm512_extractf64x4_pd(T2, 0), 0x31); + tmp.packet[3] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 0), + _mm512_extractf64x4_pd(T3, 0), 0x31); + + tmp.packet[4] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 1), + _mm512_extractf64x4_pd(T2, 1), 0x20); + tmp.packet[5] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 1), + _mm512_extractf64x4_pd(T3, 1), 0x20); + tmp.packet[6] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 1), + _mm512_extractf64x4_pd(T2, 1), 0x31); + tmp.packet[7] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 1), + _mm512_extractf64x4_pd(T3, 1), 0x31); + + tmp.packet[8] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T4, 0), + _mm512_extractf64x4_pd(T6, 0), 0x20); + tmp.packet[9] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T5, 0), + _mm512_extractf64x4_pd(T7, 0), 0x20); + tmp.packet[10] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T4, 0), + _mm512_extractf64x4_pd(T6, 0), 0x31); + tmp.packet[11] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T5, 0), + _mm512_extractf64x4_pd(T7, 0), 0x31); + + tmp.packet[12] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T4, 1), + _mm512_extractf64x4_pd(T6, 1), 0x20); + tmp.packet[13] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T5, 1), + _mm512_extractf64x4_pd(T7, 1), 0x20); + tmp.packet[14] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T4, 1), + _mm512_extractf64x4_pd(T6, 1), 0x31); + tmp.packet[15] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T5, 1), + _mm512_extractf64x4_pd(T7, 1), 0x31); + + PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 0, 8); + PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 1, 8); + PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 2, 8); + PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 3, 8); + + PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 4, 8); + PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 5, 8); + PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 6, 8); + PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 7, 8); +} +template <> +EIGEN_STRONG_INLINE Packet16f pblend(const Selector<16>& /*ifPacket*/, + const Packet16f& /*thenPacket*/, + const Packet16f& /*elsePacket*/) { + assert(false && "To be implemented"); + return Packet16f(); +} +template <> +EIGEN_STRONG_INLINE Packet8d pblend(const Selector<8>& ifPacket, + const Packet8d& thenPacket, + const Packet8d& elsePacket) { + __mmask8 m = (ifPacket.select[0] ) + | (ifPacket.select[1]<<1) + | (ifPacket.select[2]<<2) + | (ifPacket.select[3]<<3) + | (ifPacket.select[4]<<4) + | (ifPacket.select[5]<<5) + | (ifPacket.select[6]<<6) + | (ifPacket.select[7]<<7); + return _mm512_mask_blend_pd(m, elsePacket, thenPacket); +} + +// Packet math for Eigen::half +template<> EIGEN_STRONG_INLINE Packet16h pset1(const Eigen::half& from) { + return _mm256_set1_epi16(from.x); +} + +template<> EIGEN_STRONG_INLINE Eigen::half pfirst(const Packet16h& from) { + 
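+  // Packet16h stores raw fp16 bit patterns in a __m256i, so extracting lane 0
+  // and reinterpreting the 16-bit value as Eigen::half recovers the scalar.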
return half_impl::raw_uint16_to_half(static_cast(_mm256_extract_epi16(from, 0))); +} + +template<> EIGEN_STRONG_INLINE Packet16h pload(const Eigen::half* from) { + return _mm256_load_si256(reinterpret_cast(from)); +} + +template<> EIGEN_STRONG_INLINE Packet16h ploadu(const Eigen::half* from) { + return _mm256_loadu_si256(reinterpret_cast(from)); +} + +template<> EIGEN_STRONG_INLINE void pstore(Eigen::half* to, const Packet16h& from) { + // (void*) -> workaround clang warning: + // cast from 'Eigen::half *' to '__m256i *' increases required alignment from 2 to 32 + _mm256_store_si256((__m256i*)(void*)to, from); +} + +template<> EIGEN_STRONG_INLINE void pstoreu(Eigen::half* to, const Packet16h& from) { + // (void*) -> workaround clang warning: + // cast from 'Eigen::half *' to '__m256i *' increases required alignment from 2 to 32 + _mm256_storeu_si256((__m256i*)(void*)to, from); +} + +template<> EIGEN_STRONG_INLINE Packet16h +ploaddup(const Eigen::half* from) { + unsigned short a = from[0].x; + unsigned short b = from[1].x; + unsigned short c = from[2].x; + unsigned short d = from[3].x; + unsigned short e = from[4].x; + unsigned short f = from[5].x; + unsigned short g = from[6].x; + unsigned short h = from[7].x; + return _mm256_set_epi16(h, h, g, g, f, f, e, e, d, d, c, c, b, b, a, a); +} + +template<> EIGEN_STRONG_INLINE Packet16h +ploadquad(const Eigen::half* from) { + unsigned short a = from[0].x; + unsigned short b = from[1].x; + unsigned short c = from[2].x; + unsigned short d = from[3].x; + return _mm256_set_epi16(d, d, d, d, c, c, c, c, b, b, b, b, a, a, a, a); +} + +EIGEN_STRONG_INLINE Packet16f half2float(const Packet16h& a) { +#ifdef EIGEN_HAS_FP16_C + return _mm512_cvtph_ps(a); +#else + EIGEN_ALIGN64 half aux[16]; + pstore(aux, a); + float f0(aux[0]); + float f1(aux[1]); + float f2(aux[2]); + float f3(aux[3]); + float f4(aux[4]); + float f5(aux[5]); + float f6(aux[6]); + float f7(aux[7]); + float f8(aux[8]); + float f9(aux[9]); + float fa(aux[10]); + float fb(aux[11]); + float fc(aux[12]); + float fd(aux[13]); + float fe(aux[14]); + float ff(aux[15]); + + return _mm512_set_ps( + ff, fe, fd, fc, fb, fa, f9, f8, f7, f6, f5, f4, f3, f2, f1, f0); +#endif +} + +EIGEN_STRONG_INLINE Packet16h float2half(const Packet16f& a) { +#ifdef EIGEN_HAS_FP16_C + return _mm512_cvtps_ph(a, _MM_FROUND_TO_NEAREST_INT|_MM_FROUND_NO_EXC); +#else + EIGEN_ALIGN64 float aux[16]; + pstore(aux, a); + half h0(aux[0]); + half h1(aux[1]); + half h2(aux[2]); + half h3(aux[3]); + half h4(aux[4]); + half h5(aux[5]); + half h6(aux[6]); + half h7(aux[7]); + half h8(aux[8]); + half h9(aux[9]); + half ha(aux[10]); + half hb(aux[11]); + half hc(aux[12]); + half hd(aux[13]); + half he(aux[14]); + half hf(aux[15]); + + return _mm256_set_epi16( + hf.x, he.x, hd.x, hc.x, hb.x, ha.x, h9.x, h8.x, + h7.x, h6.x, h5.x, h4.x, h3.x, h2.x, h1.x, h0.x); +#endif +} + +template<> EIGEN_STRONG_INLINE Packet16h ptrue(const Packet16h& a) { + return ptrue(Packet8i(a)); +} + +template <> +EIGEN_STRONG_INLINE Packet16h pabs(const Packet16h& a) { + const __m256i sign_mask = _mm256_set1_epi16(static_cast(0x8000)); + return _mm256_andnot_si256(sign_mask, a); +} + +template <> +EIGEN_STRONG_INLINE Packet16h pmin(const Packet16h& a, + const Packet16h& b) { + return float2half(pmin(half2float(a), half2float(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16h pmax(const Packet16h& a, + const Packet16h& b) { + return float2half(pmax(half2float(a), half2float(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16h plset(const half& a) { + return 
float2half(plset(static_cast(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16h por(const Packet16h& a,const Packet16h& b) { + // in some cases Packet8i is a wrapper around __m256i, so we need to + // cast to Packet8i to call the correct overload. + return por(Packet8i(a),Packet8i(b)); +} +template<> EIGEN_STRONG_INLINE Packet16h pxor(const Packet16h& a,const Packet16h& b) { + return pxor(Packet8i(a),Packet8i(b)); +} +template<> EIGEN_STRONG_INLINE Packet16h pand(const Packet16h& a,const Packet16h& b) { + return pand(Packet8i(a),Packet8i(b)); +} +template<> EIGEN_STRONG_INLINE Packet16h pandnot(const Packet16h& a,const Packet16h& b) { + return pandnot(Packet8i(a),Packet8i(b)); +} + +template<> EIGEN_STRONG_INLINE Packet16h pselect(const Packet16h& mask, const Packet16h& a, const Packet16h& b) { + return _mm256_blendv_epi8(b, a, mask); +} + +template<> EIGEN_STRONG_INLINE Packet16h pround(const Packet16h& a) { + return float2half(pround(half2float(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16h print(const Packet16h& a) { + return float2half(print(half2float(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16h pceil(const Packet16h& a) { + return float2half(pceil(half2float(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16h pfloor(const Packet16h& a) { + return float2half(pfloor(half2float(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16h pcmp_eq(const Packet16h& a,const Packet16h& b) { + Packet16f af = half2float(a); + Packet16f bf = half2float(b); + return Pack32To16(pcmp_eq(af, bf)); +} + +template<> EIGEN_STRONG_INLINE Packet16h pcmp_le(const Packet16h& a,const Packet16h& b) { + return Pack32To16(pcmp_le(half2float(a), half2float(b))); +} + +template<> EIGEN_STRONG_INLINE Packet16h pcmp_lt(const Packet16h& a,const Packet16h& b) { + return Pack32To16(pcmp_lt(half2float(a), half2float(b))); +} + +template<> EIGEN_STRONG_INLINE Packet16h pcmp_lt_or_nan(const Packet16h& a,const Packet16h& b) { + return Pack32To16(pcmp_lt_or_nan(half2float(a), half2float(b))); +} + +template<> EIGEN_STRONG_INLINE Packet16h pconj(const Packet16h& a) { return a; } + +template<> EIGEN_STRONG_INLINE Packet16h pnegate(const Packet16h& a) { + Packet16h sign_mask = _mm256_set1_epi16(static_cast(0x8000)); + return _mm256_xor_si256(a, sign_mask); +} + +template<> EIGEN_STRONG_INLINE Packet16h padd(const Packet16h& a, const Packet16h& b) { + Packet16f af = half2float(a); + Packet16f bf = half2float(b); + Packet16f rf = padd(af, bf); + return float2half(rf); +} + +template<> EIGEN_STRONG_INLINE Packet16h psub(const Packet16h& a, const Packet16h& b) { + Packet16f af = half2float(a); + Packet16f bf = half2float(b); + Packet16f rf = psub(af, bf); + return float2half(rf); +} + +template<> EIGEN_STRONG_INLINE Packet16h pmul(const Packet16h& a, const Packet16h& b) { + Packet16f af = half2float(a); + Packet16f bf = half2float(b); + Packet16f rf = pmul(af, bf); + return float2half(rf); +} + +template<> EIGEN_STRONG_INLINE Packet16h pdiv(const Packet16h& a, const Packet16h& b) { + Packet16f af = half2float(a); + Packet16f bf = half2float(b); + Packet16f rf = pdiv(af, bf); + return float2half(rf); +} + +template<> EIGEN_STRONG_INLINE half predux(const Packet16h& from) { + Packet16f from_float = half2float(from); + return half(predux(from_float)); +} + +template <> +EIGEN_STRONG_INLINE Packet8h predux_half_dowto4(const Packet16h& a) { + Packet8h lane0 = _mm256_extractf128_si256(a, 0); + Packet8h lane1 = _mm256_extractf128_si256(a, 1); + return padd(lane0, lane1); +} + +template<> EIGEN_STRONG_INLINE Eigen::half 
predux_max(const Packet16h& a) { + Packet16f af = half2float(a); + float reduced = predux_max(af); + return Eigen::half(reduced); +} + +template<> EIGEN_STRONG_INLINE Eigen::half predux_min(const Packet16h& a) { + Packet16f af = half2float(a); + float reduced = predux_min(af); + return Eigen::half(reduced); +} + +template<> EIGEN_STRONG_INLINE half predux_mul(const Packet16h& from) { + Packet16f from_float = half2float(from); + return half(predux_mul(from_float)); +} + +template<> EIGEN_STRONG_INLINE Packet16h preverse(const Packet16h& a) +{ + __m128i m = _mm_setr_epi8(14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1); + return _mm256_insertf128_si256( + _mm256_castsi128_si256(_mm_shuffle_epi8(_mm256_extractf128_si256(a,1),m)), + _mm_shuffle_epi8(_mm256_extractf128_si256(a,0),m), 1); +} + +template<> EIGEN_STRONG_INLINE Packet16h pgather(const Eigen::half* from, Index stride) +{ + return _mm256_set_epi16( + from[15*stride].x, from[14*stride].x, from[13*stride].x, from[12*stride].x, + from[11*stride].x, from[10*stride].x, from[9*stride].x, from[8*stride].x, + from[7*stride].x, from[6*stride].x, from[5*stride].x, from[4*stride].x, + from[3*stride].x, from[2*stride].x, from[1*stride].x, from[0*stride].x); +} + +template<> EIGEN_STRONG_INLINE void pscatter(half* to, const Packet16h& from, Index stride) +{ + EIGEN_ALIGN64 half aux[16]; + pstore(aux, from); + to[stride*0] = aux[0]; + to[stride*1] = aux[1]; + to[stride*2] = aux[2]; + to[stride*3] = aux[3]; + to[stride*4] = aux[4]; + to[stride*5] = aux[5]; + to[stride*6] = aux[6]; + to[stride*7] = aux[7]; + to[stride*8] = aux[8]; + to[stride*9] = aux[9]; + to[stride*10] = aux[10]; + to[stride*11] = aux[11]; + to[stride*12] = aux[12]; + to[stride*13] = aux[13]; + to[stride*14] = aux[14]; + to[stride*15] = aux[15]; +} + +EIGEN_STRONG_INLINE void +ptranspose(PacketBlock& kernel) { + __m256i a = kernel.packet[0]; + __m256i b = kernel.packet[1]; + __m256i c = kernel.packet[2]; + __m256i d = kernel.packet[3]; + __m256i e = kernel.packet[4]; + __m256i f = kernel.packet[5]; + __m256i g = kernel.packet[6]; + __m256i h = kernel.packet[7]; + __m256i i = kernel.packet[8]; + __m256i j = kernel.packet[9]; + __m256i k = kernel.packet[10]; + __m256i l = kernel.packet[11]; + __m256i m = kernel.packet[12]; + __m256i n = kernel.packet[13]; + __m256i o = kernel.packet[14]; + __m256i p = kernel.packet[15]; + + __m256i ab_07 = _mm256_unpacklo_epi16(a, b); + __m256i cd_07 = _mm256_unpacklo_epi16(c, d); + __m256i ef_07 = _mm256_unpacklo_epi16(e, f); + __m256i gh_07 = _mm256_unpacklo_epi16(g, h); + __m256i ij_07 = _mm256_unpacklo_epi16(i, j); + __m256i kl_07 = _mm256_unpacklo_epi16(k, l); + __m256i mn_07 = _mm256_unpacklo_epi16(m, n); + __m256i op_07 = _mm256_unpacklo_epi16(o, p); + + __m256i ab_8f = _mm256_unpackhi_epi16(a, b); + __m256i cd_8f = _mm256_unpackhi_epi16(c, d); + __m256i ef_8f = _mm256_unpackhi_epi16(e, f); + __m256i gh_8f = _mm256_unpackhi_epi16(g, h); + __m256i ij_8f = _mm256_unpackhi_epi16(i, j); + __m256i kl_8f = _mm256_unpackhi_epi16(k, l); + __m256i mn_8f = _mm256_unpackhi_epi16(m, n); + __m256i op_8f = _mm256_unpackhi_epi16(o, p); + + __m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07); + __m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07); + __m256i efgh_03 = _mm256_unpacklo_epi32(ef_07, gh_07); + __m256i efgh_47 = _mm256_unpackhi_epi32(ef_07, gh_07); + __m256i ijkl_03 = _mm256_unpacklo_epi32(ij_07, kl_07); + __m256i ijkl_47 = _mm256_unpackhi_epi32(ij_07, kl_07); + __m256i mnop_03 = _mm256_unpacklo_epi32(mn_07, op_07); + __m256i mnop_47 = 
_mm256_unpackhi_epi32(mn_07, op_07); + + __m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f); + __m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f); + __m256i efgh_8b = _mm256_unpacklo_epi32(ef_8f, gh_8f); + __m256i efgh_cf = _mm256_unpackhi_epi32(ef_8f, gh_8f); + __m256i ijkl_8b = _mm256_unpacklo_epi32(ij_8f, kl_8f); + __m256i ijkl_cf = _mm256_unpackhi_epi32(ij_8f, kl_8f); + __m256i mnop_8b = _mm256_unpacklo_epi32(mn_8f, op_8f); + __m256i mnop_cf = _mm256_unpackhi_epi32(mn_8f, op_8f); + + __m256i abcdefgh_01 = _mm256_unpacklo_epi64(abcd_03, efgh_03); + __m256i abcdefgh_23 = _mm256_unpackhi_epi64(abcd_03, efgh_03); + __m256i ijklmnop_01 = _mm256_unpacklo_epi64(ijkl_03, mnop_03); + __m256i ijklmnop_23 = _mm256_unpackhi_epi64(ijkl_03, mnop_03); + __m256i abcdefgh_45 = _mm256_unpacklo_epi64(abcd_47, efgh_47); + __m256i abcdefgh_67 = _mm256_unpackhi_epi64(abcd_47, efgh_47); + __m256i ijklmnop_45 = _mm256_unpacklo_epi64(ijkl_47, mnop_47); + __m256i ijklmnop_67 = _mm256_unpackhi_epi64(ijkl_47, mnop_47); + __m256i abcdefgh_89 = _mm256_unpacklo_epi64(abcd_8b, efgh_8b); + __m256i abcdefgh_ab = _mm256_unpackhi_epi64(abcd_8b, efgh_8b); + __m256i ijklmnop_89 = _mm256_unpacklo_epi64(ijkl_8b, mnop_8b); + __m256i ijklmnop_ab = _mm256_unpackhi_epi64(ijkl_8b, mnop_8b); + __m256i abcdefgh_cd = _mm256_unpacklo_epi64(abcd_cf, efgh_cf); + __m256i abcdefgh_ef = _mm256_unpackhi_epi64(abcd_cf, efgh_cf); + __m256i ijklmnop_cd = _mm256_unpacklo_epi64(ijkl_cf, mnop_cf); + __m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf); + + // NOTE: no unpacklo/hi instr in this case, so using permute instr. + __m256i a_p_0 = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20); + __m256i a_p_1 = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20); + __m256i a_p_2 = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20); + __m256i a_p_3 = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20); + __m256i a_p_4 = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20); + __m256i a_p_5 = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20); + __m256i a_p_6 = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20); + __m256i a_p_7 = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20); + __m256i a_p_8 = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31); + __m256i a_p_9 = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31); + __m256i a_p_a = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31); + __m256i a_p_b = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31); + __m256i a_p_c = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31); + __m256i a_p_d = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31); + __m256i a_p_e = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31); + __m256i a_p_f = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31); + + kernel.packet[0] = a_p_0; + kernel.packet[1] = a_p_1; + kernel.packet[2] = a_p_2; + kernel.packet[3] = a_p_3; + kernel.packet[4] = a_p_4; + kernel.packet[5] = a_p_5; + kernel.packet[6] = a_p_6; + kernel.packet[7] = a_p_7; + kernel.packet[8] = a_p_8; + kernel.packet[9] = a_p_9; + kernel.packet[10] = a_p_a; + kernel.packet[11] = a_p_b; + kernel.packet[12] = a_p_c; + kernel.packet[13] = a_p_d; + kernel.packet[14] = a_p_e; + kernel.packet[15] = a_p_f; +} + +EIGEN_STRONG_INLINE void +ptranspose(PacketBlock& kernel) { + EIGEN_ALIGN64 half in[8][16]; + pstore(in[0], kernel.packet[0]); + pstore(in[1], kernel.packet[1]); + pstore(in[2], kernel.packet[2]); + pstore(in[3], kernel.packet[3]); + pstore(in[4], 
kernel.packet[4]); + pstore(in[5], kernel.packet[5]); + pstore(in[6], kernel.packet[6]); + pstore(in[7], kernel.packet[7]); + + EIGEN_ALIGN64 half out[8][16]; + + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { + out[i][j] = in[j][2*i]; + } + for (int j = 0; j < 8; ++j) { + out[i][j+8] = in[j][2*i+1]; + } + } + + kernel.packet[0] = pload(out[0]); + kernel.packet[1] = pload(out[1]); + kernel.packet[2] = pload(out[2]); + kernel.packet[3] = pload(out[3]); + kernel.packet[4] = pload(out[4]); + kernel.packet[5] = pload(out[5]); + kernel.packet[6] = pload(out[6]); + kernel.packet[7] = pload(out[7]); +} + +EIGEN_STRONG_INLINE void +ptranspose(PacketBlock& kernel) { + EIGEN_ALIGN64 half in[4][16]; + pstore(in[0], kernel.packet[0]); + pstore(in[1], kernel.packet[1]); + pstore(in[2], kernel.packet[2]); + pstore(in[3], kernel.packet[3]); + + EIGEN_ALIGN64 half out[4][16]; + + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + out[i][j] = in[j][4*i]; + } + for (int j = 0; j < 4; ++j) { + out[i][j+4] = in[j][4*i+1]; + } + for (int j = 0; j < 4; ++j) { + out[i][j+8] = in[j][4*i+2]; + } + for (int j = 0; j < 4; ++j) { + out[i][j+12] = in[j][4*i+3]; + } + } + + kernel.packet[0] = pload(out[0]); + kernel.packet[1] = pload(out[1]); + kernel.packet[2] = pload(out[2]); + kernel.packet[3] = pload(out[3]); +} + +template <> struct is_arithmetic { enum { value = true }; }; + +template <> +struct packet_traits : default_packet_traits { + typedef Packet16bf type; + typedef Packet8bf half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 16, + HasHalfPacket = 1, + HasBlend = 0, + HasInsert = 1, + HasSin = EIGEN_FAST_MATH, + HasCos = EIGEN_FAST_MATH, +#if EIGEN_GNUC_AT_LEAST(5, 3) || (!EIGEN_COMP_GNUC_STRICT) +#ifdef EIGEN_VECTORIZE_AVX512DQ + HasLog = 1, // Currently fails test with bad accuracy. 
+ HasLog1p = 1, + HasExpm1 = 1, + HasNdtri = 1, + HasBessel = 1, +#endif + HasExp = 1, + HasSqrt = EIGEN_FAST_MATH, + HasRsqrt = EIGEN_FAST_MATH, + HasTanh = EIGEN_FAST_MATH, + HasErf = EIGEN_FAST_MATH, +#endif + HasCmp = 1, + HasDiv = 1 + }; +}; + +template <> +struct unpacket_traits +{ + typedef bfloat16 type; + enum {size=16, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false}; + typedef Packet8bf half; +}; + +template <> +EIGEN_STRONG_INLINE Packet16bf pset1(const bfloat16& from) { + return _mm256_set1_epi16(from.value); +} + +template <> +EIGEN_STRONG_INLINE bfloat16 pfirst(const Packet16bf& from) { + bfloat16 t; + t.value = static_cast(_mm256_extract_epi16(from, 0)); + return t; +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pload(const bfloat16* from) { + return _mm256_load_si256(reinterpret_cast(from)); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf ploadu(const bfloat16* from) { + return _mm256_loadu_si256(reinterpret_cast(from)); +} + +template <> +EIGEN_STRONG_INLINE void pstore(bfloat16* to, + const Packet16bf& from) { + _mm256_store_si256(reinterpret_cast<__m256i*>(to), from); +} + +template <> +EIGEN_STRONG_INLINE void pstoreu(bfloat16* to, + const Packet16bf& from) { + _mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from); +} + +template<> EIGEN_STRONG_INLINE Packet16bf +ploaddup(const bfloat16* from) { + Packet16bf r; + unsigned short a = from[0].value; + unsigned short b = from[1].value; + unsigned short c = from[2].value; + unsigned short d = from[3].value; + unsigned short e = from[4].value; + unsigned short f = from[5].value; + unsigned short g = from[6].value; + unsigned short h = from[7].value; + return _mm256_set_epi16(h, h, g, g, f, f, e, e, d, d, c, c, b, b, a, a); +} + +template<> EIGEN_STRONG_INLINE Packet16bf +ploadquad(const bfloat16* from) { + Packet16bf r; + unsigned short a = from[0].value; + unsigned short b = from[1].value; + unsigned short c = from[2].value; + unsigned short d = from[3].value; + return _mm256_set_epi16(d, d, d, d, c, c, c, c, b, b, b, b, a, a, a, a); +} + +EIGEN_STRONG_INLINE Packet16f Bf16ToF32(const Packet16bf& a) { + return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16)); +} + +// Convert float to bfloat16 according to round-to-nearest-even/denormals algorithm. +EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) { + Packet16bf r; + +#if defined(EIGEN_VECTORIZE_AVX512BF16) && EIGEN_GNUC_AT_LEAST(10, 1) + // Since GCC 10.1 supports avx512bf16 and C style explicit cast + // (C++ static_cast is not supported yet), do converion via intrinsic + // and register path for performance. 
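+  // _mm512_cvtneps_pbh already performs round-to-nearest-even in hardware,
+  // so no manual rounding bias is needed on this path.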
+ r = (__m256i)(_mm512_cvtneps_pbh(a)); + +#else + __m512i t; + __m512i input = _mm512_castps_si512(a); + __m512i nan = _mm512_set1_epi32(0x7fc0); + + // uint32_t lsb = (input >> 16) & 1; + t = _mm512_and_si512(_mm512_srli_epi32(input, 16), _mm512_set1_epi32(1)); + // uint32_t rounding_bias = 0x7fff + lsb; + t = _mm512_add_epi32(t, _mm512_set1_epi32(0x7fff)); + // input += rounding_bias; + t = _mm512_add_epi32(t, input); + // input = input >> 16; + t = _mm512_srli_epi32(t, 16); + + // Check NaN before converting back to bf16 + __mmask16 mask = _mm512_cmp_ps_mask(a, a, _CMP_ORD_Q); + + t = _mm512_mask_blend_epi32(mask, nan, t); + // output.value = static_cast(input); + r = _mm512_cvtepi32_epi16(t); +#endif // EIGEN_VECTORIZE_AVX512BF16 + + return r; +} + +template <> +EIGEN_STRONG_INLINE Packet16bf ptrue(const Packet16bf& a) { + return ptrue(a); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf por(const Packet16bf& a, const Packet16bf& b) { + return por(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pxor(const Packet16bf& a, const Packet16bf& b) { + return pxor(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pand(const Packet16bf& a, const Packet16bf& b) { + return pand(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pandnot(const Packet16bf& a, + const Packet16bf& b) { + return pandnot(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pselect(const Packet16bf& mask, + const Packet16bf& a, + const Packet16bf& b) { + // Input mask is expected to be all 0/1, handle it with 8-bit + // intrinsic for performance. + return _mm256_blendv_epi8(b, a, mask); +} + +template<> EIGEN_STRONG_INLINE Packet16bf pround(const Packet16bf& a) +{ + return F32ToBf16(pround(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16bf print(const Packet16bf& a) { + return F32ToBf16(print(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16bf pceil(const Packet16bf& a) { + return F32ToBf16(pceil(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet16bf pfloor(const Packet16bf& a) { + return F32ToBf16(pfloor(Bf16ToF32(a))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pcmp_eq(const Packet16bf& a, + const Packet16bf& b) { + return Pack32To16(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pcmp_le(const Packet16bf& a, + const Packet16bf& b) { + return Pack32To16(pcmp_le(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pcmp_lt(const Packet16bf& a, + const Packet16bf& b) { + return Pack32To16(pcmp_lt(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pcmp_lt_or_nan(const Packet16bf& a, + const Packet16bf& b) { + return Pack32To16(pcmp_lt_or_nan(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pnegate(const Packet16bf& a) { + Packet16bf sign_mask = _mm256_set1_epi16(static_cast(0x8000)); + return _mm256_xor_si256(a, sign_mask); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pconj(const Packet16bf& a) { + return a; +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pabs(const Packet16bf& a) { + const __m256i sign_mask = _mm256_set1_epi16(static_cast(0x8000)); + return _mm256_andnot_si256(sign_mask, a); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf padd(const Packet16bf& a, + const Packet16bf& b) { + return F32ToBf16(padd(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf psub(const Packet16bf& a, + const Packet16bf& b) { + return F32ToBf16(psub(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> 
+EIGEN_STRONG_INLINE Packet16bf pmul(const Packet16bf& a, + const Packet16bf& b) { + return F32ToBf16(pmul(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pdiv(const Packet16bf& a, + const Packet16bf& b) { + return F32ToBf16(pdiv(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pmin(const Packet16bf& a, + const Packet16bf& b) { + return F32ToBf16(pmin(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pmax(const Packet16bf& a, + const Packet16bf& b) { + return F32ToBf16(pmax(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf plset(const bfloat16& a) { + return F32ToBf16(plset(static_cast(a))); +} + +template <> +EIGEN_STRONG_INLINE Packet8bf predux_half_dowto4(const Packet16bf& a) { + Packet8bf lane0 = _mm256_extractf128_si256(a, 0); + Packet8bf lane1 = _mm256_extractf128_si256(a, 1); + return padd(lane0, lane1); +} + +template <> +EIGEN_STRONG_INLINE bfloat16 predux(const Packet16bf& p) { + return static_cast(predux(Bf16ToF32(p))); +} + +template <> +EIGEN_STRONG_INLINE bfloat16 predux_mul(const Packet16bf& from) { + return static_cast(predux_mul(Bf16ToF32(from))); +} + +template <> +EIGEN_STRONG_INLINE bfloat16 predux_min(const Packet16bf& from) { + return static_cast(predux_min(Bf16ToF32(from))); +} + +template <> +EIGEN_STRONG_INLINE bfloat16 predux_max(const Packet16bf& from) { + return static_cast(predux_max(Bf16ToF32(from))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf preverse(const Packet16bf& a) { + __m256i m = _mm256_setr_epi8(14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1, + 14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1); + + Packet16bf res; + // Swap hi and lo first because shuffle is in 128-bit lanes. + res = _mm256_permute2x128_si256(a, a, 1); + // Shuffle 8-bit values in src within 2*128-bit lanes. 
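+  // The mask m reverses the eight 16-bit values within each 128-bit lane;
+  // combined with the lane swap above this yields the fully reversed packet.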
+ return _mm256_shuffle_epi8(res, m); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pgather(const bfloat16* from, + Index stride) { + return _mm256_set_epi16( + from[15*stride].value, from[14*stride].value, from[13*stride].value, from[12*stride].value, + from[11*stride].value, from[10*stride].value, from[9*stride].value, from[8*stride].value, + from[7*stride].value, from[6*stride].value, from[5*stride].value, from[4*stride].value, + from[3*stride].value, from[2*stride].value, from[1*stride].value, from[0*stride].value); +} + +template <> +EIGEN_STRONG_INLINE void pscatter(bfloat16* to, + const Packet16bf& from, + Index stride) { + EIGEN_ALIGN64 bfloat16 aux[16]; + pstore(aux, from); + to[stride*0] = aux[0]; + to[stride*1] = aux[1]; + to[stride*2] = aux[2]; + to[stride*3] = aux[3]; + to[stride*4] = aux[4]; + to[stride*5] = aux[5]; + to[stride*6] = aux[6]; + to[stride*7] = aux[7]; + to[stride*8] = aux[8]; + to[stride*9] = aux[9]; + to[stride*10] = aux[10]; + to[stride*11] = aux[11]; + to[stride*12] = aux[12]; + to[stride*13] = aux[13]; + to[stride*14] = aux[14]; + to[stride*15] = aux[15]; +} + +EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + __m256i a = kernel.packet[0]; + __m256i b = kernel.packet[1]; + __m256i c = kernel.packet[2]; + __m256i d = kernel.packet[3]; + __m256i e = kernel.packet[4]; + __m256i f = kernel.packet[5]; + __m256i g = kernel.packet[6]; + __m256i h = kernel.packet[7]; + __m256i i = kernel.packet[8]; + __m256i j = kernel.packet[9]; + __m256i k = kernel.packet[10]; + __m256i l = kernel.packet[11]; + __m256i m = kernel.packet[12]; + __m256i n = kernel.packet[13]; + __m256i o = kernel.packet[14]; + __m256i p = kernel.packet[15]; + + __m256i ab_07 = _mm256_unpacklo_epi16(a, b); + __m256i cd_07 = _mm256_unpacklo_epi16(c, d); + __m256i ef_07 = _mm256_unpacklo_epi16(e, f); + __m256i gh_07 = _mm256_unpacklo_epi16(g, h); + __m256i ij_07 = _mm256_unpacklo_epi16(i, j); + __m256i kl_07 = _mm256_unpacklo_epi16(k, l); + __m256i mn_07 = _mm256_unpacklo_epi16(m, n); + __m256i op_07 = _mm256_unpacklo_epi16(o, p); + + __m256i ab_8f = _mm256_unpackhi_epi16(a, b); + __m256i cd_8f = _mm256_unpackhi_epi16(c, d); + __m256i ef_8f = _mm256_unpackhi_epi16(e, f); + __m256i gh_8f = _mm256_unpackhi_epi16(g, h); + __m256i ij_8f = _mm256_unpackhi_epi16(i, j); + __m256i kl_8f = _mm256_unpackhi_epi16(k, l); + __m256i mn_8f = _mm256_unpackhi_epi16(m, n); + __m256i op_8f = _mm256_unpackhi_epi16(o, p); + + __m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07); + __m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07); + __m256i efgh_03 = _mm256_unpacklo_epi32(ef_07, gh_07); + __m256i efgh_47 = _mm256_unpackhi_epi32(ef_07, gh_07); + __m256i ijkl_03 = _mm256_unpacklo_epi32(ij_07, kl_07); + __m256i ijkl_47 = _mm256_unpackhi_epi32(ij_07, kl_07); + __m256i mnop_03 = _mm256_unpacklo_epi32(mn_07, op_07); + __m256i mnop_47 = _mm256_unpackhi_epi32(mn_07, op_07); + + __m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f); + __m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f); + __m256i efgh_8b = _mm256_unpacklo_epi32(ef_8f, gh_8f); + __m256i efgh_cf = _mm256_unpackhi_epi32(ef_8f, gh_8f); + __m256i ijkl_8b = _mm256_unpacklo_epi32(ij_8f, kl_8f); + __m256i ijkl_cf = _mm256_unpackhi_epi32(ij_8f, kl_8f); + __m256i mnop_8b = _mm256_unpacklo_epi32(mn_8f, op_8f); + __m256i mnop_cf = _mm256_unpackhi_epi32(mn_8f, op_8f); + + __m256i abcdefgh_01 = _mm256_unpacklo_epi64(abcd_03, efgh_03); + __m256i abcdefgh_23 = _mm256_unpackhi_epi64(abcd_03, efgh_03); + __m256i ijklmnop_01 = _mm256_unpacklo_epi64(ijkl_03, 
mnop_03); + __m256i ijklmnop_23 = _mm256_unpackhi_epi64(ijkl_03, mnop_03); + __m256i abcdefgh_45 = _mm256_unpacklo_epi64(abcd_47, efgh_47); + __m256i abcdefgh_67 = _mm256_unpackhi_epi64(abcd_47, efgh_47); + __m256i ijklmnop_45 = _mm256_unpacklo_epi64(ijkl_47, mnop_47); + __m256i ijklmnop_67 = _mm256_unpackhi_epi64(ijkl_47, mnop_47); + __m256i abcdefgh_89 = _mm256_unpacklo_epi64(abcd_8b, efgh_8b); + __m256i abcdefgh_ab = _mm256_unpackhi_epi64(abcd_8b, efgh_8b); + __m256i ijklmnop_89 = _mm256_unpacklo_epi64(ijkl_8b, mnop_8b); + __m256i ijklmnop_ab = _mm256_unpackhi_epi64(ijkl_8b, mnop_8b); + __m256i abcdefgh_cd = _mm256_unpacklo_epi64(abcd_cf, efgh_cf); + __m256i abcdefgh_ef = _mm256_unpackhi_epi64(abcd_cf, efgh_cf); + __m256i ijklmnop_cd = _mm256_unpacklo_epi64(ijkl_cf, mnop_cf); + __m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf); + + // NOTE: no unpacklo/hi instr in this case, so using permute instr. + kernel.packet[0] = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20); + kernel.packet[1] = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20); + kernel.packet[2] = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20); + kernel.packet[3] = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20); + kernel.packet[4] = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20); + kernel.packet[5] = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20); + kernel.packet[6] = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20); + kernel.packet[7] = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20); + kernel.packet[8] = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31); + kernel.packet[9] = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31); + kernel.packet[10] = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31); + kernel.packet[11] = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31); + kernel.packet[12] = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31); + kernel.packet[13] = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31); + kernel.packet[14] = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31); + kernel.packet[15] = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31); +} + +EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + __m256i a = kernel.packet[0]; + __m256i b = kernel.packet[1]; + __m256i c = kernel.packet[2]; + __m256i d = kernel.packet[3]; + + __m256i ab_07 = _mm256_unpacklo_epi16(a, b); + __m256i cd_07 = _mm256_unpacklo_epi16(c, d); + __m256i ab_8f = _mm256_unpackhi_epi16(a, b); + __m256i cd_8f = _mm256_unpackhi_epi16(c, d); + + __m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07); + __m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07); + __m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f); + __m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f); + + // NOTE: no unpacklo/hi instr in this case, so using permute instr. 
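+  // permute2x128 immediates: 0x20 concatenates the low 128-bit lanes of the
+  // two sources, 0x31 concatenates the high lanes.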
+ kernel.packet[0] = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x20); + kernel.packet[1] = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x20); + kernel.packet[2] = _mm256_permute2x128_si256(abcd_03, abcd_47, 0x31); + kernel.packet[3] = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x31); +} + +} // end namespace internal + +} // end namespace Eigen + +#endif // EIGEN_PACKET_MATH_AVX512_H diff --git a/Eigen/src/Core/arch/AVX512/TypeCasting.h b/Eigen/src/Core/arch/AVX512/TypeCasting.h new file mode 100644 index 0000000..3304127 --- /dev/null +++ b/Eigen/src/Core/arch/AVX512/TypeCasting.h @@ -0,0 +1,89 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2019 Rasmus Munk Larsen +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_TYPE_CASTING_AVX512_H +#define EIGEN_TYPE_CASTING_AVX512_H + +namespace Eigen { + +namespace internal { + +template<> EIGEN_STRONG_INLINE Packet16i pcast(const Packet16f& a) { + return _mm512_cvttps_epi32(a); +} + +template<> EIGEN_STRONG_INLINE Packet16f pcast(const Packet16i& a) { + return _mm512_cvtepi32_ps(a); +} + +template<> EIGEN_STRONG_INLINE Packet16i preinterpret(const Packet16f& a) { + return _mm512_castps_si512(a); +} + +template<> EIGEN_STRONG_INLINE Packet16f preinterpret(const Packet16i& a) { + return _mm512_castsi512_ps(a); +} + +template <> +struct type_casting_traits { + enum { + VectorizedCast = 1, + SrcCoeffRatio = 1, + TgtCoeffRatio = 1 + }; +}; + +template<> EIGEN_STRONG_INLINE Packet16f pcast(const Packet16h& a) { + return half2float(a); +} + +template <> +struct type_casting_traits { + enum { + VectorizedCast = 1, + SrcCoeffRatio = 1, + TgtCoeffRatio = 1 + }; +}; + +template<> EIGEN_STRONG_INLINE Packet16h pcast(const Packet16f& a) { + return float2half(a); +} + +template <> +struct type_casting_traits { + enum { + VectorizedCast = 1, + SrcCoeffRatio = 1, + TgtCoeffRatio = 1 + }; +}; + +template<> EIGEN_STRONG_INLINE Packet16f pcast(const Packet16bf& a) { + return Bf16ToF32(a); +} + +template <> +struct type_casting_traits { + enum { + VectorizedCast = 1, + SrcCoeffRatio = 1, + TgtCoeffRatio = 1 + }; +}; + +template<> EIGEN_STRONG_INLINE Packet16bf pcast(const Packet16f& a) { + return F32ToBf16(a); +} + +} // end namespace internal + +} // end namespace Eigen + +#endif // EIGEN_TYPE_CASTING_AVX512_H diff --git a/Eigen/src/Core/arch/AltiVec/Complex.h b/Eigen/src/Core/arch/AltiVec/Complex.h new file mode 100644 index 0000000..f424f11 --- /dev/null +++ b/Eigen/src/Core/arch/AltiVec/Complex.h @@ -0,0 +1,417 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2010 Gael Guennebaud +// Copyright (C) 2010-2016 Konstantinos Margaritis +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+ +#ifndef EIGEN_COMPLEX32_ALTIVEC_H +#define EIGEN_COMPLEX32_ALTIVEC_H + +namespace Eigen { + +namespace internal { + +static Packet4ui p4ui_CONJ_XOR = vec_mergeh((Packet4ui)p4i_ZERO, (Packet4ui)p4f_MZERO);//{ 0x00000000, 0x80000000, 0x00000000, 0x80000000 }; +#ifdef __VSX__ +#if defined(_BIG_ENDIAN) +static Packet2ul p2ul_CONJ_XOR1 = (Packet2ul) vec_sld((Packet4ui) p2d_MZERO, (Packet4ui) p2l_ZERO, 8);//{ 0x8000000000000000, 0x0000000000000000 }; +static Packet2ul p2ul_CONJ_XOR2 = (Packet2ul) vec_sld((Packet4ui) p2l_ZERO, (Packet4ui) p2d_MZERO, 8);//{ 0x8000000000000000, 0x0000000000000000 }; +#else +static Packet2ul p2ul_CONJ_XOR1 = (Packet2ul) vec_sld((Packet4ui) p2l_ZERO, (Packet4ui) p2d_MZERO, 8);//{ 0x8000000000000000, 0x0000000000000000 }; +static Packet2ul p2ul_CONJ_XOR2 = (Packet2ul) vec_sld((Packet4ui) p2d_MZERO, (Packet4ui) p2l_ZERO, 8);//{ 0x8000000000000000, 0x0000000000000000 }; +#endif +#endif + +//---------- float ---------- +struct Packet2cf +{ + EIGEN_STRONG_INLINE explicit Packet2cf() {} + EIGEN_STRONG_INLINE explicit Packet2cf(const Packet4f& a) : v(a) {} + + EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& a, const Packet2cf& b) + { + Packet4f v1, v2; + + // Permute and multiply the real parts of a and b + v1 = vec_perm(a.v, a.v, p16uc_PSET32_WODD); + // Get the imaginary parts of a + v2 = vec_perm(a.v, a.v, p16uc_PSET32_WEVEN); + // multiply a_re * b + v1 = vec_madd(v1, b.v, p4f_ZERO); + // multiply a_im * b and get the conjugate result + v2 = vec_madd(v2, b.v, p4f_ZERO); + v2 = reinterpret_cast(pxor(v2, reinterpret_cast(p4ui_CONJ_XOR))); + // permute back to a proper order + v2 = vec_perm(v2, v2, p16uc_COMPLEX32_REV); + + return Packet2cf(padd(v1, v2)); + } + + EIGEN_STRONG_INLINE Packet2cf& operator*=(const Packet2cf& b) { + v = pmul(Packet2cf(*this), b).v; + return *this; + } + EIGEN_STRONG_INLINE Packet2cf operator*(const Packet2cf& b) const { + return Packet2cf(*this) *= b; + } + + EIGEN_STRONG_INLINE Packet2cf& operator+=(const Packet2cf& b) { + v = padd(v, b.v); + return *this; + } + EIGEN_STRONG_INLINE Packet2cf operator+(const Packet2cf& b) const { + return Packet2cf(*this) += b; + } + EIGEN_STRONG_INLINE Packet2cf& operator-=(const Packet2cf& b) { + v = psub(v, b.v); + return *this; + } + EIGEN_STRONG_INLINE Packet2cf operator-(const Packet2cf& b) const { + return Packet2cf(*this) -= b; + } + EIGEN_STRONG_INLINE Packet2cf operator-(void) const { + return Packet2cf(-v); + } + + Packet4f v; +}; + +template<> struct packet_traits > : default_packet_traits +{ + typedef Packet2cf type; + typedef Packet2cf half; + typedef Packet4f as_real; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 2, + HasHalfPacket = 0, + + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 1, + HasNegate = 1, + HasAbs = 0, + HasAbs2 = 0, + HasMin = 0, + HasMax = 0, +#ifdef __VSX__ + HasBlend = 1, +#endif + HasSetLinear = 0 + }; +}; + +template<> struct unpacket_traits { typedef std::complex type; enum {size=2, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet2cf half; typedef Packet4f as_real; }; + +template<> EIGEN_STRONG_INLINE Packet2cf pset1(const std::complex& from) +{ + Packet2cf res; + if((std::ptrdiff_t(&from) % 16) == 0) + res.v = pload((const float *)&from); + else + res.v = ploadu((const float *)&from); + res.v = vec_perm(res.v, res.v, p16uc_PSET64_HI); + return res; +} + +template<> EIGEN_STRONG_INLINE Packet2cf pload(const std::complex* from) { return Packet2cf(pload((const float *) from)); } 
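+// Both the aligned and unaligned loads below reinterpret the two
+// std::complex<float> values as four packed floats.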
+template<> EIGEN_STRONG_INLINE Packet2cf ploadu(const std::complex* from) { return Packet2cf(ploadu((const float*) from)); } +template<> EIGEN_STRONG_INLINE Packet2cf ploaddup(const std::complex* from) { return pset1(*from); } + +template<> EIGEN_STRONG_INLINE void pstore >(std::complex * to, const Packet2cf& from) { pstore((float*)to, from.v); } +template<> EIGEN_STRONG_INLINE void pstoreu >(std::complex * to, const Packet2cf& from) { pstoreu((float*)to, from.v); } + +EIGEN_STRONG_INLINE Packet2cf pload2(const std::complex* from0, const std::complex* from1) +{ + Packet4f res0, res1; +#ifdef __VSX__ + __asm__ ("lxsdx %x0,%y1" : "=wa" (res0) : "Z" (*from0)); + __asm__ ("lxsdx %x0,%y1" : "=wa" (res1) : "Z" (*from1)); +#ifdef _BIG_ENDIAN + __asm__ ("xxpermdi %x0, %x1, %x2, 0" : "=wa" (res0) : "wa" (res0), "wa" (res1)); +#else + __asm__ ("xxpermdi %x0, %x2, %x1, 0" : "=wa" (res0) : "wa" (res0), "wa" (res1)); +#endif +#else + *reinterpret_cast *>(&res0) = *from0; + *reinterpret_cast *>(&res1) = *from1; + res0 = vec_perm(res0, res1, p16uc_TRANSPOSE64_HI); +#endif + return Packet2cf(res0); +} + +template<> EIGEN_DEVICE_FUNC inline Packet2cf pgather, Packet2cf>(const std::complex* from, Index stride) +{ + EIGEN_ALIGN16 std::complex af[2]; + af[0] = from[0*stride]; + af[1] = from[1*stride]; + return pload(af); +} +template<> EIGEN_DEVICE_FUNC inline void pscatter, Packet2cf>(std::complex* to, const Packet2cf& from, Index stride) +{ + EIGEN_ALIGN16 std::complex af[2]; + pstore >((std::complex *) af, from); + to[0*stride] = af[0]; + to[1*stride] = af[1]; +} + +template<> EIGEN_STRONG_INLINE Packet2cf padd(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(a.v + b.v); } +template<> EIGEN_STRONG_INLINE Packet2cf psub(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(a.v - b.v); } +template<> EIGEN_STRONG_INLINE Packet2cf pnegate(const Packet2cf& a) { return Packet2cf(pnegate(a.v)); } +template<> EIGEN_STRONG_INLINE Packet2cf pconj(const Packet2cf& a) { return Packet2cf(pxor(a.v, reinterpret_cast(p4ui_CONJ_XOR))); } + +template<> EIGEN_STRONG_INLINE Packet2cf pand (const Packet2cf& a, const Packet2cf& b) { return Packet2cf(pand(a.v, b.v)); } +template<> EIGEN_STRONG_INLINE Packet2cf por (const Packet2cf& a, const Packet2cf& b) { return Packet2cf(por(a.v, b.v)); } +template<> EIGEN_STRONG_INLINE Packet2cf pxor (const Packet2cf& a, const Packet2cf& b) { return Packet2cf(pxor(a.v, b.v)); } +template<> EIGEN_STRONG_INLINE Packet2cf pandnot(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(pandnot(a.v, b.v)); } + +template<> EIGEN_STRONG_INLINE void prefetch >(const std::complex * addr) { EIGEN_PPC_PREFETCH(addr); } + +template<> EIGEN_STRONG_INLINE std::complex pfirst(const Packet2cf& a) +{ + EIGEN_ALIGN16 std::complex res[2]; + pstore((float *)&res, a.v); + + return res[0]; +} + +template<> EIGEN_STRONG_INLINE Packet2cf preverse(const Packet2cf& a) +{ + Packet4f rev_a; + rev_a = vec_perm(a.v, a.v, p16uc_COMPLEX32_REV2); + return Packet2cf(rev_a); +} + +template<> EIGEN_STRONG_INLINE std::complex predux(const Packet2cf& a) +{ + Packet4f b; + b = vec_sld(a.v, a.v, 8); + b = padd(a.v, b); + return pfirst(Packet2cf(b)); +} + +template<> EIGEN_STRONG_INLINE std::complex predux_mul(const Packet2cf& a) +{ + Packet4f b; + Packet2cf prod; + b = vec_sld(a.v, a.v, 8); + prod = pmul(a, Packet2cf(b)); + + return pfirst(prod); +} + +EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cf,Packet4f) + +template<> EIGEN_STRONG_INLINE Packet2cf pdiv(const Packet2cf& a, const Packet2cf& b) +{ + // TODO optimize 
it for AltiVec + Packet2cf res = pmul(a, pconj(b)); + Packet4f s = pmul(b.v, b.v); + return Packet2cf(pdiv(res.v, padd(s, vec_perm(s, s, p16uc_COMPLEX32_REV)))); +} + +template<> EIGEN_STRONG_INLINE Packet2cf pcplxflip(const Packet2cf& x) +{ + return Packet2cf(vec_perm(x.v, x.v, p16uc_COMPLEX32_REV)); +} + +EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) +{ + Packet4f tmp = vec_perm(kernel.packet[0].v, kernel.packet[1].v, p16uc_TRANSPOSE64_HI); + kernel.packet[1].v = vec_perm(kernel.packet[0].v, kernel.packet[1].v, p16uc_TRANSPOSE64_LO); + kernel.packet[0].v = tmp; +} + +template<> EIGEN_STRONG_INLINE Packet2cf pcmp_eq(const Packet2cf& a, const Packet2cf& b) { + Packet4f eq = reinterpret_cast(vec_cmpeq(a.v,b.v)); + return Packet2cf(vec_and(eq, vec_perm(eq, eq, p16uc_COMPLEX32_REV))); +} + +#ifdef __VSX__ +template<> EIGEN_STRONG_INLINE Packet2cf pblend(const Selector<2>& ifPacket, const Packet2cf& thenPacket, const Packet2cf& elsePacket) { + Packet2cf result; + result.v = reinterpret_cast(pblend(ifPacket, reinterpret_cast(thenPacket.v), reinterpret_cast(elsePacket.v))); + return result; +} +#endif + +template<> EIGEN_STRONG_INLINE Packet2cf psqrt(const Packet2cf& a) +{ + return psqrt_complex(a); +} + +//---------- double ---------- +#ifdef __VSX__ +struct Packet1cd +{ + EIGEN_STRONG_INLINE Packet1cd() {} + EIGEN_STRONG_INLINE explicit Packet1cd(const Packet2d& a) : v(a) {} + + EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& a, const Packet1cd& b) + { + Packet2d a_re, a_im, v1, v2; + + // Permute and multiply the real parts of a and b + a_re = vec_perm(a.v, a.v, p16uc_PSET64_HI); + // Get the imaginary parts of a + a_im = vec_perm(a.v, a.v, p16uc_PSET64_LO); + // multiply a_re * b + v1 = vec_madd(a_re, b.v, p2d_ZERO); + // multiply a_im * b and get the conjugate result + v2 = vec_madd(a_im, b.v, p2d_ZERO); + v2 = reinterpret_cast(vec_sld(reinterpret_cast(v2), reinterpret_cast(v2), 8)); + v2 = pxor(v2, reinterpret_cast(p2ul_CONJ_XOR1)); + + return Packet1cd(padd(v1, v2)); + } + + EIGEN_STRONG_INLINE Packet1cd& operator*=(const Packet1cd& b) { + v = pmul(Packet1cd(*this), b).v; + return *this; + } + EIGEN_STRONG_INLINE Packet1cd operator*(const Packet1cd& b) const { + return Packet1cd(*this) *= b; + } + + EIGEN_STRONG_INLINE Packet1cd& operator+=(const Packet1cd& b) { + v = padd(v, b.v); + return *this; + } + EIGEN_STRONG_INLINE Packet1cd operator+(const Packet1cd& b) const { + return Packet1cd(*this) += b; + } + EIGEN_STRONG_INLINE Packet1cd& operator-=(const Packet1cd& b) { + v = psub(v, b.v); + return *this; + } + EIGEN_STRONG_INLINE Packet1cd operator-(const Packet1cd& b) const { + return Packet1cd(*this) -= b; + } + EIGEN_STRONG_INLINE Packet1cd operator-(void) const { + return Packet1cd(-v); + } + + Packet2d v; +}; + +template<> struct packet_traits > : default_packet_traits +{ + typedef Packet1cd type; + typedef Packet1cd half; + typedef Packet2d as_real; + enum { + Vectorizable = 1, + AlignedOnScalar = 0, + size = 1, + HasHalfPacket = 0, + + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 1, + HasNegate = 1, + HasAbs = 0, + HasAbs2 = 0, + HasMin = 0, + HasMax = 0, + HasSetLinear = 0 + }; +}; + +template<> struct unpacket_traits { typedef std::complex type; enum {size=1, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet1cd half; typedef Packet2d as_real; }; + +template<> EIGEN_STRONG_INLINE Packet1cd pload (const std::complex* from) { return Packet1cd(pload((const double*)from)); } +template<> 
EIGEN_STRONG_INLINE Packet1cd ploadu(const std::complex* from) { return Packet1cd(ploadu((const double*)from)); } +template<> EIGEN_STRONG_INLINE void pstore >(std::complex * to, const Packet1cd& from) { pstore((double*)to, from.v); } +template<> EIGEN_STRONG_INLINE void pstoreu >(std::complex * to, const Packet1cd& from) { pstoreu((double*)to, from.v); } + +template<> EIGEN_STRONG_INLINE Packet1cd pset1(const std::complex& from) +{ /* here we really have to use unaligned loads :( */ return ploadu(&from); } + +template<> EIGEN_DEVICE_FUNC inline Packet1cd pgather, Packet1cd>(const std::complex* from, Index) +{ + return pload(from); +} +template<> EIGEN_DEVICE_FUNC inline void pscatter, Packet1cd>(std::complex* to, const Packet1cd& from, Index) +{ + pstore >(to, from); +} + +template<> EIGEN_STRONG_INLINE Packet1cd padd(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(a.v + b.v); } +template<> EIGEN_STRONG_INLINE Packet1cd psub(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(a.v - b.v); } +template<> EIGEN_STRONG_INLINE Packet1cd pnegate(const Packet1cd& a) { return Packet1cd(pnegate(Packet2d(a.v))); } +template<> EIGEN_STRONG_INLINE Packet1cd pconj(const Packet1cd& a) { return Packet1cd(pxor(a.v, reinterpret_cast(p2ul_CONJ_XOR2))); } + +template<> EIGEN_STRONG_INLINE Packet1cd pand (const Packet1cd& a, const Packet1cd& b) { return Packet1cd(pand(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet1cd por (const Packet1cd& a, const Packet1cd& b) { return Packet1cd(por(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet1cd pxor (const Packet1cd& a, const Packet1cd& b) { return Packet1cd(pxor(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet1cd pandnot(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(pandnot(a.v, b.v)); } + +template<> EIGEN_STRONG_INLINE Packet1cd ploaddup(const std::complex* from) { return pset1(*from); } + +template<> EIGEN_STRONG_INLINE void prefetch >(const std::complex * addr) { EIGEN_PPC_PREFETCH(addr); } + +template<> EIGEN_STRONG_INLINE std::complex pfirst(const Packet1cd& a) +{ + EIGEN_ALIGN16 std::complex res[2]; + pstore >(res, a); + + return res[0]; +} + +template<> EIGEN_STRONG_INLINE Packet1cd preverse(const Packet1cd& a) { return a; } + +template<> EIGEN_STRONG_INLINE std::complex predux(const Packet1cd& a) { return pfirst(a); } + +template<> EIGEN_STRONG_INLINE std::complex predux_mul(const Packet1cd& a) { return pfirst(a); } + +EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet1cd,Packet2d) + +template<> EIGEN_STRONG_INLINE Packet1cd pdiv(const Packet1cd& a, const Packet1cd& b) +{ + // TODO optimize it for AltiVec + Packet1cd res = pmul(a,pconj(b)); + Packet2d s = pmul(b.v, b.v); + return Packet1cd(pdiv(res.v, padd(s, vec_perm(s, s, p16uc_REVERSE64)))); +} + +EIGEN_STRONG_INLINE Packet1cd pcplxflip/**/(const Packet1cd& x) +{ + return Packet1cd(preverse(Packet2d(x.v))); +} + +EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) +{ + Packet2d tmp = vec_perm(kernel.packet[0].v, kernel.packet[1].v, p16uc_TRANSPOSE64_HI); + kernel.packet[1].v = vec_perm(kernel.packet[0].v, kernel.packet[1].v, p16uc_TRANSPOSE64_LO); + kernel.packet[0].v = tmp; +} + +template<> EIGEN_STRONG_INLINE Packet1cd pcmp_eq(const Packet1cd& a, const Packet1cd& b) { + // Compare real and imaginary parts of a and b to get the mask vector: + // [re(a)==re(b), im(a)==im(b)] + Packet2d eq = reinterpret_cast(vec_cmpeq(a.v,b.v)); + // Swap real/imag elements in the mask in to get: + // [im(a)==im(b), re(a)==re(b)] + Packet2d eq_swapped = 
reinterpret_cast(vec_sld(reinterpret_cast(eq), reinterpret_cast(eq), 8)); + // Return re(a)==re(b) & im(a)==im(b) by computing bitwise AND of eq and eq_swapped + return Packet1cd(vec_and(eq, eq_swapped)); +} + +template<> EIGEN_STRONG_INLINE Packet1cd psqrt(const Packet1cd& a) +{ + return psqrt_complex(a); +} + +#endif // __VSX__ +} // end namespace internal + +} // end namespace Eigen + +#endif // EIGEN_COMPLEX32_ALTIVEC_H diff --git a/Eigen/src/Core/arch/AltiVec/MathFunctions.h b/Eigen/src/Core/arch/AltiVec/MathFunctions.h new file mode 100644 index 0000000..3a7a329 --- /dev/null +++ b/Eigen/src/Core/arch/AltiVec/MathFunctions.h @@ -0,0 +1,90 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2007 Julien Pommier +// Copyright (C) 2009 Gael Guennebaud +// Copyright (C) 2016 Konstantinos Margaritis +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_MATH_FUNCTIONS_ALTIVEC_H +#define EIGEN_MATH_FUNCTIONS_ALTIVEC_H + +namespace Eigen { + +namespace internal { + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet4f plog(const Packet4f& _x) +{ + return plog_float(_x); +} + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet4f pexp(const Packet4f& _x) +{ + return pexp_float(_x); +} + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet4f psin(const Packet4f& _x) +{ + return psin_float(_x); +} + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet4f pcos(const Packet4f& _x) +{ + return pcos_float(_x); +} + +#ifndef EIGEN_COMP_CLANG +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet4f prsqrt(const Packet4f& x) +{ + return vec_rsqrt(x); +} +#endif + +#ifdef __VSX__ +#ifndef EIGEN_COMP_CLANG +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet2d prsqrt(const Packet2d& x) +{ + return vec_rsqrt(x); +} +#endif + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet4f psqrt(const Packet4f& x) +{ + return vec_sqrt(x); +} + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet2d psqrt(const Packet2d& x) +{ + return vec_sqrt(x); +} + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet2d pexp(const Packet2d& _x) +{ + return pexp_double(_x); +} +#endif + +// Hyperbolic Tangent function. +template <> +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f +ptanh(const Packet4f& x) { + return internal::generic_fast_tanh_float(x); +} + +} // end namespace internal + +} // end namespace Eigen + +#endif // EIGEN_MATH_FUNCTIONS_ALTIVEC_H diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h new file mode 100644 index 0000000..3f79b97 --- /dev/null +++ b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h @@ -0,0 +1,2937 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2020 Everton Constantino (everton.constantino@ibm.com) +// Copyright (C) 2021 Chip Kerchner (chip.kerchner@ibm.com) +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. 
If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_MATRIX_PRODUCT_ALTIVEC_H +#define EIGEN_MATRIX_PRODUCT_ALTIVEC_H + +#ifndef EIGEN_ALTIVEC_USE_CUSTOM_PACK +#define EIGEN_ALTIVEC_USE_CUSTOM_PACK 1 +#endif + +#include "MatrixProductCommon.h" + +// Since LLVM doesn't support dynamic dispatching, force either always MMA or VSX +#if EIGEN_COMP_LLVM +#if !defined(EIGEN_ALTIVEC_DISABLE_MMA) && !defined(EIGEN_ALTIVEC_MMA_ONLY) +#ifdef __MMA__ +#define EIGEN_ALTIVEC_MMA_ONLY +#else +#define EIGEN_ALTIVEC_DISABLE_MMA +#endif +#endif +#endif + +#ifdef __has_builtin +#if __has_builtin(__builtin_mma_assemble_acc) + #define ALTIVEC_MMA_SUPPORT +#endif +#endif + +#if defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) + #include "MatrixProductMMA.h" +#endif + +/************************************************************************************************** + * TODO * + * - Check StorageOrder on dhs_pack (the innermost second loop seems unvectorized when it could). * + * - Check the possibility of transposing as GETREAL and GETIMAG when needed. * + **************************************************************************************************/ +namespace Eigen { + +namespace internal { + +/************************** + * Constants and typedefs * + **************************/ +template +struct quad_traits +{ + typedef typename packet_traits::type vectortype; + typedef PacketBlock type; + typedef vectortype rhstype; + enum + { + vectorsize = packet_traits::size, + size = 4, + rows = 4 + }; +}; + +template<> +struct quad_traits +{ + typedef Packet2d vectortype; + typedef PacketBlock type; + typedef PacketBlock rhstype; + enum + { + vectorsize = packet_traits::size, + size = 2, + rows = 4 + }; +}; + +// MatrixProduct decomposes real/imaginary vectors into a real vector and an imaginary vector, this turned out +// to be faster than Eigen's usual approach of having real/imaginary pairs on a single vector. This constants then +// are responsible to extract from convert between Eigen's and MatrixProduct approach. + +const static Packet16uc p16uc_GETREAL32 = { 0, 1, 2, 3, + 8, 9, 10, 11, + 16, 17, 18, 19, + 24, 25, 26, 27}; + +const static Packet16uc p16uc_GETIMAG32 = { 4, 5, 6, 7, + 12, 13, 14, 15, + 20, 21, 22, 23, + 28, 29, 30, 31}; +const static Packet16uc p16uc_GETREAL64 = { 0, 1, 2, 3, 4, 5, 6, 7, + 16, 17, 18, 19, 20, 21, 22, 23}; + +//[a,ai],[b,bi] = [ai,bi] +const static Packet16uc p16uc_GETIMAG64 = { 8, 9, 10, 11, 12, 13, 14, 15, + 24, 25, 26, 27, 28, 29, 30, 31}; + +/********************************************* + * Single precision real and complex packing * + * *******************************************/ + +/** + * Symm packing is related to packing of symmetric adjoint blocks, as expected the packing leaves + * the diagonal real, whatever is below it is copied from the respective upper diagonal element and + * conjugated. There's no PanelMode available for symm packing. + * + * Packing in general is supposed to leave the lhs block and the rhs block easy to be read by gemm using + * its respective rank-update instructions. The float32/64 versions are different because at this moment + * the size of the accumulator is fixed at 512-bits so you can't have a 4x4 accumulator of 64-bit elements. 
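+ * (A 512-bit accumulator holds a 4x4 tile of 32-bit values but only a 4x2 tile of
+ * 64-bit values, hence the separate float32/64 paths.)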
+ * + * As mentioned earlier MatrixProduct breaks complex numbers into a real vector and a complex vector so packing has + * to take that into account, at the moment, we run pack the real part and then the imaginary part, this is the main + * reason why packing for complex is broken down into several different parts, also the reason why we endup having a + * float32/64 and complex float32/64 version. + **/ +template +EIGEN_ALWAYS_INLINE std::complex getAdjointVal(Index i, Index j, const_blas_data_mapper, Index, StorageOrder>& dt) +{ + std::complex v; + if(i < j) + { + v.real( dt(j,i).real()); + v.imag(-dt(j,i).imag()); + } else if(i > j) + { + v.real( dt(i,j).real()); + v.imag( dt(i,j).imag()); + } else { + v.real( dt(i,j).real()); + v.imag((Scalar)0.0); + } + return v; +} + +template +EIGEN_STRONG_INLINE void symm_pack_complex_rhs_helper(std::complex* blockB, const std::complex* _rhs, Index rhsStride, Index rows, Index cols, Index k2) +{ + const Index depth = k2 + rows; + const_blas_data_mapper, Index, StorageOrder> rhs(_rhs, rhsStride); + const Index vectorSize = N*quad_traits::vectorsize; + const Index vectorDelta = vectorSize * rows; + Scalar* blockBf = reinterpret_cast(blockB); + + Index rir = 0, rii, j = 0; + for(; j + vectorSize <= cols; j+=vectorSize) + { + rii = rir + vectorDelta; + + for(Index i = k2; i < depth; i++) + { + for(Index k = 0; k < vectorSize; k++) + { + std::complex v = getAdjointVal(i, j + k, rhs); + + blockBf[rir + k] = v.real(); + blockBf[rii + k] = v.imag(); + } + rir += vectorSize; + rii += vectorSize; + } + + rir += vectorDelta; + } + if (j < cols) + { + rii = rir + ((cols - j) * rows); + + for(Index i = k2; i < depth; i++) + { + Index k = j; + for(; k < cols; k++) + { + std::complex v = getAdjointVal(i, k, rhs); + + blockBf[rir] = v.real(); + blockBf[rii] = v.imag(); + + rir += 1; + rii += 1; + } + } + } +} + +template +EIGEN_STRONG_INLINE void symm_pack_complex_lhs_helper(std::complex* blockA, const std::complex* _lhs, Index lhsStride, Index cols, Index rows) +{ + const Index depth = cols; + const_blas_data_mapper, Index, StorageOrder> lhs(_lhs, lhsStride); + const Index vectorSize = quad_traits::vectorsize; + const Index vectorDelta = vectorSize * depth; + Scalar* blockAf = (Scalar *)(blockA); + + Index rir = 0, rii, j = 0; + for(; j + vectorSize <= rows; j+=vectorSize) + { + rii = rir + vectorDelta; + + for(Index i = 0; i < depth; i++) + { + for(Index k = 0; k < vectorSize; k++) + { + std::complex v = getAdjointVal(j+k, i, lhs); + + blockAf[rir + k] = v.real(); + blockAf[rii + k] = v.imag(); + } + rir += vectorSize; + rii += vectorSize; + } + + rir += vectorDelta; + } + + if (j < rows) + { + rii = rir + ((rows - j) * depth); + + for(Index i = 0; i < depth; i++) + { + Index k = j; + for(; k < rows; k++) + { + std::complex v = getAdjointVal(k, i, lhs); + + blockAf[rir] = v.real(); + blockAf[rii] = v.imag(); + + rir += 1; + rii += 1; + } + } + } +} + +template +EIGEN_STRONG_INLINE void symm_pack_rhs_helper(Scalar* blockB, const Scalar* _rhs, Index rhsStride, Index rows, Index cols, Index k2) +{ + const Index depth = k2 + rows; + const_blas_data_mapper rhs(_rhs, rhsStride); + const Index vectorSize = quad_traits::vectorsize; + + Index ri = 0, j = 0; + for(; j + N*vectorSize <= cols; j+=N*vectorSize) + { + Index i = k2; + for(; i < depth; i++) + { + for(Index k = 0; k < N*vectorSize; k++) + { + if(i <= j+k) + blockB[ri + k] = rhs(j+k, i); + else + blockB[ri + k] = rhs(i, j+k); + } + ri += N*vectorSize; + } + } + + if (j < cols) + { + for(Index i = k2; i < depth; 
i++) + { + Index k = j; + for(; k < cols; k++) + { + if(k <= i) + blockB[ri] = rhs(i, k); + else + blockB[ri] = rhs(k, i); + ri += 1; + } + } + } +} + +template +EIGEN_STRONG_INLINE void symm_pack_lhs_helper(Scalar* blockA, const Scalar* _lhs, Index lhsStride, Index cols, Index rows) +{ + const Index depth = cols; + const_blas_data_mapper lhs(_lhs, lhsStride); + const Index vectorSize = quad_traits::vectorsize; + + Index ri = 0, j = 0; + for(; j + vectorSize <= rows; j+=vectorSize) + { + Index i = 0; + + for(; i < depth; i++) + { + for(Index k = 0; k < vectorSize; k++) + { + if(i <= j+k) + blockA[ri + k] = lhs(j+k, i); + else + blockA[ri + k] = lhs(i, j+k); + } + ri += vectorSize; + } + } + + if (j < rows) + { + for(Index i = 0; i < depth; i++) + { + Index k = j; + for(; k < rows; k++) + { + if(i <= k) + blockA[ri] = lhs(k, i); + else + blockA[ri] = lhs(i, k); + ri += 1; + } + } + } +} + +template +struct symm_pack_rhs, Index, nr, StorageOrder> +{ + void operator()(std::complex* blockB, const std::complex* _rhs, Index rhsStride, Index rows, Index cols, Index k2) + { + symm_pack_complex_rhs_helper(blockB, _rhs, rhsStride, rows, cols, k2); + } +}; + +template +struct symm_pack_lhs, Index, Pack1, Pack2_dummy, StorageOrder> +{ + void operator()(std::complex* blockA, const std::complex* _lhs, Index lhsStride, Index cols, Index rows) + { + symm_pack_complex_lhs_helper(blockA, _lhs, lhsStride, cols, rows); + } +}; + +// *********** symm_pack std::complex *********** + +template +struct symm_pack_rhs, Index, nr, StorageOrder> +{ + void operator()(std::complex* blockB, const std::complex* _rhs, Index rhsStride, Index rows, Index cols, Index k2) + { + symm_pack_complex_rhs_helper(blockB, _rhs, rhsStride, rows, cols, k2); + } +}; + +template +struct symm_pack_lhs, Index, Pack1, Pack2_dummy, StorageOrder> +{ + void operator()(std::complex* blockA, const std::complex* _lhs, Index lhsStride, Index cols, Index rows) + { + symm_pack_complex_lhs_helper(blockA, _lhs, lhsStride, cols, rows); + } +}; + +// *********** symm_pack float32 *********** +template +struct symm_pack_rhs +{ + void operator()(float* blockB, const float* _rhs, Index rhsStride, Index rows, Index cols, Index k2) + { + symm_pack_rhs_helper(blockB, _rhs, rhsStride, rows, cols, k2); + } +}; + +template +struct symm_pack_lhs +{ + void operator()(float* blockA, const float* _lhs, Index lhsStride, Index cols, Index rows) + { + symm_pack_lhs_helper(blockA, _lhs, lhsStride, cols, rows); + } +}; + +// *********** symm_pack float64 *********** +template +struct symm_pack_rhs +{ + void operator()(double* blockB, const double* _rhs, Index rhsStride, Index rows, Index cols, Index k2) + { + symm_pack_rhs_helper(blockB, _rhs, rhsStride, rows, cols, k2); + } +}; + +template +struct symm_pack_lhs +{ + void operator()(double* blockA, const double* _lhs, Index lhsStride, Index cols, Index rows) + { + symm_pack_lhs_helper(blockA, _lhs, lhsStride, cols, rows); + } +}; + +/** + * PanelMode + * Packing might be called several times before being multiplied by gebp_kernel, this happens because + * on special occasions it fills part of block with other parts of the matrix. Two variables control + * how PanelMode should behave: offset and stride. The idea is that those variables represent whatever + * is going to be the real offset and stride in the future and this is what you should obey. 
The process + * is to behave as you would with normal packing but leave the start of each part with the correct offset + * and the end as well respecting the real stride the block will have. Gebp is aware of both blocks stride + * and offset and behaves accordingly. + **/ + +template +EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock& block) +{ + const Index size = 16 / sizeof(Scalar); + pstore(to + (0 * size), block.packet[0]); + pstore(to + (1 * size), block.packet[1]); + pstore(to + (2 * size), block.packet[2]); + pstore(to + (3 * size), block.packet[3]); +} + +template +EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock& block) +{ + const Index size = 16 / sizeof(Scalar); + pstore(to + (0 * size), block.packet[0]); + pstore(to + (1 * size), block.packet[1]); +} + +// General template for lhs & rhs complex packing. +template +struct dhs_cpack { + EIGEN_STRONG_INLINE void operator()(std::complex* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) + { + const Index vectorSize = quad_traits::vectorsize; + const Index vectorDelta = vectorSize * ((PanelMode) ? stride : depth); + Index rir = ((PanelMode) ? (vectorSize*offset) : 0), rii; + Scalar* blockAt = reinterpret_cast(blockA); + Index j = 0; + + for(; j + vectorSize <= rows; j+=vectorSize) + { + Index i = 0; + + rii = rir + vectorDelta; + + for(; i + vectorSize <= depth; i+=vectorSize) + { + PacketBlock blockr, blocki; + PacketBlock cblock; + + if (UseLhs) { + bload(cblock, lhs, j, i); + } else { + bload(cblock, lhs, i, j); + } + + blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETREAL32); + blockr.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETREAL32); + blockr.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETREAL32); + blockr.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETREAL32); + + blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETIMAG32); + blocki.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETIMAG32); + blocki.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETIMAG32); + blocki.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETIMAG32); + + if(Conjugate) + { + blocki.packet[0] = -blocki.packet[0]; + blocki.packet[1] = -blocki.packet[1]; + blocki.packet[2] = -blocki.packet[2]; + blocki.packet[3] = -blocki.packet[3]; + } + + if(((StorageOrder == RowMajor) && UseLhs) || (((StorageOrder == ColMajor) && !UseLhs))) + { + ptranspose(blockr); + ptranspose(blocki); + } + + storeBlock(blockAt + rir, blockr); + storeBlock(blockAt + rii, blocki); + + rir += 4*vectorSize; + rii += 4*vectorSize; + } + for(; i < depth; i++) + { + PacketBlock blockr, blocki; + PacketBlock cblock; + + if(((StorageOrder == ColMajor) && UseLhs) || (((StorageOrder == RowMajor) && !UseLhs))) + { + if (UseLhs) { + cblock.packet[0] = lhs.template loadPacket(j + 0, i); + cblock.packet[1] = lhs.template loadPacket(j + 2, i); + } else { + cblock.packet[0] = lhs.template loadPacket(i, j + 0); + cblock.packet[1] = lhs.template loadPacket(i, j + 2); + } + } else { + std::complex lhs0, lhs1; + if (UseLhs) { + lhs0 = lhs(j + 0, i); + lhs1 = lhs(j + 1, i); + cblock.packet[0] = pload2(&lhs0, &lhs1); + lhs0 = lhs(j + 2, i); + lhs1 = lhs(j + 3, i); + cblock.packet[1] = pload2(&lhs0, &lhs1); + } else { + lhs0 = lhs(i, j + 0); + lhs1 = lhs(i, j + 1); + cblock.packet[0] = pload2(&lhs0, &lhs1); + lhs0 = lhs(i, j + 2); + lhs1 = lhs(i, j + 3); + 
cblock.packet[1] = pload2(&lhs0, &lhs1); + } + } + + blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL32); + blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG32); + + if(Conjugate) + { + blocki.packet[0] = -blocki.packet[0]; + } + + pstore(blockAt + rir, blockr.packet[0]); + pstore(blockAt + rii, blocki.packet[0]); + + rir += vectorSize; + rii += vectorSize; + } + + rir += ((PanelMode) ? (vectorSize*(2*stride - depth)) : vectorDelta); + } + + if (j < rows) + { + if(PanelMode) rir += (offset*(rows - j - vectorSize)); + rii = rir + (((PanelMode) ? stride : depth) * (rows - j)); + + for(Index i = 0; i < depth; i++) + { + Index k = j; + for(; k < rows; k++) + { + if (UseLhs) { + blockAt[rir] = lhs(k, i).real(); + + if(Conjugate) + blockAt[rii] = -lhs(k, i).imag(); + else + blockAt[rii] = lhs(k, i).imag(); + } else { + blockAt[rir] = lhs(i, k).real(); + + if(Conjugate) + blockAt[rii] = -lhs(i, k).imag(); + else + blockAt[rii] = lhs(i, k).imag(); + } + + rir += 1; + rii += 1; + } + } + } + } +}; + +// General template for lhs & rhs packing. +template +struct dhs_pack{ + EIGEN_STRONG_INLINE void operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) + { + const Index vectorSize = quad_traits::vectorsize; + Index ri = 0, j = 0; + + for(; j + vectorSize <= rows; j+=vectorSize) + { + Index i = 0; + + if(PanelMode) ri += vectorSize*offset; + + for(; i + vectorSize <= depth; i+=vectorSize) + { + PacketBlock block; + + if (UseLhs) { + bload(block, lhs, j, i); + } else { + bload(block, lhs, i, j); + } + if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs)) + { + ptranspose(block); + } + + storeBlock(blockA + ri, block); + + ri += 4*vectorSize; + } + for(; i < depth; i++) + { + if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs)) + { + if (UseLhs) { + blockA[ri+0] = lhs(j+0, i); + blockA[ri+1] = lhs(j+1, i); + blockA[ri+2] = lhs(j+2, i); + blockA[ri+3] = lhs(j+3, i); + } else { + blockA[ri+0] = lhs(i, j+0); + blockA[ri+1] = lhs(i, j+1); + blockA[ri+2] = lhs(i, j+2); + blockA[ri+3] = lhs(i, j+3); + } + } else { + Packet lhsV; + if (UseLhs) { + lhsV = lhs.template loadPacket(j, i); + } else { + lhsV = lhs.template loadPacket(i, j); + } + pstore(blockA + ri, lhsV); + } + + ri += vectorSize; + } + + if(PanelMode) ri += vectorSize*(stride - offset - depth); + } + + if (j < rows) + { + if(PanelMode) ri += offset*(rows - j); + + for(Index i = 0; i < depth; i++) + { + Index k = j; + for(; k < rows; k++) + { + if (UseLhs) { + blockA[ri] = lhs(k, i); + } else { + blockA[ri] = lhs(i, k); + } + ri += 1; + } + } + } + } +}; + +// General template for lhs packing, float64 specialization. 
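The packer above always emits whole panels: a strip of vectorSize rows (or columns, for the RHS side) is stored contiguously for every depth index, transposing on the fly when the storage order does not already match, so the micro-kernel can later stream the block linearly. Before the float64 specialization announced above, here is a minimal scalar sketch of that layout; it is a standalone illustration only, with a hard-coded panel width of 4 standing in for quad_traits<Scalar>::vectorsize and a plain column-major array standing in for the DataMapper.

#include <vector>
#include <cstdio>

// Illustrative only: pack a column-major rows x depth LHS into 4-row panels,
// so each panel stores its 4 rows contiguously for every k (depth) index.
// This mirrors the layout dhs_pack builds before the gemm micro-kernel runs.
static void pack_lhs_panels(const double* lhs, int rows, int depth,
                            std::vector<double>& blockA, int panel = 4) {
  blockA.clear();
  int j = 0;
  for (; j + panel <= rows; j += panel)       // full panels
    for (int k = 0; k < depth; ++k)
      for (int r = 0; r < panel; ++r)
        blockA.push_back(lhs[(j + r) + k * rows]);
  for (int k = 0; k < depth; ++k)             // leftover rows, packed one by one
    for (int r = j; r < rows; ++r)
      blockA.push_back(lhs[r + k * rows]);
}

int main() {
  const int rows = 6, depth = 3;
  std::vector<double> lhs(rows * depth);
  for (int i = 0; i < rows * depth; ++i) lhs[i] = i;
  std::vector<double> blockA;
  pack_lhs_panels(lhs.data(), rows, depth, blockA);
  for (double v : blockA) std::printf("%g ", v);
  std::printf("\n");
  return 0;
}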
+template +struct dhs_pack +{ + EIGEN_STRONG_INLINE void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) + { + const Index vectorSize = quad_traits::vectorsize; + Index ri = 0, j = 0; + + for(; j + vectorSize <= rows; j+=vectorSize) + { + Index i = 0; + + if(PanelMode) ri += vectorSize*offset; + + for(; i + vectorSize <= depth; i+=vectorSize) + { + PacketBlock block; + if(StorageOrder == RowMajor) + { + block.packet[0] = lhs.template loadPacket(j + 0, i); + block.packet[1] = lhs.template loadPacket(j + 1, i); + + ptranspose(block); + } else { + block.packet[0] = lhs.template loadPacket(j, i + 0); + block.packet[1] = lhs.template loadPacket(j, i + 1); + } + + storeBlock(blockA + ri, block); + + ri += 2*vectorSize; + } + for(; i < depth; i++) + { + if(StorageOrder == RowMajor) + { + blockA[ri+0] = lhs(j+0, i); + blockA[ri+1] = lhs(j+1, i); + } else { + Packet2d lhsV = lhs.template loadPacket(j, i); + pstore(blockA + ri, lhsV); + } + + ri += vectorSize; + } + + if(PanelMode) ri += vectorSize*(stride - offset - depth); + } + + if (j < rows) + { + if(PanelMode) ri += offset*(rows - j); + + for(Index i = 0; i < depth; i++) + { + Index k = j; + for(; k < rows; k++) + { + blockA[ri] = lhs(k, i); + ri += 1; + } + } + } + } +}; + +// General template for rhs packing, float64 specialization. +template +struct dhs_pack +{ + EIGEN_STRONG_INLINE void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) + { + const Index vectorSize = quad_traits::vectorsize; + Index ri = 0, j = 0; + + for(; j + 2*vectorSize <= cols; j+=2*vectorSize) + { + Index i = 0; + + if(PanelMode) ri += offset*(2*vectorSize); + + for(; i + vectorSize <= depth; i+=vectorSize) + { + PacketBlock block; + if(StorageOrder == ColMajor) + { + PacketBlock block1, block2; + block1.packet[0] = rhs.template loadPacket(i, j + 0); + block1.packet[1] = rhs.template loadPacket(i, j + 1); + block2.packet[0] = rhs.template loadPacket(i, j + 2); + block2.packet[1] = rhs.template loadPacket(i, j + 3); + + ptranspose(block1); + ptranspose(block2); + + pstore(blockB + ri , block1.packet[0]); + pstore(blockB + ri + 2, block2.packet[0]); + pstore(blockB + ri + 4, block1.packet[1]); + pstore(blockB + ri + 6, block2.packet[1]); + } else { + block.packet[0] = rhs.template loadPacket(i + 0, j + 0); //[a1 a2] + block.packet[1] = rhs.template loadPacket(i + 0, j + 2); //[a3 a4] + block.packet[2] = rhs.template loadPacket(i + 1, j + 0); //[b1 b2] + block.packet[3] = rhs.template loadPacket(i + 1, j + 2); //[b3 b4] + + storeBlock(blockB + ri, block); + } + + ri += 4*vectorSize; + } + for(; i < depth; i++) + { + if(StorageOrder == ColMajor) + { + blockB[ri+0] = rhs(i, j+0); + blockB[ri+1] = rhs(i, j+1); + + ri += vectorSize; + + blockB[ri+0] = rhs(i, j+2); + blockB[ri+1] = rhs(i, j+3); + } else { + Packet2d rhsV = rhs.template loadPacket(i, j); + pstore(blockB + ri, rhsV); + + ri += vectorSize; + + rhsV = rhs.template loadPacket(i, j + 2); + pstore(blockB + ri, rhsV); + } + ri += vectorSize; + } + + if(PanelMode) ri += (2*vectorSize)*(stride - offset - depth); + } + + if (j < cols) + { + if(PanelMode) ri += offset*(cols - j); + + for(Index i = 0; i < depth; i++) + { + Index k = j; + for(; k < cols; k++) + { + blockB[ri] = rhs(i, k); + ri += 1; + } + } + } + } +}; + +// General template for lhs complex packing, float64 specialization. 
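Both float64 packers above honor PanelMode by skipping vectorSize*offset scalars at the start of each panel and vectorSize*(stride - offset - depth) scalars at its end, so each panel's data lands at the offset and stride the full block will eventually use, as described in the PanelMode comment earlier. Before the complex float64 packer flagged above, a small standalone sketch of that index bookkeeping (the names are illustrative, not Eigen's):

#include <cstdio>

// Illustrative only: in panel mode each panel of width `panel` reserves
// panel*stride scalars; the useful data occupies panel*depth of them,
// starting panel*offset scalars into the slot.
struct PanelSlot {
  long begin;   // first scalar written for this panel
  long end;     // one past the last scalar written
  long next;    // where the following panel starts
};

static PanelSlot panel_slot(long panelIndex, long panel, long depth,
                            long stride, long offset, bool panelMode) {
  const long span = panelMode ? stride : depth;
  PanelSlot s;
  s.begin = panelIndex * panel * span + (panelMode ? panel * offset : 0);
  s.end   = s.begin + panel * depth;
  s.next  = (panelIndex + 1) * panel * span;
  return s;
}

int main() {
  PanelSlot s = panel_slot(/*panelIndex=*/1, /*panel=*/2, /*depth=*/5,
                           /*stride=*/8, /*offset=*/2, /*panelMode=*/true);
  std::printf("begin=%ld end=%ld next=%ld\n", s.begin, s.end, s.next);
  return 0;
}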
+template +struct dhs_cpack +{ + EIGEN_STRONG_INLINE void operator()(std::complex* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) + { + const Index vectorSize = quad_traits::vectorsize; + const Index vectorDelta = vectorSize * ((PanelMode) ? stride : depth); + Index rir = ((PanelMode) ? (vectorSize*offset) : 0), rii; + double* blockAt = reinterpret_cast(blockA); + Index j = 0; + + for(; j + vectorSize <= rows; j+=vectorSize) + { + Index i = 0; + + rii = rir + vectorDelta; + + for(; i + vectorSize <= depth; i+=vectorSize) + { + PacketBlock blockr, blocki; + PacketBlock cblock; + + if(StorageOrder == ColMajor) + { + cblock.packet[0] = lhs.template loadPacket(j, i + 0); //[a1 a1i] + cblock.packet[1] = lhs.template loadPacket(j, i + 1); //[b1 b1i] + + cblock.packet[2] = lhs.template loadPacket(j + 1, i + 0); //[a2 a2i] + cblock.packet[3] = lhs.template loadPacket(j + 1, i + 1); //[b2 b2i] + + blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[2].v, p16uc_GETREAL64); //[a1 a2] + blockr.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2] + + blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[2].v, p16uc_GETIMAG64); + blocki.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[3].v, p16uc_GETIMAG64); + } else { + cblock.packet[0] = lhs.template loadPacket(j + 0, i); //[a1 a1i] + cblock.packet[1] = lhs.template loadPacket(j + 1, i); //[a2 a2i] + + cblock.packet[2] = lhs.template loadPacket(j + 0, i + 1); //[b1 b1i] + cblock.packet[3] = lhs.template loadPacket(j + 1, i + 1); //[b2 b2i + + blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64); //[a1 a2] + blockr.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2] + + blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64); + blocki.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETIMAG64); + } + + if(Conjugate) + { + blocki.packet[0] = -blocki.packet[0]; + blocki.packet[1] = -blocki.packet[1]; + } + + storeBlock(blockAt + rir, blockr); + storeBlock(blockAt + rii, blocki); + + rir += 2*vectorSize; + rii += 2*vectorSize; + } + for(; i < depth; i++) + { + PacketBlock blockr, blocki; + PacketBlock cblock; + + cblock.packet[0] = lhs.template loadPacket(j + 0, i); + cblock.packet[1] = lhs.template loadPacket(j + 1, i); + + blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64); + blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64); + + if(Conjugate) + { + blocki.packet[0] = -blocki.packet[0]; + } + + pstore(blockAt + rir, blockr.packet[0]); + pstore(blockAt + rii, blocki.packet[0]); + + rir += vectorSize; + rii += vectorSize; + } + + rir += ((PanelMode) ? (vectorSize*(2*stride - depth)) : vectorDelta); + } + + if (j < rows) + { + if(PanelMode) rir += (offset*(rows - j - vectorSize)); + rii = rir + (((PanelMode) ? stride : depth) * (rows - j)); + + for(Index i = 0; i < depth; i++) + { + Index k = j; + for(; k < rows; k++) + { + blockAt[rir] = lhs(k, i).real(); + + if(Conjugate) + blockAt[rii] = -lhs(k, i).imag(); + else + blockAt[rii] = lhs(k, i).imag(); + + rir += 1; + rii += 1; + } + } + } + } +}; + +// General template for rhs complex packing, float64 specialization. 
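As with the generic complex packer, the float64 specialization above writes all real parts of a panel first and all imaginary parts vectorDelta scalars later, negating the imaginary half when a conjugated operand is requested; the same split reappears in the RHS packer that follows. A minimal scalar sketch of that layout, as standalone illustration only:

#include <complex>
#include <vector>
#include <cstdio>

// Illustrative only: pack `n` complex values so all real parts come first,
// followed by all imaginary parts (negated when conjugating), which is the
// decoupled layout the complex gemm kernels consume.
static void pack_split(const std::complex<double>* src, int n,
                       std::vector<double>& dst, bool conjugate) {
  dst.resize(2 * n);
  for (int i = 0; i < n; ++i) {
    dst[i]     = src[i].real();
    dst[n + i] = conjugate ? -src[i].imag() : src[i].imag();
  }
}

int main() {
  std::complex<double> a[3] = {{1, 2}, {3, 4}, {5, 6}};
  std::vector<double> packed;
  pack_split(a, 3, packed, /*conjugate=*/true);
  for (double v : packed) std::printf("%g ", v);  // 1 3 5 -2 -4 -6
  std::printf("\n");
  return 0;
}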
+template +struct dhs_cpack +{ + EIGEN_STRONG_INLINE void operator()(std::complex* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) + { + const Index vectorSize = quad_traits::vectorsize; + const Index vectorDelta = 2*vectorSize * ((PanelMode) ? stride : depth); + Index rir = ((PanelMode) ? (2*vectorSize*offset) : 0), rii; + double* blockBt = reinterpret_cast(blockB); + Index j = 0; + + for(; j + 2*vectorSize <= cols; j+=2*vectorSize) + { + Index i = 0; + + rii = rir + vectorDelta; + + for(; i < depth; i++) + { + PacketBlock cblock; + PacketBlock blockr, blocki; + + bload(cblock, rhs, i, j); + + blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64); + blockr.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64); + + blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64); + blocki.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETIMAG64); + + if(Conjugate) + { + blocki.packet[0] = -blocki.packet[0]; + blocki.packet[1] = -blocki.packet[1]; + } + + storeBlock(blockBt + rir, blockr); + storeBlock(blockBt + rii, blocki); + + rir += 2*vectorSize; + rii += 2*vectorSize; + } + + rir += ((PanelMode) ? (2*vectorSize*(2*stride - depth)) : vectorDelta); + } + + if (j < cols) + { + if(PanelMode) rir += (offset*(cols - j - 2*vectorSize)); + rii = rir + (((PanelMode) ? stride : depth) * (cols - j)); + + for(Index i = 0; i < depth; i++) + { + Index k = j; + for(; k < cols; k++) + { + blockBt[rir] = rhs(i, k).real(); + + if(Conjugate) + blockBt[rii] = -rhs(i, k).imag(); + else + blockBt[rii] = rhs(i, k).imag(); + + rir += 1; + rii += 1; + } + } + } + } +}; + +/************** + * GEMM utils * + **************/ + +// 512-bits rank1-update of acc. It can either positive or negative accumulate (useful for complex gemm). +template +EIGEN_ALWAYS_INLINE void pger_common(PacketBlock* acc, const Packet& lhsV, const Packet* rhsV) +{ + if(NegativeAccumulate) + { + acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]); + acc->packet[1] = vec_nmsub(lhsV, rhsV[1], acc->packet[1]); + acc->packet[2] = vec_nmsub(lhsV, rhsV[2], acc->packet[2]); + acc->packet[3] = vec_nmsub(lhsV, rhsV[3], acc->packet[3]); + } else { + acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]); + acc->packet[1] = vec_madd(lhsV, rhsV[1], acc->packet[1]); + acc->packet[2] = vec_madd(lhsV, rhsV[2], acc->packet[2]); + acc->packet[3] = vec_madd(lhsV, rhsV[3], acc->packet[3]); + } +} + +template +EIGEN_ALWAYS_INLINE void pger_common(PacketBlock* acc, const Packet& lhsV, const Packet* rhsV) +{ + if(NegativeAccumulate) + { + acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]); + } else { + acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]); + } +} + +template +EIGEN_ALWAYS_INLINE void pger(PacketBlock* acc, const Scalar* lhs, const Packet* rhsV) +{ + Packet lhsV = pload(lhs); + + pger_common(acc, lhsV, rhsV); +} + +template +EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs, Packet &lhsV, Index remaining_rows) +{ +#ifdef _ARCH_PWR9 + lhsV = vec_xl_len((Scalar *)lhs, remaining_rows * sizeof(Scalar)); +#else + Index i = 0; + do { + lhsV[i] = lhs[i]; + } while (++i < remaining_rows); +#endif +} + +template +EIGEN_ALWAYS_INLINE void pger(PacketBlock* acc, const Scalar* lhs, const Packet* rhsV, Index remaining_rows) +{ + Packet lhsV; + loadPacketRemaining(lhs, lhsV, remaining_rows); + + pger_common(acc, lhsV, rhsV); +} + +// 512-bits rank1-update of complex acc. 
It takes decoupled accumulators as entries. It also takes cares of mixed types real * complex and complex * real. +template +EIGEN_ALWAYS_INLINE void pgerc_common(PacketBlock* accReal, PacketBlock* accImag, const Packet &lhsV, const Packet &lhsVi, const Packet* rhsV, const Packet* rhsVi) +{ + pger_common(accReal, lhsV, rhsV); + if(LhsIsReal) + { + pger_common(accImag, lhsV, rhsVi); + EIGEN_UNUSED_VARIABLE(lhsVi); + } else { + if (!RhsIsReal) { + pger_common(accReal, lhsVi, rhsVi); + pger_common(accImag, lhsV, rhsVi); + } else { + EIGEN_UNUSED_VARIABLE(rhsVi); + } + pger_common(accImag, lhsVi, rhsV); + } +} + +template +EIGEN_ALWAYS_INLINE void pgerc(PacketBlock* accReal, PacketBlock* accImag, const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi) +{ + Packet lhsV = ploadLhs(lhs_ptr); + Packet lhsVi; + if(!LhsIsReal) lhsVi = ploadLhs(lhs_ptr_imag); + else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); + + pgerc_common(accReal, accImag, lhsV, lhsVi, rhsV, rhsVi); +} + +template +EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, Packet &lhsV, Packet &lhsVi, Index remaining_rows) +{ +#ifdef _ARCH_PWR9 + lhsV = vec_xl_len((Scalar *)lhs_ptr, remaining_rows * sizeof(Scalar)); + if(!LhsIsReal) lhsVi = vec_xl_len((Scalar *)lhs_ptr_imag, remaining_rows * sizeof(Scalar)); + else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); +#else + Index i = 0; + do { + lhsV[i] = lhs_ptr[i]; + if(!LhsIsReal) lhsVi[i] = lhs_ptr_imag[i]; + } while (++i < remaining_rows); + if(LhsIsReal) EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); +#endif +} + +template +EIGEN_ALWAYS_INLINE void pgerc(PacketBlock* accReal, PacketBlock* accImag, const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi, Index remaining_rows) +{ + Packet lhsV, lhsVi; + loadPacketRemaining(lhs_ptr, lhs_ptr_imag, lhsV, lhsVi, remaining_rows); + + pgerc_common(accReal, accImag, lhsV, lhsVi, rhsV, rhsVi); +} + +template +EIGEN_ALWAYS_INLINE Packet ploadLhs(const Scalar* lhs) +{ + return ploadu(lhs); +} + +// Zero the accumulator on PacketBlock. +template +EIGEN_ALWAYS_INLINE void bsetzero(PacketBlock& acc) +{ + acc.packet[0] = pset1((Scalar)0); + acc.packet[1] = pset1((Scalar)0); + acc.packet[2] = pset1((Scalar)0); + acc.packet[3] = pset1((Scalar)0); +} + +template +EIGEN_ALWAYS_INLINE void bsetzero(PacketBlock& acc) +{ + acc.packet[0] = pset1((Scalar)0); +} + +// Scale the PacketBlock vectors by alpha. +template +EIGEN_ALWAYS_INLINE void bscale(PacketBlock& acc, PacketBlock& accZ, const Packet& pAlpha) +{ + acc.packet[0] = pmadd(pAlpha, accZ.packet[0], acc.packet[0]); + acc.packet[1] = pmadd(pAlpha, accZ.packet[1], acc.packet[1]); + acc.packet[2] = pmadd(pAlpha, accZ.packet[2], acc.packet[2]); + acc.packet[3] = pmadd(pAlpha, accZ.packet[3], acc.packet[3]); +} + +template +EIGEN_ALWAYS_INLINE void bscale(PacketBlock& acc, PacketBlock& accZ, const Packet& pAlpha) +{ + acc.packet[0] = pmadd(pAlpha, accZ.packet[0], acc.packet[0]); +} + +template +EIGEN_ALWAYS_INLINE void bscalec_common(PacketBlock& acc, PacketBlock& accZ, const Packet& pAlpha) +{ + acc.packet[0] = pmul(accZ.packet[0], pAlpha); + acc.packet[1] = pmul(accZ.packet[1], pAlpha); + acc.packet[2] = pmul(accZ.packet[2], pAlpha); + acc.packet[3] = pmul(accZ.packet[3], pAlpha); +} + +template +EIGEN_ALWAYS_INLINE void bscalec_common(PacketBlock& acc, PacketBlock& accZ, const Packet& pAlpha) +{ + acc.packet[0] = pmul(accZ.packet[0], pAlpha); +} + +// Complex version of PacketBlock scaling. 
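Per lane, the decoupled update performed by pgerc_common is simply the complex product split into its real and imaginary parts, with the cross terms dropped when one operand is real; the same identity underlies the complex PacketBlock scaling below. A scalar sketch of the non-conjugated case (illustration only, not the packet code, and ignoring the conjugation flags):

#include <cstdio>

// Illustrative only: one lane of the decoupled complex rank-1 update.
// accReal/accImag are kept in separate accumulators, exactly as the packet
// version does with its real and imaginary PacketBlocks.
static void cmadd(double lhsR, double lhsI, double rhsR, double rhsI,
                  double& accReal, double& accImag,
                  bool lhsIsReal, bool rhsIsReal) {
  accReal += lhsR * rhsR;
  if (!lhsIsReal && !rhsIsReal) accReal -= lhsI * rhsI;  // cross term
  if (!rhsIsReal) accImag += lhsR * rhsI;
  if (!lhsIsReal) accImag += lhsI * rhsR;
}

int main() {
  double r = 0, i = 0;
  cmadd(1, 2, 3, 4, r, i, false, false);  // (1+2i)*(3+4i) = -5 + 10i
  std::printf("%g %g\n", r, i);
  return 0;
}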
+template +EIGEN_ALWAYS_INLINE void bscalec(PacketBlock& aReal, PacketBlock& aImag, const Packet& bReal, const Packet& bImag, PacketBlock& cReal, PacketBlock& cImag) +{ + bscalec_common(cReal, aReal, bReal); + + bscalec_common(cImag, aImag, bReal); + + pger_common(&cReal, bImag, aImag.packet); + + pger_common(&cImag, bImag, aReal.packet); +} + +template +EIGEN_ALWAYS_INLINE void band(PacketBlock& acc, const Packet& pMask) +{ + acc.packet[0] = pand(acc.packet[0], pMask); + acc.packet[1] = pand(acc.packet[1], pMask); + acc.packet[2] = pand(acc.packet[2], pMask); + acc.packet[3] = pand(acc.packet[3], pMask); +} + +template +EIGEN_ALWAYS_INLINE void bscalec(PacketBlock& aReal, PacketBlock& aImag, const Packet& bReal, const Packet& bImag, PacketBlock& cReal, PacketBlock& cImag, const Packet& pMask) +{ + band(aReal, pMask); + band(aImag, pMask); + + bscalec(aReal, aImag, bReal, bImag, cReal, cImag); +} + +// Load a PacketBlock, the N parameters make tunning gemm easier so we can add more accumulators as needed. +template +EIGEN_ALWAYS_INLINE void bload(PacketBlock& acc, const DataMapper& res, Index row, Index col) +{ + if (StorageOrder == RowMajor) { + acc.packet[0] = res.template loadPacket(row + 0, col + N*accCols); + acc.packet[1] = res.template loadPacket(row + 1, col + N*accCols); + acc.packet[2] = res.template loadPacket(row + 2, col + N*accCols); + acc.packet[3] = res.template loadPacket(row + 3, col + N*accCols); + } else { + acc.packet[0] = res.template loadPacket(row + N*accCols, col + 0); + acc.packet[1] = res.template loadPacket(row + N*accCols, col + 1); + acc.packet[2] = res.template loadPacket(row + N*accCols, col + 2); + acc.packet[3] = res.template loadPacket(row + N*accCols, col + 3); + } +} + +// An overload of bload when you have a PacketBLock with 8 vectors. 
+template +EIGEN_ALWAYS_INLINE void bload(PacketBlock& acc, const DataMapper& res, Index row, Index col) +{ + if (StorageOrder == RowMajor) { + acc.packet[0] = res.template loadPacket(row + 0, col + N*accCols); + acc.packet[1] = res.template loadPacket(row + 1, col + N*accCols); + acc.packet[2] = res.template loadPacket(row + 2, col + N*accCols); + acc.packet[3] = res.template loadPacket(row + 3, col + N*accCols); + acc.packet[4] = res.template loadPacket(row + 0, col + (N+1)*accCols); + acc.packet[5] = res.template loadPacket(row + 1, col + (N+1)*accCols); + acc.packet[6] = res.template loadPacket(row + 2, col + (N+1)*accCols); + acc.packet[7] = res.template loadPacket(row + 3, col + (N+1)*accCols); + } else { + acc.packet[0] = res.template loadPacket(row + N*accCols, col + 0); + acc.packet[1] = res.template loadPacket(row + N*accCols, col + 1); + acc.packet[2] = res.template loadPacket(row + N*accCols, col + 2); + acc.packet[3] = res.template loadPacket(row + N*accCols, col + 3); + acc.packet[4] = res.template loadPacket(row + (N+1)*accCols, col + 0); + acc.packet[5] = res.template loadPacket(row + (N+1)*accCols, col + 1); + acc.packet[6] = res.template loadPacket(row + (N+1)*accCols, col + 2); + acc.packet[7] = res.template loadPacket(row + (N+1)*accCols, col + 3); + } +} + +template +EIGEN_ALWAYS_INLINE void bload(PacketBlock& acc, const DataMapper& res, Index row, Index col) +{ + acc.packet[0] = res.template loadPacket(row + N*accCols, col + 0); + acc.packet[1] = res.template loadPacket(row + (N+1)*accCols, col + 0); +} + +const static Packet4i mask41 = { -1, 0, 0, 0 }; +const static Packet4i mask42 = { -1, -1, 0, 0 }; +const static Packet4i mask43 = { -1, -1, -1, 0 }; + +const static Packet2l mask21 = { -1, 0 }; + +template +EIGEN_ALWAYS_INLINE Packet bmask(const int remaining_rows) +{ + if (remaining_rows == 0) { + return pset1(float(0.0)); // Not used + } else { + switch (remaining_rows) { + case 1: return Packet(mask41); + case 2: return Packet(mask42); + default: return Packet(mask43); + } + } +} + +template<> +EIGEN_ALWAYS_INLINE Packet2d bmask(const int remaining_rows) +{ + if (remaining_rows == 0) { + return pset1(double(0.0)); // Not used + } else { + return Packet2d(mask21); + } +} + +template +EIGEN_ALWAYS_INLINE void bscale(PacketBlock& acc, PacketBlock& accZ, const Packet& pAlpha, const Packet& pMask) +{ + band(accZ, pMask); + + bscale(acc, accZ, pAlpha); +} + +template +EIGEN_ALWAYS_INLINE void pbroadcast4_old(const __UNPACK_TYPE__(Packet)* a, Packet& a0, Packet& a1, Packet& a2, Packet& a3) +{ + pbroadcast4(a, a0, a1, a2, a3); +} + +template<> +EIGEN_ALWAYS_INLINE void pbroadcast4_old(const double* a, Packet2d& a0, Packet2d& a1, Packet2d& a2, Packet2d& a3) +{ + a1 = pload(a); + a3 = pload(a + 2); + a0 = vec_splat(a1, 0); + a1 = vec_splat(a1, 1); + a2 = vec_splat(a3, 0); + a3 = vec_splat(a3, 1); +} + +// PEEL loop factor. 
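The accumulation loops below process PEEL depth iterations per outer step, issuing one prefetch per step, and finish the tail with a plain remainder loop. A minimal standalone sketch of that structure (the dot product and the hard-coded factor are illustrative only):

#include <cstdio>

// Illustrative only: the peeled accumulation pattern used by the gemm loops.
// PEEL_DEMO iterations are processed per outer step (one prefetch per step in
// the real kernels), and a scalar remainder loop finishes the rest.
static double dot_peeled(const double* a, const double* b, int depth) {
  const int PEEL_DEMO = 7;  // mirrors the PEEL factor; value is illustrative
  double acc = 0.0;
  int k = 0;
  for (; k + PEEL_DEMO <= depth; k += PEEL_DEMO)
    for (int l = 0; l < PEEL_DEMO; ++l)    // unrolled body in the real kernel
      acc += a[k + l] * b[k + l];
  for (; k < depth; ++k)                   // remainder
    acc += a[k] * b[k];
  return acc;
}

int main() {
  double a[10], b[10];
  for (int i = 0; i < 10; ++i) { a[i] = i; b[i] = 1.0; }
  std::printf("%g\n", dot_peeled(a, b, 10));  // 45
  return 0;
}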
+#define PEEL 7 + +template +EIGEN_ALWAYS_INLINE void MICRO_EXTRA_COL( + const Scalar* &lhs_ptr, + const Scalar* &rhs_ptr, + PacketBlock &accZero, + Index remaining_rows, + Index remaining_cols) +{ + Packet rhsV[1]; + rhsV[0] = pset1(rhs_ptr[0]); + pger<1,Scalar, Packet, false>(&accZero, lhs_ptr, rhsV); + lhs_ptr += remaining_rows; + rhs_ptr += remaining_cols; +} + +template +EIGEN_STRONG_INLINE void gemm_extra_col( + const DataMapper& res, + const Scalar* lhs_base, + const Scalar* rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index row, + Index col, + Index remaining_rows, + Index remaining_cols, + const Packet& pAlpha) +{ + const Scalar* rhs_ptr = rhs_base; + const Scalar* lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA; + PacketBlock accZero; + + bsetzero(accZero); + + Index remaining_depth = (depth & -accRows); + Index k = 0; + for(; k + PEEL <= remaining_depth; k+= PEEL) + { + EIGEN_POWER_PREFETCH(rhs_ptr); + EIGEN_POWER_PREFETCH(lhs_ptr); + for (int l = 0; l < PEEL; l++) { + MICRO_EXTRA_COL(lhs_ptr, rhs_ptr, accZero, remaining_rows, remaining_cols); + } + } + for(; k < remaining_depth; k++) + { + MICRO_EXTRA_COL(lhs_ptr, rhs_ptr, accZero, remaining_rows, remaining_cols); + } + for(; k < depth; k++) + { + Packet rhsV[1]; + rhsV[0] = pset1(rhs_ptr[0]); + pger<1, Scalar, Packet, Index, false>(&accZero, lhs_ptr, rhsV, remaining_rows); + lhs_ptr += remaining_rows; + rhs_ptr += remaining_cols; + } + + accZero.packet[0] = vec_mul(pAlpha, accZero.packet[0]); + for(Index i = 0; i < remaining_rows; i++) { + res(row + i, col) += accZero.packet[0][i]; + } +} + +template +EIGEN_ALWAYS_INLINE void MICRO_EXTRA_ROW( + const Scalar* &lhs_ptr, + const Scalar* &rhs_ptr, + PacketBlock &accZero, + Index remaining_rows) +{ + Packet rhsV[4]; + pbroadcast4(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]); + pger<4, Scalar, Packet, false>(&accZero, lhs_ptr, rhsV); + lhs_ptr += remaining_rows; + rhs_ptr += accRows; +} + +template +EIGEN_STRONG_INLINE void gemm_extra_row( + const DataMapper& res, + const Scalar* lhs_base, + const Scalar* rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index row, + Index col, + Index rows, + Index cols, + Index remaining_rows, + const Packet& pAlpha, + const Packet& pMask) +{ + const Scalar* rhs_ptr = rhs_base; + const Scalar* lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA; + PacketBlock accZero, acc; + + bsetzero(accZero); + + Index remaining_depth = (col + accRows < cols) ? 
depth : (depth & -accRows); + Index k = 0; + for(; k + PEEL <= remaining_depth; k+= PEEL) + { + EIGEN_POWER_PREFETCH(rhs_ptr); + EIGEN_POWER_PREFETCH(lhs_ptr); + for (int l = 0; l < PEEL; l++) { + MICRO_EXTRA_ROW(lhs_ptr, rhs_ptr, accZero, remaining_rows); + } + } + for(; k < remaining_depth; k++) + { + MICRO_EXTRA_ROW(lhs_ptr, rhs_ptr, accZero, remaining_rows); + } + + if ((remaining_depth == depth) && (rows >= accCols)) + { + for(Index j = 0; j < 4; j++) { + acc.packet[j] = res.template loadPacket(row, col + j); + } + bscale(acc, accZero, pAlpha, pMask); + res.template storePacketBlock(row, col, acc); + } else { + for(; k < depth; k++) + { + Packet rhsV[4]; + pbroadcast4(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]); + pger<4, Scalar, Packet, Index, false>(&accZero, lhs_ptr, rhsV, remaining_rows); + lhs_ptr += remaining_rows; + rhs_ptr += accRows; + } + + for(Index j = 0; j < 4; j++) { + accZero.packet[j] = vec_mul(pAlpha, accZero.packet[j]); + } + for(Index j = 0; j < 4; j++) { + for(Index i = 0; i < remaining_rows; i++) { + res(row + i, col + j) += accZero.packet[j][i]; + } + } + } +} + +#define MICRO_UNROLL(func) \ + func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7) + +#define MICRO_UNROLL_WORK(func, func2, peel) \ + MICRO_UNROLL(func2); \ + func(0,peel) func(1,peel) func(2,peel) func(3,peel) \ + func(4,peel) func(5,peel) func(6,peel) func(7,peel) + +#define MICRO_LOAD_ONE(iter) \ + if (unroll_factor > iter) { \ + lhsV##iter = ploadLhs(lhs_ptr##iter); \ + lhs_ptr##iter += accCols; \ + } else { \ + EIGEN_UNUSED_VARIABLE(lhsV##iter); \ + } + +#define MICRO_WORK_ONE(iter, peel) \ + if (unroll_factor > iter) { \ + pger_common(&accZero##iter, lhsV##iter, rhsV##peel); \ + } + +#define MICRO_TYPE_PEEL4(func, func2, peel) \ + if (PEEL > peel) { \ + Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \ + pbroadcast4(rhs_ptr + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \ + MICRO_UNROLL_WORK(func, func2, peel) \ + } else { \ + EIGEN_UNUSED_VARIABLE(rhsV##peel); \ + } + +#define MICRO_TYPE_PEEL1(func, func2, peel) \ + if (PEEL > peel) { \ + Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \ + rhsV##peel[0] = pset1(rhs_ptr[remaining_cols * peel]); \ + MICRO_UNROLL_WORK(func, func2, peel) \ + } else { \ + EIGEN_UNUSED_VARIABLE(rhsV##peel); \ + } + +#define MICRO_UNROLL_TYPE_PEEL(M, func, func1, func2) \ + Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M], rhsV4[M], rhsV5[M], rhsV6[M], rhsV7[M], rhsV8[M], rhsV9[M]; \ + func(func1,func2,0); func(func1,func2,1); \ + func(func1,func2,2); func(func1,func2,3); \ + func(func1,func2,4); func(func1,func2,5); \ + func(func1,func2,6); func(func1,func2,7); \ + func(func1,func2,8); func(func1,func2,9); + +#define MICRO_UNROLL_TYPE_ONE(M, func, func1, func2) \ + Packet rhsV0[M]; \ + func(func1,func2,0); + +#define MICRO_ONE_PEEL4 \ + MICRO_UNROLL_TYPE_PEEL(4, MICRO_TYPE_PEEL4, MICRO_WORK_ONE, MICRO_LOAD_ONE); \ + rhs_ptr += (accRows * PEEL); + +#define MICRO_ONE4 \ + MICRO_UNROLL_TYPE_ONE(4, MICRO_TYPE_PEEL4, MICRO_WORK_ONE, MICRO_LOAD_ONE); \ + rhs_ptr += accRows; + +#define MICRO_ONE_PEEL1 \ + MICRO_UNROLL_TYPE_PEEL(1, MICRO_TYPE_PEEL1, MICRO_WORK_ONE, MICRO_LOAD_ONE); \ + rhs_ptr += (remaining_cols * PEEL); + +#define MICRO_ONE1 \ + MICRO_UNROLL_TYPE_ONE(1, MICRO_TYPE_PEEL1, MICRO_WORK_ONE, MICRO_LOAD_ONE); \ + rhs_ptr += remaining_cols; + +#define MICRO_DST_PTR_ONE(iter) \ + if (unroll_factor > iter) { \ + bsetzero(accZero##iter); \ + } else { \ + EIGEN_UNUSED_VARIABLE(accZero##iter); \ + 
} + +#define MICRO_DST_PTR MICRO_UNROLL(MICRO_DST_PTR_ONE) + +#define MICRO_SRC_PTR_ONE(iter) \ + if (unroll_factor > iter) { \ + lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols + accCols*offsetA; \ + } else { \ + EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \ + } + +#define MICRO_SRC_PTR MICRO_UNROLL(MICRO_SRC_PTR_ONE) + +#define MICRO_PREFETCH_ONE(iter) \ + if (unroll_factor > iter) { \ + EIGEN_POWER_PREFETCH(lhs_ptr##iter); \ + } + +#define MICRO_PREFETCH MICRO_UNROLL(MICRO_PREFETCH_ONE) + +#define MICRO_STORE_ONE(iter) \ + if (unroll_factor > iter) { \ + acc.packet[0] = res.template loadPacket(row + iter*accCols, col + 0); \ + acc.packet[1] = res.template loadPacket(row + iter*accCols, col + 1); \ + acc.packet[2] = res.template loadPacket(row + iter*accCols, col + 2); \ + acc.packet[3] = res.template loadPacket(row + iter*accCols, col + 3); \ + bscale(acc, accZero##iter, pAlpha); \ + res.template storePacketBlock(row + iter*accCols, col, acc); \ + } + +#define MICRO_STORE MICRO_UNROLL(MICRO_STORE_ONE) + +#define MICRO_COL_STORE_ONE(iter) \ + if (unroll_factor > iter) { \ + acc.packet[0] = res.template loadPacket(row + iter*accCols, col + 0); \ + bscale(acc, accZero##iter, pAlpha); \ + res.template storePacketBlock(row + iter*accCols, col, acc); \ + } + +#define MICRO_COL_STORE MICRO_UNROLL(MICRO_COL_STORE_ONE) + +template +EIGEN_STRONG_INLINE void gemm_unrolled_iteration( + const DataMapper& res, + const Scalar* lhs_base, + const Scalar* rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index& row, + Index col, + const Packet& pAlpha) +{ + const Scalar* rhs_ptr = rhs_base; + const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL; + PacketBlock accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7; + PacketBlock acc; + + MICRO_SRC_PTR + MICRO_DST_PTR + + Index k = 0; + for(; k + PEEL <= depth; k+= PEEL) + { + EIGEN_POWER_PREFETCH(rhs_ptr); + MICRO_PREFETCH + MICRO_ONE_PEEL4 + } + for(; k < depth; k++) + { + MICRO_ONE4 + } + MICRO_STORE + + row += unroll_factor*accCols; +} + +template +EIGEN_STRONG_INLINE void gemm_unrolled_col_iteration( + const DataMapper& res, + const Scalar* lhs_base, + const Scalar* rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index& row, + Index col, + Index remaining_cols, + const Packet& pAlpha) +{ + const Scalar* rhs_ptr = rhs_base; + const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, *lhs_ptr7 = NULL; + PacketBlock accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7; + PacketBlock acc; + + MICRO_SRC_PTR + MICRO_DST_PTR + + Index k = 0; + for(; k + PEEL <= depth; k+= PEEL) + { + EIGEN_POWER_PREFETCH(rhs_ptr); + MICRO_PREFETCH + MICRO_ONE_PEEL1 + } + for(; k < depth; k++) + { + MICRO_ONE1 + } + MICRO_COL_STORE + + row += unroll_factor*accCols; +} + +template +EIGEN_STRONG_INLINE void gemm_unrolled_col( + const DataMapper& res, + const Scalar* lhs_base, + const Scalar* rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index& row, + Index rows, + Index col, + Index remaining_cols, + const Packet& pAlpha) +{ +#define MAX_UNROLL 6 + while(row + MAX_UNROLL*accCols <= rows) { + gemm_unrolled_col_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); + } + switch( (rows-row)/accCols ) { +#if MAX_UNROLL > 7 + case 7: + 
gemm_unrolled_col_iteration<7, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); + break; +#endif +#if MAX_UNROLL > 6 + case 6: + gemm_unrolled_col_iteration<6, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); + break; +#endif +#if MAX_UNROLL > 5 + case 5: + gemm_unrolled_col_iteration<5, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); + break; +#endif +#if MAX_UNROLL > 4 + case 4: + gemm_unrolled_col_iteration<4, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); + break; +#endif +#if MAX_UNROLL > 3 + case 3: + gemm_unrolled_col_iteration<3, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); + break; +#endif +#if MAX_UNROLL > 2 + case 2: + gemm_unrolled_col_iteration<2, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); + break; +#endif +#if MAX_UNROLL > 1 + case 1: + gemm_unrolled_col_iteration<1, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); + break; +#endif + default: + break; + } +#undef MAX_UNROLL +} + +/**************** + * GEMM kernels * + * **************/ +template +EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) +{ + const Index remaining_rows = rows % accCols; + const Index remaining_cols = cols % accRows; + + if( strideA == -1 ) strideA = depth; + if( strideB == -1 ) strideB = depth; + + const Packet pAlpha = pset1(alpha); + const Packet pMask = bmask((const int)(remaining_rows)); + + Index col = 0; + for(; col + accRows <= cols; col += accRows) + { + const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB; + const Scalar* lhs_base = blockA; + Index row = 0; + +#define MAX_UNROLL 6 + while(row + MAX_UNROLL*accCols <= rows) { + gemm_unrolled_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + } + switch( (rows-row)/accCols ) { +#if MAX_UNROLL > 7 + case 7: + gemm_unrolled_iteration<7, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; +#endif +#if MAX_UNROLL > 6 + case 6: + gemm_unrolled_iteration<6, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; +#endif +#if MAX_UNROLL > 5 + case 5: + gemm_unrolled_iteration<5, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; +#endif +#if MAX_UNROLL > 4 + case 4: + gemm_unrolled_iteration<4, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; +#endif +#if MAX_UNROLL > 3 + case 3: + gemm_unrolled_iteration<3, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; +#endif +#if MAX_UNROLL > 2 + case 2: + gemm_unrolled_iteration<2, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, 
strideA, offsetA, row, col, pAlpha); + break; +#endif +#if MAX_UNROLL > 1 + case 1: + gemm_unrolled_iteration<1, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; +#endif + default: + break; + } +#undef MAX_UNROLL + + if(remaining_rows > 0) + { + gemm_extra_row(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask); + } + } + + if(remaining_cols > 0) + { + const Scalar* rhs_base = blockB + col*strideB + remaining_cols*offsetB; + const Scalar* lhs_base = blockA; + + for(; col < cols; col++) + { + Index row = 0; + + gemm_unrolled_col(res, lhs_base, rhs_base, depth, strideA, offsetA, row, rows, col, remaining_cols, pAlpha); + + if (remaining_rows > 0) + { + gemm_extra_col(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_rows, remaining_cols, pAlpha); + } + rhs_base++; + } + } +} + +#define accColsC (accCols / 2) +#define advanceRows ((LhsIsReal) ? 1 : 2) +#define advanceCols ((RhsIsReal) ? 1 : 2) + +// PEEL_COMPLEX loop factor. +#define PEEL_COMPLEX 3 + +template +EIGEN_ALWAYS_INLINE void MICRO_COMPLEX_EXTRA_COL( + const Scalar* &lhs_ptr_real, const Scalar* &lhs_ptr_imag, + const Scalar* &rhs_ptr_real, const Scalar* &rhs_ptr_imag, + PacketBlock &accReal, PacketBlock &accImag, + Index remaining_rows, + Index remaining_cols) +{ + Packet rhsV[1], rhsVi[1]; + rhsV[0] = pset1(rhs_ptr_real[0]); + if(!RhsIsReal) rhsVi[0] = pset1(rhs_ptr_imag[0]); + pgerc<1, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi); + lhs_ptr_real += remaining_rows; + if(!LhsIsReal) lhs_ptr_imag += remaining_rows; + else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); + rhs_ptr_real += remaining_cols; + if(!RhsIsReal) rhs_ptr_imag += remaining_cols; + else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag); +} + +template +EIGEN_STRONG_INLINE void gemm_complex_extra_col( + const DataMapper& res, + const Scalar* lhs_base, + const Scalar* rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index strideB, + Index row, + Index col, + Index remaining_rows, + Index remaining_cols, + const Packet& pAlphaReal, + const Packet& pAlphaImag) +{ + const Scalar* rhs_ptr_real = rhs_base; + const Scalar* rhs_ptr_imag; + if(!RhsIsReal) rhs_ptr_imag = rhs_base + remaining_cols*strideB; + else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag); + const Scalar* lhs_ptr_real = lhs_base + advanceRows*row*strideA + remaining_rows*offsetA; + const Scalar* lhs_ptr_imag; + if(!LhsIsReal) lhs_ptr_imag = lhs_ptr_real + remaining_rows*strideA; + else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); + PacketBlock accReal, accImag; + PacketBlock taccReal, taccImag; + PacketBlock acc0, acc1; + + bsetzero(accReal); + bsetzero(accImag); + + Index remaining_depth = (depth & -accRows); + Index k = 0; + for(; k + PEEL_COMPLEX <= remaining_depth; k+= PEEL_COMPLEX) + { + EIGEN_POWER_PREFETCH(rhs_ptr_real); + if(!RhsIsReal) { + EIGEN_POWER_PREFETCH(rhs_ptr_imag); + } + EIGEN_POWER_PREFETCH(lhs_ptr_real); + if(!LhsIsReal) { + EIGEN_POWER_PREFETCH(lhs_ptr_imag); + } + for (int l = 0; l < PEEL_COMPLEX; l++) { + MICRO_COMPLEX_EXTRA_COL(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal, accImag, remaining_rows, remaining_cols); + } + } + for(; k < remaining_depth; k++) + { + MICRO_COMPLEX_EXTRA_COL(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal, accImag, remaining_rows, remaining_cols); + } + + for(; k < depth; k++) + { + Packet rhsV[1], rhsVi[1]; + rhsV[0] = 
pset1(rhs_ptr_real[0]); + if(!RhsIsReal) rhsVi[0] = pset1(rhs_ptr_imag[0]); + pgerc<1, Scalar, Packet, Index, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi, remaining_rows); + lhs_ptr_real += remaining_rows; + if(!LhsIsReal) lhs_ptr_imag += remaining_rows; + rhs_ptr_real += remaining_cols; + if(!RhsIsReal) rhs_ptr_imag += remaining_cols; + } + + bscalec(accReal, accImag, pAlphaReal, pAlphaImag, taccReal, taccImag); + bcouple_common(taccReal, taccImag, acc0, acc1); + + if ((sizeof(Scalar) == sizeof(float)) && (remaining_rows == 1)) + { + res(row + 0, col + 0) += pfirst(acc0.packet[0]); + } else { + acc0.packet[0] += res.template loadPacket(row + 0, col + 0); + res.template storePacketBlock(row + 0, col + 0, acc0); + if(remaining_rows > accColsC) { + res(row + accColsC, col + 0) += pfirst(acc1.packet[0]); + } + } +} + +template +EIGEN_ALWAYS_INLINE void MICRO_COMPLEX_EXTRA_ROW( + const Scalar* &lhs_ptr_real, const Scalar* &lhs_ptr_imag, + const Scalar* &rhs_ptr_real, const Scalar* &rhs_ptr_imag, + PacketBlock &accReal, PacketBlock &accImag, + Index remaining_rows) +{ + Packet rhsV[4], rhsVi[4]; + pbroadcast4_old(rhs_ptr_real, rhsV[0], rhsV[1], rhsV[2], rhsV[3]); + if(!RhsIsReal) pbroadcast4_old(rhs_ptr_imag, rhsVi[0], rhsVi[1], rhsVi[2], rhsVi[3]); + pgerc<4, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi); + lhs_ptr_real += remaining_rows; + if(!LhsIsReal) lhs_ptr_imag += remaining_rows; + else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); + rhs_ptr_real += accRows; + if(!RhsIsReal) rhs_ptr_imag += accRows; + else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag); +} + +template +EIGEN_STRONG_INLINE void gemm_complex_extra_row( + const DataMapper& res, + const Scalar* lhs_base, + const Scalar* rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index strideB, + Index row, + Index col, + Index rows, + Index cols, + Index remaining_rows, + const Packet& pAlphaReal, + const Packet& pAlphaImag, + const Packet& pMask) +{ + const Scalar* rhs_ptr_real = rhs_base; + const Scalar* rhs_ptr_imag; + if(!RhsIsReal) rhs_ptr_imag = rhs_base + accRows*strideB; + else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag); + const Scalar* lhs_ptr_real = lhs_base + advanceRows*row*strideA + remaining_rows*offsetA; + const Scalar* lhs_ptr_imag; + if(!LhsIsReal) lhs_ptr_imag = lhs_ptr_real + remaining_rows*strideA; + else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); + PacketBlock accReal, accImag; + PacketBlock taccReal, taccImag; + PacketBlock acc0, acc1; + PacketBlock tRes; + + bsetzero(accReal); + bsetzero(accImag); + + Index remaining_depth = (col + accRows < cols) ? 
depth : (depth & -accRows); + Index k = 0; + for(; k + PEEL_COMPLEX <= remaining_depth; k+= PEEL_COMPLEX) + { + EIGEN_POWER_PREFETCH(rhs_ptr_real); + if(!RhsIsReal) { + EIGEN_POWER_PREFETCH(rhs_ptr_imag); + } + EIGEN_POWER_PREFETCH(lhs_ptr_real); + if(!LhsIsReal) { + EIGEN_POWER_PREFETCH(lhs_ptr_imag); + } + for (int l = 0; l < PEEL_COMPLEX; l++) { + MICRO_COMPLEX_EXTRA_ROW(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal, accImag, remaining_rows); + } + } + for(; k < remaining_depth; k++) + { + MICRO_COMPLEX_EXTRA_ROW(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal, accImag, remaining_rows); + } + + if ((remaining_depth == depth) && (rows >= accCols)) + { + bload(tRes, res, row, col); + bscalec(accReal, accImag, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask); + bcouple(taccReal, taccImag, tRes, acc0, acc1); + res.template storePacketBlock(row + 0, col, acc0); + res.template storePacketBlock(row + accColsC, col, acc1); + } else { + for(; k < depth; k++) + { + Packet rhsV[4], rhsVi[4]; + pbroadcast4_old(rhs_ptr_real, rhsV[0], rhsV[1], rhsV[2], rhsV[3]); + if(!RhsIsReal) pbroadcast4_old(rhs_ptr_imag, rhsVi[0], rhsVi[1], rhsVi[2], rhsVi[3]); + pgerc<4, Scalar, Packet, Index, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi, remaining_rows); + lhs_ptr_real += remaining_rows; + if(!LhsIsReal) lhs_ptr_imag += remaining_rows; + rhs_ptr_real += accRows; + if(!RhsIsReal) rhs_ptr_imag += accRows; + } + + bscalec(accReal, accImag, pAlphaReal, pAlphaImag, taccReal, taccImag); + bcouple_common(taccReal, taccImag, acc0, acc1); + + if ((sizeof(Scalar) == sizeof(float)) && (remaining_rows == 1)) + { + for(Index j = 0; j < 4; j++) { + res(row + 0, col + j) += pfirst(acc0.packet[j]); + } + } else { + for(Index j = 0; j < 4; j++) { + PacketBlock acc2; + acc2.packet[0] = res.template loadPacket(row + 0, col + j) + acc0.packet[j]; + res.template storePacketBlock(row + 0, col + j, acc2); + if(remaining_rows > accColsC) { + res(row + accColsC, col + j) += pfirst(acc1.packet[j]); + } + } + } + } +} + +#define MICRO_COMPLEX_UNROLL(func) \ + func(0) func(1) func(2) func(3) func(4) + +#define MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \ + MICRO_COMPLEX_UNROLL(func2); \ + func(0,peel) func(1,peel) func(2,peel) func(3,peel) func(4,peel) + +#define MICRO_COMPLEX_LOAD_ONE(iter) \ + if (unroll_factor > iter) { \ + lhsV##iter = ploadLhs(lhs_ptr_real##iter); \ + lhs_ptr_real##iter += accCols; \ + if(!LhsIsReal) { \ + lhsVi##iter = ploadLhs(lhs_ptr_imag##iter); \ + lhs_ptr_imag##iter += accCols; \ + } else { \ + EIGEN_UNUSED_VARIABLE(lhsVi##iter); \ + } \ + } else { \ + EIGEN_UNUSED_VARIABLE(lhsV##iter); \ + EIGEN_UNUSED_VARIABLE(lhsVi##iter); \ + } + +#define MICRO_COMPLEX_WORK_ONE4(iter, peel) \ + if (unroll_factor > iter) { \ + pgerc_common<4, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \ + } + +#define MICRO_COMPLEX_WORK_ONE1(iter, peel) \ + if (unroll_factor > iter) { \ + pgerc_common<1, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \ + } + +#define MICRO_COMPLEX_TYPE_PEEL4(func, func2, peel) \ + if (PEEL_COMPLEX > peel) { \ + Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4; \ + Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3, lhsVi4; \ + pbroadcast4_old(rhs_ptr_real + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \ + 
if(!RhsIsReal) { \ + pbroadcast4_old(rhs_ptr_imag + (accRows * peel), rhsVi##peel[0], rhsVi##peel[1], rhsVi##peel[2], rhsVi##peel[3]); \ + } else { \ + EIGEN_UNUSED_VARIABLE(rhsVi##peel); \ + } \ + MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \ + } else { \ + EIGEN_UNUSED_VARIABLE(rhsV##peel); \ + EIGEN_UNUSED_VARIABLE(rhsVi##peel); \ + } + +#define MICRO_COMPLEX_TYPE_PEEL1(func, func2, peel) \ + if (PEEL_COMPLEX > peel) { \ + Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4; \ + Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3, lhsVi4; \ + rhsV##peel[0] = pset1(rhs_ptr_real[remaining_cols * peel]); \ + if(!RhsIsReal) { \ + rhsVi##peel[0] = pset1(rhs_ptr_imag[remaining_cols * peel]); \ + } else { \ + EIGEN_UNUSED_VARIABLE(rhsVi##peel); \ + } \ + MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \ + } else { \ + EIGEN_UNUSED_VARIABLE(rhsV##peel); \ + EIGEN_UNUSED_VARIABLE(rhsVi##peel); \ + } + +#define MICRO_COMPLEX_UNROLL_TYPE_PEEL(M, func, func1, func2) \ + Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M], rhsV4[M], rhsV5[M], rhsV6[M], rhsV7[M], rhsV8[M], rhsV9[M]; \ + Packet rhsVi0[M], rhsVi1[M], rhsVi2[M], rhsVi3[M], rhsVi4[M], rhsVi5[M], rhsVi6[M], rhsVi7[M], rhsVi8[M], rhsVi9[M]; \ + func(func1,func2,0); func(func1,func2,1); \ + func(func1,func2,2); func(func1,func2,3); \ + func(func1,func2,4); func(func1,func2,5); \ + func(func1,func2,6); func(func1,func2,7); \ + func(func1,func2,8); func(func1,func2,9); + +#define MICRO_COMPLEX_UNROLL_TYPE_ONE(M, func, func1, func2) \ + Packet rhsV0[M], rhsVi0[M];\ + func(func1,func2,0); + +#define MICRO_COMPLEX_ONE_PEEL4 \ + MICRO_COMPLEX_UNROLL_TYPE_PEEL(4, MICRO_COMPLEX_TYPE_PEEL4, MICRO_COMPLEX_WORK_ONE4, MICRO_COMPLEX_LOAD_ONE); \ + rhs_ptr_real += (accRows * PEEL_COMPLEX); \ + if(!RhsIsReal) rhs_ptr_imag += (accRows * PEEL_COMPLEX); + +#define MICRO_COMPLEX_ONE4 \ + MICRO_COMPLEX_UNROLL_TYPE_ONE(4, MICRO_COMPLEX_TYPE_PEEL4, MICRO_COMPLEX_WORK_ONE4, MICRO_COMPLEX_LOAD_ONE); \ + rhs_ptr_real += accRows; \ + if(!RhsIsReal) rhs_ptr_imag += accRows; + +#define MICRO_COMPLEX_ONE_PEEL1 \ + MICRO_COMPLEX_UNROLL_TYPE_PEEL(1, MICRO_COMPLEX_TYPE_PEEL1, MICRO_COMPLEX_WORK_ONE1, MICRO_COMPLEX_LOAD_ONE); \ + rhs_ptr_real += (remaining_cols * PEEL_COMPLEX); \ + if(!RhsIsReal) rhs_ptr_imag += (remaining_cols * PEEL_COMPLEX); + +#define MICRO_COMPLEX_ONE1 \ + MICRO_COMPLEX_UNROLL_TYPE_ONE(1, MICRO_COMPLEX_TYPE_PEEL1, MICRO_COMPLEX_WORK_ONE1, MICRO_COMPLEX_LOAD_ONE); \ + rhs_ptr_real += remaining_cols; \ + if(!RhsIsReal) rhs_ptr_imag += remaining_cols; + +#define MICRO_COMPLEX_DST_PTR_ONE(iter) \ + if (unroll_factor > iter) { \ + bsetzero(accReal##iter); \ + bsetzero(accImag##iter); \ + } else { \ + EIGEN_UNUSED_VARIABLE(accReal##iter); \ + EIGEN_UNUSED_VARIABLE(accImag##iter); \ + } + +#define MICRO_COMPLEX_DST_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_DST_PTR_ONE) + +#define MICRO_COMPLEX_SRC_PTR_ONE(iter) \ + if (unroll_factor > iter) { \ + lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols + accCols*offsetA; \ + if(!LhsIsReal) { \ + lhs_ptr_imag##iter = lhs_ptr_real##iter + accCols*strideA; \ + } else { \ + EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \ + } \ + } else { \ + EIGEN_UNUSED_VARIABLE(lhs_ptr_real##iter); \ + EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \ + } + +#define MICRO_COMPLEX_SRC_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_SRC_PTR_ONE) + +#define MICRO_COMPLEX_PREFETCH_ONE(iter) \ + if (unroll_factor > iter) { \ + EIGEN_POWER_PREFETCH(lhs_ptr_real##iter); \ + if(!LhsIsReal) { \ + EIGEN_POWER_PREFETCH(lhs_ptr_imag##iter); \ + } \ 
+ } + +#define MICRO_COMPLEX_PREFETCH MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_PREFETCH_ONE) + +#define MICRO_COMPLEX_STORE_ONE(iter) \ + if (unroll_factor > iter) { \ + bload(tRes, res, row + iter*accCols, col); \ + bscalec(accReal##iter, accImag##iter, pAlphaReal, pAlphaImag, taccReal, taccImag); \ + bcouple(taccReal, taccImag, tRes, acc0, acc1); \ + res.template storePacketBlock(row + iter*accCols + 0, col, acc0); \ + res.template storePacketBlock(row + iter*accCols + accColsC, col, acc1); \ + } + +#define MICRO_COMPLEX_STORE MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_STORE_ONE) + +#define MICRO_COMPLEX_COL_STORE_ONE(iter) \ + if (unroll_factor > iter) { \ + bload(tRes, res, row + iter*accCols, col); \ + bscalec(accReal##iter, accImag##iter, pAlphaReal, pAlphaImag, taccReal, taccImag); \ + bcouple(taccReal, taccImag, tRes, acc0, acc1); \ + res.template storePacketBlock(row + iter*accCols + 0, col, acc0); \ + res.template storePacketBlock(row + iter*accCols + accColsC, col, acc1); \ + } + +#define MICRO_COMPLEX_COL_STORE MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_COL_STORE_ONE) + +template +EIGEN_STRONG_INLINE void gemm_complex_unrolled_iteration( + const DataMapper& res, + const Scalar* lhs_base, + const Scalar* rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index strideB, + Index& row, + Index col, + const Packet& pAlphaReal, + const Packet& pAlphaImag) +{ + const Scalar* rhs_ptr_real = rhs_base; + const Scalar* rhs_ptr_imag; + if(!RhsIsReal) { + rhs_ptr_imag = rhs_base + accRows*strideB; + } else { + EIGEN_UNUSED_VARIABLE(rhs_ptr_imag); + } + const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_imag0 = NULL, * lhs_ptr_real1 = NULL, * lhs_ptr_imag1 = NULL; + const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_imag2 = NULL, * lhs_ptr_real3 = NULL, * lhs_ptr_imag3 = NULL; + const Scalar* lhs_ptr_real4 = NULL, * lhs_ptr_imag4 = NULL; + PacketBlock accReal0, accImag0, accReal1, accImag1; + PacketBlock accReal2, accImag2, accReal3, accImag3; + PacketBlock accReal4, accImag4; + PacketBlock taccReal, taccImag; + PacketBlock acc0, acc1; + PacketBlock tRes; + + MICRO_COMPLEX_SRC_PTR + MICRO_COMPLEX_DST_PTR + + Index k = 0; + for(; k + PEEL_COMPLEX <= depth; k+= PEEL_COMPLEX) + { + EIGEN_POWER_PREFETCH(rhs_ptr_real); + if(!RhsIsReal) { + EIGEN_POWER_PREFETCH(rhs_ptr_imag); + } + MICRO_COMPLEX_PREFETCH + MICRO_COMPLEX_ONE_PEEL4 + } + for(; k < depth; k++) + { + MICRO_COMPLEX_ONE4 + } + MICRO_COMPLEX_STORE + + row += unroll_factor*accCols; +} + +template +EIGEN_STRONG_INLINE void gemm_complex_unrolled_col_iteration( + const DataMapper& res, + const Scalar* lhs_base, + const Scalar* rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index strideB, + Index& row, + Index col, + Index remaining_cols, + const Packet& pAlphaReal, + const Packet& pAlphaImag) +{ + const Scalar* rhs_ptr_real = rhs_base; + const Scalar* rhs_ptr_imag; + if(!RhsIsReal) { + rhs_ptr_imag = rhs_base + remaining_cols*strideB; + } else { + EIGEN_UNUSED_VARIABLE(rhs_ptr_imag); + } + const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_imag0 = NULL, * lhs_ptr_real1 = NULL, * lhs_ptr_imag1 = NULL; + const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_imag2 = NULL, * lhs_ptr_real3 = NULL, * lhs_ptr_imag3 = NULL; + const Scalar* lhs_ptr_real4 = NULL, * lhs_ptr_imag4 = NULL; + PacketBlock accReal0, accImag0, accReal1, accImag1; + PacketBlock accReal2, accImag2, accReal3, accImag3; + PacketBlock accReal4, accImag4; + PacketBlock taccReal, taccImag; + PacketBlock acc0, acc1; + PacketBlock tRes; + + MICRO_COMPLEX_SRC_PTR + MICRO_COMPLEX_DST_PTR + + Index k = 0; + 
for(; k + PEEL_COMPLEX <= depth; k+= PEEL_COMPLEX) + { + EIGEN_POWER_PREFETCH(rhs_ptr_real); + if(!RhsIsReal) { + EIGEN_POWER_PREFETCH(rhs_ptr_imag); + } + MICRO_COMPLEX_PREFETCH + MICRO_COMPLEX_ONE_PEEL1 + } + for(; k < depth; k++) + { + MICRO_COMPLEX_ONE1 + } + MICRO_COMPLEX_COL_STORE + + row += unroll_factor*accCols; +} + +template +EIGEN_STRONG_INLINE void gemm_complex_unrolled_col( + const DataMapper& res, + const Scalar* lhs_base, + const Scalar* rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index strideB, + Index& row, + Index rows, + Index col, + Index remaining_cols, + const Packet& pAlphaReal, + const Packet& pAlphaImag) +{ +#define MAX_COMPLEX_UNROLL 3 + while(row + MAX_COMPLEX_UNROLL*accCols <= rows) { + gemm_complex_unrolled_col_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag); + } + switch( (rows-row)/accCols ) { +#if MAX_COMPLEX_UNROLL > 4 + case 4: + gemm_complex_unrolled_col_iteration<4, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag); + break; +#endif +#if MAX_COMPLEX_UNROLL > 3 + case 3: + gemm_complex_unrolled_col_iteration<3, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag); + break; +#endif +#if MAX_COMPLEX_UNROLL > 2 + case 2: + gemm_complex_unrolled_col_iteration<2, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag); + break; +#endif +#if MAX_COMPLEX_UNROLL > 1 + case 1: + gemm_complex_unrolled_col_iteration<1, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag); + break; +#endif + default: + break; + } +#undef MAX_COMPLEX_UNROLL +} + +template +EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) +{ + const Index remaining_rows = rows % accCols; + const Index remaining_cols = cols % accRows; + + if( strideA == -1 ) strideA = depth; + if( strideB == -1 ) strideB = depth; + + const Packet pAlphaReal = pset1(alpha.real()); + const Packet pAlphaImag = pset1(alpha.imag()); + const Packet pMask = bmask((const int)(remaining_rows)); + + const Scalar* blockA = (Scalar *) blockAc; + const Scalar* blockB = (Scalar *) blockBc; + + Index col = 0; + for(; col + accRows <= cols; col += accRows) + { + const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB; + const Scalar* lhs_base = blockA; + Index row = 0; + +#define MAX_COMPLEX_UNROLL 3 + while(row + MAX_COMPLEX_UNROLL*accCols <= rows) { + gemm_complex_unrolled_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); + } + switch( (rows-row)/accCols ) { +#if MAX_COMPLEX_UNROLL > 4 + case 4: + gemm_complex_unrolled_iteration<4, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, 
strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); + break; +#endif +#if MAX_COMPLEX_UNROLL > 3 + case 3: + gemm_complex_unrolled_iteration<3, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); + break; +#endif +#if MAX_COMPLEX_UNROLL > 2 + case 2: + gemm_complex_unrolled_iteration<2, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); + break; +#endif +#if MAX_COMPLEX_UNROLL > 1 + case 1: + gemm_complex_unrolled_iteration<1, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); + break; +#endif + default: + break; + } +#undef MAX_COMPLEX_UNROLL + + if(remaining_rows > 0) + { + gemm_complex_extra_row(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask); + } + } + + if(remaining_cols > 0) + { + const Scalar* rhs_base = blockB + advanceCols*col*strideB + remaining_cols*offsetB; + const Scalar* lhs_base = blockA; + + for(; col < cols; col++) + { + Index row = 0; + + gemm_complex_unrolled_col(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, col, remaining_cols, pAlphaReal, pAlphaImag); + + if (remaining_rows > 0) + { + gemm_complex_extra_col(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_rows, remaining_cols, pAlphaReal, pAlphaImag); + } + rhs_base++; + } + } +} + +#undef accColsC +#undef advanceCols +#undef advanceRows + +/************************************ + * ppc64le template specializations * + * **********************************/ +template +struct gemm_pack_lhs +{ + void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0); +}; + +template +void gemm_pack_lhs + ::operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) +{ + dhs_pack pack; + pack(blockA, lhs, depth, rows, stride, offset); +} + +template +struct gemm_pack_lhs +{ + void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0); +}; + +template +void gemm_pack_lhs + ::operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) +{ + dhs_pack pack; + pack(blockA, lhs, depth, rows, stride, offset); +} + +#if EIGEN_ALTIVEC_USE_CUSTOM_PACK +template +struct gemm_pack_rhs +{ + void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0); +}; + +template +void gemm_pack_rhs + ::operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) +{ + dhs_pack pack; + pack(blockB, rhs, depth, cols, stride, offset); +} + +template +struct gemm_pack_rhs +{ + void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0); +}; + +template +void gemm_pack_rhs + ::operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) +{ + dhs_pack pack; + pack(blockB, rhs, depth, cols, stride, offset); +} +#endif + +template +struct gemm_pack_lhs +{ + void operator()(float* blockA, const 
DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0); +}; + +template +void gemm_pack_lhs + ::operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) +{ + dhs_pack pack; + pack(blockA, lhs, depth, rows, stride, offset); +} + +template +struct gemm_pack_lhs +{ + void operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0); +}; + +template +void gemm_pack_lhs + ::operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) +{ + dhs_pack pack; + pack(blockA, lhs, depth, rows, stride, offset); +} + +template +struct gemm_pack_lhs, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> +{ + void operator()(std::complex* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0); +}; + +template +void gemm_pack_lhs, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> + ::operator()(std::complex* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) +{ + dhs_cpack pack; + pack(blockA, lhs, depth, rows, stride, offset); +} + +template +struct gemm_pack_lhs, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> +{ + void operator()(std::complex* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0); +}; + +template +void gemm_pack_lhs, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> + ::operator()(std::complex* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) +{ + dhs_cpack pack; + pack(blockA, lhs, depth, rows, stride, offset); +} + +#if EIGEN_ALTIVEC_USE_CUSTOM_PACK +template +struct gemm_pack_rhs +{ + void operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0); +}; + +template +void gemm_pack_rhs + ::operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) +{ + dhs_pack pack; + pack(blockB, rhs, depth, cols, stride, offset); +} + +template +struct gemm_pack_rhs +{ + void operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0); +}; + +template +void gemm_pack_rhs + ::operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) +{ + dhs_pack pack; + pack(blockB, rhs, depth, cols, stride, offset); +} +#endif + +template +struct gemm_pack_rhs, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> +{ + void operator()(std::complex* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0); +}; + +template +void gemm_pack_rhs, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> + ::operator()(std::complex* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) +{ + dhs_cpack pack; + pack(blockB, rhs, depth, cols, stride, offset); +} + +template +struct gemm_pack_rhs, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> +{ + void operator()(std::complex* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0); +}; + +template +void gemm_pack_rhs, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> + ::operator()(std::complex* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) +{ + dhs_cpack pack; + pack(blockB, rhs, depth, cols, stride, offset); +} + +template +struct gemm_pack_lhs, Index, DataMapper, Pack1, Pack2, Packet, 
RowMajor, Conjugate, PanelMode> +{ + void operator()(std::complex* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0); +}; + +template +void gemm_pack_lhs, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> + ::operator()(std::complex* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) +{ + dhs_cpack pack; + pack(blockA, lhs, depth, rows, stride, offset); +} + +template +struct gemm_pack_lhs, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> +{ + void operator()(std::complex* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0); +}; + +template +void gemm_pack_lhs, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> + ::operator()(std::complex* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) +{ + dhs_cpack pack; + pack(blockA, lhs, depth, rows, stride, offset); +} + +template +struct gemm_pack_rhs, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> +{ + void operator()(std::complex* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0); +}; + +template +void gemm_pack_rhs, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> + ::operator()(std::complex* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) +{ + dhs_cpack pack; + pack(blockB, rhs, depth, cols, stride, offset); +} + +template +struct gemm_pack_rhs, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> +{ + void operator()(std::complex* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0); +}; + +template +void gemm_pack_rhs, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> + ::operator()(std::complex* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) +{ + dhs_cpack pack; + pack(blockB, rhs, depth, cols, stride, offset); +} + +// ********* gebp specializations ********* +template +struct gebp_kernel +{ + typedef typename quad_traits::vectortype Packet; + typedef typename quad_traits::rhstype RhsPacket; + + void operator()(const DataMapper& res, const float* blockA, const float* blockB, + Index rows, Index depth, Index cols, float alpha, + Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); +}; + +template +void gebp_kernel + ::operator()(const DataMapper& res, const float* blockA, const float* blockB, + Index rows, Index depth, Index cols, float alpha, + Index strideA, Index strideB, Index offsetA, Index offsetB) + { + const Index accRows = quad_traits::rows; + const Index accCols = quad_traits::size; + void (*gemm_function)(const DataMapper&, const float*, const float*, Index, Index, Index, float, Index, Index, Index, Index); + + #ifdef EIGEN_ALTIVEC_MMA_ONLY + //generate with MMA only + gemm_function = &Eigen::internal::gemmMMA; + #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) + if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ + gemm_function = &Eigen::internal::gemmMMA; + } + else{ + gemm_function = &Eigen::internal::gemm; + } + #else + gemm_function = &Eigen::internal::gemm; + #endif + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); + } + +template +struct gebp_kernel, std::complex, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> +{ + typedef Packet4f Packet; + typedef Packet2cf Packetc; + typedef Packet4f RhsPacket; + + void operator()(const 
DataMapper& res, const std::complex* blockA, const std::complex* blockB, + Index rows, Index depth, Index cols, std::complex alpha, + Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); +}; + +template +void gebp_kernel, std::complex, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> + ::operator()(const DataMapper& res, const std::complex* blockA, const std::complex* blockB, + Index rows, Index depth, Index cols, std::complex alpha, + Index strideA, Index strideB, Index offsetA, Index offsetB) + { + const Index accRows = quad_traits::rows; + const Index accCols = quad_traits::size; + void (*gemm_function)(const DataMapper&, const std::complex*, const std::complex*, + Index, Index, Index, std::complex, Index, Index, Index, Index); + + #ifdef EIGEN_ALTIVEC_MMA_ONLY + //generate with MMA only + gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; + #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) + if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ + gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; + } + else{ + gemm_function = &Eigen::internal::gemm_complex, std::complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; + } + #else + gemm_function = &Eigen::internal::gemm_complex, std::complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; + #endif + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); + } + +template +struct gebp_kernel, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> +{ + typedef Packet4f Packet; + typedef Packet2cf Packetc; + typedef Packet4f RhsPacket; + + void operator()(const DataMapper& res, const float* blockA, const std::complex* blockB, + Index rows, Index depth, Index cols, std::complex alpha, + Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); +}; + +template +void gebp_kernel, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> + ::operator()(const DataMapper& res, const float* blockA, const std::complex* blockB, + Index rows, Index depth, Index cols, std::complex alpha, + Index strideA, Index strideB, Index offsetA, Index offsetB) + { + const Index accRows = quad_traits::rows; + const Index accCols = quad_traits::size; + void (*gemm_function)(const DataMapper&, const float*, const std::complex*, + Index, Index, Index, std::complex, Index, Index, Index, Index); + #ifdef EIGEN_ALTIVEC_MMA_ONLY + //generate with MMA only + gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; + #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) + if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ + gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; + } + else{ + gemm_function = &Eigen::internal::gemm_complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, 
ConjugateLhs, ConjugateRhs, true, false>; + } + #else + gemm_function = &Eigen::internal::gemm_complex, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; + #endif + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); + } + +template +struct gebp_kernel, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> +{ + typedef Packet4f Packet; + typedef Packet2cf Packetc; + typedef Packet4f RhsPacket; + + void operator()(const DataMapper& res, const std::complex* blockA, const float* blockB, + Index rows, Index depth, Index cols, std::complex alpha, + Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); +}; + +template +void gebp_kernel, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> + ::operator()(const DataMapper& res, const std::complex* blockA, const float* blockB, + Index rows, Index depth, Index cols, std::complex alpha, + Index strideA, Index strideB, Index offsetA, Index offsetB) + { + const Index accRows = quad_traits::rows; + const Index accCols = quad_traits::size; + void (*gemm_function)(const DataMapper&, const std::complex*, const float*, + Index, Index, Index, std::complex, Index, Index, Index, Index); + #ifdef EIGEN_ALTIVEC_MMA_ONLY + //generate with MMA only + gemm_function = &Eigen::internal::gemm_complexMMA, float, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; + #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) + if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ + gemm_function = &Eigen::internal::gemm_complexMMA, float, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; + } + else{ + gemm_function = &Eigen::internal::gemm_complex, float, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; + } + #else + gemm_function = &Eigen::internal::gemm_complex, float, std::complex, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; + #endif + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); + } + +template +struct gebp_kernel +{ + typedef typename quad_traits::vectortype Packet; + typedef typename quad_traits::rhstype RhsPacket; + + void operator()(const DataMapper& res, const double* blockA, const double* blockB, + Index rows, Index depth, Index cols, double alpha, + Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); +}; + +template +void gebp_kernel + ::operator()(const DataMapper& res, const double* blockA, const double* blockB, + Index rows, Index depth, Index cols, double alpha, + Index strideA, Index strideB, Index offsetA, Index offsetB) + { + const Index accRows = quad_traits::rows; + const Index accCols = quad_traits::size; + void (*gemm_function)(const DataMapper&, const double*, const double*, Index, Index, Index, double, Index, Index, Index, Index); + + #ifdef EIGEN_ALTIVEC_MMA_ONLY + //generate with MMA only + gemm_function = &Eigen::internal::gemmMMA; + #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) + if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ + gemm_function = &Eigen::internal::gemmMMA; + } + else{ + gemm_function = &Eigen::internal::gemm; + } + 
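// The selection just made is the recurring dispatch pattern in this file:
// resolve a kernel function pointer once per call, preferring the MMA path
// only when the running CPU reports Power10/MMA support. A reduced
// standalone sketch of the same pattern (the scale_* kernels are hypothetical
// stand-ins; __builtin_cpu_supports("arch_3_1") / ("mma") are the GCC feature
// checks used above, and the __BUILTIN_CPU_SUPPORTS__ guard is an assumption
// for portability of the sketch):
static void scale_generic_sketch(float* x, int n, float a)
{
  for (int i = 0; i < n; i++) x[i] *= a;   // portable fallback kernel
}
static void scale_mma_sketch(float* x, int n, float a)
{
  scale_generic_sketch(x, n, a);           // stand-in for a Power10-tuned kernel
}
static void scale_dispatch_sketch(float* x, int n, float a)
{
  void (*fn)(float*, int, float) = &scale_generic_sketch;
#if defined(__GNUC__) && defined(__BUILTIN_CPU_SUPPORTS__)
  if (__builtin_cpu_supports("arch_3_1") && __builtin_cpu_supports("mma"))
    fn = &scale_mma_sketch;                // take the MMA path only if the CPU has it
#endif
  fn(x, n, a);
}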
#else + gemm_function = &Eigen::internal::gemm; + #endif + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); + } + +template +struct gebp_kernel, std::complex, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> +{ + typedef quad_traits::vectortype Packet; + typedef Packet1cd Packetc; + typedef quad_traits::rhstype RhsPacket; + + void operator()(const DataMapper& res, const std::complex* blockA, const std::complex* blockB, + Index rows, Index depth, Index cols, std::complex alpha, + Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); +}; + +template +void gebp_kernel, std::complex, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> + ::operator()(const DataMapper& res, const std::complex* blockA, const std::complex* blockB, + Index rows, Index depth, Index cols, std::complex alpha, + Index strideA, Index strideB, Index offsetA, Index offsetB) + { + const Index accRows = quad_traits::rows; + const Index accCols = quad_traits::size; + void (*gemm_function)(const DataMapper&, const std::complex*, const std::complex*, + Index, Index, Index, std::complex, Index, Index, Index, Index); + #ifdef EIGEN_ALTIVEC_MMA_ONLY + //generate with MMA only + gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; + #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) + if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ + gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; + } + else{ + gemm_function = &Eigen::internal::gemm_complex, std::complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; + } + #else + gemm_function = &Eigen::internal::gemm_complex, std::complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; + #endif + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); + } + +template +struct gebp_kernel, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> +{ + typedef quad_traits::vectortype Packet; + typedef Packet1cd Packetc; + typedef quad_traits::rhstype RhsPacket; + + void operator()(const DataMapper& res, const std::complex* blockA, const double* blockB, + Index rows, Index depth, Index cols, std::complex alpha, + Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); +}; + +template +void gebp_kernel, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> + ::operator()(const DataMapper& res, const std::complex* blockA, const double* blockB, + Index rows, Index depth, Index cols, std::complex alpha, + Index strideA, Index strideB, Index offsetA, Index offsetB) + { + const Index accRows = quad_traits::rows; + const Index accCols = quad_traits::size; + void (*gemm_function)(const DataMapper&, const std::complex*, const double*, + Index, Index, Index, std::complex, Index, Index, Index, Index); + #ifdef EIGEN_ALTIVEC_MMA_ONLY + //generate with MMA only + gemm_function = &Eigen::internal::gemm_complexMMA, double, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; + #elif 
defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) + if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ + gemm_function = &Eigen::internal::gemm_complexMMA, double, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; + } + else{ + gemm_function = &Eigen::internal::gemm_complex, double, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; + } + #else + gemm_function = &Eigen::internal::gemm_complex, double, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; + #endif + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); + } + +template +struct gebp_kernel, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> +{ + typedef quad_traits::vectortype Packet; + typedef Packet1cd Packetc; + typedef quad_traits::rhstype RhsPacket; + + void operator()(const DataMapper& res, const double* blockA, const std::complex* blockB, + Index rows, Index depth, Index cols, std::complex alpha, + Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); +}; + +template +void gebp_kernel, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> + ::operator()(const DataMapper& res, const double* blockA, const std::complex* blockB, + Index rows, Index depth, Index cols, std::complex alpha, + Index strideA, Index strideB, Index offsetA, Index offsetB) + { + const Index accRows = quad_traits::rows; + const Index accCols = quad_traits::size; + void (*gemm_function)(const DataMapper&, const double*, const std::complex*, + Index, Index, Index, std::complex, Index, Index, Index, Index); + #ifdef EIGEN_ALTIVEC_MMA_ONLY + //generate with MMA only + gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; + #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) + if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ + gemm_function = &Eigen::internal::gemm_complexMMA, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; + } + else{ + gemm_function = &Eigen::internal::gemm_complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; + } + #else + gemm_function = &Eigen::internal::gemm_complex, std::complex, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; + #endif + gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); + } +} // end namespace internal + +} // end namespace Eigen + +#endif // EIGEN_MATRIX_PRODUCT_ALTIVEC_H diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h new file mode 100644 index 0000000..33d5434 --- /dev/null +++ b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h @@ -0,0 +1,221 @@ +//#define EIGEN_POWER_USE_PREFETCH // Use prefetching in gemm routines +#ifdef EIGEN_POWER_USE_PREFETCH +#define EIGEN_POWER_PREFETCH(p) prefetch(p) +#else +#define EIGEN_POWER_PREFETCH(p) +#endif + +namespace Eigen { + +namespace internal { + +template +EIGEN_STRONG_INLINE void gemm_extra_col( + const DataMapper& res, + 
const Scalar* lhs_base, + const Scalar* rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index row, + Index col, + Index remaining_rows, + Index remaining_cols, + const Packet& pAlpha); + +template +EIGEN_STRONG_INLINE void gemm_extra_row( + const DataMapper& res, + const Scalar* lhs_base, + const Scalar* rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index row, + Index col, + Index rows, + Index cols, + Index remaining_rows, + const Packet& pAlpha, + const Packet& pMask); + +template +EIGEN_STRONG_INLINE void gemm_unrolled_col( + const DataMapper& res, + const Scalar* lhs_base, + const Scalar* rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index& row, + Index rows, + Index col, + Index remaining_cols, + const Packet& pAlpha); + +template +EIGEN_ALWAYS_INLINE Packet bmask(const int remaining_rows); + +template +EIGEN_STRONG_INLINE void gemm_complex_extra_col( + const DataMapper& res, + const Scalar* lhs_base, + const Scalar* rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index strideB, + Index row, + Index col, + Index remaining_rows, + Index remaining_cols, + const Packet& pAlphaReal, + const Packet& pAlphaImag); + +template +EIGEN_STRONG_INLINE void gemm_complex_extra_row( + const DataMapper& res, + const Scalar* lhs_base, + const Scalar* rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index strideB, + Index row, + Index col, + Index rows, + Index cols, + Index remaining_rows, + const Packet& pAlphaReal, + const Packet& pAlphaImag, + const Packet& pMask); + +template +EIGEN_STRONG_INLINE void gemm_complex_unrolled_col( + const DataMapper& res, + const Scalar* lhs_base, + const Scalar* rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index strideB, + Index& row, + Index rows, + Index col, + Index remaining_cols, + const Packet& pAlphaReal, + const Packet& pAlphaImag); + +template +EIGEN_ALWAYS_INLINE Packet ploadLhs(const Scalar* lhs); + +template +EIGEN_ALWAYS_INLINE void bload(PacketBlock& acc, const DataMapper& res, Index row, Index col); + +template +EIGEN_ALWAYS_INLINE void bload(PacketBlock& acc, const DataMapper& res, Index row, Index col); + +template +EIGEN_ALWAYS_INLINE void bscale(PacketBlock& acc, PacketBlock& accZ, const Packet& pAlpha); + +template +EIGEN_ALWAYS_INLINE void bscalec(PacketBlock& aReal, PacketBlock& aImag, const Packet& bReal, const Packet& bImag, PacketBlock& cReal, PacketBlock& cImag); + +const static Packet16uc p16uc_SETCOMPLEX32_FIRST = { 0, 1, 2, 3, + 16, 17, 18, 19, + 4, 5, 6, 7, + 20, 21, 22, 23}; + +const static Packet16uc p16uc_SETCOMPLEX32_SECOND = { 8, 9, 10, 11, + 24, 25, 26, 27, + 12, 13, 14, 15, + 28, 29, 30, 31}; +//[a,b],[ai,bi] = [a,ai] - This is equivalent to p16uc_GETREAL64 +const static Packet16uc p16uc_SETCOMPLEX64_FIRST = { 0, 1, 2, 3, 4, 5, 6, 7, + 16, 17, 18, 19, 20, 21, 22, 23}; + +//[a,b],[ai,bi] = [b,bi] - This is equivalent to p16uc_GETIMAG64 +const static Packet16uc p16uc_SETCOMPLEX64_SECOND = { 8, 9, 10, 11, 12, 13, 14, 15, + 24, 25, 26, 27, 28, 29, 30, 31}; + + +// Grab two decouples real/imaginary PacketBlocks and return two coupled (real/imaginary pairs) PacketBlocks. 
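// In scalar terms, one column handled by the bcouple()/bcouple_common() pair
// below does the following (a plain-C model assuming 4-wide float packets, so
// four real parts and four imaginary parts become two packets of two complex
// values each; the helper below is illustrative only):
static void bcouple_column_model(const float re[4], const float im[4],
                                 const float res_in[8], float out[8])
{
  // out is the interleaved complex result re0,im0,re1,im1,...,re3,im3:
  // the (real,imag) pairs accumulated onto the previously loaded result
  // block, which is what the SETCOMPLEX permute masks followed by padd
  // achieve in the float case.
  for (int i = 0; i < 4; i++) {
    out[2*i + 0] = res_in[2*i + 0] + re[i];
    out[2*i + 1] = res_in[2*i + 1] + im[i];
  }
}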
+template +EIGEN_ALWAYS_INLINE void bcouple_common(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& acc1, PacketBlock& acc2) +{ + acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_FIRST); + acc1.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX32_FIRST); + acc1.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX32_FIRST); + acc1.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX32_FIRST); + + acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_SECOND); + acc2.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX32_SECOND); + acc2.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX32_SECOND); + acc2.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX32_SECOND); +} + +template +EIGEN_ALWAYS_INLINE void bcouple(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& tRes, PacketBlock& acc1, PacketBlock& acc2) +{ + bcouple_common(taccReal, taccImag, acc1, acc2); + + acc1.packet[0] = padd(tRes.packet[0], acc1.packet[0]); + acc1.packet[1] = padd(tRes.packet[1], acc1.packet[1]); + acc1.packet[2] = padd(tRes.packet[2], acc1.packet[2]); + acc1.packet[3] = padd(tRes.packet[3], acc1.packet[3]); + + acc2.packet[0] = padd(tRes.packet[4], acc2.packet[0]); + acc2.packet[1] = padd(tRes.packet[5], acc2.packet[1]); + acc2.packet[2] = padd(tRes.packet[6], acc2.packet[2]); + acc2.packet[3] = padd(tRes.packet[7], acc2.packet[3]); +} + +template +EIGEN_ALWAYS_INLINE void bcouple_common(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& acc1, PacketBlock& acc2) +{ + acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_FIRST); + + acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_SECOND); +} + +template +EIGEN_ALWAYS_INLINE void bcouple(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& tRes, PacketBlock& acc1, PacketBlock& acc2) +{ + bcouple_common(taccReal, taccImag, acc1, acc2); + + acc1.packet[0] = padd(tRes.packet[0], acc1.packet[0]); + + acc2.packet[0] = padd(tRes.packet[1], acc2.packet[0]); +} + +template<> +EIGEN_ALWAYS_INLINE void bcouple_common(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& acc1, PacketBlock& acc2) +{ + acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_FIRST); + acc1.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX64_FIRST); + acc1.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX64_FIRST); + acc1.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX64_FIRST); + + acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_SECOND); + acc2.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX64_SECOND); + acc2.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX64_SECOND); + acc2.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX64_SECOND); +} + +template<> +EIGEN_ALWAYS_INLINE void bcouple_common(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& acc1, PacketBlock& acc2) +{ + acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_FIRST); + + acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_SECOND); +} + +// This is necessary because 
ploadRhs for double returns a pair of vectors when MMA is enabled. +template +EIGEN_ALWAYS_INLINE Packet ploadRhs(const Scalar* rhs) +{ + return ploadu(rhs); +} + +} // end namespace internal +} // end namespace Eigen diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h new file mode 100644 index 0000000..6540c6f --- /dev/null +++ b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h @@ -0,0 +1,629 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2020 Everton Constantino (everton.constantino@ibm.com) +// Copyright (C) 2021 Chip Kerchner (chip.kerchner@ibm.com) +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H +#define EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H + +#pragma GCC target("cpu=power10") + +#ifdef __has_builtin +#if !__has_builtin(__builtin_vsx_assemble_pair) +#define __builtin_vsx_assemble_pair __builtin_mma_assemble_pair +#endif +#endif + +namespace Eigen { + +namespace internal { + +template +EIGEN_ALWAYS_INLINE void bsetzeroMMA(__vector_quad* acc) +{ + __builtin_mma_xxsetaccz(acc); +} + +template +EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, Index j, const DataMapper& data, const Packet& alpha, __vector_quad* acc) +{ + PacketBlock result; + __builtin_mma_disassemble_acc(&result.packet, acc); + + PacketBlock tRes; + bload(tRes, data, i, j); + + bscale(tRes, result, alpha); + + data.template storePacketBlock(i, j, tRes); +} + +template +EIGEN_ALWAYS_INLINE void storeComplexAccumulator(Index i, Index j, const DataMapper& data, const Packet& alphaReal, const Packet& alphaImag, __vector_quad* accReal, __vector_quad* accImag) +{ + PacketBlock resultReal, resultImag; + __builtin_mma_disassemble_acc(&resultReal.packet, accReal); + __builtin_mma_disassemble_acc(&resultImag.packet, accImag); + + PacketBlock tRes; + bload(tRes, data, i, j); + + PacketBlock taccReal, taccImag; + bscalec(resultReal, resultImag, alphaReal, alphaImag, taccReal, taccImag); + + PacketBlock acc1, acc2; + bcouple(taccReal, taccImag, tRes, acc1, acc2); + + data.template storePacketBlock(i + N*accColsC, j, acc1); + data.template storePacketBlock(i + (N+1)*accColsC, j, acc2); +} + +// Defaults to float32, since Eigen still supports C++03 we can't use default template arguments +template +EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const RhsPacket& a, const LhsPacket& b) +{ + if(NegativeAccumulate) + { + __builtin_mma_xvf32gernp(acc, (__vector unsigned char)a, (__vector unsigned char)b); + } else { + __builtin_mma_xvf32gerpp(acc, (__vector unsigned char)a, (__vector unsigned char)b); + } +} + +template +EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const PacketBlock& a, const Packet2d& b) +{ + __vector_pair* a0 = (__vector_pair *)(&a.packet[0]); + if(NegativeAccumulate) + { + __builtin_mma_xvf64gernp(acc, *a0, (__vector unsigned char)b); + } else { + __builtin_mma_xvf64gerpp(acc, *a0, (__vector unsigned char)b); + } +} + +template +EIGEN_ALWAYS_INLINE void pgerMMA(__vector_quad* acc, const __vector_pair& a, const Packet2d& b) +{ + if(NegativeAccumulate) + { + __builtin_mma_xvf64gernp(acc, (__vector_pair)a, (__vector unsigned char)b); + } else { + __builtin_mma_xvf64gerpp(acc, (__vector_pair)a, (__vector unsigned char)b); + } +} + +template +EIGEN_ALWAYS_INLINE void 
pgerMMA(__vector_quad*, const __vector_pair&, const Packet4f&) +{ + // Just for compilation +} + +template +EIGEN_ALWAYS_INLINE void pgercMMA(__vector_quad* accReal, __vector_quad* accImag, const Packet& lhsV, const Packet& lhsVi, const RhsPacket& rhsV, const RhsPacket& rhsVi) +{ + pgerMMA(accReal, rhsV, lhsV); + if(LhsIsReal) { + pgerMMA(accImag, rhsVi, lhsV); + } else { + if(!RhsIsReal) { + pgerMMA(accReal, rhsVi, lhsVi); + pgerMMA(accImag, rhsVi, lhsV); + } else { + EIGEN_UNUSED_VARIABLE(rhsVi); + } + pgerMMA(accImag, rhsV, lhsVi); + } +} + +// This is necessary because ploadRhs for double returns a pair of vectors when MMA is enabled. +template +EIGEN_ALWAYS_INLINE void ploadRhsMMA(const Scalar* rhs, Packet& rhsV) +{ + rhsV = ploadRhs((const Scalar*)(rhs)); +} + +template<> +EIGEN_ALWAYS_INLINE void ploadRhsMMA >(const double* rhs, PacketBlock& rhsV) +{ + rhsV.packet[0] = ploadRhs((const double *)((Packet2d *)rhs )); + rhsV.packet[1] = ploadRhs((const double *)(((Packet2d *)rhs) + 1)); +} + +template<> +EIGEN_ALWAYS_INLINE void ploadRhsMMA(const double* rhs, __vector_pair& rhsV) +{ +#if EIGEN_COMP_LLVM + __builtin_vsx_assemble_pair(&rhsV, + (__vector unsigned char)(ploadRhs((const double *)(((Packet2d *)rhs) + 1))), + (__vector unsigned char)(ploadRhs((const double *)((Packet2d *)rhs )))); +#else + __asm__ ("lxvp %x0,%1" : "=wa" (rhsV) : "Y" (*rhs)); +#endif +} + +template<> +EIGEN_ALWAYS_INLINE void ploadRhsMMA(const float*, __vector_pair&) +{ + // Just for compilation +} + +// PEEL_MMA loop factor. +#define PEEL_MMA 7 + +#define MICRO_MMA_UNROLL(func) \ + func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7) + +#define MICRO_MMA_LOAD_ONE(iter) \ + if (unroll_factor > iter) { \ + lhsV##iter = ploadLhs(lhs_ptr##iter); \ + lhs_ptr##iter += accCols; \ + } else { \ + EIGEN_UNUSED_VARIABLE(lhsV##iter); \ + } + +#define MICRO_MMA_WORK_ONE(iter, type, peel) \ + if (unroll_factor > iter) { \ + pgerMMA(&accZero##iter, rhsV##peel, lhsV##iter); \ + } + +#define MICRO_MMA_TYPE_PEEL(func, func2, type, peel) \ + if (PEEL_MMA > peel) { \ + Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \ + ploadRhsMMA(rhs_ptr + (accRows * peel), rhsV##peel); \ + MICRO_MMA_UNROLL(func2); \ + func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) \ + func(4,type,peel) func(5,type,peel) func(6,type,peel) func(7,type,peel) \ + } else { \ + EIGEN_UNUSED_VARIABLE(rhsV##peel); \ + } + +#define MICRO_MMA_UNROLL_TYPE_PEEL(func, func2, type) \ + type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7, rhsV8, rhsV9; \ + MICRO_MMA_TYPE_PEEL(func,func2,type,0); MICRO_MMA_TYPE_PEEL(func,func2,type,1); \ + MICRO_MMA_TYPE_PEEL(func,func2,type,2); MICRO_MMA_TYPE_PEEL(func,func2,type,3); \ + MICRO_MMA_TYPE_PEEL(func,func2,type,4); MICRO_MMA_TYPE_PEEL(func,func2,type,5); \ + MICRO_MMA_TYPE_PEEL(func,func2,type,6); MICRO_MMA_TYPE_PEEL(func,func2,type,7); \ + MICRO_MMA_TYPE_PEEL(func,func2,type,8); MICRO_MMA_TYPE_PEEL(func,func2,type,9); + +#define MICRO_MMA_UNROLL_TYPE_ONE(func, func2, type) \ + type rhsV0; \ + MICRO_MMA_TYPE_PEEL(func,func2,type,0); + +#define MICRO_MMA_ONE_PEEL \ + if (sizeof(Scalar) == sizeof(float)) { \ + MICRO_MMA_UNROLL_TYPE_PEEL(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, RhsPacket); \ + } else { \ + MICRO_MMA_UNROLL_TYPE_PEEL(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, __vector_pair); \ + } \ + rhs_ptr += (accRows * PEEL_MMA); + +#define MICRO_MMA_ONE \ + if (sizeof(Scalar) == sizeof(float)) { \ + MICRO_MMA_UNROLL_TYPE_ONE(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, 
RhsPacket); \ + } else { \ + MICRO_MMA_UNROLL_TYPE_ONE(MICRO_MMA_WORK_ONE, MICRO_MMA_LOAD_ONE, __vector_pair); \ + } \ + rhs_ptr += accRows; + +#define MICRO_MMA_DST_PTR_ONE(iter) \ + if (unroll_factor > iter) { \ + bsetzeroMMA(&accZero##iter); \ + } else { \ + EIGEN_UNUSED_VARIABLE(accZero##iter); \ + } + +#define MICRO_MMA_DST_PTR MICRO_MMA_UNROLL(MICRO_MMA_DST_PTR_ONE) + +#define MICRO_MMA_SRC_PTR_ONE(iter) \ + if (unroll_factor > iter) { \ + lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols + accCols*offsetA; \ + } else { \ + EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \ + } + +#define MICRO_MMA_SRC_PTR MICRO_MMA_UNROLL(MICRO_MMA_SRC_PTR_ONE) + +#define MICRO_MMA_PREFETCH_ONE(iter) \ + if (unroll_factor > iter) { \ + EIGEN_POWER_PREFETCH(lhs_ptr##iter); \ + } + +#define MICRO_MMA_PREFETCH MICRO_MMA_UNROLL(MICRO_MMA_PREFETCH_ONE) + +#define MICRO_MMA_STORE_ONE(iter) \ + if (unroll_factor > iter) { \ + storeAccumulator(row + iter*accCols, col, res, pAlpha, &accZero##iter); \ + } + +#define MICRO_MMA_STORE MICRO_MMA_UNROLL(MICRO_MMA_STORE_ONE) + +template +EIGEN_STRONG_INLINE void gemm_unrolled_MMA_iteration( + const DataMapper& res, + const Scalar* lhs_base, + const Scalar* rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index& row, + Index col, + const Packet& pAlpha) +{ + const Scalar* rhs_ptr = rhs_base; + const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL; + __vector_quad accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7; + + MICRO_MMA_SRC_PTR + MICRO_MMA_DST_PTR + + Index k = 0; + for(; k + PEEL_MMA <= depth; k+= PEEL_MMA) + { + EIGEN_POWER_PREFETCH(rhs_ptr); + MICRO_MMA_PREFETCH + MICRO_MMA_ONE_PEEL + } + for(; k < depth; k++) + { + MICRO_MMA_ONE + } + MICRO_MMA_STORE + + row += unroll_factor*accCols; +} + +template +void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) +{ + const Index remaining_rows = rows % accCols; + const Index remaining_cols = cols % accRows; + + if( strideA == -1 ) strideA = depth; + if( strideB == -1 ) strideB = depth; + + const Packet pAlpha = pset1(alpha); + const Packet pMask = bmask((const int)(remaining_rows)); + + Index col = 0; + for(; col + accRows <= cols; col += accRows) + { + const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB; + const Scalar* lhs_base = blockA; + + Index row = 0; +#define MAX_MMA_UNROLL 7 + while(row + MAX_MMA_UNROLL*accCols <= rows) { + gemm_unrolled_MMA_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + } + switch( (rows-row)/accCols ) { +#if MAX_MMA_UNROLL > 7 + case 7: + gemm_unrolled_MMA_iteration<7, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; +#endif +#if MAX_MMA_UNROLL > 6 + case 6: + gemm_unrolled_MMA_iteration<6, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; +#endif +#if MAX_MMA_UNROLL > 5 + case 5: + gemm_unrolled_MMA_iteration<5, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; +#endif +#if MAX_MMA_UNROLL > 4 + case 4: + gemm_unrolled_MMA_iteration<4, Scalar, Packet, RhsPacket, DataMapper, Index, 
accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; +#endif +#if MAX_MMA_UNROLL > 3 + case 3: + gemm_unrolled_MMA_iteration<3, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; +#endif +#if MAX_MMA_UNROLL > 2 + case 2: + gemm_unrolled_MMA_iteration<2, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; +#endif +#if MAX_MMA_UNROLL > 1 + case 1: + gemm_unrolled_MMA_iteration<1, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); + break; +#endif + default: + break; + } +#undef MAX_MMA_UNROLL + + if(remaining_rows > 0) + { + gemm_extra_row(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask); + } + } + + if(remaining_cols > 0) + { + const Scalar* rhs_base = blockB + col*strideB + remaining_cols*offsetB; + const Scalar* lhs_base = blockA; + + for(; col < cols; col++) + { + Index row = 0; + + gemm_unrolled_col(res, lhs_base, rhs_base, depth, strideA, offsetA, row, rows, col, remaining_cols, pAlpha); + + if (remaining_rows > 0) + { + gemm_extra_col(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_rows, remaining_cols, pAlpha); + } + rhs_base++; + } + } +} + +#define accColsC (accCols / 2) +#define advanceRows ((LhsIsReal) ? 1 : 2) +#define advanceCols ((RhsIsReal) ? 1 : 2) + +// PEEL_COMPLEX_MMA loop factor. +#define PEEL_COMPLEX_MMA 7 + +#define MICRO_COMPLEX_MMA_UNROLL(func) \ + func(0) func(1) func(2) func(3) func(4) + +#define MICRO_COMPLEX_MMA_LOAD_ONE(iter) \ + if (unroll_factor > iter) { \ + lhsV##iter = ploadLhs(lhs_ptr_real##iter); \ + lhs_ptr_real##iter += accCols; \ + if(!LhsIsReal) { \ + lhsVi##iter = ploadLhs(lhs_ptr_imag##iter); \ + lhs_ptr_imag##iter += accCols; \ + } else { \ + EIGEN_UNUSED_VARIABLE(lhsVi##iter); \ + } \ + } else { \ + EIGEN_UNUSED_VARIABLE(lhsV##iter); \ + EIGEN_UNUSED_VARIABLE(lhsVi##iter); \ + } + +#define MICRO_COMPLEX_MMA_WORK_ONE(iter, type, peel) \ + if (unroll_factor > iter) { \ + pgercMMA(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \ + } + +#define MICRO_COMPLEX_MMA_TYPE_PEEL(func, func2, type, peel) \ + if (PEEL_COMPLEX_MMA > peel) { \ + Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4; \ + Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3, lhsVi4; \ + ploadRhsMMA(rhs_ptr_real + (accRows * peel), rhsV##peel); \ + if(!RhsIsReal) { \ + ploadRhsMMA(rhs_ptr_imag + (accRows * peel), rhsVi##peel); \ + } else { \ + EIGEN_UNUSED_VARIABLE(rhsVi##peel); \ + } \ + MICRO_COMPLEX_MMA_UNROLL(func2); \ + func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) func(4,type,peel) \ + } else { \ + EIGEN_UNUSED_VARIABLE(rhsV##peel); \ + EIGEN_UNUSED_VARIABLE(rhsVi##peel); \ + } + +#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(func, func2, type) \ + type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7, rhsV8, rhsV9; \ + type rhsVi0, rhsVi1, rhsVi2, rhsVi3, rhsVi4, rhsVi5, rhsVi6, rhsVi7, rhsVi8, rhsVi9; \ + MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,1); \ + MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,2); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,3); \ + MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,4); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,5); \ + MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,6); 
MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,7); \ + MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,8); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,9); + +#define MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(func, func2, type) \ + type rhsV0, rhsVi0; \ + MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0); + +#define MICRO_COMPLEX_MMA_ONE_PEEL \ + if (sizeof(Scalar) == sizeof(float)) { \ + MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, RhsPacket); \ + } else { \ + MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, __vector_pair); \ + } \ + rhs_ptr_real += (accRows * PEEL_COMPLEX_MMA); \ + if(!RhsIsReal) rhs_ptr_imag += (accRows * PEEL_COMPLEX_MMA); + +#define MICRO_COMPLEX_MMA_ONE \ + if (sizeof(Scalar) == sizeof(float)) { \ + MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, RhsPacket); \ + } else { \ + MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_MMA_LOAD_ONE, __vector_pair); \ + } \ + rhs_ptr_real += accRows; \ + if(!RhsIsReal) rhs_ptr_imag += accRows; + +#define MICRO_COMPLEX_MMA_DST_PTR_ONE(iter) \ + if (unroll_factor > iter) { \ + bsetzeroMMA(&accReal##iter); \ + bsetzeroMMA(&accImag##iter); \ + } else { \ + EIGEN_UNUSED_VARIABLE(accReal##iter); \ + EIGEN_UNUSED_VARIABLE(accImag##iter); \ + } + +#define MICRO_COMPLEX_MMA_DST_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_DST_PTR_ONE) + +#define MICRO_COMPLEX_MMA_SRC_PTR_ONE(iter) \ + if (unroll_factor > iter) { \ + lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols + accCols*offsetA; \ + if(!LhsIsReal) { \ + lhs_ptr_imag##iter = lhs_ptr_real##iter + accCols*strideA; \ + } else { \ + EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \ + } \ + } else { \ + EIGEN_UNUSED_VARIABLE(lhs_ptr_real##iter); \ + EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \ + } + +#define MICRO_COMPLEX_MMA_SRC_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_SRC_PTR_ONE) + +#define MICRO_COMPLEX_MMA_PREFETCH_ONE(iter) \ + if (unroll_factor > iter) { \ + EIGEN_POWER_PREFETCH(lhs_ptr_real##iter); \ + if(!LhsIsReal) { \ + EIGEN_POWER_PREFETCH(lhs_ptr_imag##iter); \ + } \ + } + +#define MICRO_COMPLEX_MMA_PREFETCH MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_PREFETCH_ONE) + +#define MICRO_COMPLEX_MMA_STORE_ONE(iter) \ + if (unroll_factor > iter) { \ + storeComplexAccumulator(row + iter*accCols, col, res, pAlphaReal, pAlphaImag, &accReal##iter, &accImag##iter); \ + } + +#define MICRO_COMPLEX_MMA_STORE MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_STORE_ONE) + +template +EIGEN_STRONG_INLINE void gemm_complex_unrolled_MMA_iteration( + const DataMapper& res, + const Scalar* lhs_base, + const Scalar* rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index strideB, + Index& row, + Index col, + const Packet& pAlphaReal, + const Packet& pAlphaImag) +{ + const Scalar* rhs_ptr_real = rhs_base; + const Scalar* rhs_ptr_imag; + if(!RhsIsReal) { + rhs_ptr_imag = rhs_base + accRows*strideB; + } else { + EIGEN_UNUSED_VARIABLE(rhs_ptr_imag); + } + const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_imag0 = NULL, * lhs_ptr_real1 = NULL, * lhs_ptr_imag1 = NULL; + const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_imag2 = NULL, * lhs_ptr_real3 = NULL, * lhs_ptr_imag3 = NULL; + const Scalar* lhs_ptr_real4 = NULL, * lhs_ptr_imag4 = NULL; + __vector_quad accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3, accReal4, accImag4; + + MICRO_COMPLEX_MMA_SRC_PTR + MICRO_COMPLEX_MMA_DST_PTR + + Index k = 
0; + for(; k + PEEL_COMPLEX_MMA <= depth; k+= PEEL_COMPLEX_MMA) + { + EIGEN_POWER_PREFETCH(rhs_ptr_real); + if(!RhsIsReal) { + EIGEN_POWER_PREFETCH(rhs_ptr_imag); + } + MICRO_COMPLEX_MMA_PREFETCH + MICRO_COMPLEX_MMA_ONE_PEEL + } + for(; k < depth; k++) + { + MICRO_COMPLEX_MMA_ONE + } + MICRO_COMPLEX_MMA_STORE + + row += unroll_factor*accCols; +} + +template +void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) +{ + const Index remaining_rows = rows % accCols; + const Index remaining_cols = cols % accRows; + + if( strideA == -1 ) strideA = depth; + if( strideB == -1 ) strideB = depth; + + const Packet pAlphaReal = pset1(alpha.real()); + const Packet pAlphaImag = pset1(alpha.imag()); + const Packet pMask = bmask((const int)(remaining_rows)); + + const Scalar* blockA = (Scalar *) blockAc; + const Scalar* blockB = (Scalar *) blockBc; + + Index col = 0; + for(; col + accRows <= cols; col += accRows) + { + const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB; + const Scalar* lhs_base = blockA; + Index row = 0; + +#define MAX_COMPLEX_MMA_UNROLL 4 + while(row + MAX_COMPLEX_MMA_UNROLL*accCols <= rows) { + gemm_complex_unrolled_MMA_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); + } + switch( (rows-row)/accCols ) { +#if MAX_COMPLEX_MMA_UNROLL > 4 + case 4: + gemm_complex_unrolled_MMA_iteration<4, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); + break; +#endif +#if MAX_COMPLEX_MMA_UNROLL > 3 + case 3: + gemm_complex_unrolled_MMA_iteration<3, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); + break; +#endif +#if MAX_COMPLEX_MMA_UNROLL > 2 + case 2: + gemm_complex_unrolled_MMA_iteration<2, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); + break; +#endif +#if MAX_COMPLEX_MMA_UNROLL > 1 + case 1: + gemm_complex_unrolled_MMA_iteration<1, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); + break; +#endif + default: + break; + } +#undef MAX_COMPLEX_MMA_UNROLL + + if(remaining_rows > 0) + { + gemm_complex_extra_row(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask); + } + } + + if(remaining_cols > 0) + { + const Scalar* rhs_base = blockB + advanceCols*col*strideB + remaining_cols*offsetB; + const Scalar* lhs_base = blockA; + + for(; col < cols; col++) + { + Index row = 0; + + gemm_complex_unrolled_col(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, col, remaining_cols, pAlphaReal, pAlphaImag); + + if (remaining_rows > 0) + { + gemm_complex_extra_col(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_rows, remaining_cols, pAlphaReal, pAlphaImag); + } + rhs_base++; + } + } +} + +#undef accColsC 
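// The gemmMMA/gemm_complexMMA drivers above all reduce to the same
// accumulator lifecycle: zero a __vector_quad, apply one rank-1 update per
// depth step, then disassemble the tile, scale by alpha and store. A
// stripped-down standalone sketch of that lifecycle for a single 4x4 float
// tile (alpha = 1 and aligned pointers assumed; this helper is illustrative
// only and the exact packet-to-row ordering is simplified):
#if defined(__MMA__)
static void mma_tile_4x4_sketch(const float* a, const float* b, float* c, int depth)
{
  __vector_quad acc;
  __builtin_mma_xxsetaccz(&acc);                        // C tile := 0
  for (int k = 0; k < depth; k++) {
    __vector float ak = vec_xl(0, a + 4*k);             // 4 lhs values for step k
    __vector float bk = vec_xl(0, b + 4*k);             // 4 rhs values for step k
    // acc += ak (outer product) bk -- the same builtin pgerMMA wraps above
    __builtin_mma_xvf32gerpp(&acc, (__vector unsigned char)ak, (__vector unsigned char)bk);
  }
  __vector float rows[4];
  __builtin_mma_disassemble_acc(rows, &acc);            // one vector per tile row
  for (int i = 0; i < 4; i++)
    vec_xst(rows[i], 0, c + 4*i);                       // store row i of the 4x4 tile
}
#endif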
+#undef advanceRows +#undef advanceCols + +#pragma GCC reset_options +} // end namespace internal + +} // end namespace Eigen + +#endif // EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H + diff --git a/Eigen/src/Core/arch/AltiVec/PacketMath.h b/Eigen/src/Core/arch/AltiVec/PacketMath.h new file mode 100755 index 0000000..2a44054 --- /dev/null +++ b/Eigen/src/Core/arch/AltiVec/PacketMath.h @@ -0,0 +1,2711 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2008-2016 Konstantinos Margaritis +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_PACKET_MATH_ALTIVEC_H +#define EIGEN_PACKET_MATH_ALTIVEC_H + +namespace Eigen { + +namespace internal { + +#ifndef EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD +#define EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD 4 +#endif + +#ifndef EIGEN_HAS_SINGLE_INSTRUCTION_MADD +#define EIGEN_HAS_SINGLE_INSTRUCTION_MADD +#endif + +// NOTE Altivec has 32 registers, but Eigen only accepts a value of 8 or 16 +#ifndef EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS +#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 32 +#endif + +typedef __vector float Packet4f; +typedef __vector int Packet4i; +typedef __vector unsigned int Packet4ui; +typedef __vector __bool int Packet4bi; +typedef __vector short int Packet8s; +typedef __vector unsigned short int Packet8us; +typedef __vector signed char Packet16c; +typedef __vector unsigned char Packet16uc; +typedef eigen_packet_wrapper<__vector unsigned short int,0> Packet8bf; + +// We don't want to write the same code all the time, but we need to reuse the constants +// and it doesn't really work to declare them global, so we define macros instead +#define _EIGEN_DECLARE_CONST_FAST_Packet4f(NAME,X) \ + Packet4f p4f_##NAME = {X, X, X, X} + +#define _EIGEN_DECLARE_CONST_FAST_Packet4i(NAME,X) \ + Packet4i p4i_##NAME = vec_splat_s32(X) + +#define _EIGEN_DECLARE_CONST_FAST_Packet4ui(NAME,X) \ + Packet4ui p4ui_##NAME = {X, X, X, X} + +#define _EIGEN_DECLARE_CONST_FAST_Packet8us(NAME,X) \ + Packet8us p8us_##NAME = {X, X, X, X, X, X, X, X} + +#define _EIGEN_DECLARE_CONST_FAST_Packet16uc(NAME,X) \ + Packet16uc p16uc_##NAME = {X, X, X, X, X, X, X, X, X, X, X, X, X, X, X, X} + +#define _EIGEN_DECLARE_CONST_Packet4f(NAME,X) \ + Packet4f p4f_##NAME = pset1(X) + +#define _EIGEN_DECLARE_CONST_Packet4i(NAME,X) \ + Packet4i p4i_##NAME = pset1(X) + +#define _EIGEN_DECLARE_CONST_Packet2d(NAME,X) \ + Packet2d p2d_##NAME = pset1(X) + +#define _EIGEN_DECLARE_CONST_Packet2l(NAME,X) \ + Packet2l p2l_##NAME = pset1(X) + +#define _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(NAME,X) \ + const Packet4f p4f_##NAME = reinterpret_cast(pset1(X)) + +#define DST_CHAN 1 +#define DST_CTRL(size, count, stride) (((size) << 24) | ((count) << 16) | (stride)) +#define __UNPACK_TYPE__(PACKETNAME) typename unpacket_traits::type + +// These constants are endian-agnostic +static _EIGEN_DECLARE_CONST_FAST_Packet4f(ZERO, 0); //{ 0.0, 0.0, 0.0, 0.0} +static _EIGEN_DECLARE_CONST_FAST_Packet4i(ZERO, 0); //{ 0, 0, 0, 0,} +static _EIGEN_DECLARE_CONST_FAST_Packet4i(ONE,1); //{ 1, 1, 1, 1} +static _EIGEN_DECLARE_CONST_FAST_Packet4i(MINUS16,-16); //{ -16, -16, -16, -16} +static _EIGEN_DECLARE_CONST_FAST_Packet4i(MINUS1,-1); //{ -1, -1, -1, -1} +static _EIGEN_DECLARE_CONST_FAST_Packet4ui(SIGN, 0x80000000u); +static _EIGEN_DECLARE_CONST_FAST_Packet4ui(PREV0DOT5, 0x3EFFFFFFu); +static 
_EIGEN_DECLARE_CONST_FAST_Packet8us(ONE,1); //{ 1, 1, 1, 1, 1, 1, 1, 1} +static _EIGEN_DECLARE_CONST_FAST_Packet16uc(ONE,1); +static Packet4f p4f_MZERO = (Packet4f) vec_sl((Packet4ui)p4i_MINUS1, (Packet4ui)p4i_MINUS1); //{ 0x80000000, 0x80000000, 0x80000000, 0x80000000} +#ifndef __VSX__ +static Packet4f p4f_ONE = vec_ctf(p4i_ONE, 0); //{ 1.0, 1.0, 1.0, 1.0} +#endif + +static Packet4f p4f_COUNTDOWN = { 0.0, 1.0, 2.0, 3.0 }; +static Packet4i p4i_COUNTDOWN = { 0, 1, 2, 3 }; +static Packet8s p8s_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7 }; +static Packet8us p8us_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7 }; + +static Packet16c p16c_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15}; +static Packet16uc p16uc_COUNTDOWN = { 0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15}; + +static Packet16uc p16uc_REVERSE32 = { 12,13,14,15, 8,9,10,11, 4,5,6,7, 0,1,2,3 }; +static Packet16uc p16uc_REVERSE16 = { 14,15, 12,13, 10,11, 8,9, 6,7, 4,5, 2,3, 0,1 }; +static Packet16uc p16uc_REVERSE8 = { 15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0 }; + +static Packet16uc p16uc_DUPLICATE32_HI = { 0,1,2,3, 0,1,2,3, 4,5,6,7, 4,5,6,7 }; +static Packet16uc p16uc_DUPLICATE16_HI = { 0,1,0,1, 2,3,2,3, 4,5,4,5, 6,7,6,7 }; +static Packet16uc p16uc_DUPLICATE8_HI = { 0,0, 1,1, 2,2, 3,3, 4,4, 5,5, 6,6, 7,7 }; +static const Packet16uc p16uc_DUPLICATE16_EVEN= { 0,1 ,0,1, 4,5, 4,5, 8,9, 8,9, 12,13, 12,13 }; +static const Packet16uc p16uc_DUPLICATE16_ODD = { 2,3 ,2,3, 6,7, 6,7, 10,11, 10,11, 14,15, 14,15 }; + +static Packet16uc p16uc_QUADRUPLICATE16_HI = { 0,1,0,1,0,1,0,1, 2,3,2,3,2,3,2,3 }; + +// Handle endianness properly while loading constants +// Define global static constants: +#ifdef _BIG_ENDIAN +static Packet16uc p16uc_FORWARD = vec_lvsl(0, (float*)0); +#ifdef __VSX__ +static Packet16uc p16uc_REVERSE64 = { 8,9,10,11, 12,13,14,15, 0,1,2,3, 4,5,6,7 }; +#endif +static Packet16uc p16uc_PSET32_WODD = vec_sld((Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 0), (Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 2), 8);//{ 0,1,2,3, 0,1,2,3, 8,9,10,11, 8,9,10,11 }; +static Packet16uc p16uc_PSET32_WEVEN = vec_sld(p16uc_DUPLICATE32_HI, (Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 3), 8);//{ 4,5,6,7, 4,5,6,7, 12,13,14,15, 12,13,14,15 }; +static Packet16uc p16uc_HALF64_0_16 = vec_sld((Packet16uc)p4i_ZERO, vec_splat((Packet16uc) vec_abs(p4i_MINUS16), 3), 8); //{ 0,0,0,0, 0,0,0,0, 16,16,16,16, 16,16,16,16}; +#else +static Packet16uc p16uc_FORWARD = p16uc_REVERSE32; +static Packet16uc p16uc_REVERSE64 = { 8,9,10,11, 12,13,14,15, 0,1,2,3, 4,5,6,7 }; +static Packet16uc p16uc_PSET32_WODD = vec_sld((Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 1), (Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 3), 8);//{ 0,1,2,3, 0,1,2,3, 8,9,10,11, 8,9,10,11 }; +static Packet16uc p16uc_PSET32_WEVEN = vec_sld((Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 0), (Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 2), 8);//{ 4,5,6,7, 4,5,6,7, 12,13,14,15, 12,13,14,15 }; +static Packet16uc p16uc_HALF64_0_16 = vec_sld(vec_splat((Packet16uc) vec_abs(p4i_MINUS16), 0), (Packet16uc)p4i_ZERO, 8); //{ 0,0,0,0, 0,0,0,0, 16,16,16,16, 16,16,16,16}; +#endif // _BIG_ENDIAN + +static Packet16uc p16uc_PSET64_HI = (Packet16uc) vec_mergeh((Packet4ui)p16uc_PSET32_WODD, (Packet4ui)p16uc_PSET32_WEVEN); //{ 0,1,2,3, 4,5,6,7, 0,1,2,3, 4,5,6,7 }; +static Packet16uc p16uc_PSET64_LO = (Packet16uc) vec_mergel((Packet4ui)p16uc_PSET32_WODD, (Packet4ui)p16uc_PSET32_WEVEN); //{ 8,9,10,11, 12,13,14,15, 8,9,10,11, 12,13,14,15 }; +static Packet16uc p16uc_TRANSPOSE64_HI = p16uc_PSET64_HI + 
p16uc_HALF64_0_16; //{ 0,1,2,3, 4,5,6,7, 16,17,18,19, 20,21,22,23}; +static Packet16uc p16uc_TRANSPOSE64_LO = p16uc_PSET64_LO + p16uc_HALF64_0_16; //{ 8,9,10,11, 12,13,14,15, 24,25,26,27, 28,29,30,31}; + +static Packet16uc p16uc_COMPLEX32_REV = vec_sld(p16uc_REVERSE32, p16uc_REVERSE32, 8); //{ 4,5,6,7, 0,1,2,3, 12,13,14,15, 8,9,10,11 }; + +#ifdef _BIG_ENDIAN +static Packet16uc p16uc_COMPLEX32_REV2 = vec_sld(p16uc_FORWARD, p16uc_FORWARD, 8); //{ 8,9,10,11, 12,13,14,15, 0,1,2,3, 4,5,6,7 }; +#else +static Packet16uc p16uc_COMPLEX32_REV2 = vec_sld(p16uc_PSET64_HI, p16uc_PSET64_LO, 8); //{ 8,9,10,11, 12,13,14,15, 0,1,2,3, 4,5,6,7 }; +#endif // _BIG_ENDIAN + +#if EIGEN_HAS_BUILTIN(__builtin_prefetch) || EIGEN_COMP_GNUC + #define EIGEN_PPC_PREFETCH(ADDR) __builtin_prefetch(ADDR); +#else + #define EIGEN_PPC_PREFETCH(ADDR) asm( " dcbt [%[addr]]\n" :: [addr] "r" (ADDR) : "cc" ); +#endif + +template <> +struct packet_traits : default_packet_traits { + typedef Packet4f type; + typedef Packet4f half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 4, + HasHalfPacket = 1, + + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 1, + HasMin = 1, + HasMax = 1, + HasAbs = 1, + HasSin = EIGEN_FAST_MATH, + HasCos = EIGEN_FAST_MATH, + HasLog = 1, + HasExp = 1, +#ifdef __VSX__ + HasSqrt = 1, +#if !EIGEN_COMP_CLANG + HasRsqrt = 1, +#else + HasRsqrt = 0, +#endif +#else + HasSqrt = 0, + HasRsqrt = 0, + HasTanh = EIGEN_FAST_MATH, + HasErf = EIGEN_FAST_MATH, +#endif + HasRound = 1, + HasFloor = 1, + HasCeil = 1, + HasRint = 1, + HasNegate = 1, + HasBlend = 1 + }; +}; +template <> +struct packet_traits : default_packet_traits { + typedef Packet8bf type; + typedef Packet8bf half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 8, + HasHalfPacket = 0, + + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 1, + HasMin = 1, + HasMax = 1, + HasAbs = 1, + HasSin = EIGEN_FAST_MATH, + HasCos = EIGEN_FAST_MATH, + HasLog = 1, + HasExp = 1, +#ifdef __VSX__ + HasSqrt = 1, +#if !EIGEN_COMP_CLANG + HasRsqrt = 1, +#else + HasRsqrt = 0, +#endif +#else + HasSqrt = 0, + HasRsqrt = 0, + HasTanh = EIGEN_FAST_MATH, + HasErf = EIGEN_FAST_MATH, +#endif + HasRound = 1, + HasFloor = 1, + HasCeil = 1, + HasRint = 1, + HasNegate = 1, + HasBlend = 1 + }; +}; + +template <> +struct packet_traits : default_packet_traits { + typedef Packet4i type; + typedef Packet4i half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 4, + HasHalfPacket = 0, + + HasAdd = 1, + HasSub = 1, + HasShift = 1, + HasMul = 1, + HasDiv = 0, + HasBlend = 1 + }; +}; + +template <> +struct packet_traits : default_packet_traits { + typedef Packet8s type; + typedef Packet8s half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 8, + HasHalfPacket = 0, + + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 0, + HasBlend = 1 + }; +}; + +template <> +struct packet_traits : default_packet_traits { + typedef Packet8us type; + typedef Packet8us half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 8, + HasHalfPacket = 0, + + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 0, + HasBlend = 1 + }; +}; + +template <> +struct packet_traits : default_packet_traits { + typedef Packet16c type; + typedef Packet16c half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 16, + HasHalfPacket = 0, + + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 0, + HasBlend = 1 + }; +}; + +template <> +struct packet_traits : default_packet_traits { + typedef Packet16uc type; + typedef Packet16uc half; + enum { + Vectorizable = 1, 
+ AlignedOnScalar = 1, + size = 16, + HasHalfPacket = 0, + + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 0, + HasBlend = 1 + }; +}; + +template<> struct unpacket_traits +{ + typedef float type; + typedef Packet4f half; + typedef Packet4i integer_packet; + enum {size=4, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; +}; +template<> struct unpacket_traits +{ + typedef int type; + typedef Packet4i half; + enum {size=4, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; +}; +template<> struct unpacket_traits +{ + typedef short int type; + typedef Packet8s half; + enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; +}; +template<> struct unpacket_traits +{ + typedef unsigned short int type; + typedef Packet8us half; + enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; +}; + +template<> struct unpacket_traits +{ + typedef signed char type; + typedef Packet16c half; + enum {size=16, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; +}; +template<> struct unpacket_traits +{ + typedef unsigned char type; + typedef Packet16uc half; + enum {size=16, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; +}; + +template<> struct unpacket_traits +{ + typedef bfloat16 type; + typedef Packet8bf half; + enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; +}; +inline std::ostream & operator <<(std::ostream & s, const Packet16c & v) +{ + union { + Packet16c v; + signed char n[16]; + } vt; + vt.v = v; + for (int i=0; i< 16; i++) + s << vt.n[i] << ", "; + return s; +} + +inline std::ostream & operator <<(std::ostream & s, const Packet16uc & v) +{ + union { + Packet16uc v; + unsigned char n[16]; + } vt; + vt.v = v; + for (int i=0; i< 16; i++) + s << vt.n[i] << ", "; + return s; +} + +inline std::ostream & operator <<(std::ostream & s, const Packet4f & v) +{ + union { + Packet4f v; + float n[4]; + } vt; + vt.v = v; + s << vt.n[0] << ", " << vt.n[1] << ", " << vt.n[2] << ", " << vt.n[3]; + return s; +} + +inline std::ostream & operator <<(std::ostream & s, const Packet4i & v) +{ + union { + Packet4i v; + int n[4]; + } vt; + vt.v = v; + s << vt.n[0] << ", " << vt.n[1] << ", " << vt.n[2] << ", " << vt.n[3]; + return s; +} + +inline std::ostream & operator <<(std::ostream & s, const Packet4ui & v) +{ + union { + Packet4ui v; + unsigned int n[4]; + } vt; + vt.v = v; + s << vt.n[0] << ", " << vt.n[1] << ", " << vt.n[2] << ", " << vt.n[3]; + return s; +} + +template +EIGEN_STRONG_INLINE Packet pload_common(const __UNPACK_TYPE__(Packet)* from) +{ + // some versions of GCC throw "unused-but-set-parameter". + // ignoring these warnings for now. 
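+  // Aligned 16-byte load shared by the pload<Packet> specializations below:
+  // vec_xl on VSX targets, otherwise the classic AltiVec vec_ld, which
+  // requires `from` to be 16-byte aligned.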
+ EIGEN_UNUSED_VARIABLE(from); + EIGEN_DEBUG_ALIGNED_LOAD +#ifdef __VSX__ + return vec_xl(0, const_cast<__UNPACK_TYPE__(Packet)*>(from)); +#else + return vec_ld(0, from); +#endif +} + +// Need to define them first or we get specialization after instantiation errors +template<> EIGEN_STRONG_INLINE Packet4f pload(const float* from) +{ + return pload_common(from); +} + +template<> EIGEN_STRONG_INLINE Packet4i pload(const int* from) +{ + return pload_common(from); +} + +template<> EIGEN_STRONG_INLINE Packet8s pload(const short int* from) +{ + return pload_common(from); +} + +template<> EIGEN_STRONG_INLINE Packet8us pload(const unsigned short int* from) +{ + return pload_common(from); +} + +template<> EIGEN_STRONG_INLINE Packet16c pload(const signed char* from) +{ + return pload_common(from); +} + +template<> EIGEN_STRONG_INLINE Packet16uc pload(const unsigned char* from) +{ + return pload_common(from); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pload(const bfloat16* from) +{ + return pload_common(reinterpret_cast(from)); +} + +template +EIGEN_STRONG_INLINE void pstore_common(__UNPACK_TYPE__(Packet)* to, const Packet& from){ + // some versions of GCC throw "unused-but-set-parameter" (float *to). + // ignoring these warnings for now. + EIGEN_UNUSED_VARIABLE(to); + EIGEN_DEBUG_ALIGNED_STORE +#ifdef __VSX__ + vec_xst(from, 0, to); +#else + vec_st(from, 0, to); +#endif +} + +template<> EIGEN_STRONG_INLINE void pstore(float* to, const Packet4f& from) +{ + pstore_common(to, from); +} + +template<> EIGEN_STRONG_INLINE void pstore(int* to, const Packet4i& from) +{ + pstore_common(to, from); +} + +template<> EIGEN_STRONG_INLINE void pstore(short int* to, const Packet8s& from) +{ + pstore_common(to, from); +} + +template<> EIGEN_STRONG_INLINE void pstore(unsigned short int* to, const Packet8us& from) +{ + pstore_common(to, from); +} + +template<> EIGEN_STRONG_INLINE void pstore(bfloat16* to, const Packet8bf& from) +{ + pstore_common(reinterpret_cast(to), from); +} + +template<> EIGEN_STRONG_INLINE void pstore(signed char* to, const Packet16c& from) +{ + pstore_common(to, from); +} + +template<> EIGEN_STRONG_INLINE void pstore(unsigned char* to, const Packet16uc& from) +{ + pstore_common(to, from); +} + +template +EIGEN_STRONG_INLINE Packet pset1_size4(const __UNPACK_TYPE__(Packet)& from) +{ + Packet v = {from, from, from, from}; + return v; +} + +template +EIGEN_STRONG_INLINE Packet pset1_size8(const __UNPACK_TYPE__(Packet)& from) +{ + Packet v = {from, from, from, from, from, from, from, from}; + return v; +} + +template +EIGEN_STRONG_INLINE Packet pset1_size16(const __UNPACK_TYPE__(Packet)& from) +{ + Packet v = {from, from, from, from, from, from, from, from, from, from, from, from, from, from, from, from}; + return v; +} + +template<> EIGEN_STRONG_INLINE Packet4f pset1(const float& from) { + return pset1_size4(from); +} + +template<> EIGEN_STRONG_INLINE Packet4i pset1(const int& from) { + return pset1_size4(from); +} + +template<> EIGEN_STRONG_INLINE Packet8s pset1(const short int& from) { + return pset1_size8(from); +} + +template<> EIGEN_STRONG_INLINE Packet8us pset1(const unsigned short int& from) { + return pset1_size8(from); +} + +template<> EIGEN_STRONG_INLINE Packet16c pset1(const signed char& from) { + return pset1_size16(from); +} + +template<> EIGEN_STRONG_INLINE Packet16uc pset1(const unsigned char& from) { + return pset1_size16(from); +} + +template<> EIGEN_STRONG_INLINE Packet4f pset1frombits(unsigned int from) { + return reinterpret_cast(pset1(from)); +} + +template<> 
EIGEN_STRONG_INLINE Packet8bf pset1(const bfloat16& from) { + return pset1_size8(reinterpret_cast(from)); +} + +template EIGEN_STRONG_INLINE void +pbroadcast4_common(const __UNPACK_TYPE__(Packet) *a, + Packet& a0, Packet& a1, Packet& a2, Packet& a3) +{ + a3 = pload(a); + a0 = vec_splat(a3, 0); + a1 = vec_splat(a3, 1); + a2 = vec_splat(a3, 2); + a3 = vec_splat(a3, 3); +} + +template<> EIGEN_STRONG_INLINE void +pbroadcast4(const float *a, + Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3) +{ + pbroadcast4_common(a, a0, a1, a2, a3); +} +template<> EIGEN_STRONG_INLINE void +pbroadcast4(const int *a, + Packet4i& a0, Packet4i& a1, Packet4i& a2, Packet4i& a3) +{ + pbroadcast4_common(a, a0, a1, a2, a3); +} + +template EIGEN_DEVICE_FUNC inline Packet pgather_common(const __UNPACK_TYPE__(Packet)* from, Index stride) +{ + EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) a[4]; + a[0] = from[0*stride]; + a[1] = from[1*stride]; + a[2] = from[2*stride]; + a[3] = from[3*stride]; + return pload(a); +} + +template<> EIGEN_DEVICE_FUNC inline Packet4f pgather(const float* from, Index stride) +{ + return pgather_common(from, stride); +} + +template<> EIGEN_DEVICE_FUNC inline Packet4i pgather(const int* from, Index stride) +{ + return pgather_common(from, stride); +} + +template EIGEN_DEVICE_FUNC inline Packet pgather_size8(const __UNPACK_TYPE__(Packet)* from, Index stride) +{ + EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) a[8]; + a[0] = from[0*stride]; + a[1] = from[1*stride]; + a[2] = from[2*stride]; + a[3] = from[3*stride]; + a[4] = from[4*stride]; + a[5] = from[5*stride]; + a[6] = from[6*stride]; + a[7] = from[7*stride]; + return pload(a); +} + +template<> EIGEN_DEVICE_FUNC inline Packet8s pgather(const short int* from, Index stride) +{ + return pgather_size8(from, stride); +} + +template<> EIGEN_DEVICE_FUNC inline Packet8us pgather(const unsigned short int* from, Index stride) +{ + return pgather_size8(from, stride); +} + +template<> EIGEN_DEVICE_FUNC inline Packet8bf pgather(const bfloat16* from, Index stride) +{ + return pgather_size8(from, stride); +} + +template EIGEN_DEVICE_FUNC inline Packet pgather_size16(const __UNPACK_TYPE__(Packet)* from, Index stride) +{ + EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) a[16]; + a[0] = from[0*stride]; + a[1] = from[1*stride]; + a[2] = from[2*stride]; + a[3] = from[3*stride]; + a[4] = from[4*stride]; + a[5] = from[5*stride]; + a[6] = from[6*stride]; + a[7] = from[7*stride]; + a[8] = from[8*stride]; + a[9] = from[9*stride]; + a[10] = from[10*stride]; + a[11] = from[11*stride]; + a[12] = from[12*stride]; + a[13] = from[13*stride]; + a[14] = from[14*stride]; + a[15] = from[15*stride]; + return pload(a); +} + + +template<> EIGEN_DEVICE_FUNC inline Packet16c pgather(const signed char* from, Index stride) +{ + return pgather_size16(from, stride); +} + +template<> EIGEN_DEVICE_FUNC inline Packet16uc pgather(const unsigned char* from, Index stride) +{ + return pgather_size16(from, stride); +} + +template EIGEN_DEVICE_FUNC inline void pscatter_size4(__UNPACK_TYPE__(Packet)* to, const Packet& from, Index stride) +{ + EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) a[4]; + pstore<__UNPACK_TYPE__(Packet)>(a, from); + to[0*stride] = a[0]; + to[1*stride] = a[1]; + to[2*stride] = a[2]; + to[3*stride] = a[3]; +} + +template<> EIGEN_DEVICE_FUNC inline void pscatter(float* to, const Packet4f& from, Index stride) +{ + pscatter_size4(to, from, stride); +} + +template<> EIGEN_DEVICE_FUNC inline void pscatter(int* to, const Packet4i& from, Index stride) +{ + pscatter_size4(to, from, stride); +} + +template 
EIGEN_DEVICE_FUNC inline void pscatter_size8(__UNPACK_TYPE__(Packet)* to, const Packet& from, Index stride) +{ + EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) a[8]; + pstore<__UNPACK_TYPE__(Packet)>(a, from); + to[0*stride] = a[0]; + to[1*stride] = a[1]; + to[2*stride] = a[2]; + to[3*stride] = a[3]; + to[4*stride] = a[4]; + to[5*stride] = a[5]; + to[6*stride] = a[6]; + to[7*stride] = a[7]; +} + + +template<> EIGEN_DEVICE_FUNC inline void pscatter(short int* to, const Packet8s& from, Index stride) +{ + pscatter_size8(to, from, stride); +} + +template<> EIGEN_DEVICE_FUNC inline void pscatter(unsigned short int* to, const Packet8us& from, Index stride) +{ + pscatter_size8(to, from, stride); +} + +template<> EIGEN_DEVICE_FUNC inline void pscatter(bfloat16* to, const Packet8bf& from, Index stride) +{ + pscatter_size8(to, from, stride); +} + +template EIGEN_DEVICE_FUNC inline void pscatter_size16(__UNPACK_TYPE__(Packet)* to, const Packet& from, Index stride) +{ + EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) a[16]; + pstore<__UNPACK_TYPE__(Packet)>(a, from); + to[0*stride] = a[0]; + to[1*stride] = a[1]; + to[2*stride] = a[2]; + to[3*stride] = a[3]; + to[4*stride] = a[4]; + to[5*stride] = a[5]; + to[6*stride] = a[6]; + to[7*stride] = a[7]; + to[8*stride] = a[8]; + to[9*stride] = a[9]; + to[10*stride] = a[10]; + to[11*stride] = a[11]; + to[12*stride] = a[12]; + to[13*stride] = a[13]; + to[14*stride] = a[14]; + to[15*stride] = a[15]; +} + +template<> EIGEN_DEVICE_FUNC inline void pscatter(signed char* to, const Packet16c& from, Index stride) +{ + pscatter_size16(to, from, stride); +} + +template<> EIGEN_DEVICE_FUNC inline void pscatter(unsigned char* to, const Packet16uc& from, Index stride) +{ + pscatter_size16(to, from, stride); +} + +template<> EIGEN_STRONG_INLINE Packet4f plset(const float& a) { return pset1(a) + p4f_COUNTDOWN; } +template<> EIGEN_STRONG_INLINE Packet4i plset(const int& a) { return pset1(a) + p4i_COUNTDOWN; } +template<> EIGEN_STRONG_INLINE Packet8s plset(const short int& a) { return pset1(a) + p8s_COUNTDOWN; } +template<> EIGEN_STRONG_INLINE Packet8us plset(const unsigned short int& a) { return pset1(a) + p8us_COUNTDOWN; } +template<> EIGEN_STRONG_INLINE Packet16c plset(const signed char& a) { return pset1(a) + p16c_COUNTDOWN; } +template<> EIGEN_STRONG_INLINE Packet16uc plset(const unsigned char& a) { return pset1(a) + p16uc_COUNTDOWN; } + +template<> EIGEN_STRONG_INLINE Packet4f padd (const Packet4f& a, const Packet4f& b) { return a + b; } +template<> EIGEN_STRONG_INLINE Packet4i padd (const Packet4i& a, const Packet4i& b) { return a + b; } +template<> EIGEN_STRONG_INLINE Packet4ui padd (const Packet4ui& a, const Packet4ui& b) { return a + b; } +template<> EIGEN_STRONG_INLINE Packet8s padd (const Packet8s& a, const Packet8s& b) { return a + b; } +template<> EIGEN_STRONG_INLINE Packet8us padd (const Packet8us& a, const Packet8us& b) { return a + b; } +template<> EIGEN_STRONG_INLINE Packet16c padd (const Packet16c& a, const Packet16c& b) { return a + b; } +template<> EIGEN_STRONG_INLINE Packet16uc padd(const Packet16uc& a, const Packet16uc& b) { return a + b; } + +template<> EIGEN_STRONG_INLINE Packet4f psub (const Packet4f& a, const Packet4f& b) { return a - b; } +template<> EIGEN_STRONG_INLINE Packet4i psub (const Packet4i& a, const Packet4i& b) { return a - b; } +template<> EIGEN_STRONG_INLINE Packet8s psub (const Packet8s& a, const Packet8s& b) { return a - b; } +template<> EIGEN_STRONG_INLINE Packet8us psub (const Packet8us& a, const Packet8us& b) { return a - b; } +template<> 
EIGEN_STRONG_INLINE Packet16c psub<Packet16c> (const Packet16c& a, const Packet16c& b) { return a - b; } +template<> EIGEN_STRONG_INLINE Packet16uc psub<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return a - b; } + +template<> EIGEN_STRONG_INLINE Packet4f pnegate(const Packet4f& a) { return p4f_ZERO - a; } +template<> EIGEN_STRONG_INLINE Packet4i pnegate(const Packet4i& a) { return p4i_ZERO - a; } + +template<> EIGEN_STRONG_INLINE Packet4f pconj(const Packet4f& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet4i pconj(const Packet4i& a) { return a; } + +template<> EIGEN_STRONG_INLINE Packet4f pmul<Packet4f> (const Packet4f& a, const Packet4f& b) { return vec_madd(a,b, p4f_MZERO); } +template<> EIGEN_STRONG_INLINE Packet4i pmul<Packet4i> (const Packet4i& a, const Packet4i& b) { return a * b; } +template<> EIGEN_STRONG_INLINE Packet8s pmul<Packet8s> (const Packet8s& a, const Packet8s& b) { return vec_mul(a,b); } +template<> EIGEN_STRONG_INLINE Packet8us pmul<Packet8us> (const Packet8us& a, const Packet8us& b) { return vec_mul(a,b); } +template<> EIGEN_STRONG_INLINE Packet16c pmul<Packet16c> (const Packet16c& a, const Packet16c& b) { return vec_mul(a,b); } +template<> EIGEN_STRONG_INLINE Packet16uc pmul<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return vec_mul(a,b); } + + +template<> EIGEN_STRONG_INLINE Packet4f pdiv<Packet4f>(const Packet4f& a, const Packet4f& b) +{ +#ifndef __VSX__ // VSX actually provides a div instruction + Packet4f t, y_0, y_1; + + // Altivec does not offer a divide instruction, we have to do a reciprocal approximation + y_0 = vec_re(b); + + // Do one Newton-Raphson iteration to get the needed accuracy + t = vec_nmsub(y_0, b, p4f_ONE); + y_1 = vec_madd(y_0, t, y_0); + + return vec_madd(a, y_1, p4f_MZERO); +#else + return vec_div(a, b); +#endif +} + +template<> EIGEN_STRONG_INLINE Packet4i pdiv<Packet4i>(const Packet4i& /*a*/, const Packet4i& /*b*/) +{ eigen_assert(false && "packet integer division is not supported by AltiVec"); + return pset1<Packet4i>(0); +} + +// for some weird reasons, it has to be overloaded for packets of integers +template<> EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return vec_madd(a,b,c); } +template<> EIGEN_STRONG_INLINE Packet4i pmadd(const Packet4i& a, const Packet4i& b, const Packet4i& c) { return a*b + c; } +template<> EIGEN_STRONG_INLINE Packet8s pmadd(const Packet8s& a, const Packet8s& b, const Packet8s& c) { return vec_madd(a,b,c); } +template<> EIGEN_STRONG_INLINE Packet8us pmadd(const Packet8us& a, const Packet8us& b, const Packet8us& c) { return vec_madd(a,b,c); } + +template<> EIGEN_STRONG_INLINE Packet4f pmin<Packet4f>(const Packet4f& a, const Packet4f& b) +{ + #ifdef __VSX__ + // NOTE: about 10% slower than vec_min, but consistent with std::min and SSE regarding NaN + Packet4f ret; + __asm__ ("xvcmpgesp %x0,%x1,%x2\n\txxsel %x0,%x1,%x2,%x0" : "=&wa" (ret) : "wa" (a), "wa" (b)); + return ret; + #else + return vec_min(a, b); + #endif +} +template<> EIGEN_STRONG_INLINE Packet4i pmin<Packet4i>(const Packet4i& a, const Packet4i& b) { return vec_min(a, b); } +template<> EIGEN_STRONG_INLINE Packet8s pmin<Packet8s>(const Packet8s& a, const Packet8s& b) { return vec_min(a, b); } +template<> EIGEN_STRONG_INLINE Packet8us pmin<Packet8us>(const Packet8us& a, const Packet8us& b) { return vec_min(a, b); } +template<> EIGEN_STRONG_INLINE Packet16c pmin<Packet16c>(const Packet16c& a, const Packet16c& b) { return vec_min(a, b); } +template<> EIGEN_STRONG_INLINE Packet16uc pmin<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return vec_min(a, b); } + + +template<> EIGEN_STRONG_INLINE Packet4f pmax<Packet4f>(const Packet4f& a, const Packet4f& b) +{ +
#ifdef __VSX__ + // NOTE: about 10% slower than vec_max, but consistent with std::max and SSE regarding NaN + Packet4f ret; + __asm__ ("xvcmpgtsp %x0,%x2,%x1\n\txxsel %x0,%x1,%x2,%x0" : "=&wa" (ret) : "wa" (a), "wa" (b)); + return ret; + #else + return vec_max(a, b); + #endif +} +template<> EIGEN_STRONG_INLINE Packet4i pmax(const Packet4i& a, const Packet4i& b) { return vec_max(a, b); } +template<> EIGEN_STRONG_INLINE Packet8s pmax(const Packet8s& a, const Packet8s& b) { return vec_max(a, b); } +template<> EIGEN_STRONG_INLINE Packet8us pmax(const Packet8us& a, const Packet8us& b) { return vec_max(a, b); } +template<> EIGEN_STRONG_INLINE Packet16c pmax(const Packet16c& a, const Packet16c& b) { return vec_max(a, b); } +template<> EIGEN_STRONG_INLINE Packet16uc pmax(const Packet16uc& a, const Packet16uc& b) { return vec_max(a, b); } + +template<> EIGEN_STRONG_INLINE Packet4f pcmp_le(const Packet4f& a, const Packet4f& b) { return reinterpret_cast(vec_cmple(a,b)); } +template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt(const Packet4f& a, const Packet4f& b) { return reinterpret_cast(vec_cmplt(a,b)); } +template<> EIGEN_STRONG_INLINE Packet4f pcmp_eq(const Packet4f& a, const Packet4f& b) { return reinterpret_cast(vec_cmpeq(a,b)); } +template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt_or_nan(const Packet4f& a, const Packet4f& b) { + Packet4f c = reinterpret_cast(vec_cmpge(a,b)); + return vec_nor(c,c); +} + +template<> EIGEN_STRONG_INLINE Packet4i pcmp_le(const Packet4i& a, const Packet4i& b) { return reinterpret_cast(vec_cmple(a,b)); } +template<> EIGEN_STRONG_INLINE Packet4i pcmp_lt(const Packet4i& a, const Packet4i& b) { return reinterpret_cast(vec_cmplt(a,b)); } +template<> EIGEN_STRONG_INLINE Packet4i pcmp_eq(const Packet4i& a, const Packet4i& b) { return reinterpret_cast(vec_cmpeq(a,b)); } +template<> EIGEN_STRONG_INLINE Packet8s pcmp_le(const Packet8s& a, const Packet8s& b) { return reinterpret_cast(vec_cmple(a,b)); } +template<> EIGEN_STRONG_INLINE Packet8s pcmp_lt(const Packet8s& a, const Packet8s& b) { return reinterpret_cast(vec_cmplt(a,b)); } +template<> EIGEN_STRONG_INLINE Packet8s pcmp_eq(const Packet8s& a, const Packet8s& b) { return reinterpret_cast(vec_cmpeq(a,b)); } +template<> EIGEN_STRONG_INLINE Packet8us pcmp_le(const Packet8us& a, const Packet8us& b) { return reinterpret_cast(vec_cmple(a,b)); } +template<> EIGEN_STRONG_INLINE Packet8us pcmp_lt(const Packet8us& a, const Packet8us& b) { return reinterpret_cast(vec_cmplt(a,b)); } +template<> EIGEN_STRONG_INLINE Packet8us pcmp_eq(const Packet8us& a, const Packet8us& b) { return reinterpret_cast(vec_cmpeq(a,b)); } +template<> EIGEN_STRONG_INLINE Packet16c pcmp_le(const Packet16c& a, const Packet16c& b) { return reinterpret_cast(vec_cmple(a,b)); } +template<> EIGEN_STRONG_INLINE Packet16c pcmp_lt(const Packet16c& a, const Packet16c& b) { return reinterpret_cast(vec_cmplt(a,b)); } +template<> EIGEN_STRONG_INLINE Packet16c pcmp_eq(const Packet16c& a, const Packet16c& b) { return reinterpret_cast(vec_cmpeq(a,b)); } +template<> EIGEN_STRONG_INLINE Packet16uc pcmp_le(const Packet16uc& a, const Packet16uc& b) { return reinterpret_cast(vec_cmple(a,b)); } +template<> EIGEN_STRONG_INLINE Packet16uc pcmp_lt(const Packet16uc& a, const Packet16uc& b) { return reinterpret_cast(vec_cmplt(a,b)); } +template<> EIGEN_STRONG_INLINE Packet16uc pcmp_eq(const Packet16uc& a, const Packet16uc& b) { return reinterpret_cast(vec_cmpeq(a,b)); } + +template<> EIGEN_STRONG_INLINE Packet4f pand(const Packet4f& a, const Packet4f& b) { return vec_and(a, b); } +template<> 
EIGEN_STRONG_INLINE Packet4i pand(const Packet4i& a, const Packet4i& b) { return vec_and(a, b); } +template<> EIGEN_STRONG_INLINE Packet4ui pand(const Packet4ui& a, const Packet4ui& b) { return vec_and(a, b); } +template<> EIGEN_STRONG_INLINE Packet8us pand(const Packet8us& a, const Packet8us& b) { return vec_and(a, b); } +template<> EIGEN_STRONG_INLINE Packet8bf pand(const Packet8bf& a, const Packet8bf& b) { + return pand(a, b); +} + + +template<> EIGEN_STRONG_INLINE Packet4f por(const Packet4f& a, const Packet4f& b) { return vec_or(a, b); } +template<> EIGEN_STRONG_INLINE Packet4i por(const Packet4i& a, const Packet4i& b) { return vec_or(a, b); } +template<> EIGEN_STRONG_INLINE Packet8s por(const Packet8s& a, const Packet8s& b) { return vec_or(a, b); } +template<> EIGEN_STRONG_INLINE Packet8us por(const Packet8us& a, const Packet8us& b) { return vec_or(a, b); } +template<> EIGEN_STRONG_INLINE Packet8bf por(const Packet8bf& a, const Packet8bf& b) { + return por(a, b); +} + +template<> EIGEN_STRONG_INLINE Packet4f pxor(const Packet4f& a, const Packet4f& b) { return vec_xor(a, b); } +template<> EIGEN_STRONG_INLINE Packet4i pxor(const Packet4i& a, const Packet4i& b) { return vec_xor(a, b); } +template<> EIGEN_STRONG_INLINE Packet8bf pxor(const Packet8bf& a, const Packet8bf& b) { + return pxor(a, b); +} + +template<> EIGEN_STRONG_INLINE Packet4f pandnot(const Packet4f& a, const Packet4f& b) { return vec_andc(a, b); } +template<> EIGEN_STRONG_INLINE Packet4i pandnot(const Packet4i& a, const Packet4i& b) { return vec_andc(a, b); } + +template<> EIGEN_STRONG_INLINE Packet4f pselect(const Packet4f& mask, const Packet4f& a, const Packet4f& b) { + return vec_sel(b, a, reinterpret_cast(mask)); +} + +template<> EIGEN_STRONG_INLINE Packet4f pround(const Packet4f& a) +{ + Packet4f t = vec_add(reinterpret_cast(vec_or(vec_and(reinterpret_cast(a), p4ui_SIGN), p4ui_PREV0DOT5)), a); + Packet4f res; + +#ifdef __VSX__ + __asm__("xvrspiz %x0, %x1\n\t" + : "=&wa" (res) + : "wa" (t)); +#else + __asm__("vrfiz %0, %1\n\t" + : "=v" (res) + : "v" (t)); +#endif + + return res; +} +template<> EIGEN_STRONG_INLINE Packet4f pceil(const Packet4f& a) { return vec_ceil(a); } +template<> EIGEN_STRONG_INLINE Packet4f pfloor(const Packet4f& a) { return vec_floor(a); } +template<> EIGEN_STRONG_INLINE Packet4f print(const Packet4f& a) +{ + Packet4f res; + + __asm__("xvrspic %x0, %x1\n\t" + : "=&wa" (res) + : "wa" (a)); + + return res; +} + +template EIGEN_STRONG_INLINE Packet ploadu_common(const __UNPACK_TYPE__(Packet)* from) +{ + EIGEN_DEBUG_ALIGNED_LOAD +#ifdef _BIG_ENDIAN + Packet16uc MSQ, LSQ; + Packet16uc mask; + MSQ = vec_ld(0, (unsigned char *)from); // most significant quadword + LSQ = vec_ld(15, (unsigned char *)from); // least significant quadword + mask = vec_lvsl(0, from); // create the permute mask + //TODO: Add static_cast here + return (Packet) vec_perm(MSQ, LSQ, mask); // align the data +#else + EIGEN_DEBUG_UNALIGNED_LOAD + return vec_xl(0, const_cast<__UNPACK_TYPE__(Packet)*>(from)); +#endif +} + +template<> EIGEN_STRONG_INLINE Packet4f ploadu(const float* from) +{ + return ploadu_common(from); +} +template<> EIGEN_STRONG_INLINE Packet4i ploadu(const int* from) +{ + return ploadu_common(from); +} +template<> EIGEN_STRONG_INLINE Packet8s ploadu(const short int* from) +{ + return ploadu_common(from); +} +template<> EIGEN_STRONG_INLINE Packet8us ploadu(const unsigned short int* from) +{ + return ploadu_common(from); +} +template<> EIGEN_STRONG_INLINE Packet8bf ploadu(const bfloat16* from) +{ + return 
ploadu_common(reinterpret_cast(from)); +} +template<> EIGEN_STRONG_INLINE Packet16c ploadu(const signed char* from) +{ + return ploadu_common(from); +} +template<> EIGEN_STRONG_INLINE Packet16uc ploadu(const unsigned char* from) +{ + return ploadu_common(from); +} + +template EIGEN_STRONG_INLINE Packet ploaddup_common(const __UNPACK_TYPE__(Packet)* from) +{ + Packet p; + if((std::ptrdiff_t(from) % 16) == 0) p = pload(from); + else p = ploadu(from); + return vec_perm(p, p, p16uc_DUPLICATE32_HI); +} +template<> EIGEN_STRONG_INLINE Packet4f ploaddup(const float* from) +{ + return ploaddup_common(from); +} +template<> EIGEN_STRONG_INLINE Packet4i ploaddup(const int* from) +{ + return ploaddup_common(from); +} + +template<> EIGEN_STRONG_INLINE Packet8s ploaddup(const short int* from) +{ + Packet8s p; + if((std::ptrdiff_t(from) % 16) == 0) p = pload(from); + else p = ploadu(from); + return vec_perm(p, p, p16uc_DUPLICATE16_HI); +} + +template<> EIGEN_STRONG_INLINE Packet8us ploaddup(const unsigned short int* from) +{ + Packet8us p; + if((std::ptrdiff_t(from) % 16) == 0) p = pload(from); + else p = ploadu(from); + return vec_perm(p, p, p16uc_DUPLICATE16_HI); +} + +template<> EIGEN_STRONG_INLINE Packet8s ploadquad(const short int* from) +{ + Packet8s p; + if((std::ptrdiff_t(from) % 16) == 0) p = pload(from); + else p = ploadu(from); + return vec_perm(p, p, p16uc_QUADRUPLICATE16_HI); +} + +template<> EIGEN_STRONG_INLINE Packet8us ploadquad(const unsigned short int* from) +{ + Packet8us p; + if((std::ptrdiff_t(from) % 16) == 0) p = pload(from); + else p = ploadu(from); + return vec_perm(p, p, p16uc_QUADRUPLICATE16_HI); +} + +template<> EIGEN_STRONG_INLINE Packet8bf ploadquad(const bfloat16* from) +{ + return ploadquad(reinterpret_cast(from)); +} + +template<> EIGEN_STRONG_INLINE Packet16c ploaddup(const signed char* from) +{ + Packet16c p; + if((std::ptrdiff_t(from) % 16) == 0) p = pload(from); + else p = ploadu(from); + return vec_perm(p, p, p16uc_DUPLICATE8_HI); +} + +template<> EIGEN_STRONG_INLINE Packet16uc ploaddup(const unsigned char* from) +{ + Packet16uc p; + if((std::ptrdiff_t(from) % 16) == 0) p = pload(from); + else p = ploadu(from); + return vec_perm(p, p, p16uc_DUPLICATE8_HI); +} + +template EIGEN_STRONG_INLINE void pstoreu_common(__UNPACK_TYPE__(Packet)* to, const Packet& from) +{ + EIGEN_DEBUG_UNALIGNED_STORE +#ifdef _BIG_ENDIAN + // Taken from http://developer.apple.com/hardwaredrivers/ve/alignment.html + // Warning: not thread safe! 
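+  // The unaligned store is emulated by re-loading the two aligned quadwords
+  // that straddle `to`, splicing the packet into them with permutes, and
+  // writing both back; this read-modify-write of neighbouring data is why the
+  // sequence is not thread safe.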
+ Packet16uc MSQ, LSQ, edges; + Packet16uc edgeAlign, align; + + MSQ = vec_ld(0, (unsigned char *)to); // most significant quadword + LSQ = vec_ld(15, (unsigned char *)to); // least significant quadword + edgeAlign = vec_lvsl(0, to); // permute map to extract edges + edges=vec_perm(LSQ,MSQ,edgeAlign); // extract the edges + align = vec_lvsr( 0, to ); // permute map to misalign data + MSQ = vec_perm(edges,(Packet16uc)from,align); // misalign the data (MSQ) + LSQ = vec_perm((Packet16uc)from,edges,align); // misalign the data (LSQ) + vec_st( LSQ, 15, (unsigned char *)to ); // Store the LSQ part first + vec_st( MSQ, 0, (unsigned char *)to ); // Store the MSQ part second +#else + vec_xst(from, 0, to); +#endif +} +template<> EIGEN_STRONG_INLINE void pstoreu(float* to, const Packet4f& from) +{ + pstoreu_common(to, from); +} +template<> EIGEN_STRONG_INLINE void pstoreu(int* to, const Packet4i& from) +{ + pstoreu_common(to, from); +} +template<> EIGEN_STRONG_INLINE void pstoreu(short int* to, const Packet8s& from) +{ + pstoreu_common(to, from); +} +template<> EIGEN_STRONG_INLINE void pstoreu(unsigned short int* to, const Packet8us& from) +{ + pstoreu_common(to, from); +} +template<> EIGEN_STRONG_INLINE void pstoreu(bfloat16* to, const Packet8bf& from) +{ + pstoreu_common(reinterpret_cast(to), from); +} +template<> EIGEN_STRONG_INLINE void pstoreu(signed char* to, const Packet16c& from) +{ + pstoreu_common(to, from); +} +template<> EIGEN_STRONG_INLINE void pstoreu(unsigned char* to, const Packet16uc& from) +{ + pstoreu_common(to, from); +} + +template<> EIGEN_STRONG_INLINE void prefetch(const float* addr) { EIGEN_PPC_PREFETCH(addr); } +template<> EIGEN_STRONG_INLINE void prefetch(const int* addr) { EIGEN_PPC_PREFETCH(addr); } + +template<> EIGEN_STRONG_INLINE float pfirst(const Packet4f& a) { EIGEN_ALIGN16 float x; vec_ste(a, 0, &x); return x; } +template<> EIGEN_STRONG_INLINE int pfirst(const Packet4i& a) { EIGEN_ALIGN16 int x; vec_ste(a, 0, &x); return x; } + +template EIGEN_STRONG_INLINE __UNPACK_TYPE__(Packet) pfirst_common(const Packet& a) { + EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) x; + vec_ste(a, 0, &x); + return x; +} + +template<> EIGEN_STRONG_INLINE short int pfirst(const Packet8s& a) { + return pfirst_common(a); +} + +template<> EIGEN_STRONG_INLINE unsigned short int pfirst(const Packet8us& a) { + return pfirst_common(a); +} + +template<> EIGEN_STRONG_INLINE signed char pfirst(const Packet16c& a) +{ + return pfirst_common(a); +} + +template<> EIGEN_STRONG_INLINE unsigned char pfirst(const Packet16uc& a) +{ + return pfirst_common(a); +} + +template<> EIGEN_STRONG_INLINE Packet4f preverse(const Packet4f& a) +{ + return reinterpret_cast(vec_perm(reinterpret_cast(a), reinterpret_cast(a), p16uc_REVERSE32)); +} +template<> EIGEN_STRONG_INLINE Packet4i preverse(const Packet4i& a) +{ + return reinterpret_cast(vec_perm(reinterpret_cast(a), reinterpret_cast(a), p16uc_REVERSE32)); +} +template<> EIGEN_STRONG_INLINE Packet8s preverse(const Packet8s& a) +{ + return reinterpret_cast(vec_perm(reinterpret_cast(a), reinterpret_cast(a), p16uc_REVERSE16)); +} +template<> EIGEN_STRONG_INLINE Packet8us preverse(const Packet8us& a) +{ + return reinterpret_cast(vec_perm(reinterpret_cast(a), reinterpret_cast(a), p16uc_REVERSE16)); +} +template<> EIGEN_STRONG_INLINE Packet16c preverse(const Packet16c& a) +{ + return vec_perm(a, a, p16uc_REVERSE8); +} +template<> EIGEN_STRONG_INLINE Packet16uc preverse(const Packet16uc& a) +{ + return vec_perm(a, a, p16uc_REVERSE8); +} +template<> EIGEN_STRONG_INLINE Packet8bf 
preverse(const Packet8bf& a) +{ + return preverse(a); +} + +template<> EIGEN_STRONG_INLINE Packet4f pabs(const Packet4f& a) { return vec_abs(a); } +template<> EIGEN_STRONG_INLINE Packet4i pabs(const Packet4i& a) { return vec_abs(a); } +template<> EIGEN_STRONG_INLINE Packet8s pabs(const Packet8s& a) { return vec_abs(a); } +template<> EIGEN_STRONG_INLINE Packet8us pabs(const Packet8us& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet16c pabs(const Packet16c& a) { return vec_abs(a); } +template<> EIGEN_STRONG_INLINE Packet16uc pabs(const Packet16uc& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet8bf pabs(const Packet8bf& a) { + _EIGEN_DECLARE_CONST_FAST_Packet8us(abs_mask,0x7FFF); + return pand(p8us_abs_mask, a); +} + +template EIGEN_STRONG_INLINE Packet4i parithmetic_shift_right(const Packet4i& a) +{ return vec_sra(a,reinterpret_cast(pset1(N))); } +template EIGEN_STRONG_INLINE Packet4i plogical_shift_right(const Packet4i& a) +{ return vec_sr(a,reinterpret_cast(pset1(N))); } +template EIGEN_STRONG_INLINE Packet4i plogical_shift_left(const Packet4i& a) +{ return vec_sl(a,reinterpret_cast(pset1(N))); } +template EIGEN_STRONG_INLINE Packet4f plogical_shift_left(const Packet4f& a) +{ + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N); + Packet4ui r = vec_sl(reinterpret_cast(a), p4ui_mask); + return reinterpret_cast(r); +} + +template EIGEN_STRONG_INLINE Packet4f plogical_shift_right(const Packet4f& a) +{ + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N); + Packet4ui r = vec_sr(reinterpret_cast(a), p4ui_mask); + return reinterpret_cast(r); +} + +template EIGEN_STRONG_INLINE Packet4ui plogical_shift_right(const Packet4ui& a) +{ + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N); + return vec_sr(a, p4ui_mask); +} + +template EIGEN_STRONG_INLINE Packet4ui plogical_shift_left(const Packet4ui& a) +{ + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N); + return vec_sl(a, p4ui_mask); +} + +template EIGEN_STRONG_INLINE Packet8us plogical_shift_left(const Packet8us& a) +{ + const _EIGEN_DECLARE_CONST_FAST_Packet8us(mask, N); + return vec_sl(a, p8us_mask); +} +template EIGEN_STRONG_INLINE Packet8us plogical_shift_right(const Packet8us& a) +{ + const _EIGEN_DECLARE_CONST_FAST_Packet8us(mask, N); + return vec_sr(a, p8us_mask); +} + +EIGEN_STRONG_INLINE Packet4f Bf16ToF32Even(const Packet8bf& bf){ + return plogical_shift_left<16>(reinterpret_cast(bf.m_val)); +} + +EIGEN_STRONG_INLINE Packet4f Bf16ToF32Odd(const Packet8bf& bf){ + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(high_mask, 0xFFFF0000); + return pand( + reinterpret_cast(bf.m_val), + reinterpret_cast(p4ui_high_mask) + ); +} + +// Simple interleaving of bool masks, prevents true values from being +// converted to NaNs. 
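+// Comparison results are all-zero or all-one lanes, so no rounding is needed:
+// the upper 16 bits of each odd lane are kept in place, the even lanes are
+// shifted down into the low halves, and the two are OR-ed back together.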
+EIGEN_STRONG_INLINE Packet8bf F32ToBf16Bool(Packet4f even, Packet4f odd) { + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(high_mask, 0xFFFF0000); + Packet4f bf_odd, bf_even; + bf_odd = pand(reinterpret_cast(p4ui_high_mask), odd); + bf_even = plogical_shift_right<16>(even); + return reinterpret_cast(por(bf_even, bf_odd)); +} + +EIGEN_STRONG_INLINE Packet8bf F32ToBf16(Packet4f p4f){ + Packet4ui input = reinterpret_cast(p4f); + Packet4ui lsb = plogical_shift_right<16>(input); + lsb = pand(lsb, reinterpret_cast(p4i_ONE)); + + _EIGEN_DECLARE_CONST_FAST_Packet4ui(BIAS,0x7FFFu); + Packet4ui rounding_bias = padd(lsb, p4ui_BIAS); + input = padd(input, rounding_bias); + + //Test NaN and Subnormal - Begin + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(exp_mask, 0x7F800000); + Packet4ui exp = pand(p4ui_exp_mask, reinterpret_cast(p4f)); + + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mantissa_mask, 0x7FFFFF); + Packet4ui mantissa = pand(p4ui_mantissa_mask, reinterpret_cast(p4f)); + + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(max_exp, 0x7F800000); + Packet4bi is_max_exp = vec_cmpeq(exp, p4ui_max_exp); + Packet4bi is_zero_exp = vec_cmpeq(exp, reinterpret_cast(p4i_ZERO)); + + Packet4bi is_mant_zero = vec_cmpeq(mantissa, reinterpret_cast(p4i_ZERO)); + Packet4ui nan_selector = pandnot( + reinterpret_cast(is_max_exp), + reinterpret_cast(is_mant_zero) + ); + + Packet4ui subnormal_selector = pandnot( + reinterpret_cast(is_zero_exp), + reinterpret_cast(is_mant_zero) + ); + + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(nan, 0x7FC00000); + input = vec_sel(input, p4ui_nan, nan_selector); + input = vec_sel(input, reinterpret_cast(p4f), subnormal_selector); + //Test NaN and Subnormal - End + + input = plogical_shift_right<16>(input); + return reinterpret_cast(input); +} + +EIGEN_STRONG_INLINE Packet8bf F32ToBf16(Packet4f even, Packet4f odd){ + Packet4f bf_odd, bf_even; + bf_odd = reinterpret_cast(F32ToBf16(odd).m_val); + bf_odd = plogical_shift_left<16>(bf_odd); + bf_even = reinterpret_cast(F32ToBf16(even).m_val); + return reinterpret_cast(por(bf_even, bf_odd)); +} +#define BF16_TO_F32_UNARY_OP_WRAPPER(OP, A) \ + Packet4f a_even = Bf16ToF32Even(A);\ + Packet4f a_odd = Bf16ToF32Odd(A);\ + Packet4f op_even = OP(a_even);\ + Packet4f op_odd = OP(a_odd);\ + return F32ToBf16(op_even, op_odd);\ + +#define BF16_TO_F32_BINARY_OP_WRAPPER(OP, A, B) \ + Packet4f a_even = Bf16ToF32Even(A);\ + Packet4f a_odd = Bf16ToF32Odd(A);\ + Packet4f b_even = Bf16ToF32Even(B);\ + Packet4f b_odd = Bf16ToF32Odd(B);\ + Packet4f op_even = OP(a_even, b_even);\ + Packet4f op_odd = OP(a_odd, b_odd);\ + return F32ToBf16(op_even, op_odd);\ + +#define BF16_TO_F32_BINARY_OP_WRAPPER_BOOL(OP, A, B) \ + Packet4f a_even = Bf16ToF32Even(A);\ + Packet4f a_odd = Bf16ToF32Odd(A);\ + Packet4f b_even = Bf16ToF32Even(B);\ + Packet4f b_odd = Bf16ToF32Odd(B);\ + Packet4f op_even = OP(a_even, b_even);\ + Packet4f op_odd = OP(a_odd, b_odd);\ + return F32ToBf16Bool(op_even, op_odd);\ + +template<> EIGEN_STRONG_INLINE Packet8bf padd(const Packet8bf& a, const Packet8bf& b) { + BF16_TO_F32_BINARY_OP_WRAPPER(padd, a, b); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pmul(const Packet8bf& a, const Packet8bf& b) { + BF16_TO_F32_BINARY_OP_WRAPPER(pmul, a, b); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pdiv(const Packet8bf& a, const Packet8bf& b) { + BF16_TO_F32_BINARY_OP_WRAPPER(pdiv, a, b); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pnegate(const Packet8bf& a) { + BF16_TO_F32_UNARY_OP_WRAPPER(pnegate, a); +} + +template<> EIGEN_STRONG_INLINE Packet8bf psub(const Packet8bf& 
a, const Packet8bf& b) { + BF16_TO_F32_BINARY_OP_WRAPPER(psub, a, b); +} + +template<> EIGEN_STRONG_INLINE Packet8bf psqrt (const Packet8bf& a){ + BF16_TO_F32_UNARY_OP_WRAPPER(vec_sqrt, a); +} +template<> EIGEN_STRONG_INLINE Packet8bf prsqrt (const Packet8bf& a){ + BF16_TO_F32_UNARY_OP_WRAPPER(prsqrt, a); +} +template<> EIGEN_STRONG_INLINE Packet8bf pexp (const Packet8bf& a){ + BF16_TO_F32_UNARY_OP_WRAPPER(pexp_float, a); +} + +template<> EIGEN_STRONG_INLINE Packet4f pldexp(const Packet4f& a, const Packet4f& exponent) { + return pldexp_generic(a,exponent); +} +template<> EIGEN_STRONG_INLINE Packet8bf pldexp (const Packet8bf& a, const Packet8bf& exponent){ + BF16_TO_F32_BINARY_OP_WRAPPER(pldexp, a, exponent); +} + +template<> EIGEN_STRONG_INLINE Packet4f pfrexp(const Packet4f& a, Packet4f& exponent) { + return pfrexp_generic(a,exponent); +} +template<> EIGEN_STRONG_INLINE Packet8bf pfrexp (const Packet8bf& a, Packet8bf& e){ + Packet4f a_even = Bf16ToF32Even(a); + Packet4f a_odd = Bf16ToF32Odd(a); + Packet4f e_even; + Packet4f e_odd; + Packet4f op_even = pfrexp(a_even, e_even); + Packet4f op_odd = pfrexp(a_odd, e_odd); + e = F32ToBf16(e_even, e_odd); + return F32ToBf16(op_even, op_odd); +} + +template<> EIGEN_STRONG_INLINE Packet8bf psin (const Packet8bf& a){ + BF16_TO_F32_UNARY_OP_WRAPPER(psin_float, a); +} +template<> EIGEN_STRONG_INLINE Packet8bf pcos (const Packet8bf& a){ + BF16_TO_F32_UNARY_OP_WRAPPER(pcos_float, a); +} +template<> EIGEN_STRONG_INLINE Packet8bf plog (const Packet8bf& a){ + BF16_TO_F32_UNARY_OP_WRAPPER(plog_float, a); +} +template<> EIGEN_STRONG_INLINE Packet8bf pfloor (const Packet8bf& a){ + BF16_TO_F32_UNARY_OP_WRAPPER(pfloor, a); +} +template<> EIGEN_STRONG_INLINE Packet8bf pceil (const Packet8bf& a){ + BF16_TO_F32_UNARY_OP_WRAPPER(pceil, a); +} +template<> EIGEN_STRONG_INLINE Packet8bf pround (const Packet8bf& a){ + BF16_TO_F32_UNARY_OP_WRAPPER(pround, a); +} +template<> EIGEN_STRONG_INLINE Packet8bf print (const Packet8bf& a){ + BF16_TO_F32_UNARY_OP_WRAPPER(print, a); +} +template<> EIGEN_STRONG_INLINE Packet8bf pmadd(const Packet8bf& a, const Packet8bf& b, const Packet8bf& c) { + Packet4f a_even = Bf16ToF32Even(a); + Packet4f a_odd = Bf16ToF32Odd(a); + Packet4f b_even = Bf16ToF32Even(b); + Packet4f b_odd = Bf16ToF32Odd(b); + Packet4f c_even = Bf16ToF32Even(c); + Packet4f c_odd = Bf16ToF32Odd(c); + Packet4f pmadd_even = pmadd(a_even, b_even, c_even); + Packet4f pmadd_odd = pmadd(a_odd, b_odd, c_odd); + return F32ToBf16(pmadd_even, pmadd_odd); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pmin(const Packet8bf& a, const Packet8bf& b) { + BF16_TO_F32_BINARY_OP_WRAPPER(pmin, a, b); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pmax(const Packet8bf& a, const Packet8bf& b) { + BF16_TO_F32_BINARY_OP_WRAPPER(pmax, a, b); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt(const Packet8bf& a, const Packet8bf& b) { + BF16_TO_F32_BINARY_OP_WRAPPER_BOOL(pcmp_lt, a, b); +} +template<> EIGEN_STRONG_INLINE Packet8bf pcmp_lt_or_nan(const Packet8bf& a, const Packet8bf& b) { + BF16_TO_F32_BINARY_OP_WRAPPER_BOOL(pcmp_lt_or_nan, a, b); +} +template<> EIGEN_STRONG_INLINE Packet8bf pcmp_le(const Packet8bf& a, const Packet8bf& b) { + BF16_TO_F32_BINARY_OP_WRAPPER_BOOL(pcmp_le, a, b); +} +template<> EIGEN_STRONG_INLINE Packet8bf pcmp_eq(const Packet8bf& a, const Packet8bf& b) { + BF16_TO_F32_BINARY_OP_WRAPPER_BOOL(pcmp_eq, a, b); +} + +template<> EIGEN_STRONG_INLINE bfloat16 pfirst(const Packet8bf& a) { + return Eigen::bfloat16_impl::raw_uint16_to_bfloat16((pfirst(a))); +} + 
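+// All of the element-wise Packet8bf operations above follow one pattern: widen
+// the eight bf16 lanes into two Packet4f (even and odd lanes), run the float
+// kernel on each half, then narrow back with F32ToBf16(even, odd). The
+// narrowing uses the usual round-to-nearest-even bias; a scalar sketch of that
+// rounding (illustrative only, f32_to_bf16_scalar is not part of this file):
+//
+//   unsigned short f32_to_bf16_scalar(float x) {
+//     unsigned int bits;
+//     std::memcpy(&bits, &x, sizeof(bits));   // needs <cstring>
+//     unsigned int lsb = (bits >> 16) & 1u;   // ties round to the even bf16
+//     bits += 0x7FFFu + lsb;                  // rounding bias
+//     return (unsigned short)(bits >> 16);    // keep the high 16 bits
+//   }
+//
+// NaN and subnormal inputs are handled separately in the vectorised F32ToBf16.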
+template<> EIGEN_STRONG_INLINE Packet8bf ploaddup(const bfloat16* from) +{ + return ploaddup(reinterpret_cast(from)); +} + +template<> EIGEN_STRONG_INLINE Packet8bf plset(const bfloat16& a) { + bfloat16 countdown[8] = { bfloat16(0), bfloat16(1), bfloat16(2), bfloat16(3), + bfloat16(4), bfloat16(5), bfloat16(6), bfloat16(7) }; + return padd(pset1(a), pload(countdown)); +} + +template<> EIGEN_STRONG_INLINE float predux(const Packet4f& a) +{ + Packet4f b, sum; + b = vec_sld(a, a, 8); + sum = a + b; + b = vec_sld(sum, sum, 4); + sum += b; + return pfirst(sum); +} + +template<> EIGEN_STRONG_INLINE int predux(const Packet4i& a) +{ + Packet4i sum; + sum = vec_sums(a, p4i_ZERO); +#ifdef _BIG_ENDIAN + sum = vec_sld(sum, p4i_ZERO, 12); +#else + sum = vec_sld(p4i_ZERO, sum, 4); +#endif + return pfirst(sum); +} + +template<> EIGEN_STRONG_INLINE bfloat16 predux(const Packet8bf& a) +{ + float redux_even = predux(Bf16ToF32Even(a)); + float redux_odd = predux(Bf16ToF32Odd(a)); + float f32_result = redux_even + redux_odd; + return bfloat16(f32_result); +} +template EIGEN_STRONG_INLINE __UNPACK_TYPE__(Packet) predux_size8(const Packet& a) +{ + union{ + Packet v; + __UNPACK_TYPE__(Packet) n[8]; + } vt; + vt.v = a; + + EIGEN_ALIGN16 int first_loader[4] = { vt.n[0], vt.n[1], vt.n[2], vt.n[3] }; + EIGEN_ALIGN16 int second_loader[4] = { vt.n[4], vt.n[5], vt.n[6], vt.n[7] }; + Packet4i first_half = pload(first_loader); + Packet4i second_half = pload(second_loader); + + return static_cast<__UNPACK_TYPE__(Packet)>(predux(first_half) + predux(second_half)); +} + +template<> EIGEN_STRONG_INLINE short int predux(const Packet8s& a) +{ + return predux_size8(a); +} + +template<> EIGEN_STRONG_INLINE unsigned short int predux(const Packet8us& a) +{ + return predux_size8(a); +} + +template EIGEN_STRONG_INLINE __UNPACK_TYPE__(Packet) predux_size16(const Packet& a) +{ + union{ + Packet v; + __UNPACK_TYPE__(Packet) n[16]; + } vt; + vt.v = a; + + EIGEN_ALIGN16 int first_loader[4] = { vt.n[0], vt.n[1], vt.n[2], vt.n[3] }; + EIGEN_ALIGN16 int second_loader[4] = { vt.n[4], vt.n[5], vt.n[6], vt.n[7] }; + EIGEN_ALIGN16 int third_loader[4] = { vt.n[8], vt.n[9], vt.n[10], vt.n[11] }; + EIGEN_ALIGN16 int fourth_loader[4] = { vt.n[12], vt.n[13], vt.n[14], vt.n[15] }; + + Packet4i first_quarter = pload(first_loader); + Packet4i second_quarter = pload(second_loader); + Packet4i third_quarter = pload(third_loader); + Packet4i fourth_quarter = pload(fourth_loader); + + return static_cast<__UNPACK_TYPE__(Packet)>(predux(first_quarter) + predux(second_quarter) + + predux(third_quarter) + predux(fourth_quarter)); +} + +template<> EIGEN_STRONG_INLINE signed char predux(const Packet16c& a) +{ + return predux_size16(a); +} + +template<> EIGEN_STRONG_INLINE unsigned char predux(const Packet16uc& a) +{ + return predux_size16(a); +} + +// Other reduction functions: +// mul +template<> EIGEN_STRONG_INLINE float predux_mul(const Packet4f& a) +{ + Packet4f prod; + prod = pmul(a, vec_sld(a, a, 8)); + return pfirst(pmul(prod, vec_sld(prod, prod, 4))); +} + +template<> EIGEN_STRONG_INLINE int predux_mul(const Packet4i& a) +{ + EIGEN_ALIGN16 int aux[4]; + pstore(aux, a); + return aux[0] * aux[1] * aux[2] * aux[3]; +} + +template<> EIGEN_STRONG_INLINE short int predux_mul(const Packet8s& a) +{ + Packet8s pair, quad, octo; + + pair = vec_mul(a, vec_sld(a, a, 8)); + quad = vec_mul(pair, vec_sld(pair, pair, 4)); + octo = vec_mul(quad, vec_sld(quad, quad, 2)); + + return pfirst(octo); +} + +template<> EIGEN_STRONG_INLINE unsigned short int predux_mul(const 
Packet8us& a) +{ + Packet8us pair, quad, octo; + + pair = vec_mul(a, vec_sld(a, a, 8)); + quad = vec_mul(pair, vec_sld(pair, pair, 4)); + octo = vec_mul(quad, vec_sld(quad, quad, 2)); + + return pfirst(octo); +} + +template<> EIGEN_STRONG_INLINE bfloat16 predux_mul(const Packet8bf& a) +{ + float redux_even = predux_mul(Bf16ToF32Even(a)); + float redux_odd = predux_mul(Bf16ToF32Odd(a)); + float f32_result = redux_even * redux_odd; + return bfloat16(f32_result); +} + + +template<> EIGEN_STRONG_INLINE signed char predux_mul(const Packet16c& a) +{ + Packet16c pair, quad, octo, result; + + pair = vec_mul(a, vec_sld(a, a, 8)); + quad = vec_mul(pair, vec_sld(pair, pair, 4)); + octo = vec_mul(quad, vec_sld(quad, quad, 2)); + result = vec_mul(octo, vec_sld(octo, octo, 1)); + + return pfirst(result); +} + +template<> EIGEN_STRONG_INLINE unsigned char predux_mul(const Packet16uc& a) +{ + Packet16uc pair, quad, octo, result; + + pair = vec_mul(a, vec_sld(a, a, 8)); + quad = vec_mul(pair, vec_sld(pair, pair, 4)); + octo = vec_mul(quad, vec_sld(quad, quad, 2)); + result = vec_mul(octo, vec_sld(octo, octo, 1)); + + return pfirst(result); +} + +// min +template EIGEN_STRONG_INLINE +__UNPACK_TYPE__(Packet) predux_min4(const Packet& a) +{ + Packet b, res; + b = vec_min(a, vec_sld(a, a, 8)); + res = vec_min(b, vec_sld(b, b, 4)); + return pfirst(res); +} + + +template<> EIGEN_STRONG_INLINE float predux_min(const Packet4f& a) +{ + return predux_min4(a); +} + +template<> EIGEN_STRONG_INLINE int predux_min(const Packet4i& a) +{ + return predux_min4(a); +} + +template<> EIGEN_STRONG_INLINE bfloat16 predux_min(const Packet8bf& a) +{ + float redux_even = predux_min(Bf16ToF32Even(a)); + float redux_odd = predux_min(Bf16ToF32Odd(a)); + float f32_result = (std::min)(redux_even, redux_odd); + return bfloat16(f32_result); +} + +template<> EIGEN_STRONG_INLINE short int predux_min(const Packet8s& a) +{ + Packet8s pair, quad, octo; + + //pair = { Min(a0,a4), Min(a1,a5), Min(a2,a6), Min(a3,a7) } + pair = vec_min(a, vec_sld(a, a, 8)); + + //quad = { Min(a0, a4, a2, a6), Min(a1, a5, a3, a7) } + quad = vec_min(pair, vec_sld(pair, pair, 4)); + + //octo = { Min(a0, a4, a2, a6, a1, a5, a3, a7) } + octo = vec_min(quad, vec_sld(quad, quad, 2)); + return pfirst(octo); +} + +template<> EIGEN_STRONG_INLINE unsigned short int predux_min(const Packet8us& a) +{ + Packet8us pair, quad, octo; + + //pair = { Min(a0,a4), Min(a1,a5), Min(a2,a6), Min(a3,a7) } + pair = vec_min(a, vec_sld(a, a, 8)); + + //quad = { Min(a0, a4, a2, a6), Min(a1, a5, a3, a7) } + quad = vec_min(pair, vec_sld(pair, pair, 4)); + + //octo = { Min(a0, a4, a2, a6, a1, a5, a3, a7) } + octo = vec_min(quad, vec_sld(quad, quad, 2)); + return pfirst(octo); +} + +template<> EIGEN_STRONG_INLINE signed char predux_min(const Packet16c& a) +{ + Packet16c pair, quad, octo, result; + + pair = vec_min(a, vec_sld(a, a, 8)); + quad = vec_min(pair, vec_sld(pair, pair, 4)); + octo = vec_min(quad, vec_sld(quad, quad, 2)); + result = vec_min(octo, vec_sld(octo, octo, 1)); + + return pfirst(result); +} + +template<> EIGEN_STRONG_INLINE unsigned char predux_min(const Packet16uc& a) +{ + Packet16uc pair, quad, octo, result; + + pair = vec_min(a, vec_sld(a, a, 8)); + quad = vec_min(pair, vec_sld(pair, pair, 4)); + octo = vec_min(quad, vec_sld(quad, quad, 2)); + result = vec_min(octo, vec_sld(octo, octo, 1)); + + return pfirst(result); +} +// max +template EIGEN_STRONG_INLINE __UNPACK_TYPE__(Packet) predux_max4(const Packet& a) +{ + Packet b, res; + b = vec_max(a, vec_sld(a, a, 8)); + res = 
vec_max(b, vec_sld(b, b, 4)); + return pfirst(res); +} + +template<> EIGEN_STRONG_INLINE float predux_max(const Packet4f& a) +{ + return predux_max4(a); +} + +template<> EIGEN_STRONG_INLINE int predux_max(const Packet4i& a) +{ + return predux_max4(a); +} + +template<> EIGEN_STRONG_INLINE bfloat16 predux_max(const Packet8bf& a) +{ + float redux_even = predux_max(Bf16ToF32Even(a)); + float redux_odd = predux_max(Bf16ToF32Odd(a)); + float f32_result = (std::max)(redux_even, redux_odd); + return bfloat16(f32_result); +} + +template<> EIGEN_STRONG_INLINE short int predux_max(const Packet8s& a) +{ + Packet8s pair, quad, octo; + + //pair = { Max(a0,a4), Max(a1,a5), Max(a2,a6), Max(a3,a7) } + pair = vec_max(a, vec_sld(a, a, 8)); + + //quad = { Max(a0, a4, a2, a6), Max(a1, a5, a3, a7) } + quad = vec_max(pair, vec_sld(pair, pair, 4)); + + //octo = { Max(a0, a4, a2, a6, a1, a5, a3, a7) } + octo = vec_max(quad, vec_sld(quad, quad, 2)); + return pfirst(octo); +} + +template<> EIGEN_STRONG_INLINE unsigned short int predux_max(const Packet8us& a) +{ + Packet8us pair, quad, octo; + + //pair = { Max(a0,a4), Max(a1,a5), Max(a2,a6), Max(a3,a7) } + pair = vec_max(a, vec_sld(a, a, 8)); + + //quad = { Max(a0, a4, a2, a6), Max(a1, a5, a3, a7) } + quad = vec_max(pair, vec_sld(pair, pair, 4)); + + //octo = { Max(a0, a4, a2, a6, a1, a5, a3, a7) } + octo = vec_max(quad, vec_sld(quad, quad, 2)); + return pfirst(octo); +} + +template<> EIGEN_STRONG_INLINE signed char predux_max(const Packet16c& a) +{ + Packet16c pair, quad, octo, result; + + pair = vec_max(a, vec_sld(a, a, 8)); + quad = vec_max(pair, vec_sld(pair, pair, 4)); + octo = vec_max(quad, vec_sld(quad, quad, 2)); + result = vec_max(octo, vec_sld(octo, octo, 1)); + + return pfirst(result); +} + +template<> EIGEN_STRONG_INLINE unsigned char predux_max(const Packet16uc& a) +{ + Packet16uc pair, quad, octo, result; + + pair = vec_max(a, vec_sld(a, a, 8)); + quad = vec_max(pair, vec_sld(pair, pair, 4)); + octo = vec_max(quad, vec_sld(quad, quad, 2)); + result = vec_max(octo, vec_sld(octo, octo, 1)); + + return pfirst(result); +} + +template<> EIGEN_STRONG_INLINE bool predux_any(const Packet4f& x) +{ + return vec_any_ne(x, pzero(x)); +} + +template EIGEN_DEVICE_FUNC inline void +ptranpose_common(PacketBlock& kernel){ + T t0, t1, t2, t3; + t0 = vec_mergeh(kernel.packet[0], kernel.packet[2]); + t1 = vec_mergel(kernel.packet[0], kernel.packet[2]); + t2 = vec_mergeh(kernel.packet[1], kernel.packet[3]); + t3 = vec_mergel(kernel.packet[1], kernel.packet[3]); + kernel.packet[0] = vec_mergeh(t0, t2); + kernel.packet[1] = vec_mergel(t0, t2); + kernel.packet[2] = vec_mergeh(t1, t3); + kernel.packet[3] = vec_mergel(t1, t3); +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + ptranpose_common(kernel); +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + ptranpose_common(kernel); +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + Packet8s t0, t1, t2, t3; + t0 = vec_mergeh(kernel.packet[0], kernel.packet[2]); + t1 = vec_mergel(kernel.packet[0], kernel.packet[2]); + t2 = vec_mergeh(kernel.packet[1], kernel.packet[3]); + t3 = vec_mergel(kernel.packet[1], kernel.packet[3]); + kernel.packet[0] = vec_mergeh(t0, t2); + kernel.packet[1] = vec_mergel(t0, t2); + kernel.packet[2] = vec_mergeh(t1, t3); + kernel.packet[3] = vec_mergel(t1, t3); +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + Packet8us t0, t1, t2, t3; + t0 = vec_mergeh(kernel.packet[0], kernel.packet[2]); + t1 = 
vec_mergel(kernel.packet[0], kernel.packet[2]); + t2 = vec_mergeh(kernel.packet[1], kernel.packet[3]); + t3 = vec_mergel(kernel.packet[1], kernel.packet[3]); + kernel.packet[0] = vec_mergeh(t0, t2); + kernel.packet[1] = vec_mergel(t0, t2); + kernel.packet[2] = vec_mergeh(t1, t3); + kernel.packet[3] = vec_mergel(t1, t3); +} + + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + Packet8us t0, t1, t2, t3; + + t0 = vec_mergeh(kernel.packet[0].m_val, kernel.packet[2].m_val); + t1 = vec_mergel(kernel.packet[0].m_val, kernel.packet[2].m_val); + t2 = vec_mergeh(kernel.packet[1].m_val, kernel.packet[3].m_val); + t3 = vec_mergel(kernel.packet[1].m_val, kernel.packet[3].m_val); + kernel.packet[0] = vec_mergeh(t0, t2); + kernel.packet[1] = vec_mergel(t0, t2); + kernel.packet[2] = vec_mergeh(t1, t3); + kernel.packet[3] = vec_mergel(t1, t3); +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + Packet16c t0, t1, t2, t3; + t0 = vec_mergeh(kernel.packet[0], kernel.packet[2]); + t1 = vec_mergel(kernel.packet[0], kernel.packet[2]); + t2 = vec_mergeh(kernel.packet[1], kernel.packet[3]); + t3 = vec_mergel(kernel.packet[1], kernel.packet[3]); + kernel.packet[0] = vec_mergeh(t0, t2); + kernel.packet[1] = vec_mergel(t0, t2); + kernel.packet[2] = vec_mergeh(t1, t3); + kernel.packet[3] = vec_mergel(t1, t3); +} + + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + Packet16uc t0, t1, t2, t3; + t0 = vec_mergeh(kernel.packet[0], kernel.packet[2]); + t1 = vec_mergel(kernel.packet[0], kernel.packet[2]); + t2 = vec_mergeh(kernel.packet[1], kernel.packet[3]); + t3 = vec_mergel(kernel.packet[1], kernel.packet[3]); + kernel.packet[0] = vec_mergeh(t0, t2); + kernel.packet[1] = vec_mergel(t0, t2); + kernel.packet[2] = vec_mergeh(t1, t3); + kernel.packet[3] = vec_mergel(t1, t3); +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + Packet8s v[8], sum[8]; + + v[0] = vec_mergeh(kernel.packet[0], kernel.packet[4]); + v[1] = vec_mergel(kernel.packet[0], kernel.packet[4]); + v[2] = vec_mergeh(kernel.packet[1], kernel.packet[5]); + v[3] = vec_mergel(kernel.packet[1], kernel.packet[5]); + v[4] = vec_mergeh(kernel.packet[2], kernel.packet[6]); + v[5] = vec_mergel(kernel.packet[2], kernel.packet[6]); + v[6] = vec_mergeh(kernel.packet[3], kernel.packet[7]); + v[7] = vec_mergel(kernel.packet[3], kernel.packet[7]); + sum[0] = vec_mergeh(v[0], v[4]); + sum[1] = vec_mergel(v[0], v[4]); + sum[2] = vec_mergeh(v[1], v[5]); + sum[3] = vec_mergel(v[1], v[5]); + sum[4] = vec_mergeh(v[2], v[6]); + sum[5] = vec_mergel(v[2], v[6]); + sum[6] = vec_mergeh(v[3], v[7]); + sum[7] = vec_mergel(v[3], v[7]); + + kernel.packet[0] = vec_mergeh(sum[0], sum[4]); + kernel.packet[1] = vec_mergel(sum[0], sum[4]); + kernel.packet[2] = vec_mergeh(sum[1], sum[5]); + kernel.packet[3] = vec_mergel(sum[1], sum[5]); + kernel.packet[4] = vec_mergeh(sum[2], sum[6]); + kernel.packet[5] = vec_mergel(sum[2], sum[6]); + kernel.packet[6] = vec_mergeh(sum[3], sum[7]); + kernel.packet[7] = vec_mergel(sum[3], sum[7]); +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + Packet8us v[8], sum[8]; + + v[0] = vec_mergeh(kernel.packet[0], kernel.packet[4]); + v[1] = vec_mergel(kernel.packet[0], kernel.packet[4]); + v[2] = vec_mergeh(kernel.packet[1], kernel.packet[5]); + v[3] = vec_mergel(kernel.packet[1], kernel.packet[5]); + v[4] = vec_mergeh(kernel.packet[2], kernel.packet[6]); + v[5] = vec_mergel(kernel.packet[2], kernel.packet[6]); + v[6] = vec_mergeh(kernel.packet[3], kernel.packet[7]); + 
v[7] = vec_mergel(kernel.packet[3], kernel.packet[7]); + sum[0] = vec_mergeh(v[0], v[4]); + sum[1] = vec_mergel(v[0], v[4]); + sum[2] = vec_mergeh(v[1], v[5]); + sum[3] = vec_mergel(v[1], v[5]); + sum[4] = vec_mergeh(v[2], v[6]); + sum[5] = vec_mergel(v[2], v[6]); + sum[6] = vec_mergeh(v[3], v[7]); + sum[7] = vec_mergel(v[3], v[7]); + + kernel.packet[0] = vec_mergeh(sum[0], sum[4]); + kernel.packet[1] = vec_mergel(sum[0], sum[4]); + kernel.packet[2] = vec_mergeh(sum[1], sum[5]); + kernel.packet[3] = vec_mergel(sum[1], sum[5]); + kernel.packet[4] = vec_mergeh(sum[2], sum[6]); + kernel.packet[5] = vec_mergel(sum[2], sum[6]); + kernel.packet[6] = vec_mergeh(sum[3], sum[7]); + kernel.packet[7] = vec_mergel(sum[3], sum[7]); +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + Packet8bf v[8], sum[8]; + + v[0] = vec_mergeh(kernel.packet[0].m_val, kernel.packet[4].m_val); + v[1] = vec_mergel(kernel.packet[0].m_val, kernel.packet[4].m_val); + v[2] = vec_mergeh(kernel.packet[1].m_val, kernel.packet[5].m_val); + v[3] = vec_mergel(kernel.packet[1].m_val, kernel.packet[5].m_val); + v[4] = vec_mergeh(kernel.packet[2].m_val, kernel.packet[6].m_val); + v[5] = vec_mergel(kernel.packet[2].m_val, kernel.packet[6].m_val); + v[6] = vec_mergeh(kernel.packet[3].m_val, kernel.packet[7].m_val); + v[7] = vec_mergel(kernel.packet[3].m_val, kernel.packet[7].m_val); + sum[0] = vec_mergeh(v[0].m_val, v[4].m_val); + sum[1] = vec_mergel(v[0].m_val, v[4].m_val); + sum[2] = vec_mergeh(v[1].m_val, v[5].m_val); + sum[3] = vec_mergel(v[1].m_val, v[5].m_val); + sum[4] = vec_mergeh(v[2].m_val, v[6].m_val); + sum[5] = vec_mergel(v[2].m_val, v[6].m_val); + sum[6] = vec_mergeh(v[3].m_val, v[7].m_val); + sum[7] = vec_mergel(v[3].m_val, v[7].m_val); + + kernel.packet[0] = vec_mergeh(sum[0].m_val, sum[4].m_val); + kernel.packet[1] = vec_mergel(sum[0].m_val, sum[4].m_val); + kernel.packet[2] = vec_mergeh(sum[1].m_val, sum[5].m_val); + kernel.packet[3] = vec_mergel(sum[1].m_val, sum[5].m_val); + kernel.packet[4] = vec_mergeh(sum[2].m_val, sum[6].m_val); + kernel.packet[5] = vec_mergel(sum[2].m_val, sum[6].m_val); + kernel.packet[6] = vec_mergeh(sum[3].m_val, sum[7].m_val); + kernel.packet[7] = vec_mergel(sum[3].m_val, sum[7].m_val); +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + Packet16c step1[16], step2[16], step3[16]; + + step1[0] = vec_mergeh(kernel.packet[0], kernel.packet[8]); + step1[1] = vec_mergel(kernel.packet[0], kernel.packet[8]); + step1[2] = vec_mergeh(kernel.packet[1], kernel.packet[9]); + step1[3] = vec_mergel(kernel.packet[1], kernel.packet[9]); + step1[4] = vec_mergeh(kernel.packet[2], kernel.packet[10]); + step1[5] = vec_mergel(kernel.packet[2], kernel.packet[10]); + step1[6] = vec_mergeh(kernel.packet[3], kernel.packet[11]); + step1[7] = vec_mergel(kernel.packet[3], kernel.packet[11]); + step1[8] = vec_mergeh(kernel.packet[4], kernel.packet[12]); + step1[9] = vec_mergel(kernel.packet[4], kernel.packet[12]); + step1[10] = vec_mergeh(kernel.packet[5], kernel.packet[13]); + step1[11] = vec_mergel(kernel.packet[5], kernel.packet[13]); + step1[12] = vec_mergeh(kernel.packet[6], kernel.packet[14]); + step1[13] = vec_mergel(kernel.packet[6], kernel.packet[14]); + step1[14] = vec_mergeh(kernel.packet[7], kernel.packet[15]); + step1[15] = vec_mergel(kernel.packet[7], kernel.packet[15]); + + step2[0] = vec_mergeh(step1[0], step1[8]); + step2[1] = vec_mergel(step1[0], step1[8]); + step2[2] = vec_mergeh(step1[1], step1[9]); + step2[3] = vec_mergel(step1[1], step1[9]); + step2[4] = 
vec_mergeh(step1[2], step1[10]); + step2[5] = vec_mergel(step1[2], step1[10]); + step2[6] = vec_mergeh(step1[3], step1[11]); + step2[7] = vec_mergel(step1[3], step1[11]); + step2[8] = vec_mergeh(step1[4], step1[12]); + step2[9] = vec_mergel(step1[4], step1[12]); + step2[10] = vec_mergeh(step1[5], step1[13]); + step2[11] = vec_mergel(step1[5], step1[13]); + step2[12] = vec_mergeh(step1[6], step1[14]); + step2[13] = vec_mergel(step1[6], step1[14]); + step2[14] = vec_mergeh(step1[7], step1[15]); + step2[15] = vec_mergel(step1[7], step1[15]); + + step3[0] = vec_mergeh(step2[0], step2[8]); + step3[1] = vec_mergel(step2[0], step2[8]); + step3[2] = vec_mergeh(step2[1], step2[9]); + step3[3] = vec_mergel(step2[1], step2[9]); + step3[4] = vec_mergeh(step2[2], step2[10]); + step3[5] = vec_mergel(step2[2], step2[10]); + step3[6] = vec_mergeh(step2[3], step2[11]); + step3[7] = vec_mergel(step2[3], step2[11]); + step3[8] = vec_mergeh(step2[4], step2[12]); + step3[9] = vec_mergel(step2[4], step2[12]); + step3[10] = vec_mergeh(step2[5], step2[13]); + step3[11] = vec_mergel(step2[5], step2[13]); + step3[12] = vec_mergeh(step2[6], step2[14]); + step3[13] = vec_mergel(step2[6], step2[14]); + step3[14] = vec_mergeh(step2[7], step2[15]); + step3[15] = vec_mergel(step2[7], step2[15]); + + kernel.packet[0] = vec_mergeh(step3[0], step3[8]); + kernel.packet[1] = vec_mergel(step3[0], step3[8]); + kernel.packet[2] = vec_mergeh(step3[1], step3[9]); + kernel.packet[3] = vec_mergel(step3[1], step3[9]); + kernel.packet[4] = vec_mergeh(step3[2], step3[10]); + kernel.packet[5] = vec_mergel(step3[2], step3[10]); + kernel.packet[6] = vec_mergeh(step3[3], step3[11]); + kernel.packet[7] = vec_mergel(step3[3], step3[11]); + kernel.packet[8] = vec_mergeh(step3[4], step3[12]); + kernel.packet[9] = vec_mergel(step3[4], step3[12]); + kernel.packet[10] = vec_mergeh(step3[5], step3[13]); + kernel.packet[11] = vec_mergel(step3[5], step3[13]); + kernel.packet[12] = vec_mergeh(step3[6], step3[14]); + kernel.packet[13] = vec_mergel(step3[6], step3[14]); + kernel.packet[14] = vec_mergeh(step3[7], step3[15]); + kernel.packet[15] = vec_mergel(step3[7], step3[15]); +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + Packet16uc step1[16], step2[16], step3[16]; + + step1[0] = vec_mergeh(kernel.packet[0], kernel.packet[8]); + step1[1] = vec_mergel(kernel.packet[0], kernel.packet[8]); + step1[2] = vec_mergeh(kernel.packet[1], kernel.packet[9]); + step1[3] = vec_mergel(kernel.packet[1], kernel.packet[9]); + step1[4] = vec_mergeh(kernel.packet[2], kernel.packet[10]); + step1[5] = vec_mergel(kernel.packet[2], kernel.packet[10]); + step1[6] = vec_mergeh(kernel.packet[3], kernel.packet[11]); + step1[7] = vec_mergel(kernel.packet[3], kernel.packet[11]); + step1[8] = vec_mergeh(kernel.packet[4], kernel.packet[12]); + step1[9] = vec_mergel(kernel.packet[4], kernel.packet[12]); + step1[10] = vec_mergeh(kernel.packet[5], kernel.packet[13]); + step1[11] = vec_mergel(kernel.packet[5], kernel.packet[13]); + step1[12] = vec_mergeh(kernel.packet[6], kernel.packet[14]); + step1[13] = vec_mergel(kernel.packet[6], kernel.packet[14]); + step1[14] = vec_mergeh(kernel.packet[7], kernel.packet[15]); + step1[15] = vec_mergel(kernel.packet[7], kernel.packet[15]); + + step2[0] = vec_mergeh(step1[0], step1[8]); + step2[1] = vec_mergel(step1[0], step1[8]); + step2[2] = vec_mergeh(step1[1], step1[9]); + step2[3] = vec_mergel(step1[1], step1[9]); + step2[4] = vec_mergeh(step1[2], step1[10]); + step2[5] = vec_mergel(step1[2], step1[10]); + step2[6] = 
vec_mergeh(step1[3], step1[11]); + step2[7] = vec_mergel(step1[3], step1[11]); + step2[8] = vec_mergeh(step1[4], step1[12]); + step2[9] = vec_mergel(step1[4], step1[12]); + step2[10] = vec_mergeh(step1[5], step1[13]); + step2[11] = vec_mergel(step1[5], step1[13]); + step2[12] = vec_mergeh(step1[6], step1[14]); + step2[13] = vec_mergel(step1[6], step1[14]); + step2[14] = vec_mergeh(step1[7], step1[15]); + step2[15] = vec_mergel(step1[7], step1[15]); + + step3[0] = vec_mergeh(step2[0], step2[8]); + step3[1] = vec_mergel(step2[0], step2[8]); + step3[2] = vec_mergeh(step2[1], step2[9]); + step3[3] = vec_mergel(step2[1], step2[9]); + step3[4] = vec_mergeh(step2[2], step2[10]); + step3[5] = vec_mergel(step2[2], step2[10]); + step3[6] = vec_mergeh(step2[3], step2[11]); + step3[7] = vec_mergel(step2[3], step2[11]); + step3[8] = vec_mergeh(step2[4], step2[12]); + step3[9] = vec_mergel(step2[4], step2[12]); + step3[10] = vec_mergeh(step2[5], step2[13]); + step3[11] = vec_mergel(step2[5], step2[13]); + step3[12] = vec_mergeh(step2[6], step2[14]); + step3[13] = vec_mergel(step2[6], step2[14]); + step3[14] = vec_mergeh(step2[7], step2[15]); + step3[15] = vec_mergel(step2[7], step2[15]); + + kernel.packet[0] = vec_mergeh(step3[0], step3[8]); + kernel.packet[1] = vec_mergel(step3[0], step3[8]); + kernel.packet[2] = vec_mergeh(step3[1], step3[9]); + kernel.packet[3] = vec_mergel(step3[1], step3[9]); + kernel.packet[4] = vec_mergeh(step3[2], step3[10]); + kernel.packet[5] = vec_mergel(step3[2], step3[10]); + kernel.packet[6] = vec_mergeh(step3[3], step3[11]); + kernel.packet[7] = vec_mergel(step3[3], step3[11]); + kernel.packet[8] = vec_mergeh(step3[4], step3[12]); + kernel.packet[9] = vec_mergel(step3[4], step3[12]); + kernel.packet[10] = vec_mergeh(step3[5], step3[13]); + kernel.packet[11] = vec_mergel(step3[5], step3[13]); + kernel.packet[12] = vec_mergeh(step3[6], step3[14]); + kernel.packet[13] = vec_mergel(step3[6], step3[14]); + kernel.packet[14] = vec_mergeh(step3[7], step3[15]); + kernel.packet[15] = vec_mergel(step3[7], step3[15]); +} + +template EIGEN_STRONG_INLINE +Packet pblend4(const Selector<4>& ifPacket, const Packet& thenPacket, const Packet& elsePacket) { + Packet4ui select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3] }; + Packet4ui mask = reinterpret_cast(vec_cmpeq(reinterpret_cast(select), reinterpret_cast(p4i_ONE))); + return vec_sel(elsePacket, thenPacket, mask); +} + +template<> EIGEN_STRONG_INLINE Packet4i pblend(const Selector<4>& ifPacket, const Packet4i& thenPacket, const Packet4i& elsePacket) { + return pblend4(ifPacket, thenPacket, elsePacket); +} + +template<> EIGEN_STRONG_INLINE Packet4f pblend(const Selector<4>& ifPacket, const Packet4f& thenPacket, const Packet4f& elsePacket) { + return pblend4(ifPacket, thenPacket, elsePacket); +} + +template<> EIGEN_STRONG_INLINE Packet8s pblend(const Selector<8>& ifPacket, const Packet8s& thenPacket, const Packet8s& elsePacket) { + Packet8us select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3], + ifPacket.select[4], ifPacket.select[5], ifPacket.select[6], ifPacket.select[7] }; + Packet8us mask = reinterpret_cast(vec_cmpeq(select, p8us_ONE)); + Packet8s result = vec_sel(elsePacket, thenPacket, mask); + return result; +} + +template<> EIGEN_STRONG_INLINE Packet8us pblend(const Selector<8>& ifPacket, const Packet8us& thenPacket, const Packet8us& elsePacket) { + Packet8us select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3], + 
ifPacket.select[4], ifPacket.select[5], ifPacket.select[6], ifPacket.select[7] }; + Packet8us mask = reinterpret_cast(vec_cmpeq(reinterpret_cast(select), p8us_ONE)); + return vec_sel(elsePacket, thenPacket, mask); +} + +template<> EIGEN_STRONG_INLINE Packet8bf pblend(const Selector<8>& ifPacket, const Packet8bf& thenPacket, const Packet8bf& elsePacket) { + return pblend(ifPacket, thenPacket, elsePacket); +} + +template<> EIGEN_STRONG_INLINE Packet16c pblend(const Selector<16>& ifPacket, const Packet16c& thenPacket, const Packet16c& elsePacket) { + Packet16uc select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3], + ifPacket.select[4], ifPacket.select[5], ifPacket.select[6], ifPacket.select[7], + ifPacket.select[8], ifPacket.select[9], ifPacket.select[10], ifPacket.select[11], + ifPacket.select[12], ifPacket.select[13], ifPacket.select[14], ifPacket.select[15] }; + + Packet16uc mask = reinterpret_cast(vec_cmpeq(reinterpret_cast(select), p16uc_ONE)); + return vec_sel(elsePacket, thenPacket, mask); +} + +template<> EIGEN_STRONG_INLINE Packet16uc pblend(const Selector<16>& ifPacket, const Packet16uc& thenPacket, const Packet16uc& elsePacket) { + Packet16uc select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3], + ifPacket.select[4], ifPacket.select[5], ifPacket.select[6], ifPacket.select[7], + ifPacket.select[8], ifPacket.select[9], ifPacket.select[10], ifPacket.select[11], + ifPacket.select[12], ifPacket.select[13], ifPacket.select[14], ifPacket.select[15] }; + + Packet16uc mask = reinterpret_cast(vec_cmpeq(reinterpret_cast(select), p16uc_ONE)); + return vec_sel(elsePacket, thenPacket, mask); +} + +template <> +struct type_casting_traits { + enum { + VectorizedCast = 1, + SrcCoeffRatio = 1, + TgtCoeffRatio = 1 + }; +}; + +template <> +struct type_casting_traits { + enum { + VectorizedCast = 1, + SrcCoeffRatio = 1, + TgtCoeffRatio = 1 + }; +}; + +template <> +struct type_casting_traits { + enum { + VectorizedCast = 1, + SrcCoeffRatio = 1, + TgtCoeffRatio = 1 + }; +}; + +template <> +struct type_casting_traits { + enum { + VectorizedCast = 1, + SrcCoeffRatio = 1, + TgtCoeffRatio = 1 + }; +}; + +template<> EIGEN_STRONG_INLINE Packet4i pcast(const Packet4f& a) { + return vec_cts(a,0); +} + +template<> EIGEN_STRONG_INLINE Packet4ui pcast(const Packet4f& a) { + return vec_ctu(a,0); +} + +template<> EIGEN_STRONG_INLINE Packet4f pcast(const Packet4i& a) { + return vec_ctf(a,0); +} + +template<> EIGEN_STRONG_INLINE Packet4f pcast(const Packet4ui& a) { + return vec_ctf(a,0); +} + +template<> EIGEN_STRONG_INLINE Packet8us pcast(const Packet8bf& a) { + Packet4f float_even = Bf16ToF32Even(a); + Packet4f float_odd = Bf16ToF32Odd(a); + Packet4ui int_even = pcast(float_even); + Packet4ui int_odd = pcast(float_odd); + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(low_mask, 0x0000FFFF); + Packet4ui low_even = pand(int_even, p4ui_low_mask); + Packet4ui low_odd = pand(int_odd, p4ui_low_mask); + + //Check values that are bigger than USHRT_MAX (0xFFFF) + Packet4bi overflow_selector; + if(vec_any_gt(int_even, p4ui_low_mask)){ + overflow_selector = vec_cmpgt(int_even, p4ui_low_mask); + low_even = vec_sel(low_even, p4ui_low_mask, overflow_selector); + } + if(vec_any_gt(int_odd, p4ui_low_mask)){ + overflow_selector = vec_cmpgt(int_odd, p4ui_low_mask); + low_odd = vec_sel(low_even, p4ui_low_mask, overflow_selector); + } + + low_odd = plogical_shift_left<16>(low_odd); + + Packet4ui int_final = por(low_even, low_odd); + return reinterpret_cast(int_final); 
+} + +template<> EIGEN_STRONG_INLINE Packet8bf pcast(const Packet8us& a) { + //short -> int -> float -> bfloat16 + const _EIGEN_DECLARE_CONST_FAST_Packet4ui(low_mask, 0x0000FFFF); + Packet4ui int_cast = reinterpret_cast(a); + Packet4ui int_even = pand(int_cast, p4ui_low_mask); + Packet4ui int_odd = plogical_shift_right<16>(int_cast); + Packet4f float_even = pcast(int_even); + Packet4f float_odd = pcast(int_odd); + return F32ToBf16(float_even, float_odd); +} + + +template<> EIGEN_STRONG_INLINE Packet4i preinterpret(const Packet4f& a) { + return reinterpret_cast(a); +} + +template<> EIGEN_STRONG_INLINE Packet4f preinterpret(const Packet4i& a) { + return reinterpret_cast(a); +} + + + +//---------- double ---------- +#ifdef __VSX__ +typedef __vector double Packet2d; +typedef __vector unsigned long long Packet2ul; +typedef __vector long long Packet2l; +#if EIGEN_COMP_CLANG +typedef Packet2ul Packet2bl; +#else +typedef __vector __bool long Packet2bl; +#endif + +static Packet2l p2l_ONE = { 1, 1 }; +static Packet2l p2l_ZERO = reinterpret_cast(p4i_ZERO); +static Packet2ul p2ul_SIGN = { 0x8000000000000000ull, 0x8000000000000000ull }; +static Packet2ul p2ul_PREV0DOT5 = { 0x3FDFFFFFFFFFFFFFull, 0x3FDFFFFFFFFFFFFFull }; +static Packet2d p2d_ONE = { 1.0, 1.0 }; +static Packet2d p2d_ZERO = reinterpret_cast(p4f_ZERO); +static Packet2d p2d_MZERO = { numext::bit_cast(0x8000000000000000ull), + numext::bit_cast(0x8000000000000000ull) }; + +#ifdef _BIG_ENDIAN +static Packet2d p2d_COUNTDOWN = reinterpret_cast(vec_sld(reinterpret_cast(p2d_ZERO), reinterpret_cast(p2d_ONE), 8)); +#else +static Packet2d p2d_COUNTDOWN = reinterpret_cast(vec_sld(reinterpret_cast(p2d_ONE), reinterpret_cast(p2d_ZERO), 8)); +#endif + +template Packet2d vec_splat_dbl(Packet2d& a) +{ + return vec_splat(a, index); +} + +template<> struct packet_traits : default_packet_traits +{ + typedef Packet2d type; + typedef Packet2d half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size=2, + HasHalfPacket = 1, + + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 1, + HasMin = 1, + HasMax = 1, + HasAbs = 1, + HasSin = 0, + HasCos = 0, + HasLog = 0, + HasExp = 1, + HasSqrt = 1, + HasRsqrt = 1, + HasRound = 1, + HasFloor = 1, + HasCeil = 1, + HasRint = 1, + HasNegate = 1, + HasBlend = 1 + }; +}; + +template<> struct unpacket_traits { typedef double type; enum {size=2, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet2d half; }; + +inline std::ostream & operator <<(std::ostream & s, const Packet2l & v) +{ + union { + Packet2l v; + int64_t n[2]; + } vt; + vt.v = v; + s << vt.n[0] << ", " << vt.n[1]; + return s; +} + +inline std::ostream & operator <<(std::ostream & s, const Packet2d & v) +{ + union { + Packet2d v; + double n[2]; + } vt; + vt.v = v; + s << vt.n[0] << ", " << vt.n[1]; + return s; +} + +// Need to define them first or we get specialization after instantiation errors +template<> EIGEN_STRONG_INLINE Packet2d pload(const double* from) +{ + EIGEN_DEBUG_ALIGNED_LOAD + return vec_xl(0, const_cast(from)); // cast needed by Clang +} + +template<> EIGEN_STRONG_INLINE void pstore(double* to, const Packet2d& from) +{ + EIGEN_DEBUG_ALIGNED_STORE + vec_xst(from, 0, to); +} + +template<> EIGEN_STRONG_INLINE Packet2d pset1(const double& from) { + Packet2d v = {from, from}; + return v; +} + +template<> EIGEN_STRONG_INLINE Packet2d pset1frombits(unsigned long from) { + Packet2l v = {static_cast(from), static_cast(from)}; + return reinterpret_cast(v); +} + +template<> 
EIGEN_STRONG_INLINE void +pbroadcast4(const double *a, + Packet2d& a0, Packet2d& a1, Packet2d& a2, Packet2d& a3) +{ + //This way is faster than vec_splat (at least for doubles in Power 9) + a0 = pset1(a[0]); + a1 = pset1(a[1]); + a2 = pset1(a[2]); + a3 = pset1(a[3]); +} + +template<> EIGEN_DEVICE_FUNC inline Packet2d pgather(const double* from, Index stride) +{ + EIGEN_ALIGN16 double af[2]; + af[0] = from[0*stride]; + af[1] = from[1*stride]; + return pload(af); +} +template<> EIGEN_DEVICE_FUNC inline void pscatter(double* to, const Packet2d& from, Index stride) +{ + EIGEN_ALIGN16 double af[2]; + pstore(af, from); + to[0*stride] = af[0]; + to[1*stride] = af[1]; +} + +template<> EIGEN_STRONG_INLINE Packet2d plset(const double& a) { return pset1(a) + p2d_COUNTDOWN; } + +template<> EIGEN_STRONG_INLINE Packet2d padd(const Packet2d& a, const Packet2d& b) { return a + b; } + +template<> EIGEN_STRONG_INLINE Packet2d psub(const Packet2d& a, const Packet2d& b) { return a - b; } + +template<> EIGEN_STRONG_INLINE Packet2d pnegate(const Packet2d& a) { return p2d_ZERO - a; } + +template<> EIGEN_STRONG_INLINE Packet2d pconj(const Packet2d& a) { return a; } + +template<> EIGEN_STRONG_INLINE Packet2d pmul(const Packet2d& a, const Packet2d& b) { return vec_madd(a,b,p2d_MZERO); } +template<> EIGEN_STRONG_INLINE Packet2d pdiv(const Packet2d& a, const Packet2d& b) { return vec_div(a,b); } + +// for some weird raisons, it has to be overloaded for packet of integers +template<> EIGEN_STRONG_INLINE Packet2d pmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return vec_madd(a, b, c); } + +template<> EIGEN_STRONG_INLINE Packet2d pmin(const Packet2d& a, const Packet2d& b) +{ + // NOTE: about 10% slower than vec_min, but consistent with std::min and SSE regarding NaN + Packet2d ret; + __asm__ ("xvcmpgedp %x0,%x1,%x2\n\txxsel %x0,%x1,%x2,%x0" : "=&wa" (ret) : "wa" (a), "wa" (b)); + return ret; + } + +template<> EIGEN_STRONG_INLINE Packet2d pmax(const Packet2d& a, const Packet2d& b) +{ + // NOTE: about 10% slower than vec_max, but consistent with std::max and SSE regarding NaN + Packet2d ret; + __asm__ ("xvcmpgtdp %x0,%x2,%x1\n\txxsel %x0,%x1,%x2,%x0" : "=&wa" (ret) : "wa" (a), "wa" (b)); + return ret; +} + +template<> EIGEN_STRONG_INLINE Packet2d pcmp_le(const Packet2d& a, const Packet2d& b) { return reinterpret_cast(vec_cmple(a,b)); } +template<> EIGEN_STRONG_INLINE Packet2d pcmp_lt(const Packet2d& a, const Packet2d& b) { return reinterpret_cast(vec_cmplt(a,b)); } +template<> EIGEN_STRONG_INLINE Packet2d pcmp_eq(const Packet2d& a, const Packet2d& b) { return reinterpret_cast(vec_cmpeq(a,b)); } +template<> EIGEN_STRONG_INLINE Packet2d pcmp_lt_or_nan(const Packet2d& a, const Packet2d& b) { + Packet2d c = reinterpret_cast(vec_cmpge(a,b)); + return vec_nor(c,c); +} + +template<> EIGEN_STRONG_INLINE Packet2d pand(const Packet2d& a, const Packet2d& b) { return vec_and(a, b); } + +template<> EIGEN_STRONG_INLINE Packet2d por(const Packet2d& a, const Packet2d& b) { return vec_or(a, b); } + +template<> EIGEN_STRONG_INLINE Packet2d pxor(const Packet2d& a, const Packet2d& b) { return vec_xor(a, b); } + +template<> EIGEN_STRONG_INLINE Packet2d pandnot(const Packet2d& a, const Packet2d& b) { return vec_and(a, vec_nor(b, b)); } + +template<> EIGEN_STRONG_INLINE Packet2d pround(const Packet2d& a) +{ + Packet2d t = vec_add(reinterpret_cast(vec_or(vec_and(reinterpret_cast(a), p2ul_SIGN), p2ul_PREV0DOT5)), a); + Packet2d res; + + __asm__("xvrdpiz %x0, %x1\n\t" + : "=&wa" (res) + : "wa" (t)); + + return res; +} +template<> 
EIGEN_STRONG_INLINE Packet2d pceil(const Packet2d& a) { return vec_ceil(a); } +template<> EIGEN_STRONG_INLINE Packet2d pfloor(const Packet2d& a) { return vec_floor(a); } +template<> EIGEN_STRONG_INLINE Packet2d print(const Packet2d& a) +{ + Packet2d res; + + __asm__("xvrdpic %x0, %x1\n\t" + : "=&wa" (res) + : "wa" (a)); + + return res; +} + +template<> EIGEN_STRONG_INLINE Packet2d ploadu(const double* from) +{ + EIGEN_DEBUG_UNALIGNED_LOAD + return vec_xl(0, const_cast(from)); +} + +template<> EIGEN_STRONG_INLINE Packet2d ploaddup(const double* from) +{ + Packet2d p; + if((std::ptrdiff_t(from) % 16) == 0) p = pload(from); + else p = ploadu(from); + return vec_splat_dbl<0>(p); +} + +template<> EIGEN_STRONG_INLINE void pstoreu(double* to, const Packet2d& from) +{ + EIGEN_DEBUG_UNALIGNED_STORE + vec_xst(from, 0, to); +} + +template<> EIGEN_STRONG_INLINE void prefetch(const double* addr) { EIGEN_PPC_PREFETCH(addr); } + +template<> EIGEN_STRONG_INLINE double pfirst(const Packet2d& a) { EIGEN_ALIGN16 double x[2]; pstore(x, a); return x[0]; } + +template<> EIGEN_STRONG_INLINE Packet2d preverse(const Packet2d& a) +{ + return reinterpret_cast(vec_perm(reinterpret_cast(a), reinterpret_cast(a), p16uc_REVERSE64)); +} +template<> EIGEN_STRONG_INLINE Packet2d pabs(const Packet2d& a) { return vec_abs(a); } + +// VSX support varies between different compilers and even different +// versions of the same compiler. For gcc version >= 4.9.3, we can use +// vec_cts to efficiently convert Packet2d to Packet2l. Otherwise, use +// a slow version that works with older compilers. +// Update: apparently vec_cts/vec_ctf intrinsics for 64-bit doubles +// are buggy, https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70963 +template<> +inline Packet2l pcast(const Packet2d& x) { +#if EIGEN_GNUC_AT_LEAST(5, 4) || \ + (EIGEN_GNUC_AT(6, 1) && __GNUC_PATCHLEVEL__ >= 1) + return vec_cts(x, 0); // TODO: check clang version. +#else + double tmp[2]; + memcpy(tmp, &x, sizeof(tmp)); + Packet2l l = { static_cast(tmp[0]), + static_cast(tmp[1]) }; + return l; +#endif +} + +template<> +inline Packet2d pcast(const Packet2l& x) { + unsigned long long tmp[2]; + memcpy(tmp, &x, sizeof(tmp)); + Packet2d d = { static_cast(tmp[0]), + static_cast(tmp[1]) }; + return d; +} + + +// Packet2l shifts. +// For POWER8 we simply use vec_sr/l. +// +// Things are more complicated for POWER7. There is actually a +// vec_xxsxdi intrinsic but it is not supported by some gcc versions. +// So we need to shift by N % 32 and rearrage bytes. +#ifdef __POWER8_VECTOR__ + +template +EIGEN_STRONG_INLINE Packet2l plogical_shift_left(const Packet2l& a) { + const Packet2ul shift = { N, N }; + return vec_sl(a, shift); +} + +template +EIGEN_STRONG_INLINE Packet2l plogical_shift_right(const Packet2l& a) { + const Packet2ul shift = { N, N }; + return vec_sr(a, shift); +} + +#else + +// Shifts [A, B, C, D] to [B, 0, D, 0]. +// Used to implement left shifts for Packet2l. +EIGEN_ALWAYS_INLINE Packet4i shift_even_left(const Packet4i& a) { + static const Packet16uc perm = { + 0x14, 0x15, 0x16, 0x17, 0x00, 0x01, 0x02, 0x03, + 0x1c, 0x1d, 0x1e, 0x1f, 0x08, 0x09, 0x0a, 0x0b }; + #ifdef _BIG_ENDIAN + return vec_perm(p4i_ZERO, a, perm); + #else + return vec_perm(a, p4i_ZERO, perm); + #endif +} + +// Shifts [A, B, C, D] to [0, A, 0, C]. +// Used to implement right shifts for Packet2l. 
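// Illustrative scalar model of the POWER7 fallback described above: a 64-bit
// logical right shift by N (0 < N < 32) is built from two 32-bit lane shifts
// plus a recombination step, which is what vec_sr/vec_sl and the shift_odd_right
// permutation below achieve per lane. The helper name and the use of
// uint32_t/uint64_t are inventions of this sketch, not part of the patch.
#include <cstdint>
static inline uint64_t shr64_via_32bit_lanes(uint64_t a, unsigned N) {  // assumes 0 < N < 32
  uint32_t lo = static_cast<uint32_t>(a);
  uint32_t hi = static_cast<uint32_t>(a >> 32);
  uint32_t new_hi = hi >> N;                       // high lane: plain 32-bit shift
  uint32_t new_lo = (lo >> N) | (hi << (32 - N));  // low lane: shifted bits plus carry-in from the high lane
  return (static_cast<uint64_t>(new_hi) << 32) | new_lo;
}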
+EIGEN_ALWAYS_INLINE Packet4i shift_odd_right(const Packet4i& a) { + static const Packet16uc perm = { + 0x04, 0x05, 0x06, 0x07, 0x10, 0x11, 0x12, 0x13, + 0x0c, 0x0d, 0x0e, 0x0f, 0x18, 0x19, 0x1a, 0x1b }; + #ifdef _BIG_ENDIAN + return vec_perm(p4i_ZERO, a, perm); + #else + return vec_perm(a, p4i_ZERO, perm); + #endif +} + +template +struct plogical_shift_left_impl; + +template +struct plogical_shift_left_impl= 0)>::type> { + static EIGEN_STRONG_INLINE Packet2l run(const Packet2l& a) { + static const unsigned n = static_cast(N); + const Packet4ui shift = {n, n, n, n}; + const Packet4i ai = reinterpret_cast(a); + static const unsigned m = static_cast(32 - N); + const Packet4ui shift_right = {m, m, m, m}; + const Packet4i out_hi = vec_sl(ai, shift); + const Packet4i out_lo = shift_even_left(vec_sr(ai, shift_right)); + return reinterpret_cast(por(out_hi, out_lo)); + } +}; + +template +struct plogical_shift_left_impl= 32)>::type> { + static EIGEN_STRONG_INLINE Packet2l run(const Packet2l& a) { + static const unsigned m = static_cast(N - 32); + const Packet4ui shift = {m, m, m, m}; + const Packet4i ai = reinterpret_cast(a); + return reinterpret_cast(shift_even_left(vec_sl(ai, shift))); + } +}; + +template +EIGEN_STRONG_INLINE Packet2l plogical_shift_left(const Packet2l& a) { + return plogical_shift_left_impl::run(a); +} + +template +struct plogical_shift_right_impl; + +template +struct plogical_shift_right_impl= 0)>::type> { + static EIGEN_STRONG_INLINE Packet2l run(const Packet2l& a) { + static const unsigned n = static_cast(N); + const Packet4ui shift = {n, n, n, n}; + const Packet4i ai = reinterpret_cast(a); + static const unsigned m = static_cast(32 - N); + const Packet4ui shift_left = {m, m, m, m}; + const Packet4i out_lo = vec_sr(ai, shift); + const Packet4i out_hi = shift_odd_right(vec_sl(ai, shift_left)); + return reinterpret_cast(por(out_hi, out_lo)); + } +}; + +template +struct plogical_shift_right_impl= 32)>::type> { + static EIGEN_STRONG_INLINE Packet2l run(const Packet2l& a) { + static const unsigned m = static_cast(N - 32); + const Packet4ui shift = {m, m, m, m}; + const Packet4i ai = reinterpret_cast(a); + return reinterpret_cast(shift_odd_right(vec_sr(ai, shift))); + } +}; + +template +EIGEN_STRONG_INLINE Packet2l plogical_shift_right(const Packet2l& a) { + return plogical_shift_right_impl::run(a); +} +#endif + +template<> EIGEN_STRONG_INLINE Packet2d pldexp(const Packet2d& a, const Packet2d& exponent) { + // Clamp exponent to [-2099, 2099] + const Packet2d max_exponent = pset1(2099.0); + const Packet2l e = pcast(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent)); + + // Split 2^e into four factors and multiply: + const Packet2l bias = { 1023, 1023 }; + Packet2l b = plogical_shift_right<2>(e); // floor(e/4) + Packet2d c = reinterpret_cast(plogical_shift_left<52>(b + bias)); + Packet2d out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b) + b = psub(psub(psub(e, b), b), b); // e - 3b + c = reinterpret_cast(plogical_shift_left<52>(b + bias)); // 2^(e - 3b) + out = pmul(out, c); // a * 2^e + return out; +} + + +// Extract exponent without existence of Packet2l. 
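// Scalar sketch of the biased-exponent extraction implemented below: clear the
// sign bit, shift the IEEE-754 bit pattern right by 52 so the 11 exponent bits
// land in the low bits, and return that value as a double. The function name is
// invented for this sketch; for example it yields 1023.0 for an input of 1.0.
#include <cstdint>
#include <cstring>
static inline double biased_exponent_of(double x) {
  uint64_t bits;
  std::memcpy(&bits, &x, sizeof(bits));     // reinterpret the double's bit pattern
  bits &= 0x7FFFFFFFFFFFFFFFull;            // pabs: drop the sign bit
  return static_cast<double>(bits >> 52);   // biased exponent in [0, 2047]
}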
+template<> +EIGEN_STRONG_INLINE +Packet2d pfrexp_generic_get_biased_exponent(const Packet2d& a) { + return pcast(plogical_shift_right<52>(reinterpret_cast(pabs(a)))); +} + +template<> EIGEN_STRONG_INLINE Packet2d pfrexp (const Packet2d& a, Packet2d& exponent) { + return pfrexp_generic(a, exponent); +} + +template<> EIGEN_STRONG_INLINE double predux(const Packet2d& a) +{ + Packet2d b, sum; + b = reinterpret_cast(vec_sld(reinterpret_cast(a), reinterpret_cast(a), 8)); + sum = a + b; + return pfirst(sum); +} + +// Other reduction functions: +// mul +template<> EIGEN_STRONG_INLINE double predux_mul(const Packet2d& a) +{ + return pfirst(pmul(a, reinterpret_cast(vec_sld(reinterpret_cast(a), reinterpret_cast(a), 8)))); +} + +// min +template<> EIGEN_STRONG_INLINE double predux_min(const Packet2d& a) +{ + return pfirst(pmin(a, reinterpret_cast(vec_sld(reinterpret_cast(a), reinterpret_cast(a), 8)))); +} + +// max +template<> EIGEN_STRONG_INLINE double predux_max(const Packet2d& a) +{ + return pfirst(pmax(a, reinterpret_cast(vec_sld(reinterpret_cast(a), reinterpret_cast(a), 8)))); +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + Packet2d t0, t1; + t0 = vec_perm(kernel.packet[0], kernel.packet[1], p16uc_TRANSPOSE64_HI); + t1 = vec_perm(kernel.packet[0], kernel.packet[1], p16uc_TRANSPOSE64_LO); + kernel.packet[0] = t0; + kernel.packet[1] = t1; +} + +template<> EIGEN_STRONG_INLINE Packet2d pblend(const Selector<2>& ifPacket, const Packet2d& thenPacket, const Packet2d& elsePacket) { + Packet2l select = { ifPacket.select[0], ifPacket.select[1] }; + Packet2bl mask = reinterpret_cast( vec_cmpeq(reinterpret_cast(select), reinterpret_cast(p2l_ONE)) ); + return vec_sel(elsePacket, thenPacket, mask); +} + + +#endif // __VSX__ +} // end namespace internal + +} // end namespace Eigen + +#endif // EIGEN_PACKET_MATH_ALTIVEC_H diff --git a/Eigen/src/Core/arch/CUDA/Complex.h b/Eigen/src/Core/arch/CUDA/Complex.h new file mode 100644 index 0000000..deb4c86 --- /dev/null +++ b/Eigen/src/Core/arch/CUDA/Complex.h @@ -0,0 +1,258 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2014 Benoit Steiner +// Copyright (C) 2021 C. Antonio Sanchez +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_COMPLEX_CUDA_H +#define EIGEN_COMPLEX_CUDA_H + +// clang-format off +// Many std::complex methods such as operator+, operator-, operator* and +// operator/ are not constexpr. Due to this, GCC and older versions of clang do +// not treat them as device functions and thus Eigen functors making use of +// these operators fail to compile. Here, we manually specialize these +// operators and functors for complex types when building for CUDA to enable +// their use on-device. + +#if defined(EIGEN_CUDACC) && defined(EIGEN_GPU_COMPILE_PHASE) + +// ICC already specializes std::complex and std::complex +// operators, preventing us from making them device functions here. +// This will lead to silent runtime errors if the operators are used on device. +// +// To allow std::complex operator use on device, define _OVERRIDE_COMPLEX_SPECIALIZATION_ +// prior to first inclusion of . This prevents ICC from adding +// its own specializations, so our custom ones below can be used instead. 
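// Hypothetical usage of the escape hatch described in the note above. The macro
// name is taken from that note; the include order and header choice are
// assumptions of this sketch rather than a verified ICC recipe.
#define _OVERRIDE_COMPLEX_SPECIALIZATION_
#include <complex>      // the define must be visible before the first inclusion of <complex>
#include <Eigen/Core>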
+#if !(defined(EIGEN_COMP_ICC) && defined(_USE_COMPLEX_SPECIALIZATION_)) + +// Import Eigen's internal operator specializations. +#define EIGEN_USING_STD_COMPLEX_OPERATORS \ + using Eigen::complex_operator_detail::operator+; \ + using Eigen::complex_operator_detail::operator-; \ + using Eigen::complex_operator_detail::operator*; \ + using Eigen::complex_operator_detail::operator/; \ + using Eigen::complex_operator_detail::operator+=; \ + using Eigen::complex_operator_detail::operator-=; \ + using Eigen::complex_operator_detail::operator*=; \ + using Eigen::complex_operator_detail::operator/=; \ + using Eigen::complex_operator_detail::operator==; \ + using Eigen::complex_operator_detail::operator!=; + +namespace Eigen { + +// Specialized std::complex overloads. +namespace complex_operator_detail { + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +std::complex complex_multiply(const std::complex& a, const std::complex& b) { + const T a_real = numext::real(a); + const T a_imag = numext::imag(a); + const T b_real = numext::real(b); + const T b_imag = numext::imag(b); + return std::complex( + a_real * b_real - a_imag * b_imag, + a_imag * b_real + a_real * b_imag); +} + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +std::complex complex_divide_fast(const std::complex& a, const std::complex& b) { + const T a_real = numext::real(a); + const T a_imag = numext::imag(a); + const T b_real = numext::real(b); + const T b_imag = numext::imag(b); + const T norm = (b_real * b_real + b_imag * b_imag); + return std::complex((a_real * b_real + a_imag * b_imag) / norm, + (a_imag * b_real - a_real * b_imag) / norm); +} + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +std::complex complex_divide_stable(const std::complex& a, const std::complex& b) { + const T a_real = numext::real(a); + const T a_imag = numext::imag(a); + const T b_real = numext::real(b); + const T b_imag = numext::imag(b); + // Smith's complex division (https://arxiv.org/pdf/1210.4539.pdf), + // guards against over/under-flow. + const bool scale_imag = numext::abs(b_imag) <= numext::abs(b_real); + const T rscale = scale_imag ? T(1) : b_real / b_imag; + const T iscale = scale_imag ? b_imag / b_real : T(1); + const T denominator = b_real * rscale + b_imag * iscale; + return std::complex((a_real * rscale + a_imag * iscale) / denominator, + (a_imag * rscale - a_real * iscale) / denominator); +} + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +std::complex complex_divide(const std::complex& a, const std::complex& b) { +#if EIGEN_FAST_MATH + return complex_divide_fast(a, b); +#else + return complex_divide_stable(a, b); +#endif +} + +// NOTE: We cannot specialize compound assignment operators with Scalar T, +// (i.e. operator@=(const T&), for @=+,-,*,/) +// since they are already specialized for float/double/long double within +// the standard header. We also do not specialize the stream +// operators. 
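// Worked numeric example (values invented for illustration) of why the stable
// Smith path above matters: dividing 1 by b = 1e300 + 1e300i. The fast path
// forms |b|^2, which overflows to infinity and returns 0; Smith's scaling keeps
// every intermediate finite and recovers the correct ~5e-301 - 5e-301i.
#include <complex>
#include <cstdio>
int main() {
  std::complex<double> a(1.0, 0.0), b(1e300, 1e300);
  // Fast path: divide by the squared magnitude, which overflows here.
  double norm = b.real() * b.real() + b.imag() * b.imag();   // inf
  std::printf("fast : %g %g\n", (a.real() * b.real() + a.imag() * b.imag()) / norm,
                                (a.imag() * b.real() - a.real() * b.imag()) / norm);
  // Smith's scaling (case |b.imag| <= |b.real|): scale by the larger component first.
  double r = b.imag() / b.real();                            // 1.0, no overflow
  double d = b.real() + b.imag() * r;                        // 2e300, still finite
  std::printf("smith: %g %g\n", (a.real() + a.imag() * r) / d,
                                (a.imag() - a.real() * r) / d);
  return 0;
}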
+#define EIGEN_CREATE_STD_COMPLEX_OPERATOR_SPECIALIZATIONS(T) \ + \ +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \ +std::complex operator+(const std::complex& a) { return a; } \ + \ +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \ +std::complex operator-(const std::complex& a) { \ + return std::complex(-numext::real(a), -numext::imag(a)); \ +} \ + \ +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \ +std::complex operator+(const std::complex& a, const std::complex& b) { \ + return std::complex(numext::real(a) + numext::real(b), numext::imag(a) + numext::imag(b)); \ +} \ + \ +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \ +std::complex operator+(const std::complex& a, const T& b) { \ + return std::complex(numext::real(a) + b, numext::imag(a)); \ +} \ + \ +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \ +std::complex operator+(const T& a, const std::complex& b) { \ + return std::complex(a + numext::real(b), numext::imag(b)); \ +} \ + \ +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \ +std::complex operator-(const std::complex& a, const std::complex& b) { \ + return std::complex(numext::real(a) - numext::real(b), numext::imag(a) - numext::imag(b)); \ +} \ + \ +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \ +std::complex operator-(const std::complex& a, const T& b) { \ + return std::complex(numext::real(a) - b, numext::imag(a)); \ +} \ + \ +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \ +std::complex operator-(const T& a, const std::complex& b) { \ + return std::complex(a - numext::real(b), -numext::imag(b)); \ +} \ + \ +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \ +std::complex operator*(const std::complex& a, const std::complex& b) { \ + return complex_multiply(a, b); \ +} \ + \ +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \ +std::complex operator*(const std::complex& a, const T& b) { \ + return std::complex(numext::real(a) * b, numext::imag(a) * b); \ +} \ + \ +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \ +std::complex operator*(const T& a, const std::complex& b) { \ + return std::complex(a * numext::real(b), a * numext::imag(b)); \ +} \ + \ +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \ +std::complex operator/(const std::complex& a, const std::complex& b) { \ + return complex_divide(a, b); \ +} \ + \ +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \ +std::complex operator/(const std::complex& a, const T& b) { \ + return std::complex(numext::real(a) / b, numext::imag(a) / b); \ +} \ + \ +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \ +std::complex operator/(const T& a, const std::complex& b) { \ + return complex_divide(std::complex(a, 0), b); \ +} \ + \ +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \ +std::complex& operator+=(std::complex& a, const std::complex& b) { \ + numext::real_ref(a) += numext::real(b); \ + numext::imag_ref(a) += numext::imag(b); \ + return a; \ +} \ + \ +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \ +std::complex& operator-=(std::complex& a, const std::complex& b) { \ + numext::real_ref(a) -= numext::real(b); \ + numext::imag_ref(a) -= numext::imag(b); \ + return a; \ +} \ + \ +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \ +std::complex& operator*=(std::complex& a, const std::complex& b) { \ + a = complex_multiply(a, b); \ + return a; \ +} \ + \ +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \ +std::complex& operator/=(std::complex& a, const std::complex& b) { \ + a = complex_divide(a, b); \ + return a; \ +} \ + \ +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \ +bool operator==(const std::complex& a, const std::complex& b) { \ + return numext::real(a) == numext::real(b) && numext::imag(a) == numext::imag(b); \ +} \ + \ +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \ +bool operator==(const std::complex& 
a, const T& b) { \ + return numext::real(a) == b && numext::imag(a) == 0; \ +} \ + \ +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \ +bool operator==(const T& a, const std::complex& b) { \ + return a == numext::real(b) && 0 == numext::imag(b); \ +} \ + \ +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \ +bool operator!=(const std::complex& a, const std::complex& b) { \ + return !(a == b); \ +} \ + \ +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \ +bool operator!=(const std::complex& a, const T& b) { \ + return !(a == b); \ +} \ + \ +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE \ +bool operator!=(const T& a, const std::complex& b) { \ + return !(a == b); \ +} + +// Do not specialize for long double, since that reduces to double on device. +EIGEN_CREATE_STD_COMPLEX_OPERATOR_SPECIALIZATIONS(float) +EIGEN_CREATE_STD_COMPLEX_OPERATOR_SPECIALIZATIONS(double) + +#undef EIGEN_CREATE_STD_COMPLEX_OPERATOR_SPECIALIZATIONS + + +} // namespace complex_operator_detail + +EIGEN_USING_STD_COMPLEX_OPERATORS + +namespace numext { +EIGEN_USING_STD_COMPLEX_OPERATORS +} // namespace numext + +namespace internal { +EIGEN_USING_STD_COMPLEX_OPERATORS + +} // namespace internal +} // namespace Eigen + +#endif // !(EIGEN_COMP_ICC && _USE_COMPLEX_SPECIALIZATION_) + +#endif // EIGEN_CUDACC && EIGEN_GPU_COMPILE_PHASE + +#endif // EIGEN_COMPLEX_CUDA_H diff --git a/Eigen/src/Core/arch/Default/BFloat16.h b/Eigen/src/Core/arch/Default/BFloat16.h new file mode 100644 index 0000000..1c28f4f --- /dev/null +++ b/Eigen/src/Core/arch/Default/BFloat16.h @@ -0,0 +1,700 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef EIGEN_BFLOAT16_H +#define EIGEN_BFLOAT16_H + +#define BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, METHOD) \ + template <> \ + EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED \ + PACKET_BF16 METHOD(const PACKET_BF16& _x) { \ + return F32ToBf16(METHOD(Bf16ToF32(_x))); \ + } + +namespace Eigen { + +struct bfloat16; + +namespace bfloat16_impl { + +// Make our own __bfloat16_raw definition. 
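// Background sketch: a bfloat16 is the upper 16 bits of an IEEE-754 float32
// (same sign and 8-bit exponent, fraction cut to 7 bits), which is why the raw
// type below only needs a single unsigned short. For example 1.0f has the bit
// pattern 0x3F800000, so its bfloat16 pattern is 0x3F80. The helper name is
// invented for this sketch and, unlike the converters further down, it
// truncates instead of rounding.
#include <cstdint>
#include <cstring>
static inline unsigned short bfloat16_bits_of(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));            // reinterpret the float's bit pattern
  return static_cast<unsigned short>(bits >> 16);  // keep the high half
}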
+struct __bfloat16_raw { + EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw() : value(0) {} + explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw(unsigned short raw) : value(raw) {} + unsigned short value; +}; + +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(unsigned short value); +template +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(float ff); +// Forward declarations of template specializations, to avoid Visual C++ 2019 errors, saying: +// > error C2908: explicit specialization; 'float_to_bfloat16_rtne' has already been instantiated +template <> +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(float ff); +template <> +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(float ff); +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h); + +struct bfloat16_base : public __bfloat16_raw { + EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16_base() {} + EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16_base(const __bfloat16_raw& h) : __bfloat16_raw(h) {} +}; + +} // namespace bfloat16_impl + +// Class definition. +struct bfloat16 : public bfloat16_impl::bfloat16_base { + + typedef bfloat16_impl::__bfloat16_raw __bfloat16_raw; + + EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16() {} + + EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(const __bfloat16_raw& h) : bfloat16_impl::bfloat16_base(h) {} + + explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(bool b) + : bfloat16_impl::bfloat16_base(bfloat16_impl::raw_uint16_to_bfloat16(b ? 0x3f80 : 0)) {} + + template + explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(T val) + : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne::value>(static_cast(val))) {} + + explicit EIGEN_DEVICE_FUNC bfloat16(float f) + : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(f)) {} + + // Following the convention of numpy, converting between complex and + // float will lead to loss of imag value. + template + explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bfloat16(const std::complex& val) + : bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(static_cast(val.real()))) {} + + EIGEN_DEVICE_FUNC operator float() const { // NOLINT: Allow implicit conversion to float, because it is lossless. 
+ return bfloat16_impl::bfloat16_to_float(*this); + } +}; +} // namespace Eigen + +namespace std { +template<> +struct numeric_limits { + static const bool is_specialized = true; + static const bool is_signed = true; + static const bool is_integer = false; + static const bool is_exact = false; + static const bool has_infinity = true; + static const bool has_quiet_NaN = true; + static const bool has_signaling_NaN = true; + static const float_denorm_style has_denorm = std::denorm_absent; + static const bool has_denorm_loss = false; + static const std::float_round_style round_style = numeric_limits::round_style; + static const bool is_iec559 = false; + static const bool is_bounded = true; + static const bool is_modulo = false; + static const int digits = 8; + static const int digits10 = 2; + static const int max_digits10 = 4; + static const int radix = 2; + static const int min_exponent = numeric_limits::min_exponent; + static const int min_exponent10 = numeric_limits::min_exponent10; + static const int max_exponent = numeric_limits::max_exponent; + static const int max_exponent10 = numeric_limits::max_exponent10; + static const bool traps = numeric_limits::traps; + static const bool tinyness_before = numeric_limits::tinyness_before; + + static Eigen::bfloat16 (min)() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0080); } + static Eigen::bfloat16 lowest() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0xff7f); } + static Eigen::bfloat16 (max)() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f7f); } + static Eigen::bfloat16 epsilon() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x3c00); } + static Eigen::bfloat16 round_error() { return Eigen::bfloat16(0x3f00); } + static Eigen::bfloat16 infinity() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f80); } + static Eigen::bfloat16 quiet_NaN() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0); } + static Eigen::bfloat16 signaling_NaN() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x7f81); } + static Eigen::bfloat16 denorm_min() { return Eigen::bfloat16_impl::raw_uint16_to_bfloat16(0x0001); } +}; + +// If std::numeric_limits is specialized, should also specialize +// std::numeric_limits, std::numeric_limits, and +// std::numeric_limits +// https://stackoverflow.com/a/16519653/ +template<> +struct numeric_limits : numeric_limits {}; +template<> +struct numeric_limits : numeric_limits {}; +template<> +struct numeric_limits : numeric_limits {}; +} // namespace std + +namespace Eigen { + +namespace bfloat16_impl { + +// We need to distinguish ‘clang as the CUDA compiler’ from ‘clang as the host compiler, +// invoked by NVCC’ (e.g. on MacOS). The former needs to see both host and device implementation +// of the functions, while the latter can only deal with one of them. +#if !defined(EIGEN_HAS_NATIVE_BF16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) // Emulate support for bfloat16 floats + +#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC) +// We need to provide emulated *host-side* BF16 operators for clang. +#pragma push_macro("EIGEN_DEVICE_FUNC") +#undef EIGEN_DEVICE_FUNC +#if defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_NATIVE_BF16) +#define EIGEN_DEVICE_FUNC __host__ +#else // both host and device need emulated ops. +#define EIGEN_DEVICE_FUNC __host__ __device__ +#endif +#endif + +// Definitions for CPUs, mostly working through conversion +// to/from fp32. 
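// Sketch of the emulation pattern used by the operators below: widen both
// operands to float, compute in float, and narrow the result back to bfloat16.
// The narrowing rounds to bfloat16's 8-bit significand, so for example
// bfloat16(256.0f) + bfloat16(1.0f) yields 256 again, because 257 lies halfway
// between the representable values 256 and 258 and ties round to even. The
// helper name is invented for illustration.
inline Eigen::bfloat16 emulated_add(const Eigen::bfloat16& a, const Eigen::bfloat16& b) {
  return Eigen::bfloat16(float(a) + float(b));  // widen, add, narrow (round to nearest even)
}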
+ +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const bfloat16& a, const bfloat16& b) { + return bfloat16(float(a) + float(b)); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const bfloat16& a, const int& b) { + return bfloat16(float(a) + static_cast(b)); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator + (const int& a, const bfloat16& b) { + return bfloat16(static_cast(a) + float(b)); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator * (const bfloat16& a, const bfloat16& b) { + return bfloat16(float(a) * float(b)); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator - (const bfloat16& a, const bfloat16& b) { + return bfloat16(float(a) - float(b)); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (const bfloat16& a, const bfloat16& b) { + return bfloat16(float(a) / float(b)); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator - (const bfloat16& a) { + bfloat16 result; + result.value = a.value ^ 0x8000; + return result; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator += (bfloat16& a, const bfloat16& b) { + a = bfloat16(float(a) + float(b)); + return a; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator *= (bfloat16& a, const bfloat16& b) { + a = bfloat16(float(a) * float(b)); + return a; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator -= (bfloat16& a, const bfloat16& b) { + a = bfloat16(float(a) - float(b)); + return a; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16& operator /= (bfloat16& a, const bfloat16& b) { + a = bfloat16(float(a) / float(b)); + return a; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a) { + a += bfloat16(1); + return a; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a) { + a -= bfloat16(1); + return a; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator++(bfloat16& a, int) { + bfloat16 original_value = a; + ++a; + return original_value; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator--(bfloat16& a, int) { + bfloat16 original_value = a; + --a; + return original_value; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator == (const bfloat16& a, const bfloat16& b) { + return numext::equal_strict(float(a),float(b)); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator != (const bfloat16& a, const bfloat16& b) { + return numext::not_equal_strict(float(a), float(b)); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator < (const bfloat16& a, const bfloat16& b) { + return float(a) < float(b); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator <= (const bfloat16& a, const bfloat16& b) { + return float(a) <= float(b); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator > (const bfloat16& a, const bfloat16& b) { + return float(a) > float(b); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator >= (const bfloat16& a, const bfloat16& b) { + return float(a) >= float(b); +} + +#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC) +#pragma pop_macro("EIGEN_DEVICE_FUNC") +#endif +#endif // Emulate support for bfloat16 floats + +// Division by an index. Do it in full float precision to avoid accuracy +// issues in converting the denominator to bfloat16. 
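// Numeric illustration of the comment above (values invented for the sketch):
// bfloat16 keeps only 8 significand bits, so converting a large denominator to
// bfloat16 before dividing already skews it, e.g. 1000003 becomes 999424, an
// error of roughly 0.06%. A float holds any index below 2^24 exactly, so
// dividing in float and rounding only the final quotient avoids that extra error.
inline void index_division_sketch() {
  Eigen::Index n = 1000003;
  float exact   = static_cast<float>(n);                        // 1000003.0f, lossless
  float rounded = static_cast<float>(Eigen::bfloat16(exact));   // 999424.0f after bfloat16 rounding
  (void)exact; (void)rounded;
}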
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 operator / (const bfloat16& a, Index b) { + return bfloat16(static_cast(a) / static_cast(b)); +} + +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw truncate_to_bfloat16(const float v) { + __bfloat16_raw output; + if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(v)) { + output.value = std::signbit(v) ? 0xFFC0: 0x7FC0; + return output; + } + const uint16_t* p = reinterpret_cast(&v); +#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + output.value = p[0]; +#else + output.value = p[1]; +#endif + return output; +} + +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __bfloat16_raw raw_uint16_to_bfloat16(numext::uint16_t value) { + return __bfloat16_raw(value); +} + +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR numext::uint16_t raw_bfloat16_as_uint16(const __bfloat16_raw& bf) { + return bf.value; +} + +// float_to_bfloat16_rtne template specialization that does not make any +// assumption about the value of its function argument (ff). +template <> +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(float ff) { +#if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16)) + // Nothing to do here +#else + __bfloat16_raw output; + + if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(ff)) { + // If the value is a NaN, squash it to a qNaN with msb of fraction set, + // this makes sure after truncation we don't end up with an inf. + // + // qNaN magic: All exponent bits set + most significant bit of fraction + // set. + output.value = std::signbit(ff) ? 0xFFC0: 0x7FC0; + } else { + // Fast rounding algorithm that rounds a half value to nearest even. This + // reduces expected error when we convert a large number of floats. Here + // is how it works: + // + // Definitions: + // To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits + // with the following tags: + // + // Sign | Exp (8 bits) | Frac (23 bits) + // S EEEEEEEE FFFFFFLRTTTTTTTTTTTTTTT + // + // S: Sign bit. + // E: Exponent bits. + // F: First 6 bits of fraction. + // L: Least significant bit of resulting bfloat16 if we truncate away the + // rest of the float32. This is also the 7th bit of fraction + // R: Rounding bit, 8th bit of fraction. + // T: Sticky bits, rest of fraction, 15 bits. + // + // To round half to nearest even, there are 3 cases where we want to round + // down (simply truncate the result of the bits away, which consists of + // rounding bit and sticky bits) and two cases where we want to round up + // (truncate then add one to the result). + // + // The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of + // 1s) as the rounding bias, adds the rounding bias to the input, then + // truncates the last 16 bits away. + // + // To understand how it works, we can analyze this algorithm case by case: + // + // 1. L = 0, R = 0: + // Expect: round down, this is less than half value. + // + // Algorithm: + // - Rounding bias: 0x7fff + 0 = 0x7fff + // - Adding rounding bias to input may create any carry, depending on + // whether there is any value set to 1 in T bits. + // - R may be set to 1 if there is a carry. + // - L remains 0. + // - Note that this case also handles Inf and -Inf, where all fraction + // bits, including L, R and Ts are all 0. The output remains Inf after + // this algorithm. + // + // 2. L = 1, R = 0: + // Expect: round down, this is less than half value. 
+ // + // Algorithm: + // - Rounding bias: 0x7fff + 1 = 0x8000 + // - Adding rounding bias to input doesn't change sticky bits but + // adds 1 to rounding bit. + // - L remains 1. + // + // 3. L = 0, R = 1, all of T are 0: + // Expect: round down, this is exactly at half, the result is already + // even (L=0). + // + // Algorithm: + // - Rounding bias: 0x7fff + 0 = 0x7fff + // - Adding rounding bias to input sets all sticky bits to 1, but + // doesn't create a carry. + // - R remains 1. + // - L remains 0. + // + // 4. L = 1, R = 1: + // Expect: round up, this is exactly at half, the result needs to be + // round to the next even number. + // + // Algorithm: + // - Rounding bias: 0x7fff + 1 = 0x8000 + // - Adding rounding bias to input doesn't change sticky bits, but + // creates a carry from rounding bit. + // - The carry sets L to 0, creates another carry bit and propagate + // forward to F bits. + // - If all the F bits are 1, a carry then propagates to the exponent + // bits, which then creates the minimum value with the next exponent + // value. Note that we won't have the case where exponents are all 1, + // since that's either a NaN (handled in the other if condition) or inf + // (handled in case 1). + // + // 5. L = 0, R = 1, any of T is 1: + // Expect: round up, this is greater than half. + // + // Algorithm: + // - Rounding bias: 0x7fff + 0 = 0x7fff + // - Adding rounding bias to input creates a carry from sticky bits, + // sets rounding bit to 0, then create another carry. + // - The second carry sets L to 1. + // + // Examples: + // + // Exact half value that is already even: + // Input: + // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit) + // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT + // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1000000000000000 + // + // This falls into case 3. We truncate the rest of 16 bits and no + // carry is created into F and L: + // + // Output: + // Sign | Exp (8 bit) | Frac (first 7 bit) + // S E E E E E E E E F F F F F F L + // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 + // + // Exact half value, round to next even number: + // Input: + // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit) + // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT + // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1000000000000000 + // + // This falls into case 4. We create a carry from R and T, + // which then propagates into L and F: + // + // Output: + // Sign | Exp (8 bit) | Frac (first 7 bit) + // S E E E E E E E E F F F F F F L + // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 + // + // + // Max denormal value round to min normal value: + // Input: + // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit) + // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT + // 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1111111111111111 + // + // This falls into case 4. We create a carry from R and T, + // propagate into L and F, which then propagates into exponent + // bits: + // + // Output: + // Sign | Exp (8 bit) | Frac (first 7 bit) + // S E E E E E E E E F F F F F F L + // 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 + // + // Max normal value round to Inf: + // Input: + // Sign | Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit) + // S E E E E E E E E F F F F F F L RTTTTTTTTTTTTTTT + // 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1111111111111111 + // + // This falls into case 4. 
We create a carry from R and T, + // propagate into L and F, which then propagates into exponent + // bits: + // + // Sign | Exp (8 bit) | Frac (first 7 bit) + // S E E E E E E E E F F F F F F L + // 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 + + // At this point, ff must be either a normal float, or +/-infinity. + output = float_to_bfloat16_rtne(ff); + } + return output; +#endif +} + +// float_to_bfloat16_rtne template specialization that assumes that its function +// argument (ff) is either a normal floating point number, or +/-infinity, or +// zero. Used to improve the runtime performance of conversion from an integer +// type to bfloat16. +template <> +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(float ff) { +#if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16)) + // Nothing to do here +#else + numext::uint32_t input = numext::bit_cast(ff); + __bfloat16_raw output; + + // Least significant bit of resulting bfloat. + numext::uint32_t lsb = (input >> 16) & 1; + numext::uint32_t rounding_bias = 0x7fff + lsb; + input += rounding_bias; + output.value = static_cast(input >> 16); + return output; +#endif +} + +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h) { + float result = 0; + unsigned short* q = reinterpret_cast(&result); +#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + q[0] = h.value; +#else + q[1] = h.value; +#endif + return result; +} +// --- standard functions --- + +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isinf)(const bfloat16& a) { + EIGEN_USING_STD(isinf); + return (isinf)(float(a)); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isnan)(const bfloat16& a) { + EIGEN_USING_STD(isnan); + return (isnan)(float(a)); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isfinite)(const bfloat16& a) { + return !(isinf EIGEN_NOT_A_MACRO (a)) && !(isnan EIGEN_NOT_A_MACRO (a)); +} + +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 abs(const bfloat16& a) { + bfloat16 result; + result.value = a.value & 0x7FFF; + return result; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 exp(const bfloat16& a) { + return bfloat16(::expf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 expm1(const bfloat16& a) { + return bfloat16(numext::expm1(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log(const bfloat16& a) { + return bfloat16(::logf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log1p(const bfloat16& a) { + return bfloat16(numext::log1p(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log10(const bfloat16& a) { + return bfloat16(::log10f(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 log2(const bfloat16& a) { + return bfloat16(static_cast(EIGEN_LOG2E) * ::logf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sqrt(const bfloat16& a) { + return bfloat16(::sqrtf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 pow(const bfloat16& a, const bfloat16& b) { + return bfloat16(::powf(float(a), float(b))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sin(const bfloat16& a) { + return bfloat16(::sinf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cos(const bfloat16& a) { + return bfloat16(::cosf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tan(const bfloat16& a) { + return bfloat16(::tanf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asin(const bfloat16& a) { + return bfloat16(::asinf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acos(const bfloat16& 
a) { + return bfloat16(::acosf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atan(const bfloat16& a) { + return bfloat16(::atanf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sinh(const bfloat16& a) { + return bfloat16(::sinhf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 cosh(const bfloat16& a) { + return bfloat16(::coshf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 tanh(const bfloat16& a) { + return bfloat16(::tanhf(float(a))); +} +#if EIGEN_HAS_CXX11_MATH +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 asinh(const bfloat16& a) { + return bfloat16(::asinhf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 acosh(const bfloat16& a) { + return bfloat16(::acoshf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atanh(const bfloat16& a) { + return bfloat16(::atanhf(float(a))); +} +#endif +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 floor(const bfloat16& a) { + return bfloat16(::floorf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 ceil(const bfloat16& a) { + return bfloat16(::ceilf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 rint(const bfloat16& a) { + return bfloat16(::rintf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 round(const bfloat16& a) { + return bfloat16(::roundf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmod(const bfloat16& a, const bfloat16& b) { + return bfloat16(::fmodf(float(a), float(b))); +} + +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (min)(const bfloat16& a, const bfloat16& b) { + const float f1 = static_cast(a); + const float f2 = static_cast(b); + return f2 < f1 ? b : a; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 (max)(const bfloat16& a, const bfloat16& b) { + const float f1 = static_cast(a); + const float f2 = static_cast(b); + return f1 < f2 ? 
b : a; +} + +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmin(const bfloat16& a, const bfloat16& b) { + const float f1 = static_cast(a); + const float f2 = static_cast(b); + return bfloat16(::fminf(f1, f2)); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmax(const bfloat16& a, const bfloat16& b) { + const float f1 = static_cast(a); + const float f2 = static_cast(b); + return bfloat16(::fmaxf(f1, f2)); +} + +#ifndef EIGEN_NO_IO +EIGEN_ALWAYS_INLINE std::ostream& operator << (std::ostream& os, const bfloat16& v) { + os << static_cast(v); + return os; +} +#endif + +} // namespace bfloat16_impl + +namespace internal { + +template<> +struct random_default_impl +{ + static inline bfloat16 run(const bfloat16& x, const bfloat16& y) + { + return x + (y-x) * bfloat16(float(std::rand()) / float(RAND_MAX)); + } + static inline bfloat16 run() + { + return run(bfloat16(-1.f), bfloat16(1.f)); + } +}; + +template<> struct is_arithmetic { enum { value = true }; }; + +} // namespace internal + +template<> struct NumTraits + : GenericNumTraits +{ + enum { + IsSigned = true, + IsInteger = false, + IsComplex = false, + RequireInitialization = false + }; + + EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 epsilon() { + return bfloat16_impl::raw_uint16_to_bfloat16(0x3c00); + } + EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 dummy_precision() { + return bfloat16_impl::raw_uint16_to_bfloat16(0x3D4D); // bfloat16(5e-2f); + + } + EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 highest() { + return bfloat16_impl::raw_uint16_to_bfloat16(0x7F7F); + } + EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 lowest() { + return bfloat16_impl::raw_uint16_to_bfloat16(0xFF7F); + } + EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 infinity() { + return bfloat16_impl::raw_uint16_to_bfloat16(0x7f80); + } + EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::bfloat16 quiet_NaN() { + return bfloat16_impl::raw_uint16_to_bfloat16(0x7fc0); + } +}; + +} // namespace Eigen + +namespace Eigen { +namespace numext { + +template<> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE +bool (isnan)(const Eigen::bfloat16& h) { + return (bfloat16_impl::isnan)(h); +} + +template<> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE +bool (isinf)(const Eigen::bfloat16& h) { + return (bfloat16_impl::isinf)(h); +} + +template<> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE +bool (isfinite)(const Eigen::bfloat16& h) { + return (bfloat16_impl::isfinite)(h); +} + +template <> +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 bit_cast(const uint16_t& src) { + return Eigen::bfloat16(Eigen::bfloat16_impl::raw_uint16_to_bfloat16(src)); +} + +template <> +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast(const Eigen::bfloat16& src) { + return Eigen::bfloat16_impl::raw_bfloat16_as_uint16(src); +} + +} // namespace numext +} // namespace Eigen + +#if EIGEN_HAS_STD_HASH +namespace std { +template <> +struct hash { + EIGEN_STRONG_INLINE std::size_t operator()(const Eigen::bfloat16& a) const { + return static_cast(Eigen::numext::bit_cast(a)); + } +}; +} // namespace std +#endif + + +#endif // EIGEN_BFLOAT16_H diff --git a/Eigen/src/Core/arch/Default/ConjHelper.h b/Eigen/src/Core/arch/Default/ConjHelper.h new file mode 100644 index 0000000..53830b5 --- /dev/null +++ b/Eigen/src/Core/arch/Default/ConjHelper.h @@ -0,0 +1,117 @@ + +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. 
+// +// Copyright (C) 2017 Gael Guennebaud +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_ARCH_CONJ_HELPER_H +#define EIGEN_ARCH_CONJ_HELPER_H + +#define EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(PACKET_CPLX, PACKET_REAL) \ + template <> \ + struct conj_helper { \ + EIGEN_STRONG_INLINE PACKET_CPLX pmadd(const PACKET_REAL& x, \ + const PACKET_CPLX& y, \ + const PACKET_CPLX& c) const { \ + return padd(c, this->pmul(x, y)); \ + } \ + EIGEN_STRONG_INLINE PACKET_CPLX pmul(const PACKET_REAL& x, \ + const PACKET_CPLX& y) const { \ + return PACKET_CPLX(Eigen::internal::pmul(x, y.v)); \ + } \ + }; \ + \ + template <> \ + struct conj_helper { \ + EIGEN_STRONG_INLINE PACKET_CPLX pmadd(const PACKET_CPLX& x, \ + const PACKET_REAL& y, \ + const PACKET_CPLX& c) const { \ + return padd(c, this->pmul(x, y)); \ + } \ + EIGEN_STRONG_INLINE PACKET_CPLX pmul(const PACKET_CPLX& x, \ + const PACKET_REAL& y) const { \ + return PACKET_CPLX(Eigen::internal::pmul(x.v, y)); \ + } \ + }; + +namespace Eigen { +namespace internal { + +template struct conj_if; + +template<> struct conj_if { + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const { return numext::conj(x); } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T pconj(const T& x) const { return internal::pconj(x); } +}; + +template<> struct conj_if { + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T& operator()(const T& x) const { return x; } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T& pconj(const T& x) const { return x; } +}; + +// Generic Implementation, assume scalars since the packet-version is +// specialized below. +template +struct conj_helper { + typedef typename ScalarBinaryOpTraits::ReturnType ResultType; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType + pmadd(const LhsType& x, const RhsType& y, const ResultType& c) const + { return this->pmul(x, y) + c; } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType + pmul(const LhsType& x, const RhsType& y) const + { return conj_if()(x) * conj_if()(y); } +}; + +template +struct conj_helper { + typedef typename ScalarBinaryOpTraits::ReturnType ResultType; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType + pmadd(const LhsScalar& x, const RhsScalar& y, const ResultType& c) const + { return this->pmul(x, y) + c; } + + // We save a conjuation by using the identity conj(a)*conj(b) = conj(a*b). + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ResultType + pmul(const LhsScalar& x, const RhsScalar& y) const + { return numext::conj(x * y); } +}; + +// Implementation with equal type, use packet operations. +template +struct conj_helper +{ + typedef Packet ResultType; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet pmadd(const Packet& x, const Packet& y, const Packet& c) const + { return Eigen::internal::pmadd(conj_if().pconj(x), conj_if().pconj(y), c); } + + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet pmul(const Packet& x, const Packet& y) const + { return Eigen::internal::pmul(conj_if().pconj(x), conj_if().pconj(y)); } +}; + +template +struct conj_helper +{ + typedef Packet ResultType; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet pmadd(const Packet& x, const Packet& y, const Packet& c) const + { return Eigen::internal::pmadd(pconj(x), pconj(y), c); } + // We save a conjuation by using the identity conj(a)*conj(b) = conj(a*b). 
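+  // For illustration: with a = 1+2i and b = 3+4i, conj(a)*conj(b) = (1-2i)*(3-4i)
+  // = -5-10i, which equals conj(a*b) since a*b = -5+10i; so a single pconj on the
+  // product replaces two pconj on the operands.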
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet pmul(const Packet& x, const Packet& y) const + { return pconj(Eigen::internal::pmul(x, y)); } +}; + +} // namespace internal +} // namespace Eigen + +#endif // EIGEN_ARCH_CONJ_HELPER_H diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h new file mode 100644 index 0000000..c9fbaf6 --- /dev/null +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -0,0 +1,1649 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2007 Julien Pommier +// Copyright (C) 2014 Pedro Gonnet (pedro.gonnet@gmail.com) +// Copyright (C) 2009-2019 Gael Guennebaud +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +/* The exp and log functions of this file initially come from + * Julien Pommier's sse math library: http://gruntthepeon.free.fr/ssemath/ + */ + +#ifndef EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_H +#define EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_H + +namespace Eigen { +namespace internal { + +// Creates a Scalar integer type with same bit-width. +template struct make_integer; +template<> struct make_integer { typedef numext::int32_t type; }; +template<> struct make_integer { typedef numext::int64_t type; }; +template<> struct make_integer { typedef numext::int16_t type; }; +template<> struct make_integer { typedef numext::int16_t type; }; + +template EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC +Packet pfrexp_generic_get_biased_exponent(const Packet& a) { + typedef typename unpacket_traits::type Scalar; + typedef typename unpacket_traits::integer_packet PacketI; + enum { mantissa_bits = numext::numeric_limits::digits - 1}; + return pcast(plogical_shift_right(preinterpret(pabs(a)))); +} + +// Safely applies frexp, correctly handles denormals. +// Assumes IEEE floating point format. +template EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC +Packet pfrexp_generic(const Packet& a, Packet& exponent) { + typedef typename unpacket_traits::type Scalar; + typedef typename make_unsigned::type>::type ScalarUI; + enum { + TotalBits = sizeof(Scalar) * CHAR_BIT, + MantissaBits = numext::numeric_limits::digits - 1, + ExponentBits = int(TotalBits) - int(MantissaBits) - 1 + }; + + EIGEN_CONSTEXPR ScalarUI scalar_sign_mantissa_mask = + ~(((ScalarUI(1) << int(ExponentBits)) - ScalarUI(1)) << int(MantissaBits)); // ~0x7f800000 + const Packet sign_mantissa_mask = pset1frombits(static_cast(scalar_sign_mantissa_mask)); + const Packet half = pset1(Scalar(0.5)); + const Packet zero = pzero(a); + const Packet normal_min = pset1((numext::numeric_limits::min)()); // Minimum normal value, 2^-126 + + // To handle denormals, normalize by multiplying by 2^(int(MantissaBits)+1). + const Packet is_denormal = pcmp_lt(pabs(a), normal_min); + EIGEN_CONSTEXPR ScalarUI scalar_normalization_offset = ScalarUI(int(MantissaBits) + 1); // 24 + // The following cannot be constexpr because bfloat16(uint16_t) is not constexpr. 
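+  // For illustration (float): a = 2^-130 is denormal, so it is first scaled by the
+  // 2^24 factor computed below, giving 2^-106 (a normal number). The biased exponent
+  // read from 2^-106 is 21, and the denormal offset -126-24 = -150 then yields
+  // exponent = -129 and mantissa m = 0.5, i.e. a = 0.5 * 2^-129 exactly.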
+ const Scalar scalar_normalization_factor = Scalar(ScalarUI(1) << int(scalar_normalization_offset)); // 2^24 + const Packet normalization_factor = pset1(scalar_normalization_factor); + const Packet normalized_a = pselect(is_denormal, pmul(a, normalization_factor), a); + + // Determine exponent offset: -126 if normal, -126-24 if denormal + const Scalar scalar_exponent_offset = -Scalar((ScalarUI(1)<<(int(ExponentBits)-1)) - ScalarUI(2)); // -126 + Packet exponent_offset = pset1(scalar_exponent_offset); + const Packet normalization_offset = pset1(-Scalar(scalar_normalization_offset)); // -24 + exponent_offset = pselect(is_denormal, padd(exponent_offset, normalization_offset), exponent_offset); + + // Determine exponent and mantissa from normalized_a. + exponent = pfrexp_generic_get_biased_exponent(normalized_a); + // Zero, Inf and NaN return 'a' unmodified, exponent is zero + // (technically the exponent is unspecified for inf/NaN, but GCC/Clang set it to zero) + const Scalar scalar_non_finite_exponent = Scalar((ScalarUI(1) << int(ExponentBits)) - ScalarUI(1)); // 255 + const Packet non_finite_exponent = pset1(scalar_non_finite_exponent); + const Packet is_zero_or_not_finite = por(pcmp_eq(a, zero), pcmp_eq(exponent, non_finite_exponent)); + const Packet m = pselect(is_zero_or_not_finite, a, por(pand(normalized_a, sign_mantissa_mask), half)); + exponent = pselect(is_zero_or_not_finite, zero, padd(exponent, exponent_offset)); + return m; +} + +// Safely applies ldexp, correctly handles overflows, underflows and denormals. +// Assumes IEEE floating point format. +template EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC +Packet pldexp_generic(const Packet& a, const Packet& exponent) { + // We want to return a * 2^exponent, allowing for all possible integer + // exponents without overflowing or underflowing in intermediate + // computations. + // + // Since 'a' and the output can be denormal, the maximum range of 'exponent' + // to consider for a float is: + // -255-23 -> 255+23 + // Below -278 any finite float 'a' will become zero, and above +278 any + // finite float will become inf, including when 'a' is the smallest possible + // denormal. + // + // Unfortunately, 2^(278) cannot be represented using either one or two + // finite normal floats, so we must split the scale factor into at least + // three parts. It turns out to be faster to split 'exponent' into four + // factors, since [exponent>>2] is much faster to compute that [exponent/3]. + // + // Set e = min(max(exponent, -278), 278); + // b = floor(e/4); + // out = ((((a * 2^(b)) * 2^(b)) * 2^(b)) * 2^(e-3*b)) + // + // This will avoid any intermediate overflows and correctly handle 0, inf, + // NaN cases. 
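+  // For illustration, the same splitting in plain scalar code (a rough sketch of the
+  // idea, not the vectorized path below):
+  //   e = std::min(std::max(e, -278), 278);
+  //   int b = e >> 2;                          // floor(e/4)
+  //   float c = std::ldexp(1.0f, b);           // 2^b, always a normal float
+  //   out = ((a * c) * c) * c * std::ldexp(1.0f, e - 3*b);
+  // e.g. e = 278 gives b = 69 and factors 2^69, 2^69, 2^69, 2^71, all finite,
+  // even though 2^278 itself is not representable as a float.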
+ typedef typename unpacket_traits::integer_packet PacketI; + typedef typename unpacket_traits::type Scalar; + typedef typename unpacket_traits::type ScalarI; + enum { + TotalBits = sizeof(Scalar) * CHAR_BIT, + MantissaBits = numext::numeric_limits::digits - 1, + ExponentBits = int(TotalBits) - int(MantissaBits) - 1 + }; + + const Packet max_exponent = pset1(Scalar((ScalarI(1)<((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1)); // 127 + const PacketI e = pcast(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent)); + PacketI b = parithmetic_shift_right<2>(e); // floor(e/4); + Packet c = preinterpret(plogical_shift_left(padd(b, bias))); // 2^b + Packet out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b) + b = psub(psub(psub(e, b), b), b); // e - 3b + c = preinterpret(plogical_shift_left(padd(b, bias))); // 2^(e-3*b) + out = pmul(out, c); + return out; +} + +// Explicitly multiplies +// a * (2^e) +// clamping e to the range +// [NumTraits::min_exponent()-2, NumTraits::max_exponent()] +// +// This is approx 7x faster than pldexp_impl, but will prematurely over/underflow +// if 2^e doesn't fit into a normal floating-point Scalar. +// +// Assumes IEEE floating point format +template +struct pldexp_fast_impl { + typedef typename unpacket_traits::integer_packet PacketI; + typedef typename unpacket_traits::type Scalar; + typedef typename unpacket_traits::type ScalarI; + enum { + TotalBits = sizeof(Scalar) * CHAR_BIT, + MantissaBits = numext::numeric_limits::digits - 1, + ExponentBits = int(TotalBits) - int(MantissaBits) - 1 + }; + + static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC + Packet run(const Packet& a, const Packet& exponent) { + const Packet bias = pset1(Scalar((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1))); // 127 + const Packet limit = pset1(Scalar((ScalarI(1)<(pmin(pmax(padd(exponent, bias), pzero(limit)), limit)); // exponent + 127 + // return a * (2^e) + return pmul(a, preinterpret(plogical_shift_left(e))); + } +}; + +// Natural or base 2 logarithm. +// Computes log(x) as log(2^e * m) = C*e + log(m), where the constant C =log(2) +// and m is in the range [sqrt(1/2),sqrt(2)). In this range, the logarithm can +// be easily approximated by a polynomial centered on m=1 for stability. +// TODO(gonnet): Further reduce the interval allowing for lower-degree +// polynomial interpolants -> ... -> profit! +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +EIGEN_UNUSED +Packet plog_impl_float(const Packet _x) +{ + Packet x = _x; + + const Packet cst_1 = pset1(1.0f); + const Packet cst_neg_half = pset1(-0.5f); + // The smallest non denormalized float number. + const Packet cst_min_norm_pos = pset1frombits( 0x00800000u); + const Packet cst_minus_inf = pset1frombits( 0xff800000u); + const Packet cst_pos_inf = pset1frombits( 0x7f800000u); + + // Polynomial coefficients. + const Packet cst_cephes_SQRTHF = pset1(0.707106781186547524f); + const Packet cst_cephes_log_p0 = pset1(7.0376836292E-2f); + const Packet cst_cephes_log_p1 = pset1(-1.1514610310E-1f); + const Packet cst_cephes_log_p2 = pset1(1.1676998740E-1f); + const Packet cst_cephes_log_p3 = pset1(-1.2420140846E-1f); + const Packet cst_cephes_log_p4 = pset1(+1.4249322787E-1f); + const Packet cst_cephes_log_p5 = pset1(-1.6668057665E-1f); + const Packet cst_cephes_log_p6 = pset1(+2.0000714765E-1f); + const Packet cst_cephes_log_p7 = pset1(-2.4999993993E-1f); + const Packet cst_cephes_log_p8 = pset1(+3.3333331174E-1f); + + // Truncate input values to the minimum positive normal. 
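+  // (This guarantees that pfrexp below always sees a positive normal number; zero,
+  // negative and non-finite inputs are patched up at the end of the function via
+  // iszero_mask / invalid_mask / pos_inf_mask, while positive denormals are
+  // effectively treated as the minimum normal.)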
+ x = pmax(x, cst_min_norm_pos); + + Packet e; + // extract significant in the range [0.5,1) and exponent + x = pfrexp(x,e); + + // part2: Shift the inputs from the range [0.5,1) to [sqrt(1/2),sqrt(2)) + // and shift by -1. The values are then centered around 0, which improves + // the stability of the polynomial evaluation. + // if( x < SQRTHF ) { + // e -= 1; + // x = x + x - 1.0; + // } else { x = x - 1.0; } + Packet mask = pcmp_lt(x, cst_cephes_SQRTHF); + Packet tmp = pand(x, mask); + x = psub(x, cst_1); + e = psub(e, pand(cst_1, mask)); + x = padd(x, tmp); + + Packet x2 = pmul(x, x); + Packet x3 = pmul(x2, x); + + // Evaluate the polynomial approximant of degree 8 in three parts, probably + // to improve instruction-level parallelism. + Packet y, y1, y2; + y = pmadd(cst_cephes_log_p0, x, cst_cephes_log_p1); + y1 = pmadd(cst_cephes_log_p3, x, cst_cephes_log_p4); + y2 = pmadd(cst_cephes_log_p6, x, cst_cephes_log_p7); + y = pmadd(y, x, cst_cephes_log_p2); + y1 = pmadd(y1, x, cst_cephes_log_p5); + y2 = pmadd(y2, x, cst_cephes_log_p8); + y = pmadd(y, x3, y1); + y = pmadd(y, x3, y2); + y = pmul(y, x3); + + y = pmadd(cst_neg_half, x2, y); + x = padd(x, y); + + // Add the logarithm of the exponent back to the result of the interpolation. + if (base2) { + const Packet cst_log2e = pset1(static_cast(EIGEN_LOG2E)); + x = pmadd(x, cst_log2e, e); + } else { + const Packet cst_ln2 = pset1(static_cast(EIGEN_LN2)); + x = pmadd(e, cst_ln2, x); + } + + Packet invalid_mask = pcmp_lt_or_nan(_x, pzero(_x)); + Packet iszero_mask = pcmp_eq(_x,pzero(_x)); + Packet pos_inf_mask = pcmp_eq(_x,cst_pos_inf); + // Filter out invalid inputs, i.e.: + // - negative arg will be NAN + // - 0 will be -INF + // - +INF will be +INF + return pselect(iszero_mask, cst_minus_inf, + por(pselect(pos_inf_mask,cst_pos_inf,x), invalid_mask)); +} + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +EIGEN_UNUSED +Packet plog_float(const Packet _x) +{ + return plog_impl_float(_x); +} + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +EIGEN_UNUSED +Packet plog2_float(const Packet _x) +{ + return plog_impl_float(_x); +} + +/* Returns the base e (2.718...) or base 2 logarithm of x. + * The argument is separated into its exponent and fractional parts. + * The logarithm of the fraction in the interval [sqrt(1/2), sqrt(2)], + * is approximated by + * + * log(1+x) = x - 0.5 x**2 + x**3 P(x)/Q(x). + * + * for more detail see: http://www.netlib.org/cephes/ + */ +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +EIGEN_UNUSED +Packet plog_impl_double(const Packet _x) +{ + Packet x = _x; + + const Packet cst_1 = pset1(1.0); + const Packet cst_neg_half = pset1(-0.5); + // The smallest non denormalized double. 
+ const Packet cst_min_norm_pos = pset1frombits( static_cast(0x0010000000000000ull)); + const Packet cst_minus_inf = pset1frombits( static_cast(0xfff0000000000000ull)); + const Packet cst_pos_inf = pset1frombits( static_cast(0x7ff0000000000000ull)); + + + // Polynomial Coefficients for log(1+x) = x - x**2/2 + x**3 P(x)/Q(x) + // 1/sqrt(2) <= x < sqrt(2) + const Packet cst_cephes_SQRTHF = pset1(0.70710678118654752440E0); + const Packet cst_cephes_log_p0 = pset1(1.01875663804580931796E-4); + const Packet cst_cephes_log_p1 = pset1(4.97494994976747001425E-1); + const Packet cst_cephes_log_p2 = pset1(4.70579119878881725854E0); + const Packet cst_cephes_log_p3 = pset1(1.44989225341610930846E1); + const Packet cst_cephes_log_p4 = pset1(1.79368678507819816313E1); + const Packet cst_cephes_log_p5 = pset1(7.70838733755885391666E0); + + const Packet cst_cephes_log_q0 = pset1(1.0); + const Packet cst_cephes_log_q1 = pset1(1.12873587189167450590E1); + const Packet cst_cephes_log_q2 = pset1(4.52279145837532221105E1); + const Packet cst_cephes_log_q3 = pset1(8.29875266912776603211E1); + const Packet cst_cephes_log_q4 = pset1(7.11544750618563894466E1); + const Packet cst_cephes_log_q5 = pset1(2.31251620126765340583E1); + + // Truncate input values to the minimum positive normal. + x = pmax(x, cst_min_norm_pos); + + Packet e; + // extract significant in the range [0.5,1) and exponent + x = pfrexp(x,e); + + // Shift the inputs from the range [0.5,1) to [sqrt(1/2),sqrt(2)) + // and shift by -1. The values are then centered around 0, which improves + // the stability of the polynomial evaluation. + // if( x < SQRTHF ) { + // e -= 1; + // x = x + x - 1.0; + // } else { x = x - 1.0; } + Packet mask = pcmp_lt(x, cst_cephes_SQRTHF); + Packet tmp = pand(x, mask); + x = psub(x, cst_1); + e = psub(e, pand(cst_1, mask)); + x = padd(x, tmp); + + Packet x2 = pmul(x, x); + Packet x3 = pmul(x2, x); + + // Evaluate the polynomial approximant , probably to improve instruction-level parallelism. + // y = x - 0.5*x^2 + x^3 * polevl( x, P, 5 ) / p1evl( x, Q, 5 ) ); + Packet y, y1, y_; + y = pmadd(cst_cephes_log_p0, x, cst_cephes_log_p1); + y1 = pmadd(cst_cephes_log_p3, x, cst_cephes_log_p4); + y = pmadd(y, x, cst_cephes_log_p2); + y1 = pmadd(y1, x, cst_cephes_log_p5); + y_ = pmadd(y, x3, y1); + + y = pmadd(cst_cephes_log_q0, x, cst_cephes_log_q1); + y1 = pmadd(cst_cephes_log_q3, x, cst_cephes_log_q4); + y = pmadd(y, x, cst_cephes_log_q2); + y1 = pmadd(y1, x, cst_cephes_log_q5); + y = pmadd(y, x3, y1); + + y_ = pmul(y_, x3); + y = pdiv(y_, y); + + y = pmadd(cst_neg_half, x2, y); + x = padd(x, y); + + // Add the logarithm of the exponent back to the result of the interpolation. 
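+  // For example, for x = 8: pfrexp gives m = 0.5 and e = 4; since 0.5 < sqrt(1/2)
+  // the adjustment above turns this into x = 0, e = 3, the polynomial contributes
+  // nothing, and the result is 0 + 3*ln(2) = log(8) (or 0 + 3 = log2(8) for base2).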
+ if (base2) { + const Packet cst_log2e = pset1(static_cast(EIGEN_LOG2E)); + x = pmadd(x, cst_log2e, e); + } else { + const Packet cst_ln2 = pset1(static_cast(EIGEN_LN2)); + x = pmadd(e, cst_ln2, x); + } + + Packet invalid_mask = pcmp_lt_or_nan(_x, pzero(_x)); + Packet iszero_mask = pcmp_eq(_x,pzero(_x)); + Packet pos_inf_mask = pcmp_eq(_x,cst_pos_inf); + // Filter out invalid inputs, i.e.: + // - negative arg will be NAN + // - 0 will be -INF + // - +INF will be +INF + return pselect(iszero_mask, cst_minus_inf, + por(pselect(pos_inf_mask,cst_pos_inf,x), invalid_mask)); +} + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +EIGEN_UNUSED +Packet plog_double(const Packet _x) +{ + return plog_impl_double(_x); +} + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +EIGEN_UNUSED +Packet plog2_double(const Packet _x) +{ + return plog_impl_double(_x); +} + +/** \internal \returns log(1 + x) computed using W. Kahan's formula. + See: http://www.plunk.org/~hatch/rightway.php + */ +template +Packet generic_plog1p(const Packet& x) +{ + typedef typename unpacket_traits::type ScalarType; + const Packet one = pset1(ScalarType(1)); + Packet xp1 = padd(x, one); + Packet small_mask = pcmp_eq(xp1, one); + Packet log1 = plog(xp1); + Packet inf_mask = pcmp_eq(xp1, log1); + Packet log_large = pmul(x, pdiv(log1, psub(xp1, one))); + return pselect(por(small_mask, inf_mask), x, log_large); +} + +/** \internal \returns exp(x)-1 computed using W. Kahan's formula. + See: http://www.plunk.org/~hatch/rightway.php + */ +template +Packet generic_expm1(const Packet& x) +{ + typedef typename unpacket_traits::type ScalarType; + const Packet one = pset1(ScalarType(1)); + const Packet neg_one = pset1(ScalarType(-1)); + Packet u = pexp(x); + Packet one_mask = pcmp_eq(u, one); + Packet u_minus_one = psub(u, one); + Packet neg_one_mask = pcmp_eq(u_minus_one, neg_one); + Packet logu = plog(u); + // The following comparison is to catch the case where + // exp(x) = +inf. It is written in this way to avoid having + // to form the constant +inf, which depends on the packet + // type. + Packet pos_inf_mask = pcmp_eq(logu, u); + Packet expm1 = pmul(u_minus_one, pdiv(x, logu)); + expm1 = pselect(pos_inf_mask, u, expm1); + return pselect(one_mask, + x, + pselect(neg_one_mask, + neg_one, + expm1)); +} + + +// Exponential function. Works by writing "x = m*log(2) + r" where +// "m = floor(x/log(2)+1/2)" and "r" is the remainder. The result is then +// "exp(x) = 2^m*exp(r)" where exp(r) is in the range [-1,1). +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +EIGEN_UNUSED +Packet pexp_float(const Packet _x) +{ + const Packet cst_1 = pset1(1.0f); + const Packet cst_half = pset1(0.5f); + const Packet cst_exp_hi = pset1( 88.723f); + const Packet cst_exp_lo = pset1(-88.723f); + + const Packet cst_cephes_LOG2EF = pset1(1.44269504088896341f); + const Packet cst_cephes_exp_p0 = pset1(1.9875691500E-4f); + const Packet cst_cephes_exp_p1 = pset1(1.3981999507E-3f); + const Packet cst_cephes_exp_p2 = pset1(8.3334519073E-3f); + const Packet cst_cephes_exp_p3 = pset1(4.1665795894E-2f); + const Packet cst_cephes_exp_p4 = pset1(1.6666665459E-1f); + const Packet cst_cephes_exp_p5 = pset1(5.0000001201E-1f); + + // Clamp x. + Packet x = pmax(pmin(_x, cst_exp_hi), cst_exp_lo); + + // Express exp(x) as exp(m*ln(2) + r), start by extracting + // m = floor(x/ln(2) + 0.5). + Packet m = pfloor(pmadd(x, cst_cephes_LOG2EF, cst_half)); + + // Get r = x - m*ln(2). 
If no FMA instructions are available, m*ln(2) is + // subtracted out in two parts, m*C1+m*C2 = m*ln(2), to avoid accumulating + // truncation errors. + const Packet cst_cephes_exp_C1 = pset1(-0.693359375f); + const Packet cst_cephes_exp_C2 = pset1(2.12194440e-4f); + Packet r = pmadd(m, cst_cephes_exp_C1, x); + r = pmadd(m, cst_cephes_exp_C2, r); + + Packet r2 = pmul(r, r); + Packet r3 = pmul(r2, r); + + // Evaluate the polynomial approximant,improved by instruction-level parallelism. + Packet y, y1, y2; + y = pmadd(cst_cephes_exp_p0, r, cst_cephes_exp_p1); + y1 = pmadd(cst_cephes_exp_p3, r, cst_cephes_exp_p4); + y2 = padd(r, cst_1); + y = pmadd(y, r, cst_cephes_exp_p2); + y1 = pmadd(y1, r, cst_cephes_exp_p5); + y = pmadd(y, r3, y1); + y = pmadd(y, r2, y2); + + // Return 2^m * exp(r). + // TODO: replace pldexp with faster implementation since y in [-1, 1). + return pmax(pldexp(y,m), _x); +} + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +EIGEN_UNUSED +Packet pexp_double(const Packet _x) +{ + Packet x = _x; + + const Packet cst_1 = pset1(1.0); + const Packet cst_2 = pset1(2.0); + const Packet cst_half = pset1(0.5); + + const Packet cst_exp_hi = pset1(709.784); + const Packet cst_exp_lo = pset1(-709.784); + + const Packet cst_cephes_LOG2EF = pset1(1.4426950408889634073599); + const Packet cst_cephes_exp_p0 = pset1(1.26177193074810590878e-4); + const Packet cst_cephes_exp_p1 = pset1(3.02994407707441961300e-2); + const Packet cst_cephes_exp_p2 = pset1(9.99999999999999999910e-1); + const Packet cst_cephes_exp_q0 = pset1(3.00198505138664455042e-6); + const Packet cst_cephes_exp_q1 = pset1(2.52448340349684104192e-3); + const Packet cst_cephes_exp_q2 = pset1(2.27265548208155028766e-1); + const Packet cst_cephes_exp_q3 = pset1(2.00000000000000000009e0); + const Packet cst_cephes_exp_C1 = pset1(0.693145751953125); + const Packet cst_cephes_exp_C2 = pset1(1.42860682030941723212e-6); + + Packet tmp, fx; + + // clamp x + x = pmax(pmin(x, cst_exp_hi), cst_exp_lo); + // Express exp(x) as exp(g + n*log(2)). + fx = pmadd(cst_cephes_LOG2EF, x, cst_half); + + // Get the integer modulus of log(2), i.e. the "n" described above. + fx = pfloor(fx); + + // Get the remainder modulo log(2), i.e. the "g" described above. Subtract + // n*log(2) out in two steps, i.e. n*C1 + n*C2, C1+C2=log2 to get the last + // digits right. + tmp = pmul(fx, cst_cephes_exp_C1); + Packet z = pmul(fx, cst_cephes_exp_C2); + x = psub(x, tmp); + x = psub(x, z); + + Packet x2 = pmul(x, x); + + // Evaluate the numerator polynomial of the rational interpolant. + Packet px = cst_cephes_exp_p0; + px = pmadd(px, x2, cst_cephes_exp_p1); + px = pmadd(px, x2, cst_cephes_exp_p2); + px = pmul(px, x); + + // Evaluate the denominator polynomial of the rational interpolant. + Packet qx = cst_cephes_exp_q0; + qx = pmadd(qx, x2, cst_cephes_exp_q1); + qx = pmadd(qx, x2, cst_cephes_exp_q2); + qx = pmadd(qx, x2, cst_cephes_exp_q3); + + // I don't really get this bit, copied from the SSE2 routines, so... + // TODO(gonnet): Figure out what is going on here, perhaps find a better + // rational interpolant? + x = pdiv(px, psub(qx, px)); + x = pmadd(cst_2, x, cst_1); + + // Construct the result 2^n * exp(g) = e * x. The max is used to catch + // non-finite values in the input. + // TODO: replace pldexp with faster implementation since x in [-1, 1). 
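+  // For example, for _x = +inf the clamp above would otherwise hide the infinity;
+  // taking the max against the original argument restores +inf in the result.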
+ return pmax(pldexp(x,fx), _x); +} + +// The following code is inspired by the following stack-overflow answer: +// https://stackoverflow.com/questions/30463616/payne-hanek-algorithm-implementation-in-c/30465751#30465751 +// It has been largely optimized: +// - By-pass calls to frexp. +// - Aligned loads of required 96 bits of 2/pi. This is accomplished by +// (1) balancing the mantissa and exponent to the required bits of 2/pi are +// aligned on 8-bits, and (2) replicating the storage of the bits of 2/pi. +// - Avoid a branch in rounding and extraction of the remaining fractional part. +// Overall, I measured a speed up higher than x2 on x86-64. +inline float trig_reduce_huge (float xf, int *quadrant) +{ + using Eigen::numext::int32_t; + using Eigen::numext::uint32_t; + using Eigen::numext::int64_t; + using Eigen::numext::uint64_t; + + const double pio2_62 = 3.4061215800865545e-19; // pi/2 * 2^-62 + const uint64_t zero_dot_five = uint64_t(1) << 61; // 0.5 in 2.62-bit fixed-point foramt + + // 192 bits of 2/pi for Payne-Hanek reduction + // Bits are introduced by packet of 8 to enable aligned reads. + static const uint32_t two_over_pi [] = + { + 0x00000028, 0x000028be, 0x0028be60, 0x28be60db, + 0xbe60db93, 0x60db9391, 0xdb939105, 0x9391054a, + 0x91054a7f, 0x054a7f09, 0x4a7f09d5, 0x7f09d5f4, + 0x09d5f47d, 0xd5f47d4d, 0xf47d4d37, 0x7d4d3770, + 0x4d377036, 0x377036d8, 0x7036d8a5, 0x36d8a566, + 0xd8a5664f, 0xa5664f10, 0x664f10e4, 0x4f10e410, + 0x10e41000, 0xe4100000 + }; + + uint32_t xi = numext::bit_cast(xf); + // Below, -118 = -126 + 8. + // -126 is to get the exponent, + // +8 is to enable alignment of 2/pi's bits on 8 bits. + // This is possible because the fractional part of x as only 24 meaningful bits. + uint32_t e = (xi >> 23) - 118; + // Extract the mantissa and shift it to align it wrt the exponent + xi = ((xi & 0x007fffffu)| 0x00800000u) << (e & 0x7); + + uint32_t i = e >> 3; + uint32_t twoopi_1 = two_over_pi[i-1]; + uint32_t twoopi_2 = two_over_pi[i+3]; + uint32_t twoopi_3 = two_over_pi[i+7]; + + // Compute x * 2/pi in 2.62-bit fixed-point format. + uint64_t p; + p = uint64_t(xi) * twoopi_3; + p = uint64_t(xi) * twoopi_2 + (p >> 32); + p = (uint64_t(xi * twoopi_1) << 32) + p; + + // Round to nearest: add 0.5 and extract integral part. + uint64_t q = (p + zero_dot_five) >> 62; + *quadrant = int(q); + // Now it remains to compute "r = x - q*pi/2" with high accuracy, + // since we have p=x/(pi/2) with high accuracy, we can more efficiently compute r as: + // r = (p-q)*pi/2, + // where the product can be be carried out with sufficient accuracy using double precision. + p -= q<<62; + return float(double(int64_t(p)) * pio2_62); +} + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +EIGEN_UNUSED +#if EIGEN_GNUC_AT_LEAST(4,4) && EIGEN_COMP_GNUC_STRICT +__attribute__((optimize("-fno-unsafe-math-optimizations"))) +#endif +Packet psincos_float(const Packet& _x) +{ + typedef typename unpacket_traits::integer_packet PacketI; + + const Packet cst_2oPI = pset1(0.636619746685028076171875f); // 2/PI + const Packet cst_rounding_magic = pset1(12582912); // 2^23 for rounding + const PacketI csti_1 = pset1(1); + const Packet cst_sign_mask = pset1frombits(0x80000000u); + + Packet x = pabs(_x); + + // Scale x by 2/Pi to find x's octant. 
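+  // For example, for _x = 2.0f: 2*(2/pi) ~= 1.27 rounds to y = 1, the reduced
+  // argument is 2 - pi/2 ~= 0.43, and sin(2) is then evaluated below as
+  // cos(0.43) ~= 0.909 via the polynomial selection mask, with a positive sign.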
+ Packet y = pmul(x, cst_2oPI); + + // Rounding trick: + Packet y_round = padd(y, cst_rounding_magic); + EIGEN_OPTIMIZATION_BARRIER(y_round) + PacketI y_int = preinterpret(y_round); // last 23 digits represent integer (if abs(x)<2^24) + y = psub(y_round, cst_rounding_magic); // nearest integer to x*4/pi + + // Reduce x by y octants to get: -Pi/4 <= x <= +Pi/4 + // using "Extended precision modular arithmetic" + #if defined(EIGEN_HAS_SINGLE_INSTRUCTION_MADD) + // This version requires true FMA for high accuracy + // It provides a max error of 1ULP up to (with absolute_error < 5.9605e-08): + const float huge_th = ComputeSine ? 117435.992f : 71476.0625f; + x = pmadd(y, pset1(-1.57079601287841796875f), x); + x = pmadd(y, pset1(-3.1391647326017846353352069854736328125e-07f), x); + x = pmadd(y, pset1(-5.390302529957764765544681040410068817436695098876953125e-15f), x); + #else + // Without true FMA, the previous set of coefficients maintain 1ULP accuracy + // up to x<15.7 (for sin), but accuracy is immediately lost for x>15.7. + // We thus use one more iteration to maintain 2ULPs up to reasonably large inputs. + + // The following set of coefficients maintain 1ULP up to 9.43 and 14.16 for sin and cos respectively. + // and 2 ULP up to: + const float huge_th = ComputeSine ? 25966.f : 18838.f; + x = pmadd(y, pset1(-1.5703125), x); // = 0xbfc90000 + EIGEN_OPTIMIZATION_BARRIER(x) + x = pmadd(y, pset1(-0.000483989715576171875), x); // = 0xb9fdc000 + EIGEN_OPTIMIZATION_BARRIER(x) + x = pmadd(y, pset1(1.62865035235881805419921875e-07), x); // = 0x342ee000 + x = pmadd(y, pset1(5.5644315544167710640977020375430583953857421875e-11), x); // = 0x2e74b9ee + + // For the record, the following set of coefficients maintain 2ULP up + // to a slightly larger range: + // const float huge_th = ComputeSine ? 51981.f : 39086.125f; + // but it slightly fails to maintain 1ULP for two values of sin below pi. + // x = pmadd(y, pset1(-3.140625/2.), x); + // x = pmadd(y, pset1(-0.00048351287841796875), x); + // x = pmadd(y, pset1(-3.13855707645416259765625e-07), x); + // x = pmadd(y, pset1(-6.0771006282767103812147979624569416046142578125e-11), x); + + // For the record, with only 3 iterations it is possible to maintain + // 1 ULP up to 3PI (maybe more) and 2ULP up to 255. + // The coefficients are: 0xbfc90f80, 0xb7354480, 0x2e74b9ee + #endif + + if(predux_any(pcmp_le(pset1(huge_th),pabs(_x)))) + { + const int PacketSize = unpacket_traits::size; + EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) float vals[PacketSize]; + EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) float x_cpy[PacketSize]; + EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) int y_int2[PacketSize]; + pstoreu(vals, pabs(_x)); + pstoreu(x_cpy, x); + pstoreu(y_int2, y_int); + for(int k=0; k=huge_th && (numext::isfinite)(val)) + x_cpy[k] = trig_reduce_huge(val,&y_int2[k]); + } + x = ploadu(x_cpy); + y_int = ploadu(y_int2); + } + + // Compute the sign to apply to the polynomial. + // sin: sign = second_bit(y_int) xor signbit(_x) + // cos: sign = second_bit(y_int+1) + Packet sign_bit = ComputeSine ? pxor(_x, preinterpret(plogical_shift_left<30>(y_int))) + : preinterpret(plogical_shift_left<30>(padd(y_int,csti_1))); + sign_bit = pand(sign_bit, cst_sign_mask); // clear all but left most bit + + // Get the polynomial selection mask from the second bit of y_int + // We'll calculate both (sin and cos) polynomials and then select from the two. + Packet poly_mask = preinterpret(pcmp_eq(pand(y_int, csti_1), pzero(y_int))); + + Packet x2 = pmul(x,x); + + // Evaluate the cos(x) polynomial. 
(-Pi/4 <= x <= Pi/4) + Packet y1 = pset1(2.4372266125283204019069671630859375e-05f); + y1 = pmadd(y1, x2, pset1(-0.00138865201734006404876708984375f )); + y1 = pmadd(y1, x2, pset1(0.041666619479656219482421875f )); + y1 = pmadd(y1, x2, pset1(-0.5f)); + y1 = pmadd(y1, x2, pset1(1.f)); + + // Evaluate the sin(x) polynomial. (Pi/4 <= x <= Pi/4) + // octave/matlab code to compute those coefficients: + // x = (0:0.0001:pi/4)'; + // A = [x.^3 x.^5 x.^7]; + // w = ((1.-(x/(pi/4)).^2).^5)*2000+1; # weights trading relative accuracy + // c = (A'*diag(w)*A)\(A'*diag(w)*(sin(x)-x)); # weighted LS, linear coeff forced to 1 + // printf('%.64f\n %.64f\n%.64f\n', c(3), c(2), c(1)) + // + Packet y2 = pset1(-0.0001959234114083702898469196984621021329076029360294342041015625f); + y2 = pmadd(y2, x2, pset1( 0.0083326873655616851693794799871284340042620897293090820312500000f)); + y2 = pmadd(y2, x2, pset1(-0.1666666203982298255503735617821803316473960876464843750000000000f)); + y2 = pmul(y2, x2); + y2 = pmadd(y2, x, x); + + // Select the correct result from the two polynomials. + y = ComputeSine ? pselect(poly_mask,y2,y1) + : pselect(poly_mask,y1,y2); + + // Update the sign and filter huge inputs + return pxor(y, sign_bit); +} + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +EIGEN_UNUSED +Packet psin_float(const Packet& x) +{ + return psincos_float(x); +} + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +EIGEN_UNUSED +Packet pcos_float(const Packet& x) +{ + return psincos_float(x); +} + + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +EIGEN_UNUSED +Packet psqrt_complex(const Packet& a) { + typedef typename unpacket_traits::type Scalar; + typedef typename Scalar::value_type RealScalar; + typedef typename unpacket_traits::as_real RealPacket; + + // Computes the principal sqrt of the complex numbers in the input. + // + // For example, for packets containing 2 complex numbers stored in interleaved format + // a = [a0, a1] = [x0, y0, x1, y1], + // where x0 = real(a0), y0 = imag(a0) etc., this function returns + // b = [b0, b1] = [u0, v0, u1, v1], + // such that b0^2 = a0, b1^2 = a1. + // + // To derive the formula for the complex square roots, let's consider the equation for + // a single complex square root of the number x + i*y. We want to find real numbers + // u and v such that + // (u + i*v)^2 = x + i*y <=> + // u^2 - v^2 + i*2*u*v = x + i*v. + // By equating the real and imaginary parts we get: + // u^2 - v^2 = x + // 2*u*v = y. + // + // For x >= 0, this has the numerically stable solution + // u = sqrt(0.5 * (x + sqrt(x^2 + y^2))) + // v = 0.5 * (y / u) + // and for x < 0, + // v = sign(y) * sqrt(0.5 * (-x + sqrt(x^2 + y^2))) + // u = 0.5 * (y / v) + // + // To avoid unnecessary over- and underflow, we compute sqrt(x^2 + y^2) as + // l = max(|x|, |y|) * sqrt(1 + (min(|x|, |y|) / max(|x|, |y|))^2) , + + // In the following, without lack of generality, we have annotated the code, assuming + // that the input is a packet of 2 complex numbers. + // + // Step 1. Compute l = [l0, l0, l1, l1], where + // l0 = sqrt(x0^2 + y0^2), l1 = sqrt(x1^2 + y1^2) + // To avoid over- and underflow, we use the stable formula for each hypotenuse + // l0 = (min0 == 0 ? max0 : max0 * sqrt(1 + (min0/max0)**2)), + // where max0 = max(|x0|, |y0|), min0 = min(|x0|, |y0|), and similarly for l1. 
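+  // Worked example for a single entry: a0 = -3 + 4i gives l0 = 5; since x0 = -3 < 0
+  // we get v0 = sign(4)*sqrt(0.5*(3 + 5)) = 2 and u0 = 0.5*(4/2) = 1, so
+  // sqrt(a0) = 1 + 2i, and indeed (1 + 2i)^2 = -3 + 4i.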
+ + RealPacket a_abs = pabs(a.v); // [|x0|, |y0|, |x1|, |y1|] + RealPacket a_abs_flip = pcplxflip(Packet(a_abs)).v; // [|y0|, |x0|, |y1|, |x1|] + RealPacket a_max = pmax(a_abs, a_abs_flip); + RealPacket a_min = pmin(a_abs, a_abs_flip); + RealPacket a_min_zero_mask = pcmp_eq(a_min, pzero(a_min)); + RealPacket a_max_zero_mask = pcmp_eq(a_max, pzero(a_max)); + RealPacket r = pdiv(a_min, a_max); + const RealPacket cst_one = pset1(RealScalar(1)); + RealPacket l = pmul(a_max, psqrt(padd(cst_one, pmul(r, r)))); // [l0, l0, l1, l1] + // Set l to a_max if a_min is zero. + l = pselect(a_min_zero_mask, a_max, l); + + // Step 2. Compute [rho0, *, rho1, *], where + // rho0 = sqrt(0.5 * (l0 + |x0|)), rho1 = sqrt(0.5 * (l1 + |x1|)) + // We don't care about the imaginary parts computed here. They will be overwritten later. + const RealPacket cst_half = pset1(RealScalar(0.5)); + Packet rho; + rho.v = psqrt(pmul(cst_half, padd(a_abs, l))); + + // Step 3. Compute [rho0, eta0, rho1, eta1], where + // eta0 = (y0 / l0) / 2, and eta1 = (y1 / l1) / 2. + // set eta = 0 of input is 0 + i0. + RealPacket eta = pandnot(pmul(cst_half, pdiv(a.v, pcplxflip(rho).v)), a_max_zero_mask); + RealPacket real_mask = peven_mask(a.v); + Packet positive_real_result; + // Compute result for inputs with positive real part. + positive_real_result.v = pselect(real_mask, rho.v, eta); + + // Step 4. Compute solution for inputs with negative real part: + // [|eta0|, sign(y0)*rho0, |eta1|, sign(y1)*rho1] + const RealScalar neg_zero = RealScalar(numext::bit_cast(0x80000000u)); + const RealPacket cst_imag_sign_mask = pset1(Scalar(RealScalar(0.0), neg_zero)).v; + RealPacket imag_signs = pand(a.v, cst_imag_sign_mask); + Packet negative_real_result; + // Notice that rho is positive, so taking it's absolute value is a noop. + negative_real_result.v = por(pabs(pcplxflip(positive_real_result).v), imag_signs); + + // Step 5. Select solution branch based on the sign of the real parts. + Packet negative_real_mask; + negative_real_mask.v = pcmp_lt(pand(real_mask, a.v), pzero(a.v)); + negative_real_mask.v = por(negative_real_mask.v, pcplxflip(negative_real_mask).v); + Packet result = pselect(negative_real_mask, negative_real_result, positive_real_result); + + // Step 6. Handle special cases for infinities: + // * If z is (x,+∞), the result is (+∞,+∞) even if x is NaN + // * If z is (x,-∞), the result is (+∞,-∞) even if x is NaN + // * If z is (-∞,y), the result is (0*|y|,+∞) for finite or NaN y + // * If z is (+∞,y), the result is (+∞,0*|y|) for finite or NaN y + const RealPacket cst_pos_inf = pset1(NumTraits::infinity()); + Packet is_inf; + is_inf.v = pcmp_eq(a_abs, cst_pos_inf); + Packet is_real_inf; + is_real_inf.v = pand(is_inf.v, real_mask); + is_real_inf = por(is_real_inf, pcplxflip(is_real_inf)); + // prepare packet of (+∞,0*|y|) or (0*|y|,+∞), depending on the sign of the infinite real part. + Packet real_inf_result; + real_inf_result.v = pmul(a_abs, pset1(Scalar(RealScalar(1.0), RealScalar(0.0))).v); + real_inf_result.v = pselect(negative_real_mask.v, pcplxflip(real_inf_result).v, real_inf_result.v); + // prepare packet of (+∞,+∞) or (+∞,-∞), depending on the sign of the infinite imaginary part. 
+ Packet is_imag_inf; + is_imag_inf.v = pandnot(is_inf.v, real_mask); + is_imag_inf = por(is_imag_inf, pcplxflip(is_imag_inf)); + Packet imag_inf_result; + imag_inf_result.v = por(pand(cst_pos_inf, real_mask), pandnot(a.v, real_mask)); + + return pselect(is_imag_inf, imag_inf_result, + pselect(is_real_inf, real_inf_result,result)); +} + +// TODO(rmlarsen): The following set of utilities for double word arithmetic +// should perhaps be refactored as a separate file, since it would be generally +// useful for special function implementation etc. Writing the algorithms in +// terms if a double word type would also make the code more readable. + +// This function splits x into the nearest integer n and fractional part r, +// such that x = n + r holds exactly. +template +EIGEN_STRONG_INLINE +void absolute_split(const Packet& x, Packet& n, Packet& r) { + n = pround(x); + r = psub(x, n); +} + +// This function computes the sum {s, r}, such that x + y = s_hi + s_lo +// holds exactly, and s_hi = fl(x+y), if |x| >= |y|. +template +EIGEN_STRONG_INLINE +void fast_twosum(const Packet& x, const Packet& y, Packet& s_hi, Packet& s_lo) { + s_hi = padd(x, y); + const Packet t = psub(s_hi, x); + s_lo = psub(y, t); +} + +#ifdef EIGEN_HAS_SINGLE_INSTRUCTION_MADD +// This function implements the extended precision product of +// a pair of floating point numbers. Given {x, y}, it computes the pair +// {p_hi, p_lo} such that x * y = p_hi + p_lo holds exactly and +// p_hi = fl(x * y). +template +EIGEN_STRONG_INLINE +void twoprod(const Packet& x, const Packet& y, + Packet& p_hi, Packet& p_lo) { + p_hi = pmul(x, y); + p_lo = pmadd(x, y, pnegate(p_hi)); +} + +#else + +// This function implements the Veltkamp splitting. Given a floating point +// number x it returns the pair {x_hi, x_lo} such that x_hi + x_lo = x holds +// exactly and that half of the significant of x fits in x_hi. +// This is Algorithm 3 from Jean-Michel Muller, "Elementary Functions", +// 3rd edition, Birkh\"auser, 2016. +template +EIGEN_STRONG_INLINE +void veltkamp_splitting(const Packet& x, Packet& x_hi, Packet& x_lo) { + typedef typename unpacket_traits::type Scalar; + EIGEN_CONSTEXPR int shift = (NumTraits::digits() + 1) / 2; + const Scalar shift_scale = Scalar(uint64_t(1) << shift); // Scalar constructor not necessarily constexpr. + const Packet gamma = pmul(pset1(shift_scale + Scalar(1)), x); + Packet rho = psub(x, gamma); + x_hi = padd(rho, gamma); + x_lo = psub(x, x_hi); +} + +// This function implements Dekker's algorithm for products x * y. +// Given floating point numbers {x, y} computes the pair +// {p_hi, p_lo} such that x * y = p_hi + p_lo holds exactly and +// p_hi = fl(x * y). +template +EIGEN_STRONG_INLINE +void twoprod(const Packet& x, const Packet& y, + Packet& p_hi, Packet& p_lo) { + Packet x_hi, x_lo, y_hi, y_lo; + veltkamp_splitting(x, x_hi, x_lo); + veltkamp_splitting(y, y_hi, y_lo); + + p_hi = pmul(x, y); + p_lo = pmadd(x_hi, y_hi, pnegate(p_hi)); + p_lo = pmadd(x_hi, y_lo, p_lo); + p_lo = pmadd(x_lo, y_hi, p_lo); + p_lo = pmadd(x_lo, y_lo, p_lo); +} + +#endif // EIGEN_HAS_SINGLE_INSTRUCTION_MADD + + +// This function implements Dekker's algorithm for the addition +// of two double word numbers represented by {x_hi, x_lo} and {y_hi, y_lo}. +// It returns the result as a pair {s_hi, s_lo} such that +// x_hi + x_lo + y_hi + y_lo = s_hi + s_lo holds exactly. +// This is Algorithm 5 from Jean-Michel Muller, "Elementary Functions", +// 3rd edition, Birkh\"auser, 2016. 
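+// For example, in double precision fast_twosum(1.0, 2^-60) returns the exact pair
+// {1.0, 2^-60}, whereas fast_twosum(2^-60, 1.0) would round the low part away to
+// zero; twosum below therefore selects the operand order by magnitude before
+// applying fast_twosum.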
+template +EIGEN_STRONG_INLINE + void twosum(const Packet& x_hi, const Packet& x_lo, + const Packet& y_hi, const Packet& y_lo, + Packet& s_hi, Packet& s_lo) { + const Packet x_greater_mask = pcmp_lt(pabs(y_hi), pabs(x_hi)); + Packet r_hi_1, r_lo_1; + fast_twosum(x_hi, y_hi,r_hi_1, r_lo_1); + Packet r_hi_2, r_lo_2; + fast_twosum(y_hi, x_hi,r_hi_2, r_lo_2); + const Packet r_hi = pselect(x_greater_mask, r_hi_1, r_hi_2); + + const Packet s1 = padd(padd(y_lo, r_lo_1), x_lo); + const Packet s2 = padd(padd(x_lo, r_lo_2), y_lo); + const Packet s = pselect(x_greater_mask, s1, s2); + + fast_twosum(r_hi, s, s_hi, s_lo); +} + +// This is a version of twosum for double word numbers, +// which assumes that |x_hi| >= |y_hi|. +template +EIGEN_STRONG_INLINE + void fast_twosum(const Packet& x_hi, const Packet& x_lo, + const Packet& y_hi, const Packet& y_lo, + Packet& s_hi, Packet& s_lo) { + Packet r_hi, r_lo; + fast_twosum(x_hi, y_hi, r_hi, r_lo); + const Packet s = padd(padd(y_lo, r_lo), x_lo); + fast_twosum(r_hi, s, s_hi, s_lo); +} + +// This is a version of twosum for adding a floating point number x to +// double word number {y_hi, y_lo} number, with the assumption +// that |x| >= |y_hi|. +template +EIGEN_STRONG_INLINE +void fast_twosum(const Packet& x, + const Packet& y_hi, const Packet& y_lo, + Packet& s_hi, Packet& s_lo) { + Packet r_hi, r_lo; + fast_twosum(x, y_hi, r_hi, r_lo); + const Packet s = padd(y_lo, r_lo); + fast_twosum(r_hi, s, s_hi, s_lo); +} + +// This function implements the multiplication of a double word +// number represented by {x_hi, x_lo} by a floating point number y. +// It returns the result as a pair {p_hi, p_lo} such that +// (x_hi + x_lo) * y = p_hi + p_lo hold with a relative error +// of less than 2*2^{-2p}, where p is the number of significand bit +// in the floating point type. +// This is Algorithm 7 from Jean-Michel Muller, "Elementary Functions", +// 3rd edition, Birkh\"auser, 2016. +template +EIGEN_STRONG_INLINE +void twoprod(const Packet& x_hi, const Packet& x_lo, const Packet& y, + Packet& p_hi, Packet& p_lo) { + Packet c_hi, c_lo1; + twoprod(x_hi, y, c_hi, c_lo1); + const Packet c_lo2 = pmul(x_lo, y); + Packet t_hi, t_lo1; + fast_twosum(c_hi, c_lo2, t_hi, t_lo1); + const Packet t_lo2 = padd(t_lo1, c_lo1); + fast_twosum(t_hi, t_lo2, p_hi, p_lo); +} + +// This function implements the multiplication of two double word +// numbers represented by {x_hi, x_lo} and {y_hi, y_lo}. +// It returns the result as a pair {p_hi, p_lo} such that +// (x_hi + x_lo) * (y_hi + y_lo) = p_hi + p_lo holds with a relative error +// of less than 2*2^{-2p}, where p is the number of significand bit +// in the floating point type. +template +EIGEN_STRONG_INLINE +void twoprod(const Packet& x_hi, const Packet& x_lo, + const Packet& y_hi, const Packet& y_lo, + Packet& p_hi, Packet& p_lo) { + Packet p_hi_hi, p_hi_lo; + twoprod(x_hi, x_lo, y_hi, p_hi_hi, p_hi_lo); + Packet p_lo_hi, p_lo_lo; + twoprod(x_hi, x_lo, y_lo, p_lo_hi, p_lo_lo); + fast_twosum(p_hi_hi, p_hi_lo, p_lo_hi, p_lo_lo, p_hi, p_lo); +} + +// This function computes the reciprocal of a floating point number +// with extra precision and returns the result as a double word. +template +void doubleword_reciprocal(const Packet& x, Packet& recip_hi, Packet& recip_lo) { + typedef typename unpacket_traits::type Scalar; + // 1. Approximate the reciprocal as the reciprocal of the high order element. + Packet approx_recip = prsqrt(x); + approx_recip = pmul(approx_recip, approx_recip); + + // 2. 
Run one step of Newton-Raphson iteration in double word arithmetic + // to get the bottom half. The NR iteration for reciprocal of 'a' is + // x_{i+1} = x_i * (2 - a * x_i) + + // -a*x_i + Packet t1_hi, t1_lo; + twoprod(pnegate(x), approx_recip, t1_hi, t1_lo); + // 2 - a*x_i + Packet t2_hi, t2_lo; + fast_twosum(pset1(Scalar(2)), t1_hi, t2_hi, t2_lo); + Packet t3_hi, t3_lo; + fast_twosum(t2_hi, padd(t2_lo, t1_lo), t3_hi, t3_lo); + // x_i * (2 - a * x_i) + twoprod(t3_hi, t3_lo, approx_recip, recip_hi, recip_lo); +} + + +// This function computes log2(x) and returns the result as a double word. +template +struct accurate_log2 { + template + EIGEN_STRONG_INLINE + void operator()(const Packet& x, Packet& log2_x_hi, Packet& log2_x_lo) { + log2_x_hi = plog2(x); + log2_x_lo = pzero(x); + } +}; + +// This specialization uses a more accurate algorithm to compute log2(x) for +// floats in [1/sqrt(2);sqrt(2)] with a relative accuracy of ~6.42e-10. +// This additional accuracy is needed to counter the error-magnification +// inherent in multiplying by a potentially large exponent in pow(x,y). +// The minimax polynomial used was calculated using the Sollya tool. +// See sollya.org. +template <> +struct accurate_log2 { + template + EIGEN_STRONG_INLINE + void operator()(const Packet& z, Packet& log2_x_hi, Packet& log2_x_lo) { + // The function log(1+x)/x is approximated in the interval + // [1/sqrt(2)-1;sqrt(2)-1] by a degree 10 polynomial of the form + // Q(x) = (C0 + x * (C1 + x * (C2 + x * (C3 + x * P(x))))), + // where the degree 6 polynomial P(x) is evaluated in single precision, + // while the remaining 4 terms of Q(x), as well as the final multiplication by x + // to reconstruct log(1+x) are evaluated in extra precision using + // double word arithmetic. C0 through C3 are extra precise constants + // stored as double words. + // + // The polynomial coefficients were calculated using Sollya commands: + // > n = 10; + // > f = log2(1+x)/x; + // > interval = [sqrt(0.5)-1;sqrt(2)-1]; + // > p = fpminimax(f,n,[|double,double,double,double,single...|],interval,relative,floating); + + const Packet p6 = pset1( 9.703654795885e-2f); + const Packet p5 = pset1(-0.1690667718648f); + const Packet p4 = pset1( 0.1720575392246f); + const Packet p3 = pset1(-0.1789081543684f); + const Packet p2 = pset1( 0.2050433009862f); + const Packet p1 = pset1(-0.2404672354459f); + const Packet p0 = pset1( 0.2885761857032f); + + const Packet C3_hi = pset1(-0.360674142838f); + const Packet C3_lo = pset1(-6.13283912543e-09f); + const Packet C2_hi = pset1(0.480897903442f); + const Packet C2_lo = pset1(-1.44861207474e-08f); + const Packet C1_hi = pset1(-0.721347510815f); + const Packet C1_lo = pset1(-4.84483164698e-09f); + const Packet C0_hi = pset1(1.44269502163f); + const Packet C0_lo = pset1(2.01711713999e-08f); + const Packet one = pset1(1.0f); + + const Packet x = psub(z, one); + // Evaluate P(x) in working precision. + // We evaluate it in multiple parts to improve instruction level + // parallelism. + Packet x2 = pmul(x,x); + Packet p_even = pmadd(p6, x2, p4); + p_even = pmadd(p_even, x2, p2); + p_even = pmadd(p_even, x2, p0); + Packet p_odd = pmadd(p5, x2, p3); + p_odd = pmadd(p_odd, x2, p1); + Packet p = pmadd(p_odd, x, p_even); + + // Now evaluate the low-order tems of Q(x) in double word precision. + // In the following, due to the alternating signs and the fact that + // |x| < sqrt(2)-1, we can assume that |C*_hi| >= q_i, and use + // fast_twosum instead of the slower twosum. 
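+    // To see why the extra precision pays off downstream in
+    // pow(x,y) = exp2(y * log2(x)): a one-ulp relative error (~6e-8) in a
+    // single-precision log2(x) is scaled by y, so for pow(1.005f, 10000.0f),
+    // which is about 2^72, that error alone would already move the result by
+    // roughly 50 ulps.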
+ Packet q_hi, q_lo; + Packet t_hi, t_lo; + // C3 + x * p(x) + twoprod(p, x, t_hi, t_lo); + fast_twosum(C3_hi, C3_lo, t_hi, t_lo, q_hi, q_lo); + // C2 + x * p(x) + twoprod(q_hi, q_lo, x, t_hi, t_lo); + fast_twosum(C2_hi, C2_lo, t_hi, t_lo, q_hi, q_lo); + // C1 + x * p(x) + twoprod(q_hi, q_lo, x, t_hi, t_lo); + fast_twosum(C1_hi, C1_lo, t_hi, t_lo, q_hi, q_lo); + // C0 + x * p(x) + twoprod(q_hi, q_lo, x, t_hi, t_lo); + fast_twosum(C0_hi, C0_lo, t_hi, t_lo, q_hi, q_lo); + + // log(z) ~= x * Q(x) + twoprod(q_hi, q_lo, x, log2_x_hi, log2_x_lo); + } +}; + +// This specialization uses a more accurate algorithm to compute log2(x) for +// floats in [1/sqrt(2);sqrt(2)] with a relative accuracy of ~1.27e-18. +// This additional accuracy is needed to counter the error-magnification +// inherent in multiplying by a potentially large exponent in pow(x,y). +// The minimax polynomial used was calculated using the Sollya tool. +// See sollya.org. + +template <> +struct accurate_log2 { + template + EIGEN_STRONG_INLINE + void operator()(const Packet& x, Packet& log2_x_hi, Packet& log2_x_lo) { + // We use a transformation of variables: + // r = c * (x-1) / (x+1), + // such that + // log2(x) = log2((1 + r/c) / (1 - r/c)) = f(r). + // The function f(r) can be approximated well using an odd polynomial + // of the form + // P(r) = ((Q(r^2) * r^2 + C) * r^2 + 1) * r, + // For the implementation of log2 here, Q is of degree 6 with + // coefficient represented in working precision (double), while C is a + // constant represented in extra precision as a double word to achieve + // full accuracy. + // + // The polynomial coefficients were computed by the Sollya script: + // + // c = 2 / log(2); + // trans = c * (x-1)/(x+1); + // itrans = (1+x/c)/(1-x/c); + // interval=[trans(sqrt(0.5)); trans(sqrt(2))]; + // print(interval); + // f = log2(itrans(x)); + // p=fpminimax(f,[|1,3,5,7,9,11,13,15,17|],[|1,DD,double...|],interval,relative,floating); + const Packet q12 = pset1(2.87074255468000586e-9); + const Packet q10 = pset1(2.38957980901884082e-8); + const Packet q8 = pset1(2.31032094540014656e-7); + const Packet q6 = pset1(2.27279857398537278e-6); + const Packet q4 = pset1(2.31271023278625638e-5); + const Packet q2 = pset1(2.47556738444535513e-4); + const Packet q0 = pset1(2.88543873228900172e-3); + const Packet C_hi = pset1(0.0400377511598501157); + const Packet C_lo = pset1(-4.77726582251425391e-19); + const Packet one = pset1(1.0); + + const Packet cst_2_log2e_hi = pset1(2.88539008177792677); + const Packet cst_2_log2e_lo = pset1(4.07660016854549667e-17); + // c * (x - 1) + Packet num_hi, num_lo; + twoprod(cst_2_log2e_hi, cst_2_log2e_lo, psub(x, one), num_hi, num_lo); + // TODO(rmlarsen): Investigate if using the division algorithm by + // Muller et al. is faster/more accurate. + // 1 / (x + 1) + Packet denom_hi, denom_lo; + doubleword_reciprocal(padd(x, one), denom_hi, denom_lo); + // r = c * (x-1) / (x+1), + Packet r_hi, r_lo; + twoprod(num_hi, num_lo, denom_hi, denom_lo, r_hi, r_lo); + // r2 = r * r + Packet r2_hi, r2_lo; + twoprod(r_hi, r_lo, r_hi, r_lo, r2_hi, r2_lo); + // r4 = r2 * r2 + Packet r4_hi, r4_lo; + twoprod(r2_hi, r2_lo, r2_hi, r2_lo, r4_hi, r4_lo); + + // Evaluate Q(r^2) in working precision. We evaluate it in two parts + // (even and odd in r^2) to improve instruction level parallelism. 
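+    // With s = r2_hi and t = r4_hi, the two independent FMA chains below compute
+    // q_even = ((q12*t + q8)*t + q4)*t + q0 and q_odd = (q10*t + q6)*t + q2,
+    // which recombine as Q(s) = q_odd*s + q_even; splitting the chains roughly
+    // halves the serial dependency length of a straight Horner evaluation in s.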
+ Packet q_even = pmadd(q12, r4_hi, q8); + Packet q_odd = pmadd(q10, r4_hi, q6); + q_even = pmadd(q_even, r4_hi, q4); + q_odd = pmadd(q_odd, r4_hi, q2); + q_even = pmadd(q_even, r4_hi, q0); + Packet q = pmadd(q_odd, r2_hi, q_even); + + // Now evaluate the low order terms of P(x) in double word precision. + // In the following, due to the increasing magnitude of the coefficients + // and r being constrained to [-0.5, 0.5] we can use fast_twosum instead + // of the slower twosum. + // Q(r^2) * r^2 + Packet p_hi, p_lo; + twoprod(r2_hi, r2_lo, q, p_hi, p_lo); + // Q(r^2) * r^2 + C + Packet p1_hi, p1_lo; + fast_twosum(C_hi, C_lo, p_hi, p_lo, p1_hi, p1_lo); + // (Q(r^2) * r^2 + C) * r^2 + Packet p2_hi, p2_lo; + twoprod(r2_hi, r2_lo, p1_hi, p1_lo, p2_hi, p2_lo); + // ((Q(r^2) * r^2 + C) * r^2 + 1) + Packet p3_hi, p3_lo; + fast_twosum(one, p2_hi, p2_lo, p3_hi, p3_lo); + + // log(z) ~= ((Q(r^2) * r^2 + C) * r^2 + 1) * r + twoprod(p3_hi, p3_lo, r_hi, r_lo, log2_x_hi, log2_x_lo); + } +}; + +// This function computes exp2(x) (i.e. 2**x). +template +struct fast_accurate_exp2 { + template + EIGEN_STRONG_INLINE + Packet operator()(const Packet& x) { + // TODO(rmlarsen): Add a pexp2 packetop. + return pexp(pmul(pset1(Scalar(EIGEN_LN2)), x)); + } +}; + +// This specialization uses a faster algorithm to compute exp2(x) for floats +// in [-0.5;0.5] with a relative accuracy of 1 ulp. +// The minimax polynomial used was calculated using the Sollya tool. +// See sollya.org. +template <> +struct fast_accurate_exp2 { + template + EIGEN_STRONG_INLINE + Packet operator()(const Packet& x) { + // This function approximates exp2(x) by a degree 6 polynomial of the form + // Q(x) = 1 + x * (C + x * P(x)), where the degree 4 polynomial P(x) is evaluated in + // single precision, and the remaining steps are evaluated with extra precision using + // double word arithmetic. C is an extra precise constant stored as a double word. + // + // The polynomial coefficients were calculated using Sollya commands: + // > n = 6; + // > f = 2^x; + // > interval = [-0.5;0.5]; + // > p = fpminimax(f,n,[|1,double,single...|],interval,relative,floating); + + const Packet p4 = pset1(1.539513905e-4f); + const Packet p3 = pset1(1.340007293e-3f); + const Packet p2 = pset1(9.618283249e-3f); + const Packet p1 = pset1(5.550328270e-2f); + const Packet p0 = pset1(0.2402264923f); + + const Packet C_hi = pset1(0.6931471825f); + const Packet C_lo = pset1(2.36836577e-08f); + const Packet one = pset1(1.0f); + + // Evaluate P(x) in working precision. + // We evaluate even and odd parts of the polynomial separately + // to gain some instruction level parallelism. + Packet x2 = pmul(x,x); + Packet p_even = pmadd(p4, x2, p2); + Packet p_odd = pmadd(p3, x2, p1); + p_even = pmadd(p_even, x2, p0); + Packet p = pmadd(p_odd, x, p_even); + + // Evaluate the remaining terms of Q(x) with extra precision using + // double word arithmetic. + Packet p_hi, p_lo; + // x * p(x) + twoprod(p, x, p_hi, p_lo); + // C + x * p(x) + Packet q1_hi, q1_lo; + twosum(p_hi, p_lo, C_hi, C_lo, q1_hi, q1_lo); + // x * (C + x * p(x)) + Packet q2_hi, q2_lo; + twoprod(q1_hi, q1_lo, x, q2_hi, q2_lo); + // 1 + x * (C + x * p(x)) + Packet q3_hi, q3_lo; + // Since |q2_hi| <= sqrt(2)-1 < 1, we can use fast_twosum + // for adding it to unity here. + fast_twosum(one, q2_hi, q3_hi, q3_lo); + return padd(q3_hi, padd(q2_lo, q3_lo)); + } +}; + +// in [-0.5;0.5] with a relative accuracy of 1 ulp. +// The minimax polynomial used was calculated using the Sollya tool. +// See sollya.org. 
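Illustrative aside (not part of the patch): the p_even/p_odd pattern used in these kernels is the usual even/odd split, P(x) = P_even(x^2) + x * P_odd(x^2), which shortens the Horner dependency chain so the two halves can execute in parallel. A scalar sketch with made-up coefficients:

#include <cstdio>

// Even/odd split of a degree-4 polynomial; the two halves have independent
// dependency chains, which is the point of the p_even/p_odd variables above.
double poly_even_odd(double x) {
  const double c0 = 0.25, c1 = 0.5, c2 = 0.75, c3 = 1.0, c4 = 1.25;  // arbitrary
  double x2 = x * x;
  double p_even = (c4 * x2 + c2) * x2 + c0;   // c0 + c2*x^2 + c4*x^4
  double p_odd  = c3 * x2 + c1;               // c1 + c3*x^2
  return p_odd * x + p_even;                  // c0 + c1*x + c2*x^2 + c3*x^3 + c4*x^4
}

int main() { std::printf("%g\n", poly_even_odd(0.3)); }  // prints 0.504625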
+template <> +struct fast_accurate_exp2 { + template + EIGEN_STRONG_INLINE + Packet operator()(const Packet& x) { + // This function approximates exp2(x) by a degree 10 polynomial of the form + // Q(x) = 1 + x * (C + x * P(x)), where the degree 8 polynomial P(x) is evaluated in + // single precision, and the remaining steps are evaluated with extra precision using + // double word arithmetic. C is an extra precise constant stored as a double word. + // + // The polynomial coefficients were calculated using Sollya commands: + // > n = 11; + // > f = 2^x; + // > interval = [-0.5;0.5]; + // > p = fpminimax(f,n,[|1,DD,double...|],interval,relative,floating); + + const Packet p9 = pset1(4.431642109085495276e-10); + const Packet p8 = pset1(7.073829923303358410e-9); + const Packet p7 = pset1(1.017822306737031311e-7); + const Packet p6 = pset1(1.321543498017646657e-6); + const Packet p5 = pset1(1.525273342728892877e-5); + const Packet p4 = pset1(1.540353045780084423e-4); + const Packet p3 = pset1(1.333355814685869807e-3); + const Packet p2 = pset1(9.618129107593478832e-3); + const Packet p1 = pset1(5.550410866481961247e-2); + const Packet p0 = pset1(0.240226506959101332); + const Packet C_hi = pset1(0.693147180559945286); + const Packet C_lo = pset1(4.81927865669806721e-17); + const Packet one = pset1(1.0); + + // Evaluate P(x) in working precision. + // We evaluate even and odd parts of the polynomial separately + // to gain some instruction level parallelism. + Packet x2 = pmul(x,x); + Packet p_even = pmadd(p8, x2, p6); + Packet p_odd = pmadd(p9, x2, p7); + p_even = pmadd(p_even, x2, p4); + p_odd = pmadd(p_odd, x2, p5); + p_even = pmadd(p_even, x2, p2); + p_odd = pmadd(p_odd, x2, p3); + p_even = pmadd(p_even, x2, p0); + p_odd = pmadd(p_odd, x2, p1); + Packet p = pmadd(p_odd, x, p_even); + + // Evaluate the remaining terms of Q(x) with extra precision using + // double word arithmetic. + Packet p_hi, p_lo; + // x * p(x) + twoprod(p, x, p_hi, p_lo); + // C + x * p(x) + Packet q1_hi, q1_lo; + twosum(p_hi, p_lo, C_hi, C_lo, q1_hi, q1_lo); + // x * (C + x * p(x)) + Packet q2_hi, q2_lo; + twoprod(q1_hi, q1_lo, x, q2_hi, q2_lo); + // 1 + x * (C + x * p(x)) + Packet q3_hi, q3_lo; + // Since |q2_hi| <= sqrt(2)-1 < 1, we can use fast_twosum + // for adding it to unity here. + fast_twosum(one, q2_hi, q3_hi, q3_lo); + return padd(q3_hi, padd(q2_lo, q3_lo)); + } +}; + +// This function implements the non-trivial case of pow(x,y) where x is +// positive and y is (possibly) non-integer. +// Formally, pow(x,y) = exp2(y * log2(x)), where exp2(x) is shorthand for 2^x. +// TODO(rmlarsen): We should probably add this as a packet up 'ppow', to make it +// easier to specialize or turn off for specific types and/or backends.x +template +EIGEN_STRONG_INLINE Packet generic_pow_impl(const Packet& x, const Packet& y) { + typedef typename unpacket_traits::type Scalar; + // Split x into exponent e_x and mantissa m_x. + Packet e_x; + Packet m_x = pfrexp(x, e_x); + + // Adjust m_x to lie in [1/sqrt(2):sqrt(2)] to minimize absolute error in log2(m_x). + EIGEN_CONSTEXPR Scalar sqrt_half = Scalar(0.70710678118654752440); + const Packet m_x_scale_mask = pcmp_lt(m_x, pset1(sqrt_half)); + m_x = pselect(m_x_scale_mask, pmul(pset1(Scalar(2)), m_x), m_x); + e_x = pselect(m_x_scale_mask, psub(e_x, pset1(Scalar(1))), e_x); + + // Compute log2(m_x) with 6 extra bits of accuracy. 
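Illustrative aside (not part of the patch): the pfrexp/pselect sequence above is a range reduction; in scalar terms it is frexp followed by a conditional doubling so the mantissa lands in [1/sqrt(2), sqrt(2)), where |log2(m)| is smallest. A rough equivalent using std::frexp:

#include <cmath>
#include <cstdio>

// Split x > 0 as x = m * 2^e with m in [1/sqrt(2), sqrt(2)), mirroring the
// pfrexp/pselect adjustment above (sqrt_half is the same constant).
double split_mantissa(double x, int& e) {
  double m = std::frexp(x, &e);                  // m in [0.5, 1)
  const double sqrt_half = 0.70710678118654752440;
  if (m < sqrt_half) {                           // double m, compensate e
    m *= 2.0;
    e -= 1;
  }
  return m;
}

int main() {
  int e;
  double m = split_mantissa(10.0, e);
  std::printf("10 = %.17g * 2^%d\n", m, e);      // 10 = 1.25 * 2^3
}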
+ Packet rx_hi, rx_lo; + accurate_log2()(m_x, rx_hi, rx_lo); + + // Compute the two terms {y * e_x, y * r_x} in f = y * log2(x) with doubled + // precision using double word arithmetic. + Packet f1_hi, f1_lo, f2_hi, f2_lo; + twoprod(e_x, y, f1_hi, f1_lo); + twoprod(rx_hi, rx_lo, y, f2_hi, f2_lo); + // Sum the two terms in f using double word arithmetic. We know + // that |e_x| > |log2(m_x)|, except for the case where e_x==0. + // This means that we can use fast_twosum(f1,f2). + // In the case e_x == 0, e_x * y = f1 = 0, so we don't lose any + // accuracy by violating the assumption of fast_twosum, because + // it's a no-op. + Packet f_hi, f_lo; + fast_twosum(f1_hi, f1_lo, f2_hi, f2_lo, f_hi, f_lo); + + // Split f into integer and fractional parts. + Packet n_z, r_z; + absolute_split(f_hi, n_z, r_z); + r_z = padd(r_z, f_lo); + Packet n_r; + absolute_split(r_z, n_r, r_z); + n_z = padd(n_z, n_r); + + // We now have an accurate split of f = n_z + r_z and can compute + // x^y = 2**{n_z + r_z) = exp2(r_z) * 2**{n_z}. + // Since r_z is in [-0.5;0.5], we compute the first factor to high accuracy + // using a specialized algorithm. Multiplication by the second factor can + // be done exactly using pldexp(), since it is an integer power of 2. + const Packet e_r = fast_accurate_exp2()(r_z); + return pldexp(e_r, n_z); +} + +// Generic implementation of pow(x,y). +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +EIGEN_UNUSED +Packet generic_pow(const Packet& x, const Packet& y) { + typedef typename unpacket_traits::type Scalar; + + const Packet cst_pos_inf = pset1(NumTraits::infinity()); + const Packet cst_zero = pset1(Scalar(0)); + const Packet cst_one = pset1(Scalar(1)); + const Packet cst_nan = pset1(NumTraits::quiet_NaN()); + + const Packet abs_x = pabs(x); + // Predicates for sign and magnitude of x. + const Packet x_is_zero = pcmp_eq(x, cst_zero); + const Packet x_is_neg = pcmp_lt(x, cst_zero); + const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf); + const Packet abs_x_is_one = pcmp_eq(abs_x, cst_one); + const Packet abs_x_is_gt_one = pcmp_lt(cst_one, abs_x); + const Packet abs_x_is_lt_one = pcmp_lt(abs_x, cst_one); + const Packet x_is_one = pandnot(abs_x_is_one, x_is_neg); + const Packet x_is_neg_one = pand(abs_x_is_one, x_is_neg); + const Packet x_is_nan = pandnot(ptrue(x), pcmp_eq(x, x)); + + // Predicates for sign and magnitude of y. + const Packet y_is_one = pcmp_eq(y, cst_one); + const Packet y_is_zero = pcmp_eq(y, cst_zero); + const Packet y_is_neg = pcmp_lt(y, cst_zero); + const Packet y_is_pos = pandnot(ptrue(y), por(y_is_zero, y_is_neg)); + const Packet y_is_nan = pandnot(ptrue(y), pcmp_eq(y, y)); + const Packet abs_y_is_inf = pcmp_eq(pabs(y), cst_pos_inf); + EIGEN_CONSTEXPR Scalar huge_exponent = + (NumTraits::max_exponent() * Scalar(EIGEN_LN2)) / + NumTraits::epsilon(); + const Packet abs_y_is_huge = pcmp_le(pset1(huge_exponent), pabs(y)); + + // Predicates for whether y is integer and/or even. 
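Illustrative aside (not part of the patch): stripped of the double-word bookkeeping and the special-case handling, the overall strategy of generic_pow_impl is the classical pow(x,y) = 2^(y*log2(x)), with the exponent split into an integer part (applied exactly via ldexp) and a fractional part in [-0.5, 0.5] (fed to the accurate exp2 kernel). A plain-double caricature:

#include <cmath>
#include <cstdio>

// Scalar caricature of generic_pow_impl for x > 0.
double pow_via_exp2(double x, double y) {
  double f = y * std::log2(x);        // the patch computes this in double-word arithmetic
  double n = std::nearbyint(f);       // integer part (round to nearest)
  double r = f - n;                   // fractional part in [-0.5, 0.5]
  return std::ldexp(std::exp2(r), static_cast<int>(n));  // exact scaling by 2^n
}

int main() {
  std::printf("%.17g vs %.17g\n", pow_via_exp2(3.0, 7.5), std::pow(3.0, 7.5));
}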
+ const Packet y_is_int = pcmp_eq(pfloor(y), y); + const Packet y_div_2 = pmul(y, pset1(Scalar(0.5))); + const Packet y_is_even = pcmp_eq(pround(y_div_2), y_div_2); + + // Predicates encoding special cases for the value of pow(x,y) + const Packet invalid_negative_x = pandnot(pandnot(pandnot(x_is_neg, abs_x_is_inf), + y_is_int), + abs_y_is_inf); + const Packet pow_is_one = por(por(x_is_one, y_is_zero), + pand(x_is_neg_one, + por(abs_y_is_inf, pandnot(y_is_even, invalid_negative_x)))); + const Packet pow_is_nan = por(invalid_negative_x, por(x_is_nan, y_is_nan)); + const Packet pow_is_zero = por(por(por(pand(x_is_zero, y_is_pos), + pand(abs_x_is_inf, y_is_neg)), + pand(pand(abs_x_is_lt_one, abs_y_is_huge), + y_is_pos)), + pand(pand(abs_x_is_gt_one, abs_y_is_huge), + y_is_neg)); + const Packet pow_is_inf = por(por(por(pand(x_is_zero, y_is_neg), + pand(abs_x_is_inf, y_is_pos)), + pand(pand(abs_x_is_lt_one, abs_y_is_huge), + y_is_neg)), + pand(pand(abs_x_is_gt_one, abs_y_is_huge), + y_is_pos)); + + // General computation of pow(x,y) for positive x or negative x and integer y. + const Packet negate_pow_abs = pandnot(x_is_neg, y_is_even); + const Packet pow_abs = generic_pow_impl(abs_x, y); + return pselect(y_is_one, x, + pselect(pow_is_one, cst_one, + pselect(pow_is_nan, cst_nan, + pselect(pow_is_inf, cst_pos_inf, + pselect(pow_is_zero, cst_zero, + pselect(negate_pow_abs, pnegate(pow_abs), pow_abs)))))); +} + + + +/* polevl (modified for Eigen) + * + * Evaluate polynomial + * + * + * + * SYNOPSIS: + * + * int N; + * Scalar x, y, coef[N+1]; + * + * y = polevl( x, coef); + * + * + * + * DESCRIPTION: + * + * Evaluates polynomial of degree N: + * + * 2 N + * y = C + C x + C x +...+ C x + * 0 1 2 N + * + * Coefficients are stored in reverse order: + * + * coef[0] = C , ..., coef[N] = C . + * N 0 + * + * The function p1evl() assumes that coef[N] = 1.0 and is + * omitted from the array. Its calling arguments are + * otherwise the same as polevl(). + * + * + * The Eigen implementation is templatized. For best speed, store + * coef as a const array (constexpr), e.g. + * + * const double coef[] = {1.0, 2.0, 3.0, ...}; + * + */ +template +struct ppolevl { + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const typename unpacket_traits::type coeff[]) { + EIGEN_STATIC_ASSERT((N > 0), YOU_MADE_A_PROGRAMMING_MISTAKE); + return pmadd(ppolevl::run(x, coeff), x, pset1(coeff[N])); + } +}; + +template +struct ppolevl { + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const typename unpacket_traits::type coeff[]) { + EIGEN_UNUSED_VARIABLE(x); + return pset1(coeff[0]); + } +}; + +/* chbevl (modified for Eigen) + * + * Evaluate Chebyshev series + * + * + * + * SYNOPSIS: + * + * int N; + * Scalar x, y, coef[N], chebevl(); + * + * y = chbevl( x, coef, N ); + * + * + * + * DESCRIPTION: + * + * Evaluates the series + * + * N-1 + * - ' + * y = > coef[i] T (x/2) + * - i + * i=0 + * + * of Chebyshev polynomials Ti at argument x/2. + * + * Coefficients are stored in reverse order, i.e. the zero + * order term is last in the array. Note N is the number of + * coefficients, not the order. + * + * If coefficients are for the interval a to b, x must + * have been transformed to x -> 2(2x - b - a)/(b-a) before + * entering the routine. This maps x from (a, b) to (-1, 1), + * over which the Chebyshev polynomials are defined. 
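Illustrative aside (not part of the patch): ppolevl is a compile-time-unrolled Horner scheme over packets, with coefficients stored highest-degree first as described in the comment above. A scalar analogue with a small usage example:

#include <cstdio>

// Scalar analogue of ppolevl<Packet, N>::run: Horner's method, coefficients
// stored in reverse order (coef[0] is the leading coefficient).
template <int N>
double polevl(double x, const double (&coef)[N + 1]) {
  double y = coef[0];
  for (int i = 1; i <= N; ++i) y = y * x + coef[i];
  return y;
}

int main() {
  const double coef[] = {2.0, 3.0, 4.0};       // 2*x^2 + 3*x + 4
  std::printf("%g\n", polevl<2>(5.0, coef));   // prints 69
}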
+ * + * If the coefficients are for the inverted interval, in + * which (a, b) is mapped to (1/b, 1/a), the transformation + * required is x -> 2(2ab/x - b - a)/(b-a). If b is infinity, + * this becomes x -> 4a/x - 1. + * + * + * + * SPEED: + * + * Taking advantage of the recurrence properties of the + * Chebyshev polynomials, the routine requires one more + * addition per loop than evaluating a nested polynomial of + * the same degree. + * + */ + +template +struct pchebevl { + EIGEN_DEVICE_FUNC + static EIGEN_STRONG_INLINE Packet run(Packet x, const typename unpacket_traits::type coef[]) { + typedef typename unpacket_traits::type Scalar; + Packet b0 = pset1(coef[0]); + Packet b1 = pset1(static_cast(0.f)); + Packet b2; + + for (int i = 1; i < N; i++) { + b2 = b1; + b1 = b0; + b0 = psub(pmadd(x, b1, pset1(coef[i])), b2); + } + + return pmul(pset1(static_cast(0.5f)), psub(b0, b2)); + } +}; + +} // end namespace internal +} // end namespace Eigen + +#endif // EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_H diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h new file mode 100644 index 0000000..177a04e --- /dev/null +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h @@ -0,0 +1,110 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2019 Gael Guennebaud +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_FWD_H +#define EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_FWD_H + +namespace Eigen { +namespace internal { + +// Forward declarations of the generic math functions +// implemented in GenericPacketMathFunctions.h +// This is needed to workaround a circular dependency. + +/*************************************************************************** + * Some generic implementations to be used by implementors +***************************************************************************/ + +/** Default implementation of pfrexp. + * It is expected to be called by implementers of template<> pfrexp. + */ +template EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC +Packet pfrexp_generic(const Packet& a, Packet& exponent); + +// Extracts the biased exponent value from Packet p, and casts the results to +// a floating-point Packet type. Used by pfrexp_generic. Override this if +// there is no unpacket_traits::integer_packet. +template EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC +Packet pfrexp_generic_get_biased_exponent(const Packet& p); + +/** Default implementation of pldexp. + * It is expected to be called by implementers of template<> pldexp. 
+ */ +template EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC +Packet pldexp_generic(const Packet& a, const Packet& exponent); + +/** \internal \returns log(x) for single precision float */ +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +EIGEN_UNUSED +Packet plog_float(const Packet _x); + +/** \internal \returns log2(x) for single precision float */ +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +EIGEN_UNUSED +Packet plog2_float(const Packet _x); + +/** \internal \returns log(x) for single precision float */ +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +EIGEN_UNUSED +Packet plog_double(const Packet _x); + +/** \internal \returns log2(x) for single precision float */ +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +EIGEN_UNUSED +Packet plog2_double(const Packet _x); + +/** \internal \returns log(1 + x) */ +template +Packet generic_plog1p(const Packet& x); + +/** \internal \returns exp(x)-1 */ +template +Packet generic_expm1(const Packet& x); + +/** \internal \returns exp(x) for single precision float */ +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +EIGEN_UNUSED +Packet pexp_float(const Packet _x); + +/** \internal \returns exp(x) for double precision real numbers */ +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +EIGEN_UNUSED +Packet pexp_double(const Packet _x); + +/** \internal \returns sin(x) for single precision float */ +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +EIGEN_UNUSED +Packet psin_float(const Packet& x); + +/** \internal \returns cos(x) for single precision float */ +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +EIGEN_UNUSED +Packet pcos_float(const Packet& x); + +/** \internal \returns sqrt(x) for complex types */ +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +EIGEN_UNUSED +Packet psqrt_complex(const Packet& a); + +template struct ppolevl; + + +} // end namespace internal +} // end namespace Eigen + +#endif // EIGEN_ARCH_GENERIC_PACKET_MATH_FUNCTIONS_FWD_H diff --git a/Eigen/src/Core/arch/Default/Half.h b/Eigen/src/Core/arch/Default/Half.h new file mode 100644 index 0000000..9f8e8cc --- /dev/null +++ b/Eigen/src/Core/arch/Default/Half.h @@ -0,0 +1,942 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. +// +// The conversion routines are Copyright (c) Fabian Giesen, 2016. +// The original license follows: +// +// Copyright (c) Fabian Giesen, 2016 +// All rights reserved. +// Redistribution and use in source and binary forms, with or without +// modification, are permitted. +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +// Standard 16-bit float type, mostly useful for GPUs. Defines a new +// type Eigen::half (inheriting either from CUDA's or HIP's __half struct) with +// operator overloads such that it behaves basically as an arithmetic +// type. It will be quite slow on CPUs (so it is recommended to stay +// in fp32 for CPUs, except for simple parameter conversions, I/O +// to disk and the likes), but fast on GPUs. + + +#ifndef EIGEN_HALF_H +#define EIGEN_HALF_H + +#include + +#if defined(EIGEN_HAS_GPU_FP16) || defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) +// When compiling with GPU support, the "__half_raw" base class as well as +// some other routines are defined in the GPU compiler header files +// (cuda_fp16.h, hip_fp16.h), and they are not tagged constexpr +// As a consequence, we get compile failures when compiling Eigen with +// GPU support. Hence the need to disable EIGEN_CONSTEXPR when building +// Eigen with GPU support + #pragma push_macro("EIGEN_CONSTEXPR") + #undef EIGEN_CONSTEXPR + #define EIGEN_CONSTEXPR +#endif + +#define F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, METHOD) \ + template <> \ + EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_UNUSED \ + PACKET_F16 METHOD(const PACKET_F16& _x) { \ + return float2half(METHOD(half2float(_x))); \ + } + +namespace Eigen { + +struct half; + +namespace half_impl { + +// We want to use the __half_raw struct from the HIP header file only during the device compile phase. +// This is required because of a quirk in the way TensorFlow GPU builds are done. +// When compiling TensorFlow source code with GPU support, files that +// * contain GPU kernels (i.e. *.cu.cc files) are compiled via hipcc +// * do not contain GPU kernels ( i.e. *.cc files) are compiled via gcc (typically) +// +// Tensorflow uses the Eigen::half type as its FP16 type, and there are functions that +// * are defined in a file that gets compiled via hipcc AND +// * have Eigen::half as a pass-by-value argument AND +// * are called in a file that gets compiled via gcc +// +// In the scenario described above the caller and callee will see different versions +// of the Eigen::half base class __half_raw, and they will be compiled by different compilers +// +// There appears to be an ABI mismatch between gcc and clang (which is called by hipcc) that results in +// the callee getting corrupted values for the Eigen::half argument. +// +// Making the host side compile phase of hipcc use the same Eigen::half impl, as the gcc compile, resolves +// this error, and hence the following convoluted #if condition +#if !defined(EIGEN_HAS_GPU_FP16) || !defined(EIGEN_GPU_COMPILE_PHASE) +// Make our own __half_raw definition that is similar to CUDA's. 
+struct __half_raw { +#if (defined(EIGEN_HAS_GPU_FP16) && !defined(EIGEN_GPU_COMPILE_PHASE)) + // Eigen::half can be used as the datatype for shared memory declarations (in Eigen and TF) + // The element type for shared memory cannot have non-trivial constructors + // and hence the following special casing (which skips the zero-initilization). + // Note that this check gets done even in the host compilation phase, and + // hence the need for this + EIGEN_DEVICE_FUNC __half_raw() {} +#else + EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw() : x(0) {} +#endif +#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) + explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw(numext::uint16_t raw) : x(numext::bit_cast<__fp16>(raw)) { + } + __fp16 x; +#else + explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw(numext::uint16_t raw) : x(raw) {} + numext::uint16_t x; +#endif +}; + +#elif defined(EIGEN_HAS_HIP_FP16) + // Nothing to do here + // HIP fp16 header file has a definition for __half_raw +#elif defined(EIGEN_HAS_CUDA_FP16) + #if EIGEN_CUDA_SDK_VER < 90000 + // In CUDA < 9.0, __half is the equivalent of CUDA 9's __half_raw + typedef __half __half_raw; + #endif // defined(EIGEN_HAS_CUDA_FP16) +#elif defined(SYCL_DEVICE_ONLY) + typedef cl::sycl::half __half_raw; +#endif + +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw raw_uint16_to_half(numext::uint16_t x); +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff); +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half_raw h); + +struct half_base : public __half_raw { + EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base() {} + EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base(const __half_raw& h) : __half_raw(h) {} + +#if defined(EIGEN_HAS_GPU_FP16) + #if defined(EIGEN_HAS_HIP_FP16) + EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base(const __half& h) { x = __half_as_ushort(h); } + #elif defined(EIGEN_HAS_CUDA_FP16) + #if EIGEN_CUDA_SDK_VER >= 90000 + EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base(const __half& h) : __half_raw(*(__half_raw*)&h) {} + #endif + #endif +#endif +}; + +} // namespace half_impl + +// Class definition. +struct half : public half_impl::half_base { + + // Writing this out as separate #if-else blocks to make the code easier to follow + // The same applies to most #if-else blocks in this file +#if !defined(EIGEN_HAS_GPU_FP16) || !defined(EIGEN_GPU_COMPILE_PHASE) + // Use the same base class for the following two scenarios + // * when compiling without GPU support enabled + // * during host compile phase when compiling with GPU support enabled + typedef half_impl::__half_raw __half_raw; +#elif defined(EIGEN_HAS_HIP_FP16) + // Nothing to do here + // HIP fp16 header file has a definition for __half_raw +#elif defined(EIGEN_HAS_CUDA_FP16) + // Note that EIGEN_CUDA_SDK_VER is set to 0 even when compiling with HIP, so + // (EIGEN_CUDA_SDK_VER < 90000) is true even for HIP! 
So keeping this within + // #if defined(EIGEN_HAS_CUDA_FP16) is needed + #if defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER < 90000 + typedef half_impl::__half_raw __half_raw; + #endif +#endif + + EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half() {} + + EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(const __half_raw& h) : half_impl::half_base(h) {} + +#if defined(EIGEN_HAS_GPU_FP16) + #if defined(EIGEN_HAS_HIP_FP16) + EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(const __half& h) : half_impl::half_base(h) {} + #elif defined(EIGEN_HAS_CUDA_FP16) + #if defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER >= 90000 + EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(const __half& h) : half_impl::half_base(h) {} + #endif + #endif +#endif + + + explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(bool b) + : half_impl::half_base(half_impl::raw_uint16_to_half(b ? 0x3c00 : 0)) {} + template + explicit EIGEN_DEVICE_FUNC half(T val) + : half_impl::half_base(half_impl::float_to_half_rtne(static_cast(val))) {} + explicit EIGEN_DEVICE_FUNC half(float f) + : half_impl::half_base(half_impl::float_to_half_rtne(f)) {} + + // Following the convention of numpy, converting between complex and + // float will lead to loss of imag value. + template + explicit EIGEN_DEVICE_FUNC half(std::complex c) + : half_impl::half_base(half_impl::float_to_half_rtne(static_cast(c.real()))) {} + + EIGEN_DEVICE_FUNC operator float() const { // NOLINT: Allow implicit conversion to float, because it is lossless. + return half_impl::half_to_float(*this); + } + +#if defined(EIGEN_HAS_GPU_FP16) && !defined(EIGEN_GPU_COMPILE_PHASE) + EIGEN_DEVICE_FUNC operator __half() const { + ::__half_raw hr; + hr.x = x; + return __half(hr); + } +#endif +}; + +} // end namespace Eigen + +namespace std { +template<> +struct numeric_limits { + static const bool is_specialized = true; + static const bool is_signed = true; + static const bool is_integer = false; + static const bool is_exact = false; + static const bool has_infinity = true; + static const bool has_quiet_NaN = true; + static const bool has_signaling_NaN = true; + static const float_denorm_style has_denorm = denorm_present; + static const bool has_denorm_loss = false; + static const std::float_round_style round_style = std::round_to_nearest; + static const bool is_iec559 = false; + static const bool is_bounded = false; + static const bool is_modulo = false; + static const int digits = 11; + static const int digits10 = 3; // according to http://half.sourceforge.net/structstd_1_1numeric__limits_3_01half__float_1_1half_01_4.html + static const int max_digits10 = 5; // according to http://half.sourceforge.net/structstd_1_1numeric__limits_3_01half__float_1_1half_01_4.html + static const int radix = 2; + static const int min_exponent = -13; + static const int min_exponent10 = -4; + static const int max_exponent = 16; + static const int max_exponent10 = 4; + static const bool traps = true; + static const bool tinyness_before = false; + + static Eigen::half (min)() { return Eigen::half_impl::raw_uint16_to_half(0x400); } + static Eigen::half lowest() { return Eigen::half_impl::raw_uint16_to_half(0xfbff); } + static Eigen::half (max)() { return Eigen::half_impl::raw_uint16_to_half(0x7bff); } + static Eigen::half epsilon() { return Eigen::half_impl::raw_uint16_to_half(0x0800); } + static Eigen::half round_error() { return Eigen::half(0.5); } + static Eigen::half infinity() { return Eigen::half_impl::raw_uint16_to_half(0x7c00); } + static Eigen::half quiet_NaN() { return Eigen::half_impl::raw_uint16_to_half(0x7e00); } + static Eigen::half 
signaling_NaN() { return Eigen::half_impl::raw_uint16_to_half(0x7d00); } + static Eigen::half denorm_min() { return Eigen::half_impl::raw_uint16_to_half(0x1); } +}; + +// If std::numeric_limits is specialized, should also specialize +// std::numeric_limits, std::numeric_limits, and +// std::numeric_limits +// https://stackoverflow.com/a/16519653/ +template<> +struct numeric_limits : numeric_limits {}; +template<> +struct numeric_limits : numeric_limits {}; +template<> +struct numeric_limits : numeric_limits {}; +} // end namespace std + +namespace Eigen { + +namespace half_impl { + +#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && \ + EIGEN_CUDA_ARCH >= 530) || \ + (defined(EIGEN_HAS_HIP_FP16) && defined(HIP_DEVICE_COMPILE)) +// Note: We deliberatly do *not* define this to 1 even if we have Arm's native +// fp16 type since GPU halfs are rather different from native CPU halfs. +// TODO: Rename to something like EIGEN_HAS_NATIVE_GPU_FP16 +#define EIGEN_HAS_NATIVE_FP16 +#endif + +// Intrinsics for native fp16 support. Note that on current hardware, +// these are no faster than fp32 arithmetic (you need to use the half2 +// versions to get the ALU speed increased), but you do save the +// conversion steps back and forth. + +#if defined(EIGEN_HAS_NATIVE_FP16) +EIGEN_STRONG_INLINE __device__ half operator + (const half& a, const half& b) { +#if defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER >= 90000 + return __hadd(::__half(a), ::__half(b)); +#else + return __hadd(a, b); +#endif +} +EIGEN_STRONG_INLINE __device__ half operator * (const half& a, const half& b) { + return __hmul(a, b); +} +EIGEN_STRONG_INLINE __device__ half operator - (const half& a, const half& b) { + return __hsub(a, b); +} +EIGEN_STRONG_INLINE __device__ half operator / (const half& a, const half& b) { +#if defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER >= 90000 + return __hdiv(a, b); +#else + float num = __half2float(a); + float denom = __half2float(b); + return __float2half(num / denom); +#endif +} +EIGEN_STRONG_INLINE __device__ half operator - (const half& a) { + return __hneg(a); +} +EIGEN_STRONG_INLINE __device__ half& operator += (half& a, const half& b) { + a = a + b; + return a; +} +EIGEN_STRONG_INLINE __device__ half& operator *= (half& a, const half& b) { + a = a * b; + return a; +} +EIGEN_STRONG_INLINE __device__ half& operator -= (half& a, const half& b) { + a = a - b; + return a; +} +EIGEN_STRONG_INLINE __device__ half& operator /= (half& a, const half& b) { + a = a / b; + return a; +} +EIGEN_STRONG_INLINE __device__ bool operator == (const half& a, const half& b) { + return __heq(a, b); +} +EIGEN_STRONG_INLINE __device__ bool operator != (const half& a, const half& b) { + return __hne(a, b); +} +EIGEN_STRONG_INLINE __device__ bool operator < (const half& a, const half& b) { + return __hlt(a, b); +} +EIGEN_STRONG_INLINE __device__ bool operator <= (const half& a, const half& b) { + return __hle(a, b); +} +EIGEN_STRONG_INLINE __device__ bool operator > (const half& a, const half& b) { + return __hgt(a, b); +} +EIGEN_STRONG_INLINE __device__ bool operator >= (const half& a, const half& b) { + return __hge(a, b); +} +#endif + +#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator + (const half& a, const half& b) { + return half(vaddh_f16(a.x, b.x)); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator * (const half& a, const half& b) { + return half(vmulh_f16(a.x, b.x)); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator - (const half& 
a, const half& b) { + return half(vsubh_f16(a.x, b.x)); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator / (const half& a, const half& b) { + return half(vdivh_f16(a.x, b.x)); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator - (const half& a) { + return half(vnegh_f16(a.x)); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator += (half& a, const half& b) { + a = half(vaddh_f16(a.x, b.x)); + return a; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator *= (half& a, const half& b) { + a = half(vmulh_f16(a.x, b.x)); + return a; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator -= (half& a, const half& b) { + a = half(vsubh_f16(a.x, b.x)); + return a; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator /= (half& a, const half& b) { + a = half(vdivh_f16(a.x, b.x)); + return a; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator == (const half& a, const half& b) { + return vceqh_f16(a.x, b.x); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator != (const half& a, const half& b) { + return !vceqh_f16(a.x, b.x); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator < (const half& a, const half& b) { + return vclth_f16(a.x, b.x); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator <= (const half& a, const half& b) { + return vcleh_f16(a.x, b.x); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator > (const half& a, const half& b) { + return vcgth_f16(a.x, b.x); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator >= (const half& a, const half& b) { + return vcgeh_f16(a.x, b.x); +} +// We need to distinguish ‘clang as the CUDA compiler’ from ‘clang as the host compiler, +// invoked by NVCC’ (e.g. on MacOS). The former needs to see both host and device implementation +// of the functions, while the latter can only deal with one of them. +#elif !defined(EIGEN_HAS_NATIVE_FP16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) // Emulate support for half floats + +#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC) +// We need to provide emulated *host-side* FP16 operators for clang. +#pragma push_macro("EIGEN_DEVICE_FUNC") +#undef EIGEN_DEVICE_FUNC +#if defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_HAS_NATIVE_FP16) +#define EIGEN_DEVICE_FUNC __host__ +#else // both host and device need emulated ops. +#define EIGEN_DEVICE_FUNC __host__ __device__ +#endif +#endif + +// Definitions for CPUs and older HIP+CUDA, mostly working through conversion +// to/from fp32. 
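Illustrative aside (not part of the patch): the emulated operators below round-trip through float, while several later routines work on the raw binary16 bit pattern directly (negation via x ^ 0x8000, the isinf/isnan tests on the 0x7fff/0x7c00 masks). A tiny standalone decoder, assuming normal (non-subnormal, non-special) inputs, makes the 1-bit sign / 5-bit exponent / 10-bit mantissa layout concrete:

#include <cmath>
#include <cstdint>
#include <cstdio>

// Decode an IEEE binary16 bit pattern by hand (normal numbers only).
double decode_half(std::uint16_t h) {
  int sign     = (h >> 15) & 0x1;
  int exponent = (h >> 10) & 0x1f;   // biased by 15
  int mantissa = h & 0x3ff;          // implicit leading 1 for normals
  double value = std::ldexp(1.0 + mantissa / 1024.0, exponent - 15);
  return sign ? -value : value;
}

int main() {
  std::printf("%g %g\n", decode_half(0x3c00), decode_half(0xfbff));  // 1, -65504
}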
+EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator + (const half& a, const half& b) { + return half(float(a) + float(b)); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator * (const half& a, const half& b) { + return half(float(a) * float(b)); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator - (const half& a, const half& b) { + return half(float(a) - float(b)); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator / (const half& a, const half& b) { + return half(float(a) / float(b)); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator - (const half& a) { + half result; + result.x = a.x ^ 0x8000; + return result; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator += (half& a, const half& b) { + a = half(float(a) + float(b)); + return a; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator *= (half& a, const half& b) { + a = half(float(a) * float(b)); + return a; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator -= (half& a, const half& b) { + a = half(float(a) - float(b)); + return a; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator /= (half& a, const half& b) { + a = half(float(a) / float(b)); + return a; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator == (const half& a, const half& b) { + return numext::equal_strict(float(a),float(b)); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator != (const half& a, const half& b) { + return numext::not_equal_strict(float(a), float(b)); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator < (const half& a, const half& b) { + return float(a) < float(b); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator <= (const half& a, const half& b) { + return float(a) <= float(b); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator > (const half& a, const half& b) { + return float(a) > float(b); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator >= (const half& a, const half& b) { + return float(a) >= float(b); +} + +#if defined(__clang__) && defined(__CUDA__) +#pragma pop_macro("EIGEN_DEVICE_FUNC") +#endif +#endif // Emulate support for half floats + +// Division by an index. Do it in full float precision to avoid accuracy +// issues in converting the denominator to half. +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator / (const half& a, Index b) { + return half(static_cast(a) / static_cast(b)); +} + +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator++(half& a) { + a += half(1); + return a; +} + +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator--(half& a) { + a -= half(1); + return a; +} + +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator++(half& a, int) { + half original_value = a; + ++a; + return original_value; +} + +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator--(half& a, int) { + half original_value = a; + --a; + return original_value; +} + +// Conversion routines, including fallbacks for the host or older CUDA. +// Note that newer Intel CPUs (Haswell or newer) have vectorized versions of +// these in hardware. If we need more performance on older/other CPUs, they are +// also possible to vectorize directly. + +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw raw_uint16_to_half(numext::uint16_t x) { + // We cannot simply do a "return __half_raw(x)" here, because __half_raw is union type + // in the hip_fp16 header file, and that will trigger a compile error + // On the other hand, having anything but a return statement also triggers a compile error + // because this is constexpr function. 
+ // Fortunately, since we need to disable EIGEN_CONSTEXPR for GPU anyway, we can get out + // of this catch22 by having separate bodies for GPU / non GPU +#if defined(EIGEN_HAS_GPU_FP16) + __half_raw h; + h.x = x; + return h; +#else + return __half_raw(x); +#endif +} + +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC numext::uint16_t raw_half_as_uint16(const __half_raw& h) { + // HIP/CUDA/Default have a member 'x' of type uint16_t. + // For ARM64 native half, the member 'x' is of type __fp16, so we need to bit-cast. + // For SYCL, cl::sycl::half is _Float16, so cast directly. +#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) + return numext::bit_cast(h.x); +#elif defined(SYCL_DEVICE_ONLY) + return numext::bit_cast(h); +#else + return h.x; +#endif +} + +union float32_bits { + unsigned int u; + float f; +}; + +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff) { +#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \ + (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE)) + __half tmp_ff = __float2half(ff); + return *(__half_raw*)&tmp_ff; + +#elif defined(EIGEN_HAS_FP16_C) + __half_raw h; + h.x = _cvtss_sh(ff, 0); + return h; + +#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) + __half_raw h; + h.x = static_cast<__fp16>(ff); + return h; + +#else + float32_bits f; f.f = ff; + + const float32_bits f32infty = { 255 << 23 }; + const float32_bits f16max = { (127 + 16) << 23 }; + const float32_bits denorm_magic = { ((127 - 15) + (23 - 10) + 1) << 23 }; + unsigned int sign_mask = 0x80000000u; + __half_raw o; + o.x = static_cast(0x0u); + + unsigned int sign = f.u & sign_mask; + f.u ^= sign; + + // NOTE all the integer compares in this function can be safely + // compiled into signed compares since all operands are below + // 0x80000000. Important if you want fast straight SSE2 code + // (since there's no unsigned PCMPGTD). + + if (f.u >= f16max.u) { // result is Inf or NaN (all exponent bits set) + o.x = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf + } else { // (De)normalized number or zero + if (f.u < (113 << 23)) { // resulting FP16 is subnormal or zero + // use a magic value to align our 10 mantissa bits at the bottom of + // the float. as long as FP addition is round-to-nearest-even this + // just works. + f.f += denorm_magic.f; + + // and one integer subtract of the bias later, we have our final float! + o.x = static_cast(f.u - denorm_magic.u); + } else { + unsigned int mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd + + // update exponent, rounding bias part 1 + // Equivalent to `f.u += ((unsigned int)(15 - 127) << 23) + 0xfff`, but + // without arithmetic overflow. + f.u += 0xc8000fffU; + // rounding bias part 2 + f.u += mant_odd; + // take the bits! 
+ o.x = static_cast(f.u >> 13); + } + } + + o.x |= static_cast(sign >> 16); + return o; +#endif +} + +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half_raw h) { +#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \ + (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE)) + return __half2float(h); +#elif defined(EIGEN_HAS_FP16_C) + return _cvtsh_ss(h.x); +#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) + return static_cast(h.x); +#else + const float32_bits magic = { 113 << 23 }; + const unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift + float32_bits o; + + o.u = (h.x & 0x7fff) << 13; // exponent/mantissa bits + unsigned int exp = shifted_exp & o.u; // just the exponent + o.u += (127 - 15) << 23; // exponent adjust + + // handle exponent special cases + if (exp == shifted_exp) { // Inf/NaN? + o.u += (128 - 16) << 23; // extra exp adjust + } else if (exp == 0) { // Zero/Denormal? + o.u += 1 << 23; // extra exp adjust + o.f -= magic.f; // renormalize + } + + o.u |= (h.x & 0x8000) << 16; // sign bit + return o.f; +#endif +} + +// --- standard functions --- + +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isinf)(const half& a) { +#ifdef EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC + return (numext::bit_cast(a.x) & 0x7fff) == 0x7c00; +#else + return (a.x & 0x7fff) == 0x7c00; +#endif +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isnan)(const half& a) { +#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \ + (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE)) + return __hisnan(a); +#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) + return (numext::bit_cast(a.x) & 0x7fff) > 0x7c00; +#else + return (a.x & 0x7fff) > 0x7c00; +#endif +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isfinite)(const half& a) { + return !(isinf EIGEN_NOT_A_MACRO (a)) && !(isnan EIGEN_NOT_A_MACRO (a)); +} + +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half abs(const half& a) { +#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) + return half(vabsh_f16(a.x)); +#else + half result; + result.x = a.x & 0x7FFF; + return result; +#endif +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half exp(const half& a) { +#if (EIGEN_CUDA_SDK_VER >= 80000 && defined EIGEN_CUDA_ARCH && EIGEN_CUDA_ARCH >= 530) || \ + defined(EIGEN_HIP_DEVICE_COMPILE) + return half(hexp(a)); +#else + return half(::expf(float(a))); +#endif +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half expm1(const half& a) { + return half(numext::expm1(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half log(const half& a) { +#if (defined(EIGEN_HAS_CUDA_FP16) && EIGEN_CUDA_SDK_VER >= 80000 && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \ + (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE)) + return half(::hlog(a)); +#else + return half(::logf(float(a))); +#endif +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half log1p(const half& a) { + return half(numext::log1p(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half log10(const half& a) { + return half(::log10f(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half log2(const half& a) { + return half(static_cast(EIGEN_LOG2E) * ::logf(float(a))); +} + +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half sqrt(const half& a) { +#if (EIGEN_CUDA_SDK_VER >= 80000 && defined EIGEN_CUDA_ARCH && EIGEN_CUDA_ARCH >= 530) || \ + defined(EIGEN_HIP_DEVICE_COMPILE) + return half(hsqrt(a)); +#else + return half(::sqrtf(float(a))); +#endif +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC 
half pow(const half& a, const half& b) { + return half(::powf(float(a), float(b))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half sin(const half& a) { + return half(::sinf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half cos(const half& a) { + return half(::cosf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half tan(const half& a) { + return half(::tanf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half tanh(const half& a) { + return half(::tanhf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half asin(const half& a) { + return half(::asinf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half acos(const half& a) { + return half(::acosf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half floor(const half& a) { +#if (EIGEN_CUDA_SDK_VER >= 80000 && defined EIGEN_CUDA_ARCH && EIGEN_CUDA_ARCH >= 300) || \ + defined(EIGEN_HIP_DEVICE_COMPILE) + return half(hfloor(a)); +#else + return half(::floorf(float(a))); +#endif +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half ceil(const half& a) { +#if (EIGEN_CUDA_SDK_VER >= 80000 && defined EIGEN_CUDA_ARCH && EIGEN_CUDA_ARCH >= 300) || \ + defined(EIGEN_HIP_DEVICE_COMPILE) + return half(hceil(a)); +#else + return half(::ceilf(float(a))); +#endif +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half rint(const half& a) { + return half(::rintf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half round(const half& a) { + return half(::roundf(float(a))); +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half fmod(const half& a, const half& b) { + return half(::fmodf(float(a), float(b))); +} + +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half (min)(const half& a, const half& b) { +#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \ + (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE)) + return __hlt(b, a) ? b : a; +#else + const float f1 = static_cast(a); + const float f2 = static_cast(b); + return f2 < f1 ? b : a; +#endif +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half (max)(const half& a, const half& b) { +#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \ + (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE)) + return __hlt(a, b) ? b : a; +#else + const float f1 = static_cast(a); + const float f2 = static_cast(b); + return f1 < f2 ? 
b : a; +#endif +} + +#ifndef EIGEN_NO_IO +EIGEN_ALWAYS_INLINE std::ostream& operator << (std::ostream& os, const half& v) { + os << static_cast(v); + return os; +} +#endif + +} // end namespace half_impl + +// import Eigen::half_impl::half into Eigen namespace +// using half_impl::half; + +namespace internal { + +template<> +struct random_default_impl +{ + static inline half run(const half& x, const half& y) + { + return x + (y-x) * half(float(std::rand()) / float(RAND_MAX)); + } + static inline half run() + { + return run(half(-1.f), half(1.f)); + } +}; + +template<> struct is_arithmetic { enum { value = true }; }; + +} // end namespace internal + +template<> struct NumTraits + : GenericNumTraits +{ + enum { + IsSigned = true, + IsInteger = false, + IsComplex = false, + RequireInitialization = false + }; + + EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half epsilon() { + return half_impl::raw_uint16_to_half(0x0800); + } + EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half dummy_precision() { + return half_impl::raw_uint16_to_half(0x211f); // Eigen::half(1e-2f); + } + EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half highest() { + return half_impl::raw_uint16_to_half(0x7bff); + } + EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half lowest() { + return half_impl::raw_uint16_to_half(0xfbff); + } + EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half infinity() { + return half_impl::raw_uint16_to_half(0x7c00); + } + EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half quiet_NaN() { + return half_impl::raw_uint16_to_half(0x7e00); + } +}; + +} // end namespace Eigen + +#if defined(EIGEN_HAS_GPU_FP16) || defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) + #pragma pop_macro("EIGEN_CONSTEXPR") +#endif + +namespace Eigen { +namespace numext { + +#if defined(EIGEN_GPU_COMPILE_PHASE) + +template <> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isnan)(const Eigen::half& h) { + return (half_impl::isnan)(h); +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isinf)(const Eigen::half& h) { + return (half_impl::isinf)(h); +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool(isfinite)(const Eigen::half& h) { + return (half_impl::isfinite)(h); +} + +#endif + +template <> +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half bit_cast(const uint16_t& src) { + return Eigen::half(Eigen::half_impl::raw_uint16_to_half(src)); +} + +template <> +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast(const Eigen::half& src) { + return Eigen::half_impl::raw_half_as_uint16(src); +} + +} // namespace numext +} // namespace Eigen + +// Add the missing shfl* intrinsics. +// The __shfl* functions are only valid on HIP or _CUDA_ARCH_ >= 300. +// CUDA defines them for (__CUDA_ARCH__ >= 300 || !defined(__CUDA_ARCH__)) +// +// HIP and CUDA prior to SDK 9.0 define +// __shfl, __shfl_up, __shfl_down, __shfl_xor for int and float +// CUDA since 9.0 deprecates those and instead defines +// __shfl_sync, __shfl_up_sync, __shfl_down_sync, __shfl_xor_sync, +// with native support for __half and __nv_bfloat16 +// +// Note that the following are __device__ - only functions. 
+#if (defined(EIGEN_CUDACC) && (!defined(EIGEN_CUDA_ARCH) || EIGEN_CUDA_ARCH >= 300)) \ + || defined(EIGEN_HIPCC) + +#if defined(EIGEN_HAS_CUDA_FP16) && EIGEN_CUDA_SDK_VER >= 90000 + +__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_sync(unsigned mask, Eigen::half var, int srcLane, int width=warpSize) { + const __half h = var; + return static_cast(__shfl_sync(mask, h, srcLane, width)); +} + +__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_up_sync(unsigned mask, Eigen::half var, unsigned int delta, int width=warpSize) { + const __half h = var; + return static_cast(__shfl_up_sync(mask, h, delta, width)); +} + +__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_down_sync(unsigned mask, Eigen::half var, unsigned int delta, int width=warpSize) { + const __half h = var; + return static_cast(__shfl_down_sync(mask, h, delta, width)); +} + +__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_xor_sync(unsigned mask, Eigen::half var, int laneMask, int width=warpSize) { + const __half h = var; + return static_cast(__shfl_xor_sync(mask, h, laneMask, width)); +} + +#else // HIP or CUDA SDK < 9.0 + +__device__ EIGEN_STRONG_INLINE Eigen::half __shfl(Eigen::half var, int srcLane, int width=warpSize) { + const int ivar = static_cast(Eigen::numext::bit_cast(var)); + return Eigen::numext::bit_cast(static_cast(__shfl(ivar, srcLane, width))); +} + +__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_up(Eigen::half var, unsigned int delta, int width=warpSize) { + const int ivar = static_cast(Eigen::numext::bit_cast(var)); + return Eigen::numext::bit_cast(static_cast(__shfl_up(ivar, delta, width))); +} + +__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_down(Eigen::half var, unsigned int delta, int width=warpSize) { + const int ivar = static_cast(Eigen::numext::bit_cast(var)); + return Eigen::numext::bit_cast(static_cast(__shfl_down(ivar, delta, width))); +} + +__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_xor(Eigen::half var, int laneMask, int width=warpSize) { + const int ivar = static_cast(Eigen::numext::bit_cast(var)); + return Eigen::numext::bit_cast(static_cast(__shfl_xor(ivar, laneMask, width))); +} + +#endif // HIP vs CUDA +#endif // __shfl* + +// ldg() has an overload for __half_raw, but we also need one for Eigen::half. +#if (defined(EIGEN_CUDACC) && (!defined(EIGEN_CUDA_ARCH) || EIGEN_CUDA_ARCH >= 350)) \ + || defined(EIGEN_HIPCC) +EIGEN_STRONG_INLINE __device__ Eigen::half __ldg(const Eigen::half* ptr) { + return Eigen::half_impl::raw_uint16_to_half(__ldg(reinterpret_cast(ptr))); +} +#endif // __ldg + +#if EIGEN_HAS_STD_HASH +namespace std { +template <> +struct hash { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t operator()(const Eigen::half& a) const { + return static_cast(Eigen::numext::bit_cast(a)); + } +}; +} // end namespace std +#endif + +#endif // EIGEN_HALF_H diff --git a/Eigen/src/Core/arch/Default/Settings.h b/Eigen/src/Core/arch/Default/Settings.h new file mode 100644 index 0000000..a5c3ada --- /dev/null +++ b/Eigen/src/Core/arch/Default/Settings.h @@ -0,0 +1,49 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2008-2010 Gael Guennebaud +// Copyright (C) 2006-2008 Benoit Jacob +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + + +/* All the parameters defined in this file can be specialized in the + * architecture specific files, and/or by the user. 
+ * More to come... */ + +#ifndef EIGEN_DEFAULT_SETTINGS_H +#define EIGEN_DEFAULT_SETTINGS_H + +/** Defines the maximal loop size to enable meta unrolling of loops. + * Note that the value here is expressed in Eigen's own notion of "number of FLOPS", + * it does not correspond to the number of iterations or the number of instructions + */ +#ifndef EIGEN_UNROLLING_LIMIT +#define EIGEN_UNROLLING_LIMIT 110 +#endif + +/** Defines the threshold between a "small" and a "large" matrix. + * This threshold is mainly used to select the proper product implementation. + */ +#ifndef EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD +#define EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD 8 +#endif + +/** Defines the maximal width of the blocks used in the triangular product and solver + * for vectors (level 2 blas xTRMV and xTRSV). The default is 8. + */ +#ifndef EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH +#define EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH 8 +#endif + + +/** Defines the default number of registers available for that architecture. + * Currently it must be 8 or 16. Other values will fail. + */ +#ifndef EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS +#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 8 +#endif + +#endif // EIGEN_DEFAULT_SETTINGS_H diff --git a/Eigen/src/Core/arch/Default/TypeCasting.h b/Eigen/src/Core/arch/Default/TypeCasting.h new file mode 100644 index 0000000..fb8183b --- /dev/null +++ b/Eigen/src/Core/arch/Default/TypeCasting.h @@ -0,0 +1,120 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2016 Benoit Steiner +// Copyright (C) 2019 Rasmus Munk Larsen +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+ +#ifndef EIGEN_GENERIC_TYPE_CASTING_H +#define EIGEN_GENERIC_TYPE_CASTING_H + +namespace Eigen { + +namespace internal { + +template<> +struct scalar_cast_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op) + typedef Eigen::half result_type; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half operator() (const float& a) const { + #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \ + (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE)) + return __float2half(a); + #else + return Eigen::half(a); + #endif + } +}; + +template<> +struct functor_traits > +{ enum { Cost = NumTraits::AddCost, PacketAccess = false }; }; + + +template<> +struct scalar_cast_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op) + typedef Eigen::half result_type; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half operator() (const int& a) const { + #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \ + (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE)) + return __float2half(static_cast(a)); + #else + return Eigen::half(static_cast(a)); + #endif + } +}; + +template<> +struct functor_traits > +{ enum { Cost = NumTraits::AddCost, PacketAccess = false }; }; + + +template<> +struct scalar_cast_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op) + typedef float result_type; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator() (const Eigen::half& a) const { + #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \ + (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE)) + return __half2float(a); + #else + return static_cast(a); + #endif + } +}; + +template<> +struct functor_traits > +{ enum { Cost = NumTraits::AddCost, PacketAccess = false }; }; + + +template<> +struct scalar_cast_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op) + typedef Eigen::bfloat16 result_type; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::bfloat16 operator() (const float& a) const { + return Eigen::bfloat16(a); + } +}; + +template<> +struct functor_traits > +{ enum { Cost = NumTraits::AddCost, PacketAccess = false }; }; + + +template<> +struct scalar_cast_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op) + typedef Eigen::bfloat16 result_type; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::bfloat16 operator() (const int& a) const { + return Eigen::bfloat16(static_cast(a)); + } +}; + +template<> +struct functor_traits > +{ enum { Cost = NumTraits::AddCost, PacketAccess = false }; }; + + +template<> +struct scalar_cast_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_cast_op) + typedef float result_type; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator() (const Eigen::bfloat16& a) const { + return static_cast(a); + } +}; + +template<> +struct functor_traits > +{ enum { Cost = NumTraits::AddCost, PacketAccess = false }; }; + + +} +} + +#endif // EIGEN_GENERIC_TYPE_CASTING_H diff --git a/Eigen/src/Core/arch/GPU/MathFunctions.h b/Eigen/src/Core/arch/GPU/MathFunctions.h new file mode 100644 index 0000000..d2b3a25 --- /dev/null +++ b/Eigen/src/Core/arch/GPU/MathFunctions.h @@ -0,0 +1,103 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2014 Benoit Steiner +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+ +#ifndef EIGEN_MATH_FUNCTIONS_GPU_H +#define EIGEN_MATH_FUNCTIONS_GPU_H + +namespace Eigen { + +namespace internal { + +// Make sure this is only available when targeting a GPU: we don't want to +// introduce conflicts between these packet_traits definitions and the ones +// we'll use on the host side (SSE, AVX, ...) +#if defined(EIGEN_GPUCC) && defined(EIGEN_USE_GPU) +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +float4 plog(const float4& a) +{ + return make_float4(logf(a.x), logf(a.y), logf(a.z), logf(a.w)); +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +double2 plog(const double2& a) +{ + using ::log; + return make_double2(log(a.x), log(a.y)); +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +float4 plog1p(const float4& a) +{ + return make_float4(log1pf(a.x), log1pf(a.y), log1pf(a.z), log1pf(a.w)); +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +double2 plog1p(const double2& a) +{ + return make_double2(log1p(a.x), log1p(a.y)); +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +float4 pexp(const float4& a) +{ + return make_float4(expf(a.x), expf(a.y), expf(a.z), expf(a.w)); +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +double2 pexp(const double2& a) +{ + using ::exp; + return make_double2(exp(a.x), exp(a.y)); +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +float4 pexpm1(const float4& a) +{ + return make_float4(expm1f(a.x), expm1f(a.y), expm1f(a.z), expm1f(a.w)); +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +double2 pexpm1(const double2& a) +{ + return make_double2(expm1(a.x), expm1(a.y)); +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +float4 psqrt(const float4& a) +{ + return make_float4(sqrtf(a.x), sqrtf(a.y), sqrtf(a.z), sqrtf(a.w)); +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +double2 psqrt(const double2& a) +{ + using ::sqrt; + return make_double2(sqrt(a.x), sqrt(a.y)); +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +float4 prsqrt(const float4& a) +{ + return make_float4(rsqrtf(a.x), rsqrtf(a.y), rsqrtf(a.z), rsqrtf(a.w)); +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +double2 prsqrt(const double2& a) +{ + return make_double2(rsqrt(a.x), rsqrt(a.y)); +} + + +#endif + +} // end namespace internal + +} // end namespace Eigen + +#endif // EIGEN_MATH_FUNCTIONS_GPU_H diff --git a/Eigen/src/Core/arch/GPU/PacketMath.h b/Eigen/src/Core/arch/GPU/PacketMath.h new file mode 100644 index 0000000..689110d --- /dev/null +++ b/Eigen/src/Core/arch/GPU/PacketMath.h @@ -0,0 +1,1685 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2014 Benoit Steiner +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_PACKET_MATH_GPU_H +#define EIGEN_PACKET_MATH_GPU_H + +namespace Eigen { + +namespace internal { + +// Read-only data cached load available. +#if defined(EIGEN_HIP_DEVICE_COMPILE) || (defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 350) +#define EIGEN_GPU_HAS_LDG 1 +#endif + +// FP16 math available. 
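The GPU MathFunctions specializations above all follow one per-lane pattern: call the scalar CUDA math routine once for every component of float4/double2. A host-only sketch of that pattern, using a hand-rolled stand-in rather than the CUDA vector type:

#include <cmath>
#include <cstdio>

struct Float4 { float x, y, z, w; };  // stand-in for CUDA's built-in float4

// Mirrors the shape of the device plog specialization: one scalar call per lane.
static Float4 plog_like(const Float4& a) {
  return { std::log(a.x), std::log(a.y), std::log(a.z), std::log(a.w) };
}

int main() {
  Float4 v{1.0f, 2.0f, 4.0f, 8.0f};
  Float4 r = plog_like(v);
  std::printf("%g %g %g %g\n", r.x, r.y, r.z, r.w);  // 0 0.693147 1.38629 2.07944
  return 0;
}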
+#if (defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) +#define EIGEN_CUDA_HAS_FP16_ARITHMETIC 1 +#endif + +#if defined(EIGEN_HIP_DEVICE_COMPILE) || defined(EIGEN_CUDA_HAS_FP16_ARITHMETIC) +#define EIGEN_GPU_HAS_FP16_ARITHMETIC 1 +#endif + +// Make sure this is only available when targeting a GPU: we don't want to +// introduce conflicts between these packet_traits definitions and the ones +// we'll use on the host side (SSE, AVX, ...) +#if defined(EIGEN_GPUCC) && defined(EIGEN_USE_GPU) + +template<> struct is_arithmetic { enum { value = true }; }; +template<> struct is_arithmetic { enum { value = true }; }; + +template<> struct packet_traits : default_packet_traits +{ + typedef float4 type; + typedef float4 half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size=4, + HasHalfPacket = 0, + + HasDiv = 1, + HasSin = 0, + HasCos = 0, + HasLog = 1, + HasExp = 1, + HasSqrt = 1, + HasRsqrt = 1, + HasLGamma = 1, + HasDiGamma = 1, + HasZeta = 1, + HasPolygamma = 1, + HasErf = 1, + HasErfc = 1, + HasNdtri = 1, + HasBessel = 1, + HasIGamma = 1, + HasIGammaDerA = 1, + HasGammaSampleDerAlpha = 1, + HasIGammac = 1, + HasBetaInc = 1, + + HasBlend = 0, + HasFloor = 1, + }; +}; + +template<> struct packet_traits : default_packet_traits +{ + typedef double2 type; + typedef double2 half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size=2, + HasHalfPacket = 0, + + HasDiv = 1, + HasLog = 1, + HasExp = 1, + HasSqrt = 1, + HasRsqrt = 1, + HasLGamma = 1, + HasDiGamma = 1, + HasZeta = 1, + HasPolygamma = 1, + HasErf = 1, + HasErfc = 1, + HasNdtri = 1, + HasBessel = 1, + HasIGamma = 1, + HasIGammaDerA = 1, + HasGammaSampleDerAlpha = 1, + HasIGammac = 1, + HasBetaInc = 1, + + HasBlend = 0, + HasFloor = 1, + }; +}; + + +template<> struct unpacket_traits { typedef float type; enum {size=4, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef float4 half; }; +template<> struct unpacket_traits { typedef double type; enum {size=2, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef double2 half; }; + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pset1(const float& from) { + return make_float4(from, from, from, from); +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pset1(const double& from) { + return make_double2(from, from); +} + +// We need to distinguish ‘clang as the CUDA compiler’ from ‘clang as the host compiler, +// invoked by NVCC’ (e.g. on MacOS). The former needs to see both host and device implementation +// of the functions, while the latter can only deal with one of them. 
+#if defined(EIGEN_CUDA_ARCH) || defined(EIGEN_HIPCC) || (defined(EIGEN_CUDACC) && EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) +namespace { + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float bitwise_and(const float& a, + const float& b) { + return __int_as_float(__float_as_int(a) & __float_as_int(b)); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double bitwise_and(const double& a, + const double& b) { + return __longlong_as_double(__double_as_longlong(a) & + __double_as_longlong(b)); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float bitwise_or(const float& a, + const float& b) { + return __int_as_float(__float_as_int(a) | __float_as_int(b)); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double bitwise_or(const double& a, + const double& b) { + return __longlong_as_double(__double_as_longlong(a) | + __double_as_longlong(b)); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float bitwise_xor(const float& a, + const float& b) { + return __int_as_float(__float_as_int(a) ^ __float_as_int(b)); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double bitwise_xor(const double& a, + const double& b) { + return __longlong_as_double(__double_as_longlong(a) ^ + __double_as_longlong(b)); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float bitwise_andnot(const float& a, + const float& b) { + return __int_as_float(__float_as_int(a) & ~__float_as_int(b)); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double bitwise_andnot(const double& a, + const double& b) { + return __longlong_as_double(__double_as_longlong(a) & + ~__double_as_longlong(b)); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float eq_mask(const float& a, + const float& b) { + return __int_as_float(a == b ? 0xffffffffu : 0u); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double eq_mask(const double& a, + const double& b) { + return __longlong_as_double(a == b ? 0xffffffffffffffffull : 0ull); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float lt_mask(const float& a, + const float& b) { + return __int_as_float(a < b ? 0xffffffffu : 0u); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double lt_mask(const double& a, + const double& b) { + return __longlong_as_double(a < b ? 
0xffffffffffffffffull : 0ull); +} + +} // namespace + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pand(const float4& a, + const float4& b) { + return make_float4(bitwise_and(a.x, b.x), bitwise_and(a.y, b.y), + bitwise_and(a.z, b.z), bitwise_and(a.w, b.w)); +} +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pand(const double2& a, + const double2& b) { + return make_double2(bitwise_and(a.x, b.x), bitwise_and(a.y, b.y)); +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 por(const float4& a, + const float4& b) { + return make_float4(bitwise_or(a.x, b.x), bitwise_or(a.y, b.y), + bitwise_or(a.z, b.z), bitwise_or(a.w, b.w)); +} +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 por(const double2& a, + const double2& b) { + return make_double2(bitwise_or(a.x, b.x), bitwise_or(a.y, b.y)); +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pxor(const float4& a, + const float4& b) { + return make_float4(bitwise_xor(a.x, b.x), bitwise_xor(a.y, b.y), + bitwise_xor(a.z, b.z), bitwise_xor(a.w, b.w)); +} +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pxor(const double2& a, + const double2& b) { + return make_double2(bitwise_xor(a.x, b.x), bitwise_xor(a.y, b.y)); +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pandnot(const float4& a, + const float4& b) { + return make_float4(bitwise_andnot(a.x, b.x), bitwise_andnot(a.y, b.y), + bitwise_andnot(a.z, b.z), bitwise_andnot(a.w, b.w)); +} +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 +pandnot(const double2& a, const double2& b) { + return make_double2(bitwise_andnot(a.x, b.x), bitwise_andnot(a.y, b.y)); +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pcmp_eq(const float4& a, + const float4& b) { + return make_float4(eq_mask(a.x, b.x), eq_mask(a.y, b.y), eq_mask(a.z, b.z), + eq_mask(a.w, b.w)); +} +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pcmp_lt(const float4& a, + const float4& b) { + return make_float4(lt_mask(a.x, b.x), lt_mask(a.y, b.y), lt_mask(a.z, b.z), + lt_mask(a.w, b.w)); +} +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 +pcmp_eq(const double2& a, const double2& b) { + return make_double2(eq_mask(a.x, b.x), eq_mask(a.y, b.y)); +} +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 +pcmp_lt(const double2& a, const double2& b) { + return make_double2(lt_mask(a.x, b.x), lt_mask(a.y, b.y)); +} +#endif // defined(EIGEN_CUDA_ARCH) || defined(EIGEN_HIPCC) || (defined(EIGEN_CUDACC) && EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 plset(const float& a) { + return make_float4(a, a+1, a+2, a+3); +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 plset(const double& a) { + return make_double2(a, a+1); +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 padd(const float4& a, const float4& b) { + return make_float4(a.x+b.x, a.y+b.y, a.z+b.z, a.w+b.w); +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 padd(const double2& a, const double2& b) { + return make_double2(a.x+b.x, a.y+b.y); +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 psub(const float4& a, const float4& b) { + return make_float4(a.x-b.x, a.y-b.y, a.z-b.z, a.w-b.w); +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 psub(const double2& a, const double2& b) { + return make_double2(a.x-b.x, a.y-b.y); +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pnegate(const float4& a) { + return make_float4(-a.x, -a.y, -a.z, -a.w); +} 
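The eq_mask/lt_mask helpers and the pcmp_* specializations above encode comparison results as all-ones or all-zeros bit patterns reinterpreted as float, which is what lets the bitwise packet ops pick lanes without branches. A host-side sketch of the same idiom, using memcpy in place of __float_as_int/__int_as_float (the select helper shown is the generic mask-blend idiom these building blocks enable, not a function from the patch):

#include <cstdint>
#include <cstring>
#include <cstdio>

static uint32_t f2i(float f) { uint32_t i; std::memcpy(&i, &f, sizeof(i)); return i; }
static float    i2f(uint32_t i) { float f; std::memcpy(&f, &i, sizeof(f)); return f; }

// Comparison result as a full bit mask, like lt_mask above.
static float lt_mask(float a, float b) { return i2f(a < b ? 0xffffffffu : 0u); }

// select(mask, a, b) = (mask & a) | (~mask & b): a branch-free lane pick.
static float select(float mask, float a, float b) {
  return i2f((f2i(mask) & f2i(a)) | (~f2i(mask) & f2i(b)));
}

int main() {
  std::printf("%g\n", select(lt_mask(1.0f, 2.0f), 10.0f, 20.0f));  // 10 (mask all ones)
  std::printf("%g\n", select(lt_mask(3.0f, 2.0f), 10.0f, 20.0f));  // 20 (mask zero)
  return 0;
}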
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pnegate(const double2& a) { + return make_double2(-a.x, -a.y); +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pconj(const float4& a) { return a; } +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pconj(const double2& a) { return a; } + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pmul(const float4& a, const float4& b) { + return make_float4(a.x*b.x, a.y*b.y, a.z*b.z, a.w*b.w); +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pmul(const double2& a, const double2& b) { + return make_double2(a.x*b.x, a.y*b.y); +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pdiv(const float4& a, const float4& b) { + return make_float4(a.x/b.x, a.y/b.y, a.z/b.z, a.w/b.w); +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pdiv(const double2& a, const double2& b) { + return make_double2(a.x/b.x, a.y/b.y); +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pmin(const float4& a, const float4& b) { + return make_float4(fminf(a.x, b.x), fminf(a.y, b.y), fminf(a.z, b.z), fminf(a.w, b.w)); +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pmin(const double2& a, const double2& b) { + return make_double2(fmin(a.x, b.x), fmin(a.y, b.y)); +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pmax(const float4& a, const float4& b) { + return make_float4(fmaxf(a.x, b.x), fmaxf(a.y, b.y), fmaxf(a.z, b.z), fmaxf(a.w, b.w)); +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pmax(const double2& a, const double2& b) { + return make_double2(fmax(a.x, b.x), fmax(a.y, b.y)); +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pload(const float* from) { + return *reinterpret_cast(from); +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pload(const double* from) { + return *reinterpret_cast(from); +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 ploadu(const float* from) { + return make_float4(from[0], from[1], from[2], from[3]); +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 ploadu(const double* from) { + return make_double2(from[0], from[1]); +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 ploaddup(const float* from) { + return make_float4(from[0], from[0], from[1], from[1]); +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 ploaddup(const double* from) { + return make_double2(from[0], from[0]); +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstore(float* to, const float4& from) { + *reinterpret_cast(to) = from; +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstore(double* to, const double2& from) { + *reinterpret_cast(to) = from; +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstoreu(float* to, const float4& from) { + to[0] = from.x; + to[1] = from.y; + to[2] = from.z; + to[3] = from.w; +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstoreu(double* to, const double2& from) { + to[0] = from.x; + to[1] = from.y; +} + +template<> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float4 ploadt_ro(const float* from) { +#if defined(EIGEN_GPU_HAS_LDG) + return __ldg((const float4*)from); +#else + return make_float4(from[0], from[1], from[2], from[3]); +#endif +} +template<> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double2 ploadt_ro(const double* from) { +#if defined(EIGEN_GPU_HAS_LDG) + return __ldg((const double2*)from); +#else + return make_double2(from[0], from[1]); +#endif +} + +template<> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float4 ploadt_ro(const 
float* from) { +#if defined(EIGEN_GPU_HAS_LDG) + return make_float4(__ldg(from+0), __ldg(from+1), __ldg(from+2), __ldg(from+3)); +#else + return make_float4(from[0], from[1], from[2], from[3]); +#endif +} +template<> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double2 ploadt_ro(const double* from) { +#if defined(EIGEN_GPU_HAS_LDG) + return make_double2(__ldg(from+0), __ldg(from+1)); +#else + return make_double2(from[0], from[1]); +#endif +} + +template<> EIGEN_DEVICE_FUNC inline float4 pgather(const float* from, Index stride) { + return make_float4(from[0*stride], from[1*stride], from[2*stride], from[3*stride]); +} + +template<> EIGEN_DEVICE_FUNC inline double2 pgather(const double* from, Index stride) { + return make_double2(from[0*stride], from[1*stride]); +} + +template<> EIGEN_DEVICE_FUNC inline void pscatter(float* to, const float4& from, Index stride) { + to[stride*0] = from.x; + to[stride*1] = from.y; + to[stride*2] = from.z; + to[stride*3] = from.w; +} +template<> EIGEN_DEVICE_FUNC inline void pscatter(double* to, const double2& from, Index stride) { + to[stride*0] = from.x; + to[stride*1] = from.y; +} + +template<> EIGEN_DEVICE_FUNC inline float pfirst(const float4& a) { + return a.x; +} +template<> EIGEN_DEVICE_FUNC inline double pfirst(const double2& a) { + return a.x; +} + +template<> EIGEN_DEVICE_FUNC inline float predux(const float4& a) { + return a.x + a.y + a.z + a.w; +} +template<> EIGEN_DEVICE_FUNC inline double predux(const double2& a) { + return a.x + a.y; +} + +template<> EIGEN_DEVICE_FUNC inline float predux_max(const float4& a) { + return fmaxf(fmaxf(a.x, a.y), fmaxf(a.z, a.w)); +} +template<> EIGEN_DEVICE_FUNC inline double predux_max(const double2& a) { + return fmax(a.x, a.y); +} + +template<> EIGEN_DEVICE_FUNC inline float predux_min(const float4& a) { + return fminf(fminf(a.x, a.y), fminf(a.z, a.w)); +} +template<> EIGEN_DEVICE_FUNC inline double predux_min(const double2& a) { + return fmin(a.x, a.y); +} + +template<> EIGEN_DEVICE_FUNC inline float predux_mul(const float4& a) { + return a.x * a.y * a.z * a.w; +} +template<> EIGEN_DEVICE_FUNC inline double predux_mul(const double2& a) { + return a.x * a.y; +} + +template<> EIGEN_DEVICE_FUNC inline float4 pabs(const float4& a) { + return make_float4(fabsf(a.x), fabsf(a.y), fabsf(a.z), fabsf(a.w)); +} +template<> EIGEN_DEVICE_FUNC inline double2 pabs(const double2& a) { + return make_double2(fabs(a.x), fabs(a.y)); +} + +template<> EIGEN_DEVICE_FUNC inline float4 pfloor(const float4& a) { + return make_float4(floorf(a.x), floorf(a.y), floorf(a.z), floorf(a.w)); +} +template<> EIGEN_DEVICE_FUNC inline double2 pfloor(const double2& a) { + return make_double2(floor(a.x), floor(a.y)); +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + float tmp = kernel.packet[0].y; + kernel.packet[0].y = kernel.packet[1].x; + kernel.packet[1].x = tmp; + + tmp = kernel.packet[0].z; + kernel.packet[0].z = kernel.packet[2].x; + kernel.packet[2].x = tmp; + + tmp = kernel.packet[0].w; + kernel.packet[0].w = kernel.packet[3].x; + kernel.packet[3].x = tmp; + + tmp = kernel.packet[1].z; + kernel.packet[1].z = kernel.packet[2].y; + kernel.packet[2].y = tmp; + + tmp = kernel.packet[1].w; + kernel.packet[1].w = kernel.packet[3].y; + kernel.packet[3].y = tmp; + + tmp = kernel.packet[2].w; + kernel.packet[2].w = kernel.packet[3].z; + kernel.packet[3].z = tmp; +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + double tmp = kernel.packet[0].y; + kernel.packet[0].y = kernel.packet[1].x; + kernel.packet[1].x = 
tmp; +} + +#endif // defined(EIGEN_GPUCC) && defined(EIGEN_USE_GPU) + +// Packet4h2 must be defined in the macro without EIGEN_CUDA_ARCH, meaning +// its corresponding packet_traits must be visible on host. +#if defined(EIGEN_HAS_CUDA_FP16) || defined(EIGEN_HAS_HIP_FP16) + +typedef ulonglong2 Packet4h2; +template<> struct unpacket_traits { typedef Eigen::half type; enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet4h2 half; }; +template<> struct is_arithmetic { enum { value = true }; }; + +template<> struct unpacket_traits { typedef Eigen::half type; enum {size=2, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef half2 half; }; +template<> struct is_arithmetic { enum { value = true }; }; + +template<> struct packet_traits : default_packet_traits +{ + typedef Packet4h2 type; + typedef Packet4h2 half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size=8, + HasHalfPacket = 0, + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 1, + HasSqrt = 1, + HasRsqrt = 1, + HasExp = 1, + HasExpm1 = 1, + HasLog = 1, + HasLog1p = 1 + }; +}; + +namespace { +// This is equivalent to make_half2, which is undocumented and doesn't seem to always exist. +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 combine_half(const __half& a, const __half& b) { +#if defined(EIGEN_GPU_COMPILE_PHASE) + return __halves2half2(a, b); +#else + // Round-about way since __halves2half2 is a __device__ function. + return __floats2half2_rn(__half2float(a), __half2float(b)); +#endif +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE __half get_half2_low(const half2& a) { +#if defined(EIGEN_GPU_COMPILE_PHASE) + return __low2half(a); +#else + return __float2half(__low2float(a)); +#endif +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE __half get_half2_high(const half2& a) { +#if defined(EIGEN_GPU_COMPILE_PHASE) + return __high2half(a); +#else + return __float2half(__high2float(a)); +#endif +} +} // namespace + +template<> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pset1(const Eigen::half& from) { +#if defined(EIGEN_GPU_COMPILE_PHASE) + return __half2half2(from); +#else + const float f = __half2float(from); + return __floats2half2_rn(f, f); +#endif +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 +pset1(const Eigen::half& from) { + Packet4h2 r; + half2* p_alias = reinterpret_cast(&r); + p_alias[0] = pset1(from); + p_alias[1] = pset1(from); + p_alias[2] = pset1(from); + p_alias[3] = pset1(from); + return r; +} + +// We now need this visible on both host and device. 
+// #if defined(EIGEN_CUDA_ARCH) || defined(EIGEN_HIPCC) || (defined(EIGEN_CUDACC) && EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) +namespace { + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pload(const Eigen::half* from) { + return *reinterpret_cast(from); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 ploadu(const Eigen::half* from) { + return combine_half(from[0], from[1]); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 ploaddup(const Eigen::half* from) { + return combine_half(from[0], from[0]); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstore(Eigen::half* to, + const half2& from) { + *reinterpret_cast(to) = from; +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstoreu(Eigen::half* to, + const half2& from) { + to[0] = get_half2_low(from); + to[1] = get_half2_high(from); +} + + +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE half2 ploadt_ro_aligned( + const Eigen::half* from) { +#if defined(EIGEN_GPU_HAS_LDG) + // Input is guaranteed to be properly aligned. + return __ldg(reinterpret_cast(from)); +#else + return combine_half(*(from+0), *(from+1)); +#endif +} + +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE half2 ploadt_ro_unaligned( + const Eigen::half* from) { +#if defined(EIGEN_GPU_HAS_LDG) + return __halves2half2(__ldg(from+0), __ldg(from+1)); +#else + return combine_half(*(from+0), *(from+1)); +#endif +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pgather(const Eigen::half* from, + Index stride) { + return combine_half(from[0*stride], from[1*stride]); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter( + Eigen::half* to, const half2& from, Index stride) { + to[stride*0] = get_half2_low(from); + to[stride*1] = get_half2_high(from); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half pfirst(const half2& a) { + return get_half2_low(a); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pabs(const half2& a) { + half a1 = get_half2_low(a); + half a2 = get_half2_high(a); + half result1 = half_impl::raw_uint16_to_half(a1.x & 0x7FFF); + half result2 = half_impl::raw_uint16_to_half(a2.x & 0x7FFF); + return combine_half(result1, result2); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 ptrue(const half2& /*a*/) { + half true_half = half_impl::raw_uint16_to_half(0xffffu); + return pset1(true_half); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pzero(const half2& /*a*/) { + half false_half = half_impl::raw_uint16_to_half(0x0000u); + return pset1(false_half); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void +ptranspose(PacketBlock& kernel) { + __half a1 = get_half2_low(kernel.packet[0]); + __half a2 = get_half2_high(kernel.packet[0]); + __half b1 = get_half2_low(kernel.packet[1]); + __half b2 = get_half2_high(kernel.packet[1]); + kernel.packet[0] = combine_half(a1, b1); + kernel.packet[1] = combine_half(a2, b2); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 plset(const Eigen::half& a) { +#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC) + return __halves2half2(a, __hadd(a, __float2half(1.0f))); +#else + float f = __half2float(a) + 1.0f; + return combine_half(a, __float2half(f)); +#endif +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pselect(const half2& mask, + const half2& a, + const half2& b) { + half mask_low = get_half2_low(mask); + half mask_high = get_half2_high(mask); + half result_low = mask_low == half(0) ? get_half2_low(b) : get_half2_low(a); + half result_high = mask_high == half(0) ? 
get_half2_high(b) : get_half2_high(a); + return combine_half(result_low, result_high); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pcmp_eq(const half2& a, + const half2& b) { + half true_half = half_impl::raw_uint16_to_half(0xffffu); + half false_half = half_impl::raw_uint16_to_half(0x0000u); + half a1 = get_half2_low(a); + half a2 = get_half2_high(a); + half b1 = get_half2_low(b); + half b2 = get_half2_high(b); + half eq1 = __half2float(a1) == __half2float(b1) ? true_half : false_half; + half eq2 = __half2float(a2) == __half2float(b2) ? true_half : false_half; + return combine_half(eq1, eq2); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pcmp_lt(const half2& a, + const half2& b) { + half true_half = half_impl::raw_uint16_to_half(0xffffu); + half false_half = half_impl::raw_uint16_to_half(0x0000u); + half a1 = get_half2_low(a); + half a2 = get_half2_high(a); + half b1 = get_half2_low(b); + half b2 = get_half2_high(b); + half eq1 = __half2float(a1) < __half2float(b1) ? true_half : false_half; + half eq2 = __half2float(a2) < __half2float(b2) ? true_half : false_half; + return combine_half(eq1, eq2); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pand(const half2& a, + const half2& b) { + half a1 = get_half2_low(a); + half a2 = get_half2_high(a); + half b1 = get_half2_low(b); + half b2 = get_half2_high(b); + half result1 = half_impl::raw_uint16_to_half(a1.x & b1.x); + half result2 = half_impl::raw_uint16_to_half(a2.x & b2.x); + return combine_half(result1, result2); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 por(const half2& a, + const half2& b) { + half a1 = get_half2_low(a); + half a2 = get_half2_high(a); + half b1 = get_half2_low(b); + half b2 = get_half2_high(b); + half result1 = half_impl::raw_uint16_to_half(a1.x | b1.x); + half result2 = half_impl::raw_uint16_to_half(a2.x | b2.x); + return combine_half(result1, result2); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pxor(const half2& a, + const half2& b) { + half a1 = get_half2_low(a); + half a2 = get_half2_high(a); + half b1 = get_half2_low(b); + half b2 = get_half2_high(b); + half result1 = half_impl::raw_uint16_to_half(a1.x ^ b1.x); + half result2 = half_impl::raw_uint16_to_half(a2.x ^ b2.x); + return combine_half(result1, result2); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pandnot(const half2& a, + const half2& b) { + half a1 = get_half2_low(a); + half a2 = get_half2_high(a); + half b1 = get_half2_low(b); + half b2 = get_half2_high(b); + half result1 = half_impl::raw_uint16_to_half(a1.x & ~b1.x); + half result2 = half_impl::raw_uint16_to_half(a2.x & ~b2.x); + return combine_half(result1, result2); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 padd(const half2& a, + const half2& b) { +#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC) + return __hadd2(a, b); +#else + float a1 = __low2float(a); + float a2 = __high2float(a); + float b1 = __low2float(b); + float b2 = __high2float(b); + float r1 = a1 + b1; + float r2 = a2 + b2; + return __floats2half2_rn(r1, r2); +#endif +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 psub(const half2& a, + const half2& b) { +#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC) + return __hsub2(a, b); +#else + float a1 = __low2float(a); + float a2 = __high2float(a); + float b1 = __low2float(b); + float b2 = __high2float(b); + float r1 = a1 - b1; + float r2 = a2 - b2; + return __floats2half2_rn(r1, r2); +#endif +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pnegate(const half2& a) { +#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC) + return __hneg2(a); +#else + float a1 = __low2float(a); 
+ float a2 = __high2float(a); + return __floats2half2_rn(-a1, -a2); +#endif +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pconj(const half2& a) { return a; } + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmul(const half2& a, + const half2& b) { +#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC) + return __hmul2(a, b); +#else + float a1 = __low2float(a); + float a2 = __high2float(a); + float b1 = __low2float(b); + float b2 = __high2float(b); + float r1 = a1 * b1; + float r2 = a2 * b2; + return __floats2half2_rn(r1, r2); +#endif +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmadd(const half2& a, + const half2& b, + const half2& c) { +#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC) + return __hfma2(a, b, c); +#else + float a1 = __low2float(a); + float a2 = __high2float(a); + float b1 = __low2float(b); + float b2 = __high2float(b); + float c1 = __low2float(c); + float c2 = __high2float(c); + float r1 = a1 * b1 + c1; + float r2 = a2 * b2 + c2; + return __floats2half2_rn(r1, r2); +#endif +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pdiv(const half2& a, + const half2& b) { +#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC) + return __h2div(a, b); +#else + float a1 = __low2float(a); + float a2 = __high2float(a); + float b1 = __low2float(b); + float b2 = __high2float(b); + float r1 = a1 / b1; + float r2 = a2 / b2; + return __floats2half2_rn(r1, r2); +#endif +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmin(const half2& a, + const half2& b) { + float a1 = __low2float(a); + float a2 = __high2float(a); + float b1 = __low2float(b); + float b2 = __high2float(b); + __half r1 = a1 < b1 ? get_half2_low(a) : get_half2_low(b); + __half r2 = a2 < b2 ? get_half2_high(a) : get_half2_high(b); + return combine_half(r1, r2); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmax(const half2& a, + const half2& b) { + float a1 = __low2float(a); + float a2 = __high2float(a); + float b1 = __low2float(b); + float b2 = __high2float(b); + __half r1 = a1 > b1 ? get_half2_low(a) : get_half2_low(b); + __half r2 = a2 > b2 ? get_half2_high(a) : get_half2_high(b); + return combine_half(r1, r2); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux(const half2& a) { +#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC) + return __hadd(__low2half(a), __high2half(a)); +#else + float a1 = __low2float(a); + float a2 = __high2float(a); + return Eigen::half(__float2half(a1 + a2)); +#endif +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_max(const half2& a) { +#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC) + __half first = __low2half(a); + __half second = __high2half(a); + return __hgt(first, second) ? first : second; +#else + float a1 = __low2float(a); + float a2 = __high2float(a); + return a1 > a2 ? get_half2_low(a) : get_half2_high(a); +#endif +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_min(const half2& a) { +#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC) + __half first = __low2half(a); + __half second = __high2half(a); + return __hlt(first, second) ? first : second; +#else + float a1 = __low2float(a); + float a2 = __high2float(a); + return a1 < a2 ? 
get_half2_low(a) : get_half2_high(a); +#endif +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_mul(const half2& a) { +#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC) + return __hmul(__low2half(a), __high2half(a)); +#else + float a1 = __low2float(a); + float a2 = __high2float(a); + return Eigen::half(__float2half(a1 * a2)); +#endif +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 plog1p(const half2& a) { + float a1 = __low2float(a); + float a2 = __high2float(a); + float r1 = log1pf(a1); + float r2 = log1pf(a2); + return __floats2half2_rn(r1, r2); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pexpm1(const half2& a) { + float a1 = __low2float(a); + float a2 = __high2float(a); + float r1 = expm1f(a1); + float r2 = expm1f(a2); + return __floats2half2_rn(r1, r2); +} + +#if (EIGEN_CUDA_SDK_VER >= 80000 && defined(EIGEN_CUDA_HAS_FP16_ARITHMETIC)) || \ + defined(EIGEN_HIP_DEVICE_COMPILE) + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +half2 plog(const half2& a) { + return h2log(a); +} + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +half2 pexp(const half2& a) { + return h2exp(a); +} + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +half2 psqrt(const half2& a) { + return h2sqrt(a); +} + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +half2 prsqrt(const half2& a) { + return h2rsqrt(a); +} + +#else + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 plog(const half2& a) { + float a1 = __low2float(a); + float a2 = __high2float(a); + float r1 = logf(a1); + float r2 = logf(a2); + return __floats2half2_rn(r1, r2); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pexp(const half2& a) { + float a1 = __low2float(a); + float a2 = __high2float(a); + float r1 = expf(a1); + float r2 = expf(a2); + return __floats2half2_rn(r1, r2); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 psqrt(const half2& a) { + float a1 = __low2float(a); + float a2 = __high2float(a); + float r1 = sqrtf(a1); + float r2 = sqrtf(a2); + return __floats2half2_rn(r1, r2); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 prsqrt(const half2& a) { + float a1 = __low2float(a); + float a2 = __high2float(a); + float r1 = rsqrtf(a1); + float r2 = rsqrtf(a2); + return __floats2half2_rn(r1, r2); +} +#endif +} // namespace + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 +pload(const Eigen::half* from) { + return *reinterpret_cast(from); +} + +// unaligned load; +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 +ploadu(const Eigen::half* from) { + Packet4h2 r; + half2* p_alias = reinterpret_cast(&r); + p_alias[0] = ploadu(from + 0); + p_alias[1] = ploadu(from + 2); + p_alias[2] = ploadu(from + 4); + p_alias[3] = ploadu(from + 6); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 +ploaddup(const Eigen::half* from) { + Packet4h2 r; + half2* p_alias = reinterpret_cast(&r); + p_alias[0] = ploaddup(from + 0); + p_alias[1] = ploaddup(from + 1); + p_alias[2] = ploaddup(from + 2); + p_alias[3] = ploaddup(from + 3); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstore( + Eigen::half* to, const Packet4h2& from) { + *reinterpret_cast(to) = from; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pstoreu( + Eigen::half* to, const Packet4h2& from) { + const half2* from_alias = reinterpret_cast(&from); + pstoreu(to + 0,from_alias[0]); + pstoreu(to + 2,from_alias[1]); + pstoreu(to + 4,from_alias[2]); + pstoreu(to + 6,from_alias[3]); +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet4h2 +ploadt_ro(const Eigen::half* from) { +#if defined(EIGEN_GPU_HAS_LDG) + Packet4h2 r; + r = 
__ldg(reinterpret_cast(from)); + return r; +#else + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + r_alias[0] = ploadt_ro_aligned(from + 0); + r_alias[1] = ploadt_ro_aligned(from + 2); + r_alias[2] = ploadt_ro_aligned(from + 4); + r_alias[3] = ploadt_ro_aligned(from + 6); + return r; +#endif +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet4h2 +ploadt_ro(const Eigen::half* from) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + r_alias[0] = ploadt_ro_unaligned(from + 0); + r_alias[1] = ploadt_ro_unaligned(from + 2); + r_alias[2] = ploadt_ro_unaligned(from + 4); + r_alias[3] = ploadt_ro_unaligned(from + 6); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 +pgather(const Eigen::half* from, Index stride) { + Packet4h2 r; + half2* p_alias = reinterpret_cast(&r); + p_alias[0] = combine_half(from[0 * stride], from[1 * stride]); + p_alias[1] = combine_half(from[2 * stride], from[3 * stride]); + p_alias[2] = combine_half(from[4 * stride], from[5 * stride]); + p_alias[3] = combine_half(from[6 * stride], from[7 * stride]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter( + Eigen::half* to, const Packet4h2& from, Index stride) { + const half2* from_alias = reinterpret_cast(&from); + pscatter(to + stride * 0, from_alias[0], stride); + pscatter(to + stride * 2, from_alias[1], stride); + pscatter(to + stride * 4, from_alias[2], stride); + pscatter(to + stride * 6, from_alias[3], stride); +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half pfirst( + const Packet4h2& a) { + return pfirst(*(reinterpret_cast(&a))); +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pabs( + const Packet4h2& a) { + Packet4h2 r; + half2* p_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + p_alias[0] = pabs(a_alias[0]); + p_alias[1] = pabs(a_alias[1]); + p_alias[2] = pabs(a_alias[2]); + p_alias[3] = pabs(a_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 ptrue( + const Packet4h2& /*a*/) { + half true_half = half_impl::raw_uint16_to_half(0xffffu); + return pset1(true_half); +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pzero(const Packet4h2& /*a*/) { + half false_half = half_impl::raw_uint16_to_half(0x0000u); + return pset1(false_half); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose_double( + double* d_row0, double* d_row1, double* d_row2, double* d_row3, + double* d_row4, double* d_row5, double* d_row6, double* d_row7) { + double d_tmp; + d_tmp = d_row0[1]; + d_row0[1] = d_row4[0]; + d_row4[0] = d_tmp; + + d_tmp = d_row1[1]; + d_row1[1] = d_row5[0]; + d_row5[0] = d_tmp; + + d_tmp = d_row2[1]; + d_row2[1] = d_row6[0]; + d_row6[0] = d_tmp; + + d_tmp = d_row3[1]; + d_row3[1] = d_row7[0]; + d_row7[0] = d_tmp; +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose_half2( + half2* f_row0, half2* f_row1, half2* f_row2, half2* f_row3) { + half2 f_tmp; + f_tmp = f_row0[1]; + f_row0[1] = f_row2[0]; + f_row2[0] = f_tmp; + + f_tmp = f_row1[1]; + f_row1[1] = f_row3[0]; + f_row3[0] = f_tmp; +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void +ptranspose_half(half2& f0, half2& f1) { + __half a1 = get_half2_low(f0); + __half a2 = get_half2_high(f0); + __half b1 = get_half2_low(f1); + __half b2 = get_half2_high(f1); + f0 = combine_half(a1, b1); + f1 = combine_half(a2, b2); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void +ptranspose(PacketBlock& kernel) { + double* d_row0 = reinterpret_cast(&kernel.packet[0]); + 
double* d_row1 = reinterpret_cast(&kernel.packet[1]); + double* d_row2 = reinterpret_cast(&kernel.packet[2]); + double* d_row3 = reinterpret_cast(&kernel.packet[3]); + double* d_row4 = reinterpret_cast(&kernel.packet[4]); + double* d_row5 = reinterpret_cast(&kernel.packet[5]); + double* d_row6 = reinterpret_cast(&kernel.packet[6]); + double* d_row7 = reinterpret_cast(&kernel.packet[7]); + ptranspose_double(d_row0, d_row1, d_row2, d_row3, + d_row4, d_row5, d_row6, d_row7); + + + half2* f_row0 = reinterpret_cast(d_row0); + half2* f_row1 = reinterpret_cast(d_row1); + half2* f_row2 = reinterpret_cast(d_row2); + half2* f_row3 = reinterpret_cast(d_row3); + ptranspose_half2(f_row0, f_row1, f_row2, f_row3); + ptranspose_half(f_row0[0], f_row1[0]); + ptranspose_half(f_row0[1], f_row1[1]); + ptranspose_half(f_row2[0], f_row3[0]); + ptranspose_half(f_row2[1], f_row3[1]); + + f_row0 = reinterpret_cast(d_row0 + 1); + f_row1 = reinterpret_cast(d_row1 + 1); + f_row2 = reinterpret_cast(d_row2 + 1); + f_row3 = reinterpret_cast(d_row3 + 1); + ptranspose_half2(f_row0, f_row1, f_row2, f_row3); + ptranspose_half(f_row0[0], f_row1[0]); + ptranspose_half(f_row0[1], f_row1[1]); + ptranspose_half(f_row2[0], f_row3[0]); + ptranspose_half(f_row2[1], f_row3[1]); + + f_row0 = reinterpret_cast(d_row4); + f_row1 = reinterpret_cast(d_row5); + f_row2 = reinterpret_cast(d_row6); + f_row3 = reinterpret_cast(d_row7); + ptranspose_half2(f_row0, f_row1, f_row2, f_row3); + ptranspose_half(f_row0[0], f_row1[0]); + ptranspose_half(f_row0[1], f_row1[1]); + ptranspose_half(f_row2[0], f_row3[0]); + ptranspose_half(f_row2[1], f_row3[1]); + + f_row0 = reinterpret_cast(d_row4 + 1); + f_row1 = reinterpret_cast(d_row5 + 1); + f_row2 = reinterpret_cast(d_row6 + 1); + f_row3 = reinterpret_cast(d_row7 + 1); + ptranspose_half2(f_row0, f_row1, f_row2, f_row3); + ptranspose_half(f_row0[0], f_row1[0]); + ptranspose_half(f_row0[1], f_row1[1]); + ptranspose_half(f_row2[0], f_row3[0]); + ptranspose_half(f_row2[1], f_row3[1]); + +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 +plset(const Eigen::half& a) { +#if defined(EIGEN_HIP_DEVICE_COMPILE) + + Packet4h2 r; + half2* p_alias = reinterpret_cast(&r); + p_alias[0] = __halves2half2(a, __hadd(a, __float2half(1.0f))); + p_alias[1] = __halves2half2(__hadd(a, __float2half(2.0f)), + __hadd(a, __float2half(3.0f))); + p_alias[2] = __halves2half2(__hadd(a, __float2half(4.0f)), + __hadd(a, __float2half(5.0f))); + p_alias[3] = __halves2half2(__hadd(a, __float2half(6.0f)), + __hadd(a, __float2half(7.0f))); + return r; +#elif defined(EIGEN_CUDA_HAS_FP16_ARITHMETIC) + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + + half2 b = pset1(a); + half2 c; + half2 half_offset0 = __halves2half2(__float2half(0.0f),__float2half(2.0f)); + half2 half_offset1 = __halves2half2(__float2half(4.0f),__float2half(6.0f)); + + c = __hadd2(b, half_offset0); + r_alias[0] = plset(__low2half(c)); + r_alias[1] = plset(__high2half(c)); + + c = __hadd2(b, half_offset1); + r_alias[2] = plset(__low2half(c)); + r_alias[3] = plset(__high2half(c)); + + return r; + +#else + float f = __half2float(a); + Packet4h2 r; + half2* p_alias = reinterpret_cast(&r); + p_alias[0] = combine_half(a, __float2half(f + 1.0f)); + p_alias[1] = combine_half(__float2half(f + 2.0f), __float2half(f + 3.0f)); + p_alias[2] = combine_half(__float2half(f + 4.0f), __float2half(f + 5.0f)); + p_alias[3] = combine_half(__float2half(f + 6.0f), __float2half(f + 7.0f)); + return r; +#endif +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE 
Packet4h2 +pselect(const Packet4h2& mask, const Packet4h2& a, + const Packet4h2& b) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* mask_alias = reinterpret_cast(&mask); + const half2* a_alias = reinterpret_cast(&a); + const half2* b_alias = reinterpret_cast(&b); + r_alias[0] = pselect(mask_alias[0], a_alias[0], b_alias[0]); + r_alias[1] = pselect(mask_alias[1], a_alias[1], b_alias[1]); + r_alias[2] = pselect(mask_alias[2], a_alias[2], b_alias[2]); + r_alias[3] = pselect(mask_alias[3], a_alias[3], b_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 +pcmp_eq(const Packet4h2& a, const Packet4h2& b) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + const half2* b_alias = reinterpret_cast(&b); + r_alias[0] = pcmp_eq(a_alias[0], b_alias[0]); + r_alias[1] = pcmp_eq(a_alias[1], b_alias[1]); + r_alias[2] = pcmp_eq(a_alias[2], b_alias[2]); + r_alias[3] = pcmp_eq(a_alias[3], b_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pand( + const Packet4h2& a, const Packet4h2& b) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + const half2* b_alias = reinterpret_cast(&b); + r_alias[0] = pand(a_alias[0], b_alias[0]); + r_alias[1] = pand(a_alias[1], b_alias[1]); + r_alias[2] = pand(a_alias[2], b_alias[2]); + r_alias[3] = pand(a_alias[3], b_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 por( + const Packet4h2& a, const Packet4h2& b) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + const half2* b_alias = reinterpret_cast(&b); + r_alias[0] = por(a_alias[0], b_alias[0]); + r_alias[1] = por(a_alias[1], b_alias[1]); + r_alias[2] = por(a_alias[2], b_alias[2]); + r_alias[3] = por(a_alias[3], b_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pxor( + const Packet4h2& a, const Packet4h2& b) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + const half2* b_alias = reinterpret_cast(&b); + r_alias[0] = pxor(a_alias[0], b_alias[0]); + r_alias[1] = pxor(a_alias[1], b_alias[1]); + r_alias[2] = pxor(a_alias[2], b_alias[2]); + r_alias[3] = pxor(a_alias[3], b_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 +pandnot(const Packet4h2& a, const Packet4h2& b) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + const half2* b_alias = reinterpret_cast(&b); + r_alias[0] = pandnot(a_alias[0], b_alias[0]); + r_alias[1] = pandnot(a_alias[1], b_alias[1]); + r_alias[2] = pandnot(a_alias[2], b_alias[2]); + r_alias[3] = pandnot(a_alias[3], b_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 padd( + const Packet4h2& a, const Packet4h2& b) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + const half2* b_alias = reinterpret_cast(&b); + r_alias[0] = padd(a_alias[0], b_alias[0]); + r_alias[1] = padd(a_alias[1], b_alias[1]); + r_alias[2] = padd(a_alias[2], b_alias[2]); + r_alias[3] = padd(a_alias[3], b_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 psub( + const Packet4h2& a, const Packet4h2& b) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + 
const half2* b_alias = reinterpret_cast(&b); + r_alias[0] = psub(a_alias[0], b_alias[0]); + r_alias[1] = psub(a_alias[1], b_alias[1]); + r_alias[2] = psub(a_alias[2], b_alias[2]); + r_alias[3] = psub(a_alias[3], b_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pnegate(const Packet4h2& a) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + r_alias[0] = pnegate(a_alias[0]); + r_alias[1] = pnegate(a_alias[1]); + r_alias[2] = pnegate(a_alias[2]); + r_alias[3] = pnegate(a_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pconj(const Packet4h2& a) { + return a; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pmul( + const Packet4h2& a, const Packet4h2& b) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + const half2* b_alias = reinterpret_cast(&b); + r_alias[0] = pmul(a_alias[0], b_alias[0]); + r_alias[1] = pmul(a_alias[1], b_alias[1]); + r_alias[2] = pmul(a_alias[2], b_alias[2]); + r_alias[3] = pmul(a_alias[3], b_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pmadd( + const Packet4h2& a, const Packet4h2& b, const Packet4h2& c) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + const half2* b_alias = reinterpret_cast(&b); + const half2* c_alias = reinterpret_cast(&c); + r_alias[0] = pmadd(a_alias[0], b_alias[0], c_alias[0]); + r_alias[1] = pmadd(a_alias[1], b_alias[1], c_alias[1]); + r_alias[2] = pmadd(a_alias[2], b_alias[2], c_alias[2]); + r_alias[3] = pmadd(a_alias[3], b_alias[3], c_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pdiv( + const Packet4h2& a, const Packet4h2& b) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + const half2* b_alias = reinterpret_cast(&b); + r_alias[0] = pdiv(a_alias[0], b_alias[0]); + r_alias[1] = pdiv(a_alias[1], b_alias[1]); + r_alias[2] = pdiv(a_alias[2], b_alias[2]); + r_alias[3] = pdiv(a_alias[3], b_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pmin( + const Packet4h2& a, const Packet4h2& b) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + const half2* b_alias = reinterpret_cast(&b); + r_alias[0] = pmin(a_alias[0], b_alias[0]); + r_alias[1] = pmin(a_alias[1], b_alias[1]); + r_alias[2] = pmin(a_alias[2], b_alias[2]); + r_alias[3] = pmin(a_alias[3], b_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pmax( + const Packet4h2& a, const Packet4h2& b) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + const half2* b_alias = reinterpret_cast(&b); + r_alias[0] = pmax(a_alias[0], b_alias[0]); + r_alias[1] = pmax(a_alias[1], b_alias[1]); + r_alias[2] = pmax(a_alias[2], b_alias[2]); + r_alias[3] = pmax(a_alias[3], b_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux( + const Packet4h2& a) { + const half2* a_alias = reinterpret_cast(&a); + + return predux(a_alias[0]) + predux(a_alias[1]) + + predux(a_alias[2]) + predux(a_alias[3]); +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_max( + const Packet4h2& a) { + const half2* a_alias = reinterpret_cast(&a); + half2 m0 = 
combine_half(predux_max(a_alias[0]), + predux_max(a_alias[1])); + half2 m1 = combine_half(predux_max(a_alias[2]), + predux_max(a_alias[3])); + __half first = predux_max(m0); + __half second = predux_max(m1); +#if defined(EIGEN_CUDA_HAS_FP16_ARITHMETIC) + return (__hgt(first, second) ? first : second); +#else + float ffirst = __half2float(first); + float fsecond = __half2float(second); + return (ffirst > fsecond)? first: second; +#endif +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_min( + const Packet4h2& a) { + const half2* a_alias = reinterpret_cast(&a); + half2 m0 = combine_half(predux_min(a_alias[0]), + predux_min(a_alias[1])); + half2 m1 = combine_half(predux_min(a_alias[2]), + predux_min(a_alias[3])); + __half first = predux_min(m0); + __half second = predux_min(m1); +#if defined(EIGEN_CUDA_HAS_FP16_ARITHMETIC) + return (__hlt(first, second) ? first : second); +#else + float ffirst = __half2float(first); + float fsecond = __half2float(second); + return (ffirst < fsecond)? first: second; +#endif +} + +// likely overflow/underflow +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Eigen::half predux_mul( + const Packet4h2& a) { + const half2* a_alias = reinterpret_cast(&a); + return predux_mul(pmul(pmul(a_alias[0], a_alias[1]), + pmul(a_alias[2], a_alias[3]))); +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 +plog1p(const Packet4h2& a) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + r_alias[0] = plog1p(a_alias[0]); + r_alias[1] = plog1p(a_alias[1]); + r_alias[2] = plog1p(a_alias[2]); + r_alias[3] = plog1p(a_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 +pexpm1(const Packet4h2& a) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + r_alias[0] = pexpm1(a_alias[0]); + r_alias[1] = pexpm1(a_alias[1]); + r_alias[2] = pexpm1(a_alias[2]); + r_alias[3] = pexpm1(a_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 plog(const Packet4h2& a) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + r_alias[0] = plog(a_alias[0]); + r_alias[1] = plog(a_alias[1]); + r_alias[2] = plog(a_alias[2]); + r_alias[3] = plog(a_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pexp(const Packet4h2& a) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + r_alias[0] = pexp(a_alias[0]); + r_alias[1] = pexp(a_alias[1]); + r_alias[2] = pexp(a_alias[2]); + r_alias[3] = pexp(a_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 psqrt(const Packet4h2& a) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + r_alias[0] = psqrt(a_alias[0]); + r_alias[1] = psqrt(a_alias[1]); + r_alias[2] = psqrt(a_alias[2]); + r_alias[3] = psqrt(a_alias[3]); + return r; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 +prsqrt(const Packet4h2& a) { + Packet4h2 r; + half2* r_alias = reinterpret_cast(&r); + const half2* a_alias = reinterpret_cast(&a); + r_alias[0] = prsqrt(a_alias[0]); + r_alias[1] = prsqrt(a_alias[1]); + r_alias[2] = prsqrt(a_alias[2]); + r_alias[3] = prsqrt(a_alias[3]); + return r; +} + +// The following specialized padd, pmul, pdiv, pmin, pmax, pset1 are needed for +// the implementation of GPU half reduction. 
+template<> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 padd(const half2& a, + const half2& b) { +#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC) + return __hadd2(a, b); +#else + float a1 = __low2float(a); + float a2 = __high2float(a); + float b1 = __low2float(b); + float b2 = __high2float(b); + float r1 = a1 + b1; + float r2 = a2 + b2; + return __floats2half2_rn(r1, r2); +#endif +} + +template<> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmul(const half2& a, + const half2& b) { +#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC) + return __hmul2(a, b); +#else + float a1 = __low2float(a); + float a2 = __high2float(a); + float b1 = __low2float(b); + float b2 = __high2float(b); + float r1 = a1 * b1; + float r2 = a2 * b2; + return __floats2half2_rn(r1, r2); +#endif +} + +template<> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pdiv(const half2& a, + const half2& b) { +#if defined(EIGEN_GPU_HAS_FP16_ARITHMETIC) + return __h2div(a, b); +#else + float a1 = __low2float(a); + float a2 = __high2float(a); + float b1 = __low2float(b); + float b2 = __high2float(b); + float r1 = a1 / b1; + float r2 = a2 / b2; + return __floats2half2_rn(r1, r2); +#endif +} + +template<> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmin(const half2& a, + const half2& b) { + float a1 = __low2float(a); + float a2 = __high2float(a); + float b1 = __low2float(b); + float b2 = __high2float(b); + __half r1 = a1 < b1 ? get_half2_low(a) : get_half2_low(b); + __half r2 = a2 < b2 ? get_half2_high(a) : get_half2_high(b); + return combine_half(r1, r2); +} + +template<> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pmax(const half2& a, + const half2& b) { + float a1 = __low2float(a); + float a2 = __high2float(a); + float b1 = __low2float(b); + float b2 = __high2float(b); + __half r1 = a1 > b1 ? get_half2_low(a) : get_half2_low(b); + __half r2 = a2 > b2 ? get_half2_high(a) : get_half2_high(b); + return combine_half(r1, r2); +} + +// #endif // defined(EIGEN_CUDA_ARCH) || defined(EIGEN_HIPCC) || (defined(EIGEN_CUDACC) && EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) + +#endif // defined(EIGEN_HAS_CUDA_FP16) || defined(EIGEN_HAS_HIP_FP16) + +#undef EIGEN_GPU_HAS_LDG +#undef EIGEN_CUDA_HAS_FP16_ARITHMETIC +#undef EIGEN_GPU_HAS_FP16_ARITHMETIC + +} // end namespace internal + +} // end namespace Eigen + + +#endif // EIGEN_PACKET_MATH_GPU_H diff --git a/Eigen/src/Core/arch/GPU/TypeCasting.h b/Eigen/src/Core/arch/GPU/TypeCasting.h new file mode 100644 index 0000000..7545462 --- /dev/null +++ b/Eigen/src/Core/arch/GPU/TypeCasting.h @@ -0,0 +1,80 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2016 Benoit Steiner +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
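Nearly every Packet4h2 operation above follows one pattern: reinterpret the 16-byte packet as four half2 lanes and delegate to the half2 primitive lane by lane (unrolled in the real code). A host-only sketch of that layout and delegation, with raw uint16 payloads standing in for CUDA __half (purely illustrative):

#include <cstdint>
#include <cstdio>

struct Half2Like      { std::uint16_t lo, hi; };  // stand-in for CUDA's half2
struct Packet4h2Like  { Half2Like v[4]; };        // eight "halves" = 16 bytes, like ulonglong2

// A toy narrow primitive: OR the raw payloads, as the half2 por above does.
static Half2Like por_like(const Half2Like& a, const Half2Like& b) {
  return { static_cast<std::uint16_t>(a.lo | b.lo),
           static_cast<std::uint16_t>(a.hi | b.hi) };
}

// The wide op just forwards to the narrow one for each of the four lanes.
static Packet4h2Like por_wide(const Packet4h2Like& a, const Packet4h2Like& b) {
  Packet4h2Like r;
  for (int i = 0; i < 4; ++i) r.v[i] = por_like(a.v[i], b.v[i]);
  return r;
}

int main() {
  Packet4h2Like a{}, b{};
  a.v[0] = {0x00ff, 0x0f0f};
  b.v[0] = {0xff00, 0x00f0};
  Packet4h2Like r = por_wide(a, b);
  std::printf("%04x %04x\n", r.v[0].lo, r.v[0].hi);  // ffff 0fff
  return 0;
}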
+ +#ifndef EIGEN_TYPE_CASTING_GPU_H +#define EIGEN_TYPE_CASTING_GPU_H + +namespace Eigen { + +namespace internal { + +#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \ + (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE)) + + +template <> +struct type_casting_traits { + enum { + VectorizedCast = 1, + SrcCoeffRatio = 1, + TgtCoeffRatio = 2 + }; +}; + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pcast(const half2& a, const half2& b) { + float2 r1 = __half22float2(a); + float2 r2 = __half22float2(b); + return make_float4(r1.x, r1.y, r2.x, r2.y); +} + + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pcast(const float4& a, const float4& b) { + Packet4h2 r; + half2* r_alias=reinterpret_cast(&r); + r_alias[0]=__floats2half2_rn(a.x,a.y); + r_alias[1]=__floats2half2_rn(a.z,a.w); + r_alias[2]=__floats2half2_rn(b.x,b.y); + r_alias[3]=__floats2half2_rn(b.z,b.w); + return r; +} + +template <> +struct type_casting_traits { + enum { + VectorizedCast = 1, + SrcCoeffRatio = 2, + TgtCoeffRatio = 1 + }; +}; + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pcast(const Packet4h2& a) { + // Simply discard the second half of the input + float4 r; + const half2* a_alias=reinterpret_cast(&a); + float2 r1 = __half22float2(a_alias[0]); + float2 r2 = __half22float2(a_alias[1]); + r.x=static_cast(r1.x); + r.y=static_cast(r1.y); + r.z=static_cast(r2.x); + r.w=static_cast(r2.y); + return r; +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pcast(const float4& a) { + // Simply discard the second half of the input + return __floats2half2_rn(a.x, a.y); +} + +#endif + +} // end namespace internal + +} // end namespace Eigen + +#endif // EIGEN_TYPE_CASTING_GPU_H diff --git a/Eigen/src/Core/arch/HIP/hcc/math_constants.h b/Eigen/src/Core/arch/HIP/hcc/math_constants.h new file mode 100644 index 0000000..25375a0 --- /dev/null +++ b/Eigen/src/Core/arch/HIP/hcc/math_constants.h @@ -0,0 +1,23 @@ +/* + * math_constants.h - + * HIP equivalent of the CUDA header of the same name + */ + +#ifndef __MATH_CONSTANTS_H__ +#define __MATH_CONSTANTS_H__ + +/* single precision constants */ + +#define HIPRT_INF_F __int_as_float(0x7f800000) +#define HIPRT_NAN_F __int_as_float(0x7fffffff) +#define HIPRT_MIN_DENORM_F __int_as_float(0x00000001) +#define HIPRT_MAX_NORMAL_F __int_as_float(0x7f7fffff) +#define HIPRT_NEG_ZERO_F __int_as_float(0x80000000) +#define HIPRT_ZERO_F 0.0f +#define HIPRT_ONE_F 1.0f + +/* double precision constants */ +#define HIPRT_INF __hiloint2double(0x7ff00000, 0x00000000) +#define HIPRT_NAN __hiloint2double(0xfff80000, 0x00000000) + +#endif diff --git a/Eigen/src/Core/arch/MSA/Complex.h b/Eigen/src/Core/arch/MSA/Complex.h new file mode 100644 index 0000000..53dacfa --- /dev/null +++ b/Eigen/src/Core/arch/MSA/Complex.h @@ -0,0 +1,648 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2018 Wave Computing, Inc. +// Written by: +// Chris Larsen +// Alexey Frunze (afrunze@wavecomp.com) +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
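The vectorized casts above have to reconcile packet widths: a float4 carries four lanes while a Packet4h2 carries eight halves, so the widening half-to-float cast consumes two half2 inputs per float4 produced, and the narrowing direction discards or splits lanes accordingly. A host-only sketch of the widening shape, with float storage standing in for __half so it stays runnable off-device (illustrative only):

#include <cstdio>

struct Half2Like  { float lo, hi; };   // two "half" lanes, stored as float for the sketch
struct Float4Like { float x, y, z, w; };

// Mirrors the two-argument pcast above: two narrow source packets feed one wide target.
static Float4Like cast_widen(const Half2Like& a, const Half2Like& b) {
  return { a.lo, a.hi, b.lo, b.hi };
}

int main() {
  Half2Like a{1.0f, 2.0f}, b{3.0f, 4.0f};
  Float4Like f = cast_widen(a, b);
  std::printf("%g %g %g %g\n", f.x, f.y, f.z, f.w);  // 1 2 3 4
  return 0;
}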
+ +#ifndef EIGEN_COMPLEX_MSA_H +#define EIGEN_COMPLEX_MSA_H + +#include + +namespace Eigen { + +namespace internal { + +//---------- float ---------- +struct Packet2cf { + EIGEN_STRONG_INLINE Packet2cf() { + } + EIGEN_STRONG_INLINE explicit Packet2cf(const std::complex& a, + const std::complex& b) { + Packet4f t = { std::real(a), std::imag(a), std::real(b), std::imag(b) }; + v = t; + } + EIGEN_STRONG_INLINE explicit Packet2cf(const Packet4f& a) : v(a) { + } + EIGEN_STRONG_INLINE Packet2cf(const Packet2cf& a) : v(a.v) { + } + EIGEN_STRONG_INLINE Packet2cf& operator=(const Packet2cf& b) { + v = b.v; + return *this; + } + EIGEN_STRONG_INLINE Packet2cf conjugate(void) const { + return Packet2cf((Packet4f)__builtin_msa_bnegi_d((v2u64)v, 63)); + } + EIGEN_STRONG_INLINE Packet2cf& operator*=(const Packet2cf& b) { + Packet4f v1, v2; + + // Get the real values of a | a1_re | a1_re | a2_re | a2_re | + v1 = (Packet4f)__builtin_msa_ilvev_w((v4i32)v, (v4i32)v); + // Get the imag values of a | a1_im | a1_im | a2_im | a2_im | + v2 = (Packet4f)__builtin_msa_ilvod_w((v4i32)v, (v4i32)v); + // Multiply the real a with b + v1 = pmul(v1, b.v); + // Multiply the imag a with b + v2 = pmul(v2, b.v); + // Conjugate v2 + v2 = Packet2cf(v2).conjugate().v; + // Swap real/imag elements in v2. + v2 = (Packet4f)__builtin_msa_shf_w((v4i32)v2, EIGEN_MSA_SHF_I8(1, 0, 3, 2)); + // Add and return the result + v = padd(v1, v2); + return *this; + } + EIGEN_STRONG_INLINE Packet2cf operator*(const Packet2cf& b) const { + return Packet2cf(*this) *= b; + } + EIGEN_STRONG_INLINE Packet2cf& operator+=(const Packet2cf& b) { + v = padd(v, b.v); + return *this; + } + EIGEN_STRONG_INLINE Packet2cf operator+(const Packet2cf& b) const { + return Packet2cf(*this) += b; + } + EIGEN_STRONG_INLINE Packet2cf& operator-=(const Packet2cf& b) { + v = psub(v, b.v); + return *this; + } + EIGEN_STRONG_INLINE Packet2cf operator-(const Packet2cf& b) const { + return Packet2cf(*this) -= b; + } + EIGEN_STRONG_INLINE Packet2cf& operator/=(const Packet2cf& b) { + *this *= b.conjugate(); + Packet4f s = pmul(b.v, b.v); + s = padd(s, (Packet4f)__builtin_msa_shf_w((v4i32)s, EIGEN_MSA_SHF_I8(1, 0, 3, 2))); + v = pdiv(v, s); + return *this; + } + EIGEN_STRONG_INLINE Packet2cf operator/(const Packet2cf& b) const { + return Packet2cf(*this) /= b; + } + EIGEN_STRONG_INLINE Packet2cf operator-(void) const { + return Packet2cf(pnegate(v)); + } + + Packet4f v; +}; + +inline std::ostream& operator<<(std::ostream& os, const Packet2cf& value) { + os << "[ (" << value.v[0] << ", " << value.v[1] + << "i)," + " (" + << value.v[2] << ", " << value.v[3] << "i) ]"; + return os; +} + +template <> +struct packet_traits > : default_packet_traits { + typedef Packet2cf type; + typedef Packet2cf half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 2, + HasHalfPacket = 0, + + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 1, + HasNegate = 1, + HasAbs = 0, + HasAbs2 = 0, + HasMin = 0, + HasMax = 0, + HasSetLinear = 0, + HasBlend = 1 + }; +}; + +template <> +struct unpacket_traits { + typedef std::complex type; + enum { size = 2, alignment = Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false }; + typedef Packet2cf half; +}; + +template <> +EIGEN_STRONG_INLINE Packet2cf pset1(const std::complex& from) { + EIGEN_MSA_DEBUG; + + float f0 = from.real(), f1 = from.imag(); + Packet4f v0 = { f0, f0, f0, f0 }; + Packet4f v1 = { f1, f1, f1, f1 }; + return Packet2cf((Packet4f)__builtin_msa_ilvr_w((Packet4i)v1, (Packet4i)v0)); +} + +template 
<> +EIGEN_STRONG_INLINE Packet2cf padd(const Packet2cf& a, const Packet2cf& b) { + EIGEN_MSA_DEBUG; + + return a + b; +} + +template <> +EIGEN_STRONG_INLINE Packet2cf psub(const Packet2cf& a, const Packet2cf& b) { + EIGEN_MSA_DEBUG; + + return a - b; +} + +template <> +EIGEN_STRONG_INLINE Packet2cf pnegate(const Packet2cf& a) { + EIGEN_MSA_DEBUG; + + return -a; +} + +template <> +EIGEN_STRONG_INLINE Packet2cf pconj(const Packet2cf& a) { + EIGEN_MSA_DEBUG; + + return a.conjugate(); +} + +template <> +EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& a, const Packet2cf& b) { + EIGEN_MSA_DEBUG; + + return a * b; +} + +template <> +EIGEN_STRONG_INLINE Packet2cf pand(const Packet2cf& a, const Packet2cf& b) { + EIGEN_MSA_DEBUG; + + return Packet2cf(pand(a.v, b.v)); +} + +template <> +EIGEN_STRONG_INLINE Packet2cf por(const Packet2cf& a, const Packet2cf& b) { + EIGEN_MSA_DEBUG; + + return Packet2cf(por(a.v, b.v)); +} + +template <> +EIGEN_STRONG_INLINE Packet2cf pxor(const Packet2cf& a, const Packet2cf& b) { + EIGEN_MSA_DEBUG; + + return Packet2cf(pxor(a.v, b.v)); +} + +template <> +EIGEN_STRONG_INLINE Packet2cf pandnot(const Packet2cf& a, const Packet2cf& b) { + EIGEN_MSA_DEBUG; + + return Packet2cf(pandnot(a.v, b.v)); +} + +template <> +EIGEN_STRONG_INLINE Packet2cf pload(const std::complex* from) { + EIGEN_MSA_DEBUG; + + EIGEN_DEBUG_ALIGNED_LOAD return Packet2cf(pload((const float*)from)); +} + +template <> +EIGEN_STRONG_INLINE Packet2cf ploadu(const std::complex* from) { + EIGEN_MSA_DEBUG; + + EIGEN_DEBUG_UNALIGNED_LOAD return Packet2cf(ploadu((const float*)from)); +} + +template <> +EIGEN_STRONG_INLINE Packet2cf ploaddup(const std::complex* from) { + EIGEN_MSA_DEBUG; + + return pset1(*from); +} + +template <> +EIGEN_STRONG_INLINE void pstore >(std::complex* to, + const Packet2cf& from) { + EIGEN_MSA_DEBUG; + + EIGEN_DEBUG_ALIGNED_STORE pstore((float*)to, from.v); +} + +template <> +EIGEN_STRONG_INLINE void pstoreu >(std::complex* to, + const Packet2cf& from) { + EIGEN_MSA_DEBUG; + + EIGEN_DEBUG_UNALIGNED_STORE pstoreu((float*)to, from.v); +} + +template <> +EIGEN_DEVICE_FUNC inline Packet2cf pgather, Packet2cf>( + const std::complex* from, Index stride) { + EIGEN_MSA_DEBUG; + + return Packet2cf(from[0 * stride], from[1 * stride]); +} + +template <> +EIGEN_DEVICE_FUNC inline void pscatter, Packet2cf>(std::complex* to, + const Packet2cf& from, + Index stride) { + EIGEN_MSA_DEBUG; + + *to = std::complex(from.v[0], from.v[1]); + to += stride; + *to = std::complex(from.v[2], from.v[3]); +} + +template <> +EIGEN_STRONG_INLINE void prefetch >(const std::complex* addr) { + EIGEN_MSA_DEBUG; + + prefetch(reinterpret_cast(addr)); +} + +template <> +EIGEN_STRONG_INLINE std::complex pfirst(const Packet2cf& a) { + EIGEN_MSA_DEBUG; + + return std::complex(a.v[0], a.v[1]); +} + +template <> +EIGEN_STRONG_INLINE Packet2cf preverse(const Packet2cf& a) { + EIGEN_MSA_DEBUG; + + return Packet2cf((Packet4f)__builtin_msa_shf_w((v4i32)a.v, EIGEN_MSA_SHF_I8(2, 3, 0, 1))); +} + +template <> +EIGEN_STRONG_INLINE Packet2cf pcplxflip(const Packet2cf& a) { + EIGEN_MSA_DEBUG; + + return Packet2cf((Packet4f)__builtin_msa_shf_w((v4i32)a.v, EIGEN_MSA_SHF_I8(1, 0, 3, 2))); +} + +template <> +EIGEN_STRONG_INLINE std::complex predux(const Packet2cf& a) { + EIGEN_MSA_DEBUG; + + Packet4f value = (Packet4f)preverse((Packet2d)a.v); + value += a.v; + return std::complex(value[0], value[1]); +} + +template <> +EIGEN_STRONG_INLINE std::complex predux_mul(const Packet2cf& a) { + EIGEN_MSA_DEBUG; + + return std::complex((a.v[0] * 
a.v[2]) - (a.v[1] * a.v[3]), + (a.v[0] * a.v[3]) + (a.v[1] * a.v[2])); +} + +EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cf, Packet4f) + +template <> +EIGEN_STRONG_INLINE Packet2cf pdiv(const Packet2cf& a, const Packet2cf& b) { + EIGEN_MSA_DEBUG; + + return a / b; +} + +inline std::ostream& operator<<(std::ostream& os, const PacketBlock& value) { + os << "[ " << value.packet[0] << ", " << std::endl << " " << value.packet[1] << " ]"; + return os; +} + +EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { + EIGEN_MSA_DEBUG; + + Packet4f tmp = + (Packet4f)__builtin_msa_ilvl_d((v2i64)kernel.packet[1].v, (v2i64)kernel.packet[0].v); + kernel.packet[0].v = + (Packet4f)__builtin_msa_ilvr_d((v2i64)kernel.packet[1].v, (v2i64)kernel.packet[0].v); + kernel.packet[1].v = tmp; +} + +template <> +EIGEN_STRONG_INLINE Packet2cf pblend(const Selector<2>& ifPacket, const Packet2cf& thenPacket, + const Packet2cf& elsePacket) { + return (Packet2cf)(Packet4f)pblend(ifPacket, (Packet2d)thenPacket.v, + (Packet2d)elsePacket.v); +} + +//---------- double ---------- + +struct Packet1cd { + EIGEN_STRONG_INLINE Packet1cd() { + } + EIGEN_STRONG_INLINE explicit Packet1cd(const std::complex& a) { + v[0] = std::real(a); + v[1] = std::imag(a); + } + EIGEN_STRONG_INLINE explicit Packet1cd(const Packet2d& a) : v(a) { + } + EIGEN_STRONG_INLINE Packet1cd(const Packet1cd& a) : v(a.v) { + } + EIGEN_STRONG_INLINE Packet1cd& operator=(const Packet1cd& b) { + v = b.v; + return *this; + } + EIGEN_STRONG_INLINE Packet1cd conjugate(void) const { + static const v2u64 p2ul_CONJ_XOR = { 0x0, 0x8000000000000000 }; + return (Packet1cd)pxor(v, (Packet2d)p2ul_CONJ_XOR); + } + EIGEN_STRONG_INLINE Packet1cd& operator*=(const Packet1cd& b) { + Packet2d v1, v2; + + // Get the real values of a | a1_re | a1_re + v1 = (Packet2d)__builtin_msa_ilvev_d((v2i64)v, (v2i64)v); + // Get the imag values of a | a1_im | a1_im + v2 = (Packet2d)__builtin_msa_ilvod_d((v2i64)v, (v2i64)v); + // Multiply the real a with b + v1 = pmul(v1, b.v); + // Multiply the imag a with b + v2 = pmul(v2, b.v); + // Conjugate v2 + v2 = Packet1cd(v2).conjugate().v; + // Swap real/imag elements in v2. 
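+    // After this swap v2 = { -(a_im*b_im), a_im*b_re }, so adding it to
+    // v1 = { a_re*b_re, a_re*b_im } below yields the complex product
+    // { a_re*b_re - a_im*b_im, a_re*b_im + a_im*b_re }.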
+ v2 = (Packet2d)__builtin_msa_shf_w((v4i32)v2, EIGEN_MSA_SHF_I8(2, 3, 0, 1)); + // Add and return the result + v = padd(v1, v2); + return *this; + } + EIGEN_STRONG_INLINE Packet1cd operator*(const Packet1cd& b) const { + return Packet1cd(*this) *= b; + } + EIGEN_STRONG_INLINE Packet1cd& operator+=(const Packet1cd& b) { + v = padd(v, b.v); + return *this; + } + EIGEN_STRONG_INLINE Packet1cd operator+(const Packet1cd& b) const { + return Packet1cd(*this) += b; + } + EIGEN_STRONG_INLINE Packet1cd& operator-=(const Packet1cd& b) { + v = psub(v, b.v); + return *this; + } + EIGEN_STRONG_INLINE Packet1cd operator-(const Packet1cd& b) const { + return Packet1cd(*this) -= b; + } + EIGEN_STRONG_INLINE Packet1cd& operator/=(const Packet1cd& b) { + *this *= b.conjugate(); + Packet2d s = pmul(b.v, b.v); + s = padd(s, preverse(s)); + v = pdiv(v, s); + return *this; + } + EIGEN_STRONG_INLINE Packet1cd operator/(const Packet1cd& b) const { + return Packet1cd(*this) /= b; + } + EIGEN_STRONG_INLINE Packet1cd operator-(void) const { + return Packet1cd(pnegate(v)); + } + + Packet2d v; +}; + +inline std::ostream& operator<<(std::ostream& os, const Packet1cd& value) { + os << "[ (" << value.v[0] << ", " << value.v[1] << "i) ]"; + return os; +} + +template <> +struct packet_traits > : default_packet_traits { + typedef Packet1cd type; + typedef Packet1cd half; + enum { + Vectorizable = 1, + AlignedOnScalar = 0, + size = 1, + HasHalfPacket = 0, + + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 1, + HasNegate = 1, + HasAbs = 0, + HasAbs2 = 0, + HasMin = 0, + HasMax = 0, + HasSetLinear = 0 + }; +}; + +template <> +struct unpacket_traits { + typedef std::complex type; + enum { size = 1, alignment = Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false }; + typedef Packet1cd half; +}; + +template <> +EIGEN_STRONG_INLINE Packet1cd pload(const std::complex* from) { + EIGEN_MSA_DEBUG; + + EIGEN_DEBUG_ALIGNED_LOAD return Packet1cd(pload((const double*)from)); +} + +template <> +EIGEN_STRONG_INLINE Packet1cd ploadu(const std::complex* from) { + EIGEN_MSA_DEBUG; + + EIGEN_DEBUG_UNALIGNED_LOAD return Packet1cd(ploadu((const double*)from)); +} + +template <> +EIGEN_STRONG_INLINE Packet1cd pset1(const std::complex& from) { + EIGEN_MSA_DEBUG; + + return Packet1cd(from); +} + +template <> +EIGEN_STRONG_INLINE Packet1cd padd(const Packet1cd& a, const Packet1cd& b) { + EIGEN_MSA_DEBUG; + + return a + b; +} + +template <> +EIGEN_STRONG_INLINE Packet1cd psub(const Packet1cd& a, const Packet1cd& b) { + EIGEN_MSA_DEBUG; + + return a - b; +} + +template <> +EIGEN_STRONG_INLINE Packet1cd pnegate(const Packet1cd& a) { + EIGEN_MSA_DEBUG; + + return -a; +} + +template <> +EIGEN_STRONG_INLINE Packet1cd pconj(const Packet1cd& a) { + EIGEN_MSA_DEBUG; + + return a.conjugate(); +} + +template <> +EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& a, const Packet1cd& b) { + EIGEN_MSA_DEBUG; + + return a * b; +} + +template <> +EIGEN_STRONG_INLINE Packet1cd pand(const Packet1cd& a, const Packet1cd& b) { + EIGEN_MSA_DEBUG; + + return Packet1cd(pand(a.v, b.v)); +} + +template <> +EIGEN_STRONG_INLINE Packet1cd por(const Packet1cd& a, const Packet1cd& b) { + EIGEN_MSA_DEBUG; + + return Packet1cd(por(a.v, b.v)); +} + +template <> +EIGEN_STRONG_INLINE Packet1cd pxor(const Packet1cd& a, const Packet1cd& b) { + EIGEN_MSA_DEBUG; + + return Packet1cd(pxor(a.v, b.v)); +} + +template <> +EIGEN_STRONG_INLINE Packet1cd pandnot(const Packet1cd& a, const Packet1cd& b) { + EIGEN_MSA_DEBUG; + + return Packet1cd(pandnot(a.v, 
b.v)); +} + +template <> +EIGEN_STRONG_INLINE Packet1cd ploaddup(const std::complex* from) { + EIGEN_MSA_DEBUG; + + return pset1(*from); +} + +template <> +EIGEN_STRONG_INLINE void pstore >(std::complex* to, + const Packet1cd& from) { + EIGEN_MSA_DEBUG; + + EIGEN_DEBUG_ALIGNED_STORE pstore((double*)to, from.v); +} + +template <> +EIGEN_STRONG_INLINE void pstoreu >(std::complex* to, + const Packet1cd& from) { + EIGEN_MSA_DEBUG; + + EIGEN_DEBUG_UNALIGNED_STORE pstoreu((double*)to, from.v); +} + +template <> +EIGEN_STRONG_INLINE void prefetch >(const std::complex* addr) { + EIGEN_MSA_DEBUG; + + prefetch(reinterpret_cast(addr)); +} + +template <> +EIGEN_DEVICE_FUNC inline Packet1cd pgather, Packet1cd>( + const std::complex* from, Index stride __attribute__((unused))) { + EIGEN_MSA_DEBUG; + + Packet1cd res; + res.v[0] = std::real(from[0]); + res.v[1] = std::imag(from[0]); + return res; +} + +template <> +EIGEN_DEVICE_FUNC inline void pscatter, Packet1cd>(std::complex* to, + const Packet1cd& from, + Index stride + __attribute__((unused))) { + EIGEN_MSA_DEBUG; + + pstore(to, from); +} + +template <> +EIGEN_STRONG_INLINE std::complex pfirst(const Packet1cd& a) { + EIGEN_MSA_DEBUG; + + return std::complex(a.v[0], a.v[1]); +} + +template <> +EIGEN_STRONG_INLINE Packet1cd preverse(const Packet1cd& a) { + EIGEN_MSA_DEBUG; + + return a; +} + +template <> +EIGEN_STRONG_INLINE std::complex predux(const Packet1cd& a) { + EIGEN_MSA_DEBUG; + + return pfirst(a); +} + +template <> +EIGEN_STRONG_INLINE std::complex predux_mul(const Packet1cd& a) { + EIGEN_MSA_DEBUG; + + return pfirst(a); +} + +EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet1cd, Packet2d) + +template <> +EIGEN_STRONG_INLINE Packet1cd pdiv(const Packet1cd& a, const Packet1cd& b) { + EIGEN_MSA_DEBUG; + + return a / b; +} + +EIGEN_STRONG_INLINE Packet1cd pcplxflip /**/ (const Packet1cd& x) { + EIGEN_MSA_DEBUG; + + return Packet1cd(preverse(Packet2d(x.v))); +} + +inline std::ostream& operator<<(std::ostream& os, const PacketBlock& value) { + os << "[ " << value.packet[0] << ", " << std::endl << " " << value.packet[1] << " ]"; + return os; +} + +EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + EIGEN_MSA_DEBUG; + + Packet2d v1, v2; + + v1 = (Packet2d)__builtin_msa_ilvev_d((v2i64)kernel.packet[0].v, (v2i64)kernel.packet[1].v); + // Get the imag values of a + v2 = (Packet2d)__builtin_msa_ilvod_d((v2i64)kernel.packet[0].v, (v2i64)kernel.packet[1].v); + + kernel.packet[0].v = v1; + kernel.packet[1].v = v2; +} + +} // end namespace internal + +} // end namespace Eigen + +#endif // EIGEN_COMPLEX_MSA_H diff --git a/Eigen/src/Core/arch/MSA/MathFunctions.h b/Eigen/src/Core/arch/MSA/MathFunctions.h new file mode 100644 index 0000000..f5181b9 --- /dev/null +++ b/Eigen/src/Core/arch/MSA/MathFunctions.h @@ -0,0 +1,387 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2007 Julien Pommier +// Copyright (C) 2014 Pedro Gonnet (pedro.gonnet@gmail.com) +// Copyright (C) 2016 Gael Guennebaud +// +// Copyright (C) 2018 Wave Computing, Inc. +// Written by: +// Chris Larsen +// Alexey Frunze (afrunze@wavecomp.com) +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
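The float plog() below follows the classic Cephes recipe: split the argument into mantissa and exponent, keep the mantissa in [sqrt(1/2), sqrt(2)) so the polynomial stays accurate, evaluate a degree-8 polynomial in (mantissa - 1), and add the exponent times ln(2) split across two constants. A scalar reference of that pipeline using std::frexp in place of the flog2/fexp2 instructions — cephes_logf is an illustrative name, and special inputs (zero, negatives, infinity) are handled separately in the vector code:

#include <cmath>
#include <cstdio>

float cephes_logf(float v) {
  int e;
  float m = std::frexp(v, &e);          // v = m * 2^e with 0.5 <= m < 1
  if (m < 0.707106781186547524f) {      // cephes_SQRTHF
    m = m + m;                          // keep m in [sqrt(1/2), sqrt(2))
    e -= 1;
  }
  float x = m - 1.0f;
  float z = x * x;
  float p = 7.0376836292e-2f;           // cephes_log_p0 .. p8 in Horner order
  p = p * x + -1.1514610310e-1f;
  p = p * x +  1.1676998740e-1f;
  p = p * x + -1.2420140846e-1f;
  p = p * x +  1.4249322787e-1f;
  p = p * x + -1.6668057665e-1f;
  p = p * x +  2.0000714765e-1f;
  p = p * x + -2.4999993993e-1f;
  p = p * x +  3.3333331174e-1f;
  float y = x * z * p;                  // x^3 * P(x)
  y += e * -2.12194440e-4f;             // cephes_log_q1
  y -= 0.5f * z;                        // ... + x - x^2/2
  y = x + y;
  y += e * 0.693359375f;                // cephes_log_q2 (q1 + q2 ~= ln 2)
  return y;
}

int main() {
  const float xs[] = {0.5f, 1.0f, 2.7182818f, 100.0f};
  for (float v : xs)
    std::printf("v=%-10g approx=%-12g std::log=%g\n", v, cephes_logf(v), std::log(v));
  return 0;
}

The sin, cos and exp kernels further down use the same shape of computation: a range reduction driven by a handful of split constants, followed by a short polynomial and a final scaling.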
+ +/* The sin, cos, exp, and log functions of this file come from + * Julien Pommier's sse math library: http://gruntthepeon.free.fr/ssemath/ + */ + +/* The tanh function of this file is an adaptation of + * template T generic_fast_tanh_float(const T&) + * from MathFunctionsImpl.h. + */ + +#ifndef EIGEN_MATH_FUNCTIONS_MSA_H +#define EIGEN_MATH_FUNCTIONS_MSA_H + +namespace Eigen { + +namespace internal { + +template <> +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f +plog(const Packet4f& _x) { + static _EIGEN_DECLARE_CONST_Packet4f(cephes_SQRTHF, 0.707106781186547524f); + static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p0, 7.0376836292e-2f); + static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p1, -1.1514610310e-1f); + static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p2, 1.1676998740e-1f); + static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p3, -1.2420140846e-1f); + static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p4, +1.4249322787e-1f); + static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p5, -1.6668057665e-1f); + static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p6, +2.0000714765e-1f); + static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p7, -2.4999993993e-1f); + static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p8, +3.3333331174e-1f); + static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_q1, -2.12194440e-4f); + static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_q2, 0.693359375f); + static _EIGEN_DECLARE_CONST_Packet4f(half, 0.5f); + static _EIGEN_DECLARE_CONST_Packet4f(1, 1.0f); + + // Convert negative argument into NAN (quiet negative, to be specific). + Packet4f zero = (Packet4f)__builtin_msa_ldi_w(0); + Packet4i neg_mask = __builtin_msa_fclt_w(_x, zero); + Packet4i zero_mask = __builtin_msa_fceq_w(_x, zero); + Packet4f non_neg_x_or_nan = padd(_x, (Packet4f)neg_mask); // Add 0.0 or NAN. + Packet4f x = non_neg_x_or_nan; + + // Extract exponent from x = mantissa * 2**exponent, where 1.0 <= mantissa < 2.0. + // N.B. the exponent is one less of what frexpf() would return. + Packet4i e_int = __builtin_msa_ftint_s_w(__builtin_msa_flog2_w(x)); + // Multiply x by 2**(-exponent-1) to get 0.5 <= x < 1.0 as from frexpf(). + x = __builtin_msa_fexp2_w(x, (Packet4i)__builtin_msa_nori_b((v16u8)e_int, 0)); + + /* + if (x < SQRTHF) { + x = x + x - 1.0; + } else { + e += 1; + x = x - 1.0; + } + */ + Packet4f xx = padd(x, x); + Packet4i ge_mask = __builtin_msa_fcle_w(p4f_cephes_SQRTHF, x); + e_int = psub(e_int, ge_mask); + x = (Packet4f)__builtin_msa_bsel_v((v16u8)ge_mask, (v16u8)xx, (v16u8)x); + x = psub(x, p4f_1); + Packet4f e = __builtin_msa_ffint_s_w(e_int); + + Packet4f x2 = pmul(x, x); + Packet4f x3 = pmul(x2, x); + + Packet4f y, y1, y2; + y = pmadd(p4f_cephes_log_p0, x, p4f_cephes_log_p1); + y1 = pmadd(p4f_cephes_log_p3, x, p4f_cephes_log_p4); + y2 = pmadd(p4f_cephes_log_p6, x, p4f_cephes_log_p7); + y = pmadd(y, x, p4f_cephes_log_p2); + y1 = pmadd(y1, x, p4f_cephes_log_p5); + y2 = pmadd(y2, x, p4f_cephes_log_p8); + y = pmadd(y, x3, y1); + y = pmadd(y, x3, y2); + y = pmul(y, x3); + + y = pmadd(e, p4f_cephes_log_q1, y); + x = __builtin_msa_fmsub_w(x, x2, p4f_half); + x = padd(x, y); + x = pmadd(e, p4f_cephes_log_q2, x); + + // x is now the logarithm result candidate. We still need to handle the + // extreme arguments of zero and positive infinity, though. + // N.B. if the argument is +INFINITY, x is NAN because the polynomial terms + // contain infinities of both signs (see the coefficients and code above). + // INFINITY - INFINITY is NAN. 
+ + // If the argument is +INFINITY, make it the new result candidate. + // To achieve that we choose the smaller of the result candidate and the + // argument. + // This is correct for all finite pairs of values (the logarithm is smaller + // than the argument). + // This is also correct in the special case when the argument is +INFINITY + // and the result candidate is NAN. This is because the fmin.df instruction + // prefers non-NANs to NANs. + x = __builtin_msa_fmin_w(x, non_neg_x_or_nan); + + // If the argument is zero (including -0.0), the result becomes -INFINITY. + Packet4i neg_infs = __builtin_msa_slli_w(zero_mask, 23); + x = (Packet4f)__builtin_msa_bsel_v((v16u8)zero_mask, (v16u8)x, (v16u8)neg_infs); + + return x; +} + +template <> +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f +pexp(const Packet4f& _x) { + // Limiting single-precision pexp's argument to [-128, +128] lets pexp + // reach 0 and INFINITY naturally. + static _EIGEN_DECLARE_CONST_Packet4f(exp_lo, -128.0f); + static _EIGEN_DECLARE_CONST_Packet4f(exp_hi, +128.0f); + static _EIGEN_DECLARE_CONST_Packet4f(cephes_LOG2EF, 1.44269504088896341f); + static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_C1, 0.693359375f); + static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_C2, -2.12194440e-4f); + static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p0, 1.9875691500e-4f); + static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p1, 1.3981999507e-3f); + static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p2, 8.3334519073e-3f); + static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p3, 4.1665795894e-2f); + static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p4, 1.6666665459e-1f); + static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p5, 5.0000001201e-1f); + static _EIGEN_DECLARE_CONST_Packet4f(half, 0.5f); + static _EIGEN_DECLARE_CONST_Packet4f(1, 1.0f); + + Packet4f x = _x; + + // Clamp x. + x = (Packet4f)__builtin_msa_bsel_v((v16u8)__builtin_msa_fclt_w(x, p4f_exp_lo), (v16u8)x, + (v16u8)p4f_exp_lo); + x = (Packet4f)__builtin_msa_bsel_v((v16u8)__builtin_msa_fclt_w(p4f_exp_hi, x), (v16u8)x, + (v16u8)p4f_exp_hi); + + // Round to nearest integer by adding 0.5 (with x's sign) and truncating. + Packet4f x2_add = (Packet4f)__builtin_msa_binsli_w((v4u32)p4f_half, (v4u32)x, 0); + Packet4f x2 = pmadd(x, p4f_cephes_LOG2EF, x2_add); + Packet4i x2_int = __builtin_msa_ftrunc_s_w(x2); + Packet4f x2_int_f = __builtin_msa_ffint_s_w(x2_int); + + x = __builtin_msa_fmsub_w(x, x2_int_f, p4f_cephes_exp_C1); + x = __builtin_msa_fmsub_w(x, x2_int_f, p4f_cephes_exp_C2); + + Packet4f z = pmul(x, x); + + Packet4f y = p4f_cephes_exp_p0; + y = pmadd(y, x, p4f_cephes_exp_p1); + y = pmadd(y, x, p4f_cephes_exp_p2); + y = pmadd(y, x, p4f_cephes_exp_p3); + y = pmadd(y, x, p4f_cephes_exp_p4); + y = pmadd(y, x, p4f_cephes_exp_p5); + y = pmadd(y, z, x); + y = padd(y, p4f_1); + + // y *= 2**exponent. + y = __builtin_msa_fexp2_w(y, x2_int); + + return y; +} + +template <> +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f +ptanh(const Packet4f& _x) { + static _EIGEN_DECLARE_CONST_Packet4f(tanh_tiny, 1e-4f); + static _EIGEN_DECLARE_CONST_Packet4f(tanh_hi, 9.0f); + // The monomial coefficients of the numerator polynomial (odd). 
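+  // These coefficients feed the rational approximation tanh(x) ~= x * P(x^2) / Q(x^2)
+  // evaluated below; arguments with |x| < tanh_tiny bypass it and are returned unchanged.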
+ static _EIGEN_DECLARE_CONST_Packet4f(alpha_1, 4.89352455891786e-3f); + static _EIGEN_DECLARE_CONST_Packet4f(alpha_3, 6.37261928875436e-4f); + static _EIGEN_DECLARE_CONST_Packet4f(alpha_5, 1.48572235717979e-5f); + static _EIGEN_DECLARE_CONST_Packet4f(alpha_7, 5.12229709037114e-8f); + static _EIGEN_DECLARE_CONST_Packet4f(alpha_9, -8.60467152213735e-11f); + static _EIGEN_DECLARE_CONST_Packet4f(alpha_11, 2.00018790482477e-13f); + static _EIGEN_DECLARE_CONST_Packet4f(alpha_13, -2.76076847742355e-16f); + // The monomial coefficients of the denominator polynomial (even). + static _EIGEN_DECLARE_CONST_Packet4f(beta_0, 4.89352518554385e-3f); + static _EIGEN_DECLARE_CONST_Packet4f(beta_2, 2.26843463243900e-3f); + static _EIGEN_DECLARE_CONST_Packet4f(beta_4, 1.18534705686654e-4f); + static _EIGEN_DECLARE_CONST_Packet4f(beta_6, 1.19825839466702e-6f); + + Packet4f x = pabs(_x); + Packet4i tiny_mask = __builtin_msa_fclt_w(x, p4f_tanh_tiny); + + // Clamp the inputs to the range [-9, 9] since anything outside + // this range is -/+1.0f in single-precision. + x = (Packet4f)__builtin_msa_bsel_v((v16u8)__builtin_msa_fclt_w(p4f_tanh_hi, x), (v16u8)x, + (v16u8)p4f_tanh_hi); + + // Since the polynomials are odd/even, we need x**2. + Packet4f x2 = pmul(x, x); + + // Evaluate the numerator polynomial p. + Packet4f p = pmadd(x2, p4f_alpha_13, p4f_alpha_11); + p = pmadd(x2, p, p4f_alpha_9); + p = pmadd(x2, p, p4f_alpha_7); + p = pmadd(x2, p, p4f_alpha_5); + p = pmadd(x2, p, p4f_alpha_3); + p = pmadd(x2, p, p4f_alpha_1); + p = pmul(x, p); + + // Evaluate the denominator polynomial q. + Packet4f q = pmadd(x2, p4f_beta_6, p4f_beta_4); + q = pmadd(x2, q, p4f_beta_2); + q = pmadd(x2, q, p4f_beta_0); + + // Divide the numerator by the denominator. + p = pdiv(p, q); + + // Reinstate the sign. + p = (Packet4f)__builtin_msa_binsli_w((v4u32)p, (v4u32)_x, 0); + + // When the argument is very small in magnitude it's more accurate to just return it. + p = (Packet4f)__builtin_msa_bsel_v((v16u8)tiny_mask, (v16u8)p, (v16u8)_x); + + return p; +} + +template +Packet4f psincos_inner_msa_float(const Packet4f& _x) { + static _EIGEN_DECLARE_CONST_Packet4f(sincos_max_arg, 13176795.0f); // Approx. (2**24) / (4/Pi). + static _EIGEN_DECLARE_CONST_Packet4f(minus_cephes_DP1, -0.78515625f); + static _EIGEN_DECLARE_CONST_Packet4f(minus_cephes_DP2, -2.4187564849853515625e-4f); + static _EIGEN_DECLARE_CONST_Packet4f(minus_cephes_DP3, -3.77489497744594108e-8f); + static _EIGEN_DECLARE_CONST_Packet4f(sincof_p0, -1.9515295891e-4f); + static _EIGEN_DECLARE_CONST_Packet4f(sincof_p1, 8.3321608736e-3f); + static _EIGEN_DECLARE_CONST_Packet4f(sincof_p2, -1.6666654611e-1f); + static _EIGEN_DECLARE_CONST_Packet4f(coscof_p0, 2.443315711809948e-5f); + static _EIGEN_DECLARE_CONST_Packet4f(coscof_p1, -1.388731625493765e-3f); + static _EIGEN_DECLARE_CONST_Packet4f(coscof_p2, 4.166664568298827e-2f); + static _EIGEN_DECLARE_CONST_Packet4f(cephes_FOPI, 1.27323954473516f); // 4/Pi. + static _EIGEN_DECLARE_CONST_Packet4f(half, 0.5f); + static _EIGEN_DECLARE_CONST_Packet4f(1, 1.0f); + + Packet4f x = pabs(_x); + + // Translate infinite arguments into NANs. + Packet4f zero_or_nan_if_inf = psub(_x, _x); + x = padd(x, zero_or_nan_if_inf); + // Prevent sin/cos from generating values larger than 1.0 in magnitude + // for very large arguments by setting x to 0.0. + Packet4i small_or_nan_mask = __builtin_msa_fcult_w(x, p4f_sincos_max_arg); + x = pand(x, (Packet4f)small_or_nan_mask); + + // Scale x by 4/Pi to find x's octant. 
+ Packet4f y = pmul(x, p4f_cephes_FOPI); + // Get the octant. We'll reduce x by this number of octants or by one more than it. + Packet4i y_int = __builtin_msa_ftrunc_s_w(y); + // x's from even-numbered octants will translate to octant 0: [0, +Pi/4]. + // x's from odd-numbered octants will translate to octant -1: [-Pi/4, 0]. + // Adjustment for odd-numbered octants: octant = (octant + 1) & (~1). + Packet4i y_int1 = __builtin_msa_addvi_w(y_int, 1); + Packet4i y_int2 = (Packet4i)__builtin_msa_bclri_w((Packet4ui)y_int1, 0); // bclri = bit-clear + y = __builtin_msa_ffint_s_w(y_int2); + + // Compute the sign to apply to the polynomial. + Packet4i sign_mask = sine ? pxor(__builtin_msa_slli_w(y_int1, 29), (Packet4i)_x) + : __builtin_msa_slli_w(__builtin_msa_addvi_w(y_int, 3), 29); + + // Get the polynomial selection mask. + // We'll calculate both (sin and cos) polynomials and then select from the two. + Packet4i poly_mask = __builtin_msa_ceqi_w(__builtin_msa_slli_w(y_int2, 30), 0); + + // Reduce x by y octants to get: -Pi/4 <= x <= +Pi/4. + // The magic pass: "Extended precision modular arithmetic" + // x = ((x - y * DP1) - y * DP2) - y * DP3 + Packet4f tmp1 = pmul(y, p4f_minus_cephes_DP1); + Packet4f tmp2 = pmul(y, p4f_minus_cephes_DP2); + Packet4f tmp3 = pmul(y, p4f_minus_cephes_DP3); + x = padd(x, tmp1); + x = padd(x, tmp2); + x = padd(x, tmp3); + + // Evaluate the cos(x) polynomial. + y = p4f_coscof_p0; + Packet4f z = pmul(x, x); + y = pmadd(y, z, p4f_coscof_p1); + y = pmadd(y, z, p4f_coscof_p2); + y = pmul(y, z); + y = pmul(y, z); + y = __builtin_msa_fmsub_w(y, z, p4f_half); + y = padd(y, p4f_1); + + // Evaluate the sin(x) polynomial. + Packet4f y2 = p4f_sincof_p0; + y2 = pmadd(y2, z, p4f_sincof_p1); + y2 = pmadd(y2, z, p4f_sincof_p2); + y2 = pmul(y2, z); + y2 = pmadd(y2, x, x); + + // Select the correct result from the two polynomials. + y = sine ? (Packet4f)__builtin_msa_bsel_v((v16u8)poly_mask, (v16u8)y, (v16u8)y2) + : (Packet4f)__builtin_msa_bsel_v((v16u8)poly_mask, (v16u8)y2, (v16u8)y); + + // Update the sign. + sign_mask = pxor(sign_mask, (Packet4i)y); + y = (Packet4f)__builtin_msa_binsli_w((v4u32)y, (v4u32)sign_mask, 0); // binsli = bit-insert-left + return y; +} + +template <> +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f +psin(const Packet4f& x) { + return psincos_inner_msa_float(x); +} + +template <> +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f +pcos(const Packet4f& x) { + return psincos_inner_msa_float(x); +} + +template <> +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet2d +pexp(const Packet2d& _x) { + // Limiting double-precision pexp's argument to [-1024, +1024] lets pexp + // reach 0 and INFINITY naturally. 
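+  // The reduction below is the same n*ln(2) split (C1 + C2) as in the float path; the
+  // reduced argument r is then run through the Cephes rational form
+  // exp(r) ~= 1 + 2*r*P(r^2) / (Q(r^2) - r*P(r^2)) before the final 2^n scaling.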
+ static _EIGEN_DECLARE_CONST_Packet2d(exp_lo, -1024.0); + static _EIGEN_DECLARE_CONST_Packet2d(exp_hi, +1024.0); + static _EIGEN_DECLARE_CONST_Packet2d(cephes_LOG2EF, 1.4426950408889634073599); + static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_C1, 0.693145751953125); + static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_C2, 1.42860682030941723212e-6); + static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_p0, 1.26177193074810590878e-4); + static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_p1, 3.02994407707441961300e-2); + static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_p2, 9.99999999999999999910e-1); + static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_q0, 3.00198505138664455042e-6); + static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_q1, 2.52448340349684104192e-3); + static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_q2, 2.27265548208155028766e-1); + static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_q3, 2.00000000000000000009e0); + static _EIGEN_DECLARE_CONST_Packet2d(half, 0.5); + static _EIGEN_DECLARE_CONST_Packet2d(1, 1.0); + static _EIGEN_DECLARE_CONST_Packet2d(2, 2.0); + + Packet2d x = _x; + + // Clamp x. + x = (Packet2d)__builtin_msa_bsel_v((v16u8)__builtin_msa_fclt_d(x, p2d_exp_lo), (v16u8)x, + (v16u8)p2d_exp_lo); + x = (Packet2d)__builtin_msa_bsel_v((v16u8)__builtin_msa_fclt_d(p2d_exp_hi, x), (v16u8)x, + (v16u8)p2d_exp_hi); + + // Round to nearest integer by adding 0.5 (with x's sign) and truncating. + Packet2d x2_add = (Packet2d)__builtin_msa_binsli_d((v2u64)p2d_half, (v2u64)x, 0); + Packet2d x2 = pmadd(x, p2d_cephes_LOG2EF, x2_add); + Packet2l x2_long = __builtin_msa_ftrunc_s_d(x2); + Packet2d x2_long_d = __builtin_msa_ffint_s_d(x2_long); + + x = __builtin_msa_fmsub_d(x, x2_long_d, p2d_cephes_exp_C1); + x = __builtin_msa_fmsub_d(x, x2_long_d, p2d_cephes_exp_C2); + + x2 = pmul(x, x); + + Packet2d px = p2d_cephes_exp_p0; + px = pmadd(px, x2, p2d_cephes_exp_p1); + px = pmadd(px, x2, p2d_cephes_exp_p2); + px = pmul(px, x); + + Packet2d qx = p2d_cephes_exp_q0; + qx = pmadd(qx, x2, p2d_cephes_exp_q1); + qx = pmadd(qx, x2, p2d_cephes_exp_q2); + qx = pmadd(qx, x2, p2d_cephes_exp_q3); + + x = pdiv(px, psub(qx, px)); + x = pmadd(p2d_2, x, p2d_1); + + // x *= 2**exponent. + x = __builtin_msa_fexp2_d(x, x2_long); + + return x; +} + +} // end namespace internal + +} // end namespace Eigen + +#endif // EIGEN_MATH_FUNCTIONS_MSA_H diff --git a/Eigen/src/Core/arch/MSA/PacketMath.h b/Eigen/src/Core/arch/MSA/PacketMath.h new file mode 100644 index 0000000..afe8f33 --- /dev/null +++ b/Eigen/src/Core/arch/MSA/PacketMath.h @@ -0,0 +1,1233 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2018 Wave Computing, Inc. +// Written by: +// Chris Larsen +// Alexey Frunze (afrunze@wavecomp.com) +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
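Several reductions in this file (predux, predux_mul, predux_min, predux_max) fold the four lanes in two shuffle/combine steps: swap the two 64-bit halves and combine, then swap the remaining 32-bit pair and combine again. A portable sketch of that pattern on a plain 4-lane array, with shuffle following the EIGEN_MSA_SHF_I8 convention out[k] = in[index_k] — Vec4 and predux_ref are illustrative names, not Eigen's:

#include <array>
#include <cstdio>

using Vec4 = std::array<float, 4>;

// out[k] = in[i_k], the same selection rule as shf.w with EIGEN_MSA_SHF_I8(i0, i1, i2, i3).
Vec4 shuffle(const Vec4& v, int i0, int i1, int i2, int i3) {
  return {v[i0], v[i1], v[i2], v[i3]};
}

Vec4 add(const Vec4& a, const Vec4& b) {
  return {a[0] + b[0], a[1] + b[1], a[2] + b[2], a[3] + b[3]};
}

// Horizontal sum in two combine steps, mirroring the Packet4f predux.
float predux_ref(const Vec4& a) {
  Vec4 s = add(a, shuffle(a, 2, 3, 0, 1));  // fold the upper 64-bit half onto the lower
  s = add(s, shuffle(s, 1, 0, 3, 2));       // fold the remaining pair
  return s[0];                              // every lane now holds the total
}

int main() {
  std::printf("%g\n", predux_ref(Vec4{1, 2, 3, 4}));  // prints 10
  return 0;
}

Swapping the add for multiply, min or max gives the other reductions; the non-fast-math min/max variants additionally substitute quiet NaNs when any lane pair compares unordered.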
+ +#ifndef EIGEN_PACKET_MATH_MSA_H +#define EIGEN_PACKET_MATH_MSA_H + +#include +#include + +namespace Eigen { + +namespace internal { + +#ifndef EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD +#define EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD 8 +#endif + +#ifndef EIGEN_HAS_SINGLE_INSTRUCTION_MADD +#define EIGEN_HAS_SINGLE_INSTRUCTION_MADD +#endif + +#ifndef EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS +#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 32 +#endif + +#if 0 +#define EIGEN_MSA_DEBUG \ + static bool firstTime = true; \ + do { \ + if (firstTime) { \ + std::cout << __FILE__ << ':' << __LINE__ << ':' << __FUNCTION__ << std::endl; \ + firstTime = false; \ + } \ + } while (0) +#else +#define EIGEN_MSA_DEBUG +#endif + +#define EIGEN_MSA_SHF_I8(a, b, c, d) (((d) << 6) | ((c) << 4) | ((b) << 2) | (a)) + +typedef v4f32 Packet4f; +typedef v4i32 Packet4i; +typedef v4u32 Packet4ui; + +#define _EIGEN_DECLARE_CONST_Packet4f(NAME, X) const Packet4f p4f_##NAME = { X, X, X, X } +#define _EIGEN_DECLARE_CONST_Packet4i(NAME, X) const Packet4i p4i_##NAME = { X, X, X, X } +#define _EIGEN_DECLARE_CONST_Packet4ui(NAME, X) const Packet4ui p4ui_##NAME = { X, X, X, X } + +inline std::ostream& operator<<(std::ostream& os, const Packet4f& value) { + os << "[ " << value[0] << ", " << value[1] << ", " << value[2] << ", " << value[3] << " ]"; + return os; +} + +inline std::ostream& operator<<(std::ostream& os, const Packet4i& value) { + os << "[ " << value[0] << ", " << value[1] << ", " << value[2] << ", " << value[3] << " ]"; + return os; +} + +inline std::ostream& operator<<(std::ostream& os, const Packet4ui& value) { + os << "[ " << value[0] << ", " << value[1] << ", " << value[2] << ", " << value[3] << " ]"; + return os; +} + +template <> +struct packet_traits : default_packet_traits { + typedef Packet4f type; + typedef Packet4f half; // Packet2f intrinsics not implemented yet + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 4, + HasHalfPacket = 0, // Packet2f intrinsics not implemented yet + // FIXME check the Has* + HasDiv = 1, + HasSin = EIGEN_FAST_MATH, + HasCos = EIGEN_FAST_MATH, + HasTanh = EIGEN_FAST_MATH, + HasErf = EIGEN_FAST_MATH, + HasLog = 1, + HasExp = 1, + HasSqrt = 1, + HasRsqrt = 1, + HasRound = 1, + HasFloor = 1, + HasCeil = 1, + HasBlend = 1 + }; +}; + +template <> +struct packet_traits : default_packet_traits { + typedef Packet4i type; + typedef Packet4i half; // Packet2i intrinsics not implemented yet + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 4, + HasHalfPacket = 0, // Packet2i intrinsics not implemented yet + // FIXME check the Has* + HasDiv = 1, + HasBlend = 1 + }; +}; + +template <> +struct unpacket_traits { + typedef float type; + enum { size = 4, alignment = Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false }; + typedef Packet4f half; +}; + +template <> +struct unpacket_traits { + typedef int32_t type; + enum { size = 4, alignment = Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false }; + typedef Packet4i half; +}; + +template <> +EIGEN_STRONG_INLINE Packet4f pset1(const float& from) { + EIGEN_MSA_DEBUG; + + Packet4f v = { from, from, from, from }; + return v; +} + +template <> +EIGEN_STRONG_INLINE Packet4i pset1(const int32_t& from) { + EIGEN_MSA_DEBUG; + + return __builtin_msa_fill_w(from); +} + +template <> +EIGEN_STRONG_INLINE Packet4f pload1(const float* from) { + EIGEN_MSA_DEBUG; + + float f = *from; + Packet4f v = { f, f, f, f }; + return v; +} + +template <> +EIGEN_STRONG_INLINE Packet4i 
pload1(const int32_t* from) { + EIGEN_MSA_DEBUG; + + return __builtin_msa_fill_w(*from); +} + +template <> +EIGEN_STRONG_INLINE Packet4f padd(const Packet4f& a, const Packet4f& b) { + EIGEN_MSA_DEBUG; + + return __builtin_msa_fadd_w(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet4i padd(const Packet4i& a, const Packet4i& b) { + EIGEN_MSA_DEBUG; + + return __builtin_msa_addv_w(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet4f plset(const float& a) { + EIGEN_MSA_DEBUG; + + static const Packet4f countdown = { 0.0f, 1.0f, 2.0f, 3.0f }; + return padd(pset1(a), countdown); +} + +template <> +EIGEN_STRONG_INLINE Packet4i plset(const int32_t& a) { + EIGEN_MSA_DEBUG; + + static const Packet4i countdown = { 0, 1, 2, 3 }; + return padd(pset1(a), countdown); +} + +template <> +EIGEN_STRONG_INLINE Packet4f psub(const Packet4f& a, const Packet4f& b) { + EIGEN_MSA_DEBUG; + + return __builtin_msa_fsub_w(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet4i psub(const Packet4i& a, const Packet4i& b) { + EIGEN_MSA_DEBUG; + + return __builtin_msa_subv_w(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet4f pnegate(const Packet4f& a) { + EIGEN_MSA_DEBUG; + + return (Packet4f)__builtin_msa_bnegi_w((v4u32)a, 31); +} + +template <> +EIGEN_STRONG_INLINE Packet4i pnegate(const Packet4i& a) { + EIGEN_MSA_DEBUG; + + return __builtin_msa_addvi_w((v4i32)__builtin_msa_nori_b((v16u8)a, 0), 1); +} + +template <> +EIGEN_STRONG_INLINE Packet4f pconj(const Packet4f& a) { + EIGEN_MSA_DEBUG; + + return a; +} + +template <> +EIGEN_STRONG_INLINE Packet4i pconj(const Packet4i& a) { + EIGEN_MSA_DEBUG; + + return a; +} + +template <> +EIGEN_STRONG_INLINE Packet4f pmul(const Packet4f& a, const Packet4f& b) { + EIGEN_MSA_DEBUG; + + return __builtin_msa_fmul_w(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet4i pmul(const Packet4i& a, const Packet4i& b) { + EIGEN_MSA_DEBUG; + + return __builtin_msa_mulv_w(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet4f pdiv(const Packet4f& a, const Packet4f& b) { + EIGEN_MSA_DEBUG; + + return __builtin_msa_fdiv_w(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet4i pdiv(const Packet4i& a, const Packet4i& b) { + EIGEN_MSA_DEBUG; + + return __builtin_msa_div_s_w(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c) { + EIGEN_MSA_DEBUG; + + return __builtin_msa_fmadd_w(c, a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet4i pmadd(const Packet4i& a, const Packet4i& b, const Packet4i& c) { + EIGEN_MSA_DEBUG; + + // Use "asm" construct to avoid __builtin_msa_maddv_w GNU C bug. 
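+  // maddv.w accumulates value += a * b per 32-bit lane; the "+f"/"f" constraints keep
+  // all three operands in vector registers, making this the integer analogue of the
+  // float pmadd above.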
+ Packet4i value = c; + __asm__("maddv.w %w[value], %w[a], %w[b]\n" + // Outputs + : [value] "+f"(value) + // Inputs + : [a] "f"(a), [b] "f"(b)); + return value; +} + +template <> +EIGEN_STRONG_INLINE Packet4f pand(const Packet4f& a, const Packet4f& b) { + EIGEN_MSA_DEBUG; + + return (Packet4f)__builtin_msa_and_v((v16u8)a, (v16u8)b); +} + +template <> +EIGEN_STRONG_INLINE Packet4i pand(const Packet4i& a, const Packet4i& b) { + EIGEN_MSA_DEBUG; + + return (Packet4i)__builtin_msa_and_v((v16u8)a, (v16u8)b); +} + +template <> +EIGEN_STRONG_INLINE Packet4f por(const Packet4f& a, const Packet4f& b) { + EIGEN_MSA_DEBUG; + + return (Packet4f)__builtin_msa_or_v((v16u8)a, (v16u8)b); +} + +template <> +EIGEN_STRONG_INLINE Packet4i por(const Packet4i& a, const Packet4i& b) { + EIGEN_MSA_DEBUG; + + return (Packet4i)__builtin_msa_or_v((v16u8)a, (v16u8)b); +} + +template <> +EIGEN_STRONG_INLINE Packet4f pxor(const Packet4f& a, const Packet4f& b) { + EIGEN_MSA_DEBUG; + + return (Packet4f)__builtin_msa_xor_v((v16u8)a, (v16u8)b); +} + +template <> +EIGEN_STRONG_INLINE Packet4i pxor(const Packet4i& a, const Packet4i& b) { + EIGEN_MSA_DEBUG; + + return (Packet4i)__builtin_msa_xor_v((v16u8)a, (v16u8)b); +} + +template <> +EIGEN_STRONG_INLINE Packet4f pandnot(const Packet4f& a, const Packet4f& b) { + EIGEN_MSA_DEBUG; + + return pand(a, (Packet4f)__builtin_msa_xori_b((v16u8)b, 255)); +} + +template <> +EIGEN_STRONG_INLINE Packet4i pandnot(const Packet4i& a, const Packet4i& b) { + EIGEN_MSA_DEBUG; + + return pand(a, (Packet4i)__builtin_msa_xori_b((v16u8)b, 255)); +} + +template <> +EIGEN_STRONG_INLINE Packet4f pmin(const Packet4f& a, const Packet4f& b) { + EIGEN_MSA_DEBUG; + +#if EIGEN_FAST_MATH + // This prefers numbers to NaNs. + return __builtin_msa_fmin_w(a, b); +#else + // This prefers NaNs to numbers. + Packet4i aNaN = __builtin_msa_fcun_w(a, a); + Packet4i aMinOrNaN = por(__builtin_msa_fclt_w(a, b), aNaN); + return (Packet4f)__builtin_msa_bsel_v((v16u8)aMinOrNaN, (v16u8)b, (v16u8)a); +#endif +} + +template <> +EIGEN_STRONG_INLINE Packet4i pmin(const Packet4i& a, const Packet4i& b) { + EIGEN_MSA_DEBUG; + + return __builtin_msa_min_s_w(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet4f pmax(const Packet4f& a, const Packet4f& b) { + EIGEN_MSA_DEBUG; + +#if EIGEN_FAST_MATH + // This prefers numbers to NaNs. + return __builtin_msa_fmax_w(a, b); +#else + // This prefers NaNs to numbers. 
+ Packet4i aNaN = __builtin_msa_fcun_w(a, a); + Packet4i aMaxOrNaN = por(__builtin_msa_fclt_w(b, a), aNaN); + return (Packet4f)__builtin_msa_bsel_v((v16u8)aMaxOrNaN, (v16u8)b, (v16u8)a); +#endif +} + +template <> +EIGEN_STRONG_INLINE Packet4i pmax(const Packet4i& a, const Packet4i& b) { + EIGEN_MSA_DEBUG; + + return __builtin_msa_max_s_w(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet4f pload(const float* from) { + EIGEN_MSA_DEBUG; + + EIGEN_DEBUG_ALIGNED_LOAD return (Packet4f)__builtin_msa_ld_w(const_cast(from), 0); +} + +template <> +EIGEN_STRONG_INLINE Packet4i pload(const int32_t* from) { + EIGEN_MSA_DEBUG; + + EIGEN_DEBUG_ALIGNED_LOAD return __builtin_msa_ld_w(const_cast(from), 0); +} + +template <> +EIGEN_STRONG_INLINE Packet4f ploadu(const float* from) { + EIGEN_MSA_DEBUG; + + EIGEN_DEBUG_UNALIGNED_LOAD return (Packet4f)__builtin_msa_ld_w(const_cast(from), 0); +} + +template <> +EIGEN_STRONG_INLINE Packet4i ploadu(const int32_t* from) { + EIGEN_MSA_DEBUG; + + EIGEN_DEBUG_UNALIGNED_LOAD return (Packet4i)__builtin_msa_ld_w(const_cast(from), 0); +} + +template <> +EIGEN_STRONG_INLINE Packet4f ploaddup(const float* from) { + EIGEN_MSA_DEBUG; + + float f0 = from[0], f1 = from[1]; + Packet4f v0 = { f0, f0, f0, f0 }; + Packet4f v1 = { f1, f1, f1, f1 }; + return (Packet4f)__builtin_msa_ilvr_d((v2i64)v1, (v2i64)v0); +} + +template <> +EIGEN_STRONG_INLINE Packet4i ploaddup(const int32_t* from) { + EIGEN_MSA_DEBUG; + + int32_t i0 = from[0], i1 = from[1]; + Packet4i v0 = { i0, i0, i0, i0 }; + Packet4i v1 = { i1, i1, i1, i1 }; + return (Packet4i)__builtin_msa_ilvr_d((v2i64)v1, (v2i64)v0); +} + +template <> +EIGEN_STRONG_INLINE void pstore(float* to, const Packet4f& from) { + EIGEN_MSA_DEBUG; + + EIGEN_DEBUG_ALIGNED_STORE __builtin_msa_st_w((Packet4i)from, to, 0); +} + +template <> +EIGEN_STRONG_INLINE void pstore(int32_t* to, const Packet4i& from) { + EIGEN_MSA_DEBUG; + + EIGEN_DEBUG_ALIGNED_STORE __builtin_msa_st_w(from, to, 0); +} + +template <> +EIGEN_STRONG_INLINE void pstoreu(float* to, const Packet4f& from) { + EIGEN_MSA_DEBUG; + + EIGEN_DEBUG_UNALIGNED_STORE __builtin_msa_st_w((Packet4i)from, to, 0); +} + +template <> +EIGEN_STRONG_INLINE void pstoreu(int32_t* to, const Packet4i& from) { + EIGEN_MSA_DEBUG; + + EIGEN_DEBUG_UNALIGNED_STORE __builtin_msa_st_w(from, to, 0); +} + +template <> +EIGEN_DEVICE_FUNC inline Packet4f pgather(const float* from, Index stride) { + EIGEN_MSA_DEBUG; + + float f = *from; + Packet4f v = { f, f, f, f }; + v[1] = from[stride]; + v[2] = from[2 * stride]; + v[3] = from[3 * stride]; + return v; +} + +template <> +EIGEN_DEVICE_FUNC inline Packet4i pgather(const int32_t* from, Index stride) { + EIGEN_MSA_DEBUG; + + int32_t i = *from; + Packet4i v = { i, i, i, i }; + v[1] = from[stride]; + v[2] = from[2 * stride]; + v[3] = from[3 * stride]; + return v; +} + +template <> +EIGEN_DEVICE_FUNC inline void pscatter(float* to, const Packet4f& from, + Index stride) { + EIGEN_MSA_DEBUG; + + *to = from[0]; + to += stride; + *to = from[1]; + to += stride; + *to = from[2]; + to += stride; + *to = from[3]; +} + +template <> +EIGEN_DEVICE_FUNC inline void pscatter(int32_t* to, const Packet4i& from, + Index stride) { + EIGEN_MSA_DEBUG; + + *to = from[0]; + to += stride; + *to = from[1]; + to += stride; + *to = from[2]; + to += stride; + *to = from[3]; +} + +template <> +EIGEN_STRONG_INLINE void prefetch(const float* addr) { + EIGEN_MSA_DEBUG; + + __builtin_prefetch(addr); +} + +template <> +EIGEN_STRONG_INLINE void prefetch(const int32_t* addr) { + EIGEN_MSA_DEBUG; + + 
__builtin_prefetch(addr); +} + +template <> +EIGEN_STRONG_INLINE float pfirst(const Packet4f& a) { + EIGEN_MSA_DEBUG; + + return a[0]; +} + +template <> +EIGEN_STRONG_INLINE int32_t pfirst(const Packet4i& a) { + EIGEN_MSA_DEBUG; + + return a[0]; +} + +template <> +EIGEN_STRONG_INLINE Packet4f preverse(const Packet4f& a) { + EIGEN_MSA_DEBUG; + + return (Packet4f)__builtin_msa_shf_w((v4i32)a, EIGEN_MSA_SHF_I8(3, 2, 1, 0)); +} + +template <> +EIGEN_STRONG_INLINE Packet4i preverse(const Packet4i& a) { + EIGEN_MSA_DEBUG; + + return __builtin_msa_shf_w(a, EIGEN_MSA_SHF_I8(3, 2, 1, 0)); +} + +template <> +EIGEN_STRONG_INLINE Packet4f pabs(const Packet4f& a) { + EIGEN_MSA_DEBUG; + + return (Packet4f)__builtin_msa_bclri_w((v4u32)a, 31); +} + +template <> +EIGEN_STRONG_INLINE Packet4i pabs(const Packet4i& a) { + EIGEN_MSA_DEBUG; + + Packet4i zero = __builtin_msa_ldi_w(0); + return __builtin_msa_add_a_w(zero, a); +} + +template <> +EIGEN_STRONG_INLINE float predux(const Packet4f& a) { + EIGEN_MSA_DEBUG; + + Packet4f s = padd(a, (Packet4f)__builtin_msa_shf_w((v4i32)a, EIGEN_MSA_SHF_I8(2, 3, 0, 1))); + s = padd(s, (Packet4f)__builtin_msa_shf_w((v4i32)s, EIGEN_MSA_SHF_I8(1, 0, 3, 2))); + return s[0]; +} + + +template <> +EIGEN_STRONG_INLINE int32_t predux(const Packet4i& a) { + EIGEN_MSA_DEBUG; + + Packet4i s = padd(a, __builtin_msa_shf_w(a, EIGEN_MSA_SHF_I8(2, 3, 0, 1))); + s = padd(s, __builtin_msa_shf_w(s, EIGEN_MSA_SHF_I8(1, 0, 3, 2))); + return s[0]; +} + +// Other reduction functions: +// mul +template <> +EIGEN_STRONG_INLINE float predux_mul(const Packet4f& a) { + EIGEN_MSA_DEBUG; + + Packet4f p = pmul(a, (Packet4f)__builtin_msa_shf_w((v4i32)a, EIGEN_MSA_SHF_I8(2, 3, 0, 1))); + p = pmul(p, (Packet4f)__builtin_msa_shf_w((v4i32)p, EIGEN_MSA_SHF_I8(1, 0, 3, 2))); + return p[0]; +} + +template <> +EIGEN_STRONG_INLINE int32_t predux_mul(const Packet4i& a) { + EIGEN_MSA_DEBUG; + + Packet4i p = pmul(a, __builtin_msa_shf_w(a, EIGEN_MSA_SHF_I8(2, 3, 0, 1))); + p = pmul(p, __builtin_msa_shf_w(p, EIGEN_MSA_SHF_I8(1, 0, 3, 2))); + return p[0]; +} + +// min +template <> +EIGEN_STRONG_INLINE float predux_min(const Packet4f& a) { + EIGEN_MSA_DEBUG; + + // Swap 64-bit halves of a. + Packet4f swapped = (Packet4f)__builtin_msa_shf_w((Packet4i)a, EIGEN_MSA_SHF_I8(2, 3, 0, 1)); +#if !EIGEN_FAST_MATH + // Detect presence of NaNs from pairs a[0]-a[2] and a[1]-a[3] as two 32-bit + // masks of all zeroes/ones in low 64 bits. + v16u8 unord = (v16u8)__builtin_msa_fcun_w(a, swapped); + // Combine the two masks into one: 64 ones if no NaNs, otherwise 64 zeroes. + unord = (v16u8)__builtin_msa_ceqi_d((v2i64)unord, 0); +#endif + // Continue with min computation. + Packet4f v = __builtin_msa_fmin_w(a, swapped); + v = __builtin_msa_fmin_w( + v, (Packet4f)__builtin_msa_shf_w((Packet4i)v, EIGEN_MSA_SHF_I8(1, 0, 3, 2))); +#if !EIGEN_FAST_MATH + // Based on the mask select between v and 4 qNaNs. + v16u8 qnans = (v16u8)__builtin_msa_fill_w(0x7FC00000); + v = (Packet4f)__builtin_msa_bsel_v(unord, qnans, (v16u8)v); +#endif + return v[0]; +} + +template <> +EIGEN_STRONG_INLINE int32_t predux_min(const Packet4i& a) { + EIGEN_MSA_DEBUG; + + Packet4i m = pmin(a, __builtin_msa_shf_w(a, EIGEN_MSA_SHF_I8(2, 3, 0, 1))); + m = pmin(m, __builtin_msa_shf_w(m, EIGEN_MSA_SHF_I8(1, 0, 3, 2))); + return m[0]; +} + +// max +template <> +EIGEN_STRONG_INLINE float predux_max(const Packet4f& a) { + EIGEN_MSA_DEBUG; + + // Swap 64-bit halves of a. 
+ Packet4f swapped = (Packet4f)__builtin_msa_shf_w((Packet4i)a, EIGEN_MSA_SHF_I8(2, 3, 0, 1)); +#if !EIGEN_FAST_MATH + // Detect presence of NaNs from pairs a[0]-a[2] and a[1]-a[3] as two 32-bit + // masks of all zeroes/ones in low 64 bits. + v16u8 unord = (v16u8)__builtin_msa_fcun_w(a, swapped); + // Combine the two masks into one: 64 ones if no NaNs, otherwise 64 zeroes. + unord = (v16u8)__builtin_msa_ceqi_d((v2i64)unord, 0); +#endif + // Continue with max computation. + Packet4f v = __builtin_msa_fmax_w(a, swapped); + v = __builtin_msa_fmax_w( + v, (Packet4f)__builtin_msa_shf_w((Packet4i)v, EIGEN_MSA_SHF_I8(1, 0, 3, 2))); +#if !EIGEN_FAST_MATH + // Based on the mask select between v and 4 qNaNs. + v16u8 qnans = (v16u8)__builtin_msa_fill_w(0x7FC00000); + v = (Packet4f)__builtin_msa_bsel_v(unord, qnans, (v16u8)v); +#endif + return v[0]; +} + +template <> +EIGEN_STRONG_INLINE int32_t predux_max(const Packet4i& a) { + EIGEN_MSA_DEBUG; + + Packet4i m = pmax(a, __builtin_msa_shf_w(a, EIGEN_MSA_SHF_I8(2, 3, 0, 1))); + m = pmax(m, __builtin_msa_shf_w(m, EIGEN_MSA_SHF_I8(1, 0, 3, 2))); + return m[0]; +} + +inline std::ostream& operator<<(std::ostream& os, const PacketBlock& value) { + os << "[ " << value.packet[0] << "," << std::endl + << " " << value.packet[1] << "," << std::endl + << " " << value.packet[2] << "," << std::endl + << " " << value.packet[3] << " ]"; + return os; +} + +EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { + EIGEN_MSA_DEBUG; + + v4i32 tmp1, tmp2, tmp3, tmp4; + + tmp1 = __builtin_msa_ilvr_w((v4i32)kernel.packet[1], (v4i32)kernel.packet[0]); + tmp2 = __builtin_msa_ilvr_w((v4i32)kernel.packet[3], (v4i32)kernel.packet[2]); + tmp3 = __builtin_msa_ilvl_w((v4i32)kernel.packet[1], (v4i32)kernel.packet[0]); + tmp4 = __builtin_msa_ilvl_w((v4i32)kernel.packet[3], (v4i32)kernel.packet[2]); + + kernel.packet[0] = (Packet4f)__builtin_msa_ilvr_d((v2i64)tmp2, (v2i64)tmp1); + kernel.packet[1] = (Packet4f)__builtin_msa_ilvod_d((v2i64)tmp2, (v2i64)tmp1); + kernel.packet[2] = (Packet4f)__builtin_msa_ilvr_d((v2i64)tmp4, (v2i64)tmp3); + kernel.packet[3] = (Packet4f)__builtin_msa_ilvod_d((v2i64)tmp4, (v2i64)tmp3); +} + +inline std::ostream& operator<<(std::ostream& os, const PacketBlock& value) { + os << "[ " << value.packet[0] << "," << std::endl + << " " << value.packet[1] << "," << std::endl + << " " << value.packet[2] << "," << std::endl + << " " << value.packet[3] << " ]"; + return os; +} + +EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { + EIGEN_MSA_DEBUG; + + v4i32 tmp1, tmp2, tmp3, tmp4; + + tmp1 = __builtin_msa_ilvr_w(kernel.packet[1], kernel.packet[0]); + tmp2 = __builtin_msa_ilvr_w(kernel.packet[3], kernel.packet[2]); + tmp3 = __builtin_msa_ilvl_w(kernel.packet[1], kernel.packet[0]); + tmp4 = __builtin_msa_ilvl_w(kernel.packet[3], kernel.packet[2]); + + kernel.packet[0] = (Packet4i)__builtin_msa_ilvr_d((v2i64)tmp2, (v2i64)tmp1); + kernel.packet[1] = (Packet4i)__builtin_msa_ilvod_d((v2i64)tmp2, (v2i64)tmp1); + kernel.packet[2] = (Packet4i)__builtin_msa_ilvr_d((v2i64)tmp4, (v2i64)tmp3); + kernel.packet[3] = (Packet4i)__builtin_msa_ilvod_d((v2i64)tmp4, (v2i64)tmp3); +} + +template <> +EIGEN_STRONG_INLINE Packet4f psqrt(const Packet4f& a) { + EIGEN_MSA_DEBUG; + + return __builtin_msa_fsqrt_w(a); +} + +template <> +EIGEN_STRONG_INLINE Packet4f prsqrt(const Packet4f& a) { + EIGEN_MSA_DEBUG; + +#if EIGEN_FAST_MATH + return __builtin_msa_frsqrt_w(a); +#else + Packet4f ones = __builtin_msa_ffint_s_w(__builtin_msa_ldi_w(1)); + return pdiv(ones, psqrt(a)); 
+#endif +} + +template <> +EIGEN_STRONG_INLINE Packet4f pfloor(const Packet4f& a) { + Packet4f v = a; + int32_t old_mode, new_mode; + asm volatile( + "cfcmsa %[old_mode], $1\n" + "ori %[new_mode], %[old_mode], 3\n" // 3 = round towards -INFINITY. + "ctcmsa $1, %[new_mode]\n" + "frint.w %w[v], %w[v]\n" + "ctcmsa $1, %[old_mode]\n" + : // outputs + [old_mode] "=r"(old_mode), [new_mode] "=r"(new_mode), + [v] "+f"(v) + : // inputs + : // clobbers + ); + return v; +} + +template <> +EIGEN_STRONG_INLINE Packet4f pceil(const Packet4f& a) { + Packet4f v = a; + int32_t old_mode, new_mode; + asm volatile( + "cfcmsa %[old_mode], $1\n" + "ori %[new_mode], %[old_mode], 3\n" + "xori %[new_mode], %[new_mode], 1\n" // 2 = round towards +INFINITY. + "ctcmsa $1, %[new_mode]\n" + "frint.w %w[v], %w[v]\n" + "ctcmsa $1, %[old_mode]\n" + : // outputs + [old_mode] "=r"(old_mode), [new_mode] "=r"(new_mode), + [v] "+f"(v) + : // inputs + : // clobbers + ); + return v; +} + +template <> +EIGEN_STRONG_INLINE Packet4f pround(const Packet4f& a) { + Packet4f v = a; + int32_t old_mode, new_mode; + asm volatile( + "cfcmsa %[old_mode], $1\n" + "ori %[new_mode], %[old_mode], 3\n" + "xori %[new_mode], %[new_mode], 3\n" // 0 = round to nearest, ties to even. + "ctcmsa $1, %[new_mode]\n" + "frint.w %w[v], %w[v]\n" + "ctcmsa $1, %[old_mode]\n" + : // outputs + [old_mode] "=r"(old_mode), [new_mode] "=r"(new_mode), + [v] "+f"(v) + : // inputs + : // clobbers + ); + return v; +} + +template <> +EIGEN_STRONG_INLINE Packet4f pblend(const Selector<4>& ifPacket, const Packet4f& thenPacket, + const Packet4f& elsePacket) { + Packet4ui select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], + ifPacket.select[3] }; + Packet4i mask = __builtin_msa_ceqi_w((Packet4i)select, 0); + return (Packet4f)__builtin_msa_bsel_v((v16u8)mask, (v16u8)thenPacket, (v16u8)elsePacket); +} + +template <> +EIGEN_STRONG_INLINE Packet4i pblend(const Selector<4>& ifPacket, const Packet4i& thenPacket, + const Packet4i& elsePacket) { + Packet4ui select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], + ifPacket.select[3] }; + Packet4i mask = __builtin_msa_ceqi_w((Packet4i)select, 0); + return (Packet4i)__builtin_msa_bsel_v((v16u8)mask, (v16u8)thenPacket, (v16u8)elsePacket); +} + +//---------- double ---------- + +typedef v2f64 Packet2d; +typedef v2i64 Packet2l; +typedef v2u64 Packet2ul; + +#define _EIGEN_DECLARE_CONST_Packet2d(NAME, X) const Packet2d p2d_##NAME = { X, X } +#define _EIGEN_DECLARE_CONST_Packet2l(NAME, X) const Packet2l p2l_##NAME = { X, X } +#define _EIGEN_DECLARE_CONST_Packet2ul(NAME, X) const Packet2ul p2ul_##NAME = { X, X } + +inline std::ostream& operator<<(std::ostream& os, const Packet2d& value) { + os << "[ " << value[0] << ", " << value[1] << " ]"; + return os; +} + +inline std::ostream& operator<<(std::ostream& os, const Packet2l& value) { + os << "[ " << value[0] << ", " << value[1] << " ]"; + return os; +} + +inline std::ostream& operator<<(std::ostream& os, const Packet2ul& value) { + os << "[ " << value[0] << ", " << value[1] << " ]"; + return os; +} + +template <> +struct packet_traits : default_packet_traits { + typedef Packet2d type; + typedef Packet2d half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 2, + HasHalfPacket = 0, + // FIXME check the Has* + HasDiv = 1, + HasExp = 1, + HasSqrt = 1, + HasRsqrt = 1, + HasRound = 1, + HasFloor = 1, + HasCeil = 1, + HasBlend = 1 + }; +}; + +template <> +struct unpacket_traits { + typedef double type; + enum { size = 2, alignment = Aligned16, 
vectorizable=true, masked_load_available=false, masked_store_available=false }; + typedef Packet2d half; +}; + +template <> +EIGEN_STRONG_INLINE Packet2d pset1(const double& from) { + EIGEN_MSA_DEBUG; + + Packet2d value = { from, from }; + return value; +} + +template <> +EIGEN_STRONG_INLINE Packet2d padd(const Packet2d& a, const Packet2d& b) { + EIGEN_MSA_DEBUG; + + return __builtin_msa_fadd_d(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet2d plset(const double& a) { + EIGEN_MSA_DEBUG; + + static const Packet2d countdown = { 0.0, 1.0 }; + return padd(pset1(a), countdown); +} + +template <> +EIGEN_STRONG_INLINE Packet2d psub(const Packet2d& a, const Packet2d& b) { + EIGEN_MSA_DEBUG; + + return __builtin_msa_fsub_d(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet2d pnegate(const Packet2d& a) { + EIGEN_MSA_DEBUG; + + return (Packet2d)__builtin_msa_bnegi_d((v2u64)a, 63); +} + +template <> +EIGEN_STRONG_INLINE Packet2d pconj(const Packet2d& a) { + EIGEN_MSA_DEBUG; + + return a; +} + +template <> +EIGEN_STRONG_INLINE Packet2d pmul(const Packet2d& a, const Packet2d& b) { + EIGEN_MSA_DEBUG; + + return __builtin_msa_fmul_d(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet2d pdiv(const Packet2d& a, const Packet2d& b) { + EIGEN_MSA_DEBUG; + + return __builtin_msa_fdiv_d(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet2d pmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) { + EIGEN_MSA_DEBUG; + + return __builtin_msa_fmadd_d(c, a, b); +} + +// Logical Operations are not supported for float, so we have to reinterpret casts using MSA +// intrinsics +template <> +EIGEN_STRONG_INLINE Packet2d pand(const Packet2d& a, const Packet2d& b) { + EIGEN_MSA_DEBUG; + + return (Packet2d)__builtin_msa_and_v((v16u8)a, (v16u8)b); +} + +template <> +EIGEN_STRONG_INLINE Packet2d por(const Packet2d& a, const Packet2d& b) { + EIGEN_MSA_DEBUG; + + return (Packet2d)__builtin_msa_or_v((v16u8)a, (v16u8)b); +} + +template <> +EIGEN_STRONG_INLINE Packet2d pxor(const Packet2d& a, const Packet2d& b) { + EIGEN_MSA_DEBUG; + + return (Packet2d)__builtin_msa_xor_v((v16u8)a, (v16u8)b); +} + +template <> +EIGEN_STRONG_INLINE Packet2d pandnot(const Packet2d& a, const Packet2d& b) { + EIGEN_MSA_DEBUG; + + return pand(a, (Packet2d)__builtin_msa_xori_b((v16u8)b, 255)); +} + +template <> +EIGEN_STRONG_INLINE Packet2d pload(const double* from) { + EIGEN_MSA_DEBUG; + + EIGEN_DEBUG_UNALIGNED_LOAD return (Packet2d)__builtin_msa_ld_d(const_cast(from), 0); +} + +template <> +EIGEN_STRONG_INLINE Packet2d pmin(const Packet2d& a, const Packet2d& b) { + EIGEN_MSA_DEBUG; + +#if EIGEN_FAST_MATH + // This prefers numbers to NaNs. + return __builtin_msa_fmin_d(a, b); +#else + // This prefers NaNs to numbers. + v2i64 aNaN = __builtin_msa_fcun_d(a, a); + v2i64 aMinOrNaN = por(__builtin_msa_fclt_d(a, b), aNaN); + return (Packet2d)__builtin_msa_bsel_v((v16u8)aMinOrNaN, (v16u8)b, (v16u8)a); +#endif +} + +template <> +EIGEN_STRONG_INLINE Packet2d pmax(const Packet2d& a, const Packet2d& b) { + EIGEN_MSA_DEBUG; + +#if EIGEN_FAST_MATH + // This prefers numbers to NaNs. + return __builtin_msa_fmax_d(a, b); +#else + // This prefers NaNs to numbers. 
+ v2i64 aNaN = __builtin_msa_fcun_d(a, a); + v2i64 aMaxOrNaN = por(__builtin_msa_fclt_d(b, a), aNaN); + return (Packet2d)__builtin_msa_bsel_v((v16u8)aMaxOrNaN, (v16u8)b, (v16u8)a); +#endif +} + +template <> +EIGEN_STRONG_INLINE Packet2d ploadu(const double* from) { + EIGEN_MSA_DEBUG; + + EIGEN_DEBUG_UNALIGNED_LOAD return (Packet2d)__builtin_msa_ld_d(const_cast(from), 0); +} + +template <> +EIGEN_STRONG_INLINE Packet2d ploaddup(const double* from) { + EIGEN_MSA_DEBUG; + + Packet2d value = { *from, *from }; + return value; +} + +template <> +EIGEN_STRONG_INLINE void pstore(double* to, const Packet2d& from) { + EIGEN_MSA_DEBUG; + + EIGEN_DEBUG_ALIGNED_STORE __builtin_msa_st_d((v2i64)from, to, 0); +} + +template <> +EIGEN_STRONG_INLINE void pstoreu(double* to, const Packet2d& from) { + EIGEN_MSA_DEBUG; + + EIGEN_DEBUG_UNALIGNED_STORE __builtin_msa_st_d((v2i64)from, to, 0); +} + +template <> +EIGEN_DEVICE_FUNC inline Packet2d pgather(const double* from, Index stride) { + EIGEN_MSA_DEBUG; + + Packet2d value; + value[0] = *from; + from += stride; + value[1] = *from; + return value; +} + +template <> +EIGEN_DEVICE_FUNC inline void pscatter(double* to, const Packet2d& from, + Index stride) { + EIGEN_MSA_DEBUG; + + *to = from[0]; + to += stride; + *to = from[1]; +} + +template <> +EIGEN_STRONG_INLINE void prefetch(const double* addr) { + EIGEN_MSA_DEBUG; + + __builtin_prefetch(addr); +} + +template <> +EIGEN_STRONG_INLINE double pfirst(const Packet2d& a) { + EIGEN_MSA_DEBUG; + + return a[0]; +} + +template <> +EIGEN_STRONG_INLINE Packet2d preverse(const Packet2d& a) { + EIGEN_MSA_DEBUG; + + return (Packet2d)__builtin_msa_shf_w((v4i32)a, EIGEN_MSA_SHF_I8(2, 3, 0, 1)); +} + +template <> +EIGEN_STRONG_INLINE Packet2d pabs(const Packet2d& a) { + EIGEN_MSA_DEBUG; + + return (Packet2d)__builtin_msa_bclri_d((v2u64)a, 63); +} + +template <> +EIGEN_STRONG_INLINE double predux(const Packet2d& a) { + EIGEN_MSA_DEBUG; + + Packet2d s = padd(a, preverse(a)); + return s[0]; +} + +// Other reduction functions: +// mul +template <> +EIGEN_STRONG_INLINE double predux_mul(const Packet2d& a) { + EIGEN_MSA_DEBUG; + + Packet2d p = pmul(a, preverse(a)); + return p[0]; +} + +// min +template <> +EIGEN_STRONG_INLINE double predux_min(const Packet2d& a) { + EIGEN_MSA_DEBUG; + +#if EIGEN_FAST_MATH + Packet2d swapped = (Packet2d)__builtin_msa_shf_w((Packet4i)a, EIGEN_MSA_SHF_I8(2, 3, 0, 1)); + Packet2d v = __builtin_msa_fmin_d(a, swapped); + return v[0]; +#else + double a0 = a[0], a1 = a[1]; + return ((numext::isnan)(a0) || a0 < a1) ? a0 : a1; +#endif +} + +// max +template <> +EIGEN_STRONG_INLINE double predux_max(const Packet2d& a) { + EIGEN_MSA_DEBUG; + +#if EIGEN_FAST_MATH + Packet2d swapped = (Packet2d)__builtin_msa_shf_w((Packet4i)a, EIGEN_MSA_SHF_I8(2, 3, 0, 1)); + Packet2d v = __builtin_msa_fmax_d(a, swapped); + return v[0]; +#else + double a0 = a[0], a1 = a[1]; + return ((numext::isnan)(a0) || a0 > a1) ? 
a0 : a1; +#endif +} + +template <> +EIGEN_STRONG_INLINE Packet2d psqrt(const Packet2d& a) { + EIGEN_MSA_DEBUG; + + return __builtin_msa_fsqrt_d(a); +} + +template <> +EIGEN_STRONG_INLINE Packet2d prsqrt(const Packet2d& a) { + EIGEN_MSA_DEBUG; + +#if EIGEN_FAST_MATH + return __builtin_msa_frsqrt_d(a); +#else + Packet2d ones = __builtin_msa_ffint_s_d(__builtin_msa_ldi_d(1)); + return pdiv(ones, psqrt(a)); +#endif +} + +inline std::ostream& operator<<(std::ostream& os, const PacketBlock& value) { + os << "[ " << value.packet[0] << "," << std::endl << " " << value.packet[1] << " ]"; + return os; +} + +EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { + EIGEN_MSA_DEBUG; + + Packet2d trn1 = (Packet2d)__builtin_msa_ilvev_d((v2i64)kernel.packet[1], (v2i64)kernel.packet[0]); + Packet2d trn2 = (Packet2d)__builtin_msa_ilvod_d((v2i64)kernel.packet[1], (v2i64)kernel.packet[0]); + kernel.packet[0] = trn1; + kernel.packet[1] = trn2; +} + +template <> +EIGEN_STRONG_INLINE Packet2d pfloor(const Packet2d& a) { + Packet2d v = a; + int32_t old_mode, new_mode; + asm volatile( + "cfcmsa %[old_mode], $1\n" + "ori %[new_mode], %[old_mode], 3\n" // 3 = round towards -INFINITY. + "ctcmsa $1, %[new_mode]\n" + "frint.d %w[v], %w[v]\n" + "ctcmsa $1, %[old_mode]\n" + : // outputs + [old_mode] "=r"(old_mode), [new_mode] "=r"(new_mode), + [v] "+f"(v) + : // inputs + : // clobbers + ); + return v; +} + +template <> +EIGEN_STRONG_INLINE Packet2d pceil(const Packet2d& a) { + Packet2d v = a; + int32_t old_mode, new_mode; + asm volatile( + "cfcmsa %[old_mode], $1\n" + "ori %[new_mode], %[old_mode], 3\n" + "xori %[new_mode], %[new_mode], 1\n" // 2 = round towards +INFINITY. + "ctcmsa $1, %[new_mode]\n" + "frint.d %w[v], %w[v]\n" + "ctcmsa $1, %[old_mode]\n" + : // outputs + [old_mode] "=r"(old_mode), [new_mode] "=r"(new_mode), + [v] "+f"(v) + : // inputs + : // clobbers + ); + return v; +} + +template <> +EIGEN_STRONG_INLINE Packet2d pround(const Packet2d& a) { + Packet2d v = a; + int32_t old_mode, new_mode; + asm volatile( + "cfcmsa %[old_mode], $1\n" + "ori %[new_mode], %[old_mode], 3\n" + "xori %[new_mode], %[new_mode], 3\n" // 0 = round to nearest, ties to even. + "ctcmsa $1, %[new_mode]\n" + "frint.d %w[v], %w[v]\n" + "ctcmsa $1, %[old_mode]\n" + : // outputs + [old_mode] "=r"(old_mode), [new_mode] "=r"(new_mode), + [v] "+f"(v) + : // inputs + : // clobbers + ); + return v; +} + +template <> +EIGEN_STRONG_INLINE Packet2d pblend(const Selector<2>& ifPacket, const Packet2d& thenPacket, + const Packet2d& elsePacket) { + Packet2ul select = { ifPacket.select[0], ifPacket.select[1] }; + Packet2l mask = __builtin_msa_ceqi_d((Packet2l)select, 0); + return (Packet2d)__builtin_msa_bsel_v((v16u8)mask, (v16u8)thenPacket, (v16u8)elsePacket); +} + +} // end namespace internal + +} // end namespace Eigen + +#endif // EIGEN_PACKET_MATH_MSA_H diff --git a/Eigen/src/Core/arch/NEON/Complex.h b/Eigen/src/Core/arch/NEON/Complex.h new file mode 100644 index 0000000..f40af7f --- /dev/null +++ b/Eigen/src/Core/arch/NEON/Complex.h @@ -0,0 +1,584 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2010 Gael Guennebaud +// Copyright (C) 2010 Konstantinos Margaritis +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
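// The pmul kernels in this file assemble a complex product from a dup / multiply /
// sign-flip (CONJ_XOR) / rev64 / add sequence. As an illustrative worked sketch (not part
// of the Eigen sources), with a = (ar, ai) and b = (br, bi) stored as interleaved
// real/imag lanes:
//   v1 = (ar, ar) * (br, bi)   = (ar*br,  ar*bi)
//   v2 = (ai, ai) * (br, bi)   = (ai*br,  ai*bi)
//   v2 = v2 XOR conj mask      = (ai*br, -ai*bi)    // flip the sign of the second lane
//   v2 = rev64(v2)             = (-ai*bi, ai*br)
//   v1 + v2                    = (ar*br - ai*bi, ar*bi + ai*br) = (re(a*b), im(a*b))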
+ +#ifndef EIGEN_COMPLEX_NEON_H +#define EIGEN_COMPLEX_NEON_H + +namespace Eigen { + +namespace internal { + +inline uint32x4_t p4ui_CONJ_XOR() +{ +// See bug 1325, clang fails to call vld1q_u64. +#if EIGEN_COMP_CLANG || EIGEN_COMP_CASTXML + uint32x4_t ret = { 0x00000000, 0x80000000, 0x00000000, 0x80000000 }; + return ret; +#else + static const uint32_t conj_XOR_DATA[] = { 0x00000000, 0x80000000, 0x00000000, 0x80000000 }; + return vld1q_u32( conj_XOR_DATA ); +#endif +} + +inline uint32x2_t p2ui_CONJ_XOR() +{ + static const uint32_t conj_XOR_DATA[] = { 0x00000000, 0x80000000 }; + return vld1_u32( conj_XOR_DATA ); +} + +//---------- float ---------- + +struct Packet1cf +{ + EIGEN_STRONG_INLINE Packet1cf() {} + EIGEN_STRONG_INLINE explicit Packet1cf(const Packet2f& a) : v(a) {} + Packet2f v; +}; +struct Packet2cf +{ + EIGEN_STRONG_INLINE Packet2cf() {} + EIGEN_STRONG_INLINE explicit Packet2cf(const Packet4f& a) : v(a) {} + Packet4f v; +}; + +template<> struct packet_traits > : default_packet_traits +{ + typedef Packet2cf type; + typedef Packet1cf half; + enum + { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 2, + HasHalfPacket = 1, + + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 1, + HasNegate = 1, + HasAbs = 0, + HasAbs2 = 0, + HasMin = 0, + HasMax = 0, + HasSetLinear = 0 + }; +}; + +template<> struct unpacket_traits +{ + typedef std::complex type; + typedef Packet1cf half; + typedef Packet2f as_real; + enum + { + size = 1, + alignment = Aligned16, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; +template<> struct unpacket_traits +{ + typedef std::complex type; + typedef Packet1cf half; + typedef Packet4f as_real; + enum + { + size = 2, + alignment = Aligned16, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; + +template<> EIGEN_STRONG_INLINE Packet1cf pcast(const float& a) +{ return Packet1cf(vset_lane_f32(a, vdup_n_f32(0.f), 0)); } +template<> EIGEN_STRONG_INLINE Packet2cf pcast(const Packet2f& a) +{ return Packet2cf(vreinterpretq_f32_u64(vmovl_u32(vreinterpret_u32_f32(a)))); } + +template<> EIGEN_STRONG_INLINE Packet1cf pset1(const std::complex& from) +{ return Packet1cf(vld1_f32(reinterpret_cast(&from))); } +template<> EIGEN_STRONG_INLINE Packet2cf pset1(const std::complex& from) +{ + const float32x2_t r64 = vld1_f32(reinterpret_cast(&from)); + return Packet2cf(vcombine_f32(r64, r64)); +} + +template<> EIGEN_STRONG_INLINE Packet1cf padd(const Packet1cf& a, const Packet1cf& b) +{ return Packet1cf(padd(a.v, b.v)); } +template<> EIGEN_STRONG_INLINE Packet2cf padd(const Packet2cf& a, const Packet2cf& b) +{ return Packet2cf(padd(a.v, b.v)); } + +template<> EIGEN_STRONG_INLINE Packet1cf psub(const Packet1cf& a, const Packet1cf& b) +{ return Packet1cf(psub(a.v, b.v)); } +template<> EIGEN_STRONG_INLINE Packet2cf psub(const Packet2cf& a, const Packet2cf& b) +{ return Packet2cf(psub(a.v, b.v)); } + +template<> EIGEN_STRONG_INLINE Packet1cf pnegate(const Packet1cf& a) { return Packet1cf(pnegate(a.v)); } +template<> EIGEN_STRONG_INLINE Packet2cf pnegate(const Packet2cf& a) { return Packet2cf(pnegate(a.v)); } + +template<> EIGEN_STRONG_INLINE Packet1cf pconj(const Packet1cf& a) +{ + const Packet2ui b = vreinterpret_u32_f32(a.v); + return Packet1cf(vreinterpret_f32_u32(veor_u32(b, p2ui_CONJ_XOR()))); +} +template<> EIGEN_STRONG_INLINE Packet2cf pconj(const Packet2cf& a) +{ + const Packet4ui b = vreinterpretq_u32_f32(a.v); + return Packet2cf(vreinterpretq_f32_u32(veorq_u32(b, 
p4ui_CONJ_XOR()))); +} + +template<> EIGEN_STRONG_INLINE Packet1cf pmul(const Packet1cf& a, const Packet1cf& b) +{ + Packet2f v1, v2; + + // Get the real values of a | a1_re | a1_re | + v1 = vdup_lane_f32(a.v, 0); + // Get the imag values of a | a1_im | a1_im | + v2 = vdup_lane_f32(a.v, 1); + // Multiply the real a with b + v1 = vmul_f32(v1, b.v); + // Multiply the imag a with b + v2 = vmul_f32(v2, b.v); + // Conjugate v2 + v2 = vreinterpret_f32_u32(veor_u32(vreinterpret_u32_f32(v2), p2ui_CONJ_XOR())); + // Swap real/imag elements in v2. + v2 = vrev64_f32(v2); + // Add and return the result + return Packet1cf(vadd_f32(v1, v2)); +} +template<> EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& a, const Packet2cf& b) +{ + Packet4f v1, v2; + + // Get the real values of a | a1_re | a1_re | a2_re | a2_re | + v1 = vcombine_f32(vdup_lane_f32(vget_low_f32(a.v), 0), vdup_lane_f32(vget_high_f32(a.v), 0)); + // Get the imag values of a | a1_im | a1_im | a2_im | a2_im | + v2 = vcombine_f32(vdup_lane_f32(vget_low_f32(a.v), 1), vdup_lane_f32(vget_high_f32(a.v), 1)); + // Multiply the real a with b + v1 = vmulq_f32(v1, b.v); + // Multiply the imag a with b + v2 = vmulq_f32(v2, b.v); + // Conjugate v2 + v2 = vreinterpretq_f32_u32(veorq_u32(vreinterpretq_u32_f32(v2), p4ui_CONJ_XOR())); + // Swap real/imag elements in v2. + v2 = vrev64q_f32(v2); + // Add and return the result + return Packet2cf(vaddq_f32(v1, v2)); +} + +template<> EIGEN_STRONG_INLINE Packet1cf pcmp_eq(const Packet1cf& a, const Packet1cf& b) +{ + // Compare real and imaginary parts of a and b to get the mask vector: + // [re(a[0])==re(b[0]), im(a[0])==im(b[0])] + Packet2f eq = pcmp_eq(a.v, b.v); + // Swap real/imag elements in the mask in to get: + // [im(a[0])==im(b[0]), re(a[0])==re(b[0])] + Packet2f eq_swapped = vrev64_f32(eq); + // Return re(a)==re(b) && im(a)==im(b) by computing bitwise AND of eq and eq_swapped + return Packet1cf(pand(eq, eq_swapped)); +} +template<> EIGEN_STRONG_INLINE Packet2cf pcmp_eq(const Packet2cf& a, const Packet2cf& b) +{ + // Compare real and imaginary parts of a and b to get the mask vector: + // [re(a[0])==re(b[0]), im(a[0])==im(b[0]), re(a[1])==re(b[1]), im(a[1])==im(b[1])] + Packet4f eq = pcmp_eq(a.v, b.v); + // Swap real/imag elements in the mask in to get: + // [im(a[0])==im(b[0]), re(a[0])==re(b[0]), im(a[1])==im(b[1]), re(a[1])==re(b[1])] + Packet4f eq_swapped = vrev64q_f32(eq); + // Return re(a)==re(b) && im(a)==im(b) by computing bitwise AND of eq and eq_swapped + return Packet2cf(pand(eq, eq_swapped)); +} + +template<> EIGEN_STRONG_INLINE Packet1cf pand(const Packet1cf& a, const Packet1cf& b) +{ return Packet1cf(vreinterpret_f32_u32(vand_u32(vreinterpret_u32_f32(a.v), vreinterpret_u32_f32(b.v)))); } +template<> EIGEN_STRONG_INLINE Packet2cf pand(const Packet2cf& a, const Packet2cf& b) +{ return Packet2cf(vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(a.v), vreinterpretq_u32_f32(b.v)))); } + +template<> EIGEN_STRONG_INLINE Packet1cf por(const Packet1cf& a, const Packet1cf& b) +{ return Packet1cf(vreinterpret_f32_u32(vorr_u32(vreinterpret_u32_f32(a.v), vreinterpret_u32_f32(b.v)))); } +template<> EIGEN_STRONG_INLINE Packet2cf por(const Packet2cf& a, const Packet2cf& b) +{ return Packet2cf(vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(a.v), vreinterpretq_u32_f32(b.v)))); } + +template<> EIGEN_STRONG_INLINE Packet1cf pxor(const Packet1cf& a, const Packet1cf& b) +{ return Packet1cf(vreinterpret_f32_u32(veor_u32(vreinterpret_u32_f32(a.v), vreinterpret_u32_f32(b.v)))); } +template<> 
EIGEN_STRONG_INLINE Packet2cf pxor(const Packet2cf& a, const Packet2cf& b) +{ return Packet2cf(vreinterpretq_f32_u32(veorq_u32(vreinterpretq_u32_f32(a.v), vreinterpretq_u32_f32(b.v)))); } + +template<> EIGEN_STRONG_INLINE Packet1cf pandnot(const Packet1cf& a, const Packet1cf& b) +{ return Packet1cf(vreinterpret_f32_u32(vbic_u32(vreinterpret_u32_f32(a.v), vreinterpret_u32_f32(b.v)))); } +template<> EIGEN_STRONG_INLINE Packet2cf pandnot(const Packet2cf& a, const Packet2cf& b) +{ return Packet2cf(vreinterpretq_f32_u32(vbicq_u32(vreinterpretq_u32_f32(a.v), vreinterpretq_u32_f32(b.v)))); } + +template<> EIGEN_STRONG_INLINE Packet1cf pload(const std::complex* from) +{ EIGEN_DEBUG_ALIGNED_LOAD return Packet1cf(pload((const float*)from)); } +template<> EIGEN_STRONG_INLINE Packet2cf pload(const std::complex* from) +{ EIGEN_DEBUG_ALIGNED_LOAD return Packet2cf(pload(reinterpret_cast(from))); } + +template<> EIGEN_STRONG_INLINE Packet1cf ploadu(const std::complex* from) +{ EIGEN_DEBUG_UNALIGNED_LOAD return Packet1cf(ploadu((const float*)from)); } +template<> EIGEN_STRONG_INLINE Packet2cf ploadu(const std::complex* from) +{ EIGEN_DEBUG_UNALIGNED_LOAD return Packet2cf(ploadu(reinterpret_cast(from))); } + +template<> EIGEN_STRONG_INLINE Packet1cf ploaddup(const std::complex* from) +{ return pset1(*from); } +template<> EIGEN_STRONG_INLINE Packet2cf ploaddup(const std::complex* from) +{ return pset1(*from); } + +template<> EIGEN_STRONG_INLINE void pstore >(std::complex *to, const Packet1cf& from) +{ EIGEN_DEBUG_ALIGNED_STORE pstore((float*)to, from.v); } +template<> EIGEN_STRONG_INLINE void pstore >(std::complex *to, const Packet2cf& from) +{ EIGEN_DEBUG_ALIGNED_STORE pstore(reinterpret_cast(to), from.v); } + +template<> EIGEN_STRONG_INLINE void pstoreu >(std::complex *to, const Packet1cf& from) +{ EIGEN_DEBUG_UNALIGNED_STORE pstoreu((float*)to, from.v); } +template<> EIGEN_STRONG_INLINE void pstoreu >(std::complex *to, const Packet2cf& from) +{ EIGEN_DEBUG_UNALIGNED_STORE pstoreu(reinterpret_cast(to), from.v); } + +template<> EIGEN_DEVICE_FUNC inline Packet1cf pgather, Packet1cf>( + const std::complex* from, Index stride) +{ + const Packet2f tmp = vdup_n_f32(std::real(from[0*stride])); + return Packet1cf(vset_lane_f32(std::imag(from[0*stride]), tmp, 1)); +} +template<> EIGEN_DEVICE_FUNC inline Packet2cf pgather, Packet2cf>( + const std::complex* from, Index stride) +{ + Packet4f res = vdupq_n_f32(std::real(from[0*stride])); + res = vsetq_lane_f32(std::imag(from[0*stride]), res, 1); + res = vsetq_lane_f32(std::real(from[1*stride]), res, 2); + res = vsetq_lane_f32(std::imag(from[1*stride]), res, 3); + return Packet2cf(res); +} + +template<> EIGEN_DEVICE_FUNC inline void pscatter, Packet1cf>( + std::complex* to, const Packet1cf& from, Index stride) +{ to[stride*0] = std::complex(vget_lane_f32(from.v, 0), vget_lane_f32(from.v, 1)); } +template<> EIGEN_DEVICE_FUNC inline void pscatter, Packet2cf>( + std::complex* to, const Packet2cf& from, Index stride) +{ + to[stride*0] = std::complex(vgetq_lane_f32(from.v, 0), vgetq_lane_f32(from.v, 1)); + to[stride*1] = std::complex(vgetq_lane_f32(from.v, 2), vgetq_lane_f32(from.v, 3)); +} + +template<> EIGEN_STRONG_INLINE void prefetch >(const std::complex *addr) +{ EIGEN_ARM_PREFETCH(reinterpret_cast(addr)); } + +template<> EIGEN_STRONG_INLINE std::complex pfirst(const Packet1cf& a) +{ + EIGEN_ALIGN16 std::complex x; + vst1_f32(reinterpret_cast(&x), a.v); + return x; +} +template<> EIGEN_STRONG_INLINE std::complex pfirst(const Packet2cf& a) +{ + EIGEN_ALIGN16 std::complex 
x[2]; + vst1q_f32(reinterpret_cast(x), a.v); + return x[0]; +} + +template<> EIGEN_STRONG_INLINE Packet1cf preverse(const Packet1cf& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet2cf preverse(const Packet2cf& a) +{ return Packet2cf(vcombine_f32(vget_high_f32(a.v), vget_low_f32(a.v))); } + +template<> EIGEN_STRONG_INLINE Packet1cf pcplxflip(const Packet1cf& a) +{ return Packet1cf(vrev64_f32(a.v)); } +template<> EIGEN_STRONG_INLINE Packet2cf pcplxflip(const Packet2cf& a) +{ return Packet2cf(vrev64q_f32(a.v)); } + +template<> EIGEN_STRONG_INLINE std::complex predux(const Packet1cf& a) +{ + std::complex s; + vst1_f32((float *)&s, a.v); + return s; +} +template<> EIGEN_STRONG_INLINE std::complex predux(const Packet2cf& a) +{ + std::complex s; + vst1_f32(reinterpret_cast(&s), vadd_f32(vget_low_f32(a.v), vget_high_f32(a.v))); + return s; +} + +template<> EIGEN_STRONG_INLINE std::complex predux_mul(const Packet1cf& a) +{ + std::complex s; + vst1_f32((float *)&s, a.v); + return s; +} +template<> EIGEN_STRONG_INLINE std::complex predux_mul(const Packet2cf& a) +{ + float32x2_t a1, a2, v1, v2, prod; + std::complex s; + + a1 = vget_low_f32(a.v); + a2 = vget_high_f32(a.v); + // Get the real values of a | a1_re | a1_re | a2_re | a2_re | + v1 = vdup_lane_f32(a1, 0); + // Get the real values of a | a1_im | a1_im | a2_im | a2_im | + v2 = vdup_lane_f32(a1, 1); + // Multiply the real a with b + v1 = vmul_f32(v1, a2); + // Multiply the imag a with b + v2 = vmul_f32(v2, a2); + // Conjugate v2 + v2 = vreinterpret_f32_u32(veor_u32(vreinterpret_u32_f32(v2), p2ui_CONJ_XOR())); + // Swap real/imag elements in v2. + v2 = vrev64_f32(v2); + // Add v1, v2 + prod = vadd_f32(v1, v2); + + vst1_f32(reinterpret_cast(&s), prod); + + return s; +} + +EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet1cf,Packet2f) +EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cf,Packet4f) + +template<> EIGEN_STRONG_INLINE Packet1cf pdiv(const Packet1cf& a, const Packet1cf& b) +{ + // TODO optimize it for NEON + Packet1cf res = pmul(a, pconj(b)); + Packet2f s, rev_s; + + // this computes the norm + s = vmul_f32(b.v, b.v); + rev_s = vrev64_f32(s); + + return Packet1cf(pdiv(res.v, vadd_f32(s, rev_s))); +} +template<> EIGEN_STRONG_INLINE Packet2cf pdiv(const Packet2cf& a, const Packet2cf& b) +{ + // TODO optimize it for NEON + Packet2cf res = pmul(a,pconj(b)); + Packet4f s, rev_s; + + // this computes the norm + s = vmulq_f32(b.v, b.v); + rev_s = vrev64q_f32(s); + + return Packet2cf(pdiv(res.v, vaddq_f32(s, rev_s))); +} + +EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& /*kernel*/) {} +EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) +{ + Packet4f tmp = vcombine_f32(vget_high_f32(kernel.packet[0].v), vget_high_f32(kernel.packet[1].v)); + kernel.packet[0].v = vcombine_f32(vget_low_f32(kernel.packet[0].v), vget_low_f32(kernel.packet[1].v)); + kernel.packet[1].v = tmp; +} + +template<> EIGEN_STRONG_INLINE Packet1cf psqrt(const Packet1cf& a) { + return psqrt_complex(a); +} + +template<> EIGEN_STRONG_INLINE Packet2cf psqrt(const Packet2cf& a) { + return psqrt_complex(a); +} + +//---------- double ---------- +#if EIGEN_ARCH_ARM64 && !EIGEN_APPLE_DOUBLE_NEON_BUG + +// See bug 1325, clang fails to call vld1q_u64. 
+#if EIGEN_COMP_CLANG || EIGEN_COMP_CASTXML + static uint64x2_t p2ul_CONJ_XOR = {0x0, 0x8000000000000000}; +#else + const uint64_t p2ul_conj_XOR_DATA[] = { 0x0, 0x8000000000000000 }; + static uint64x2_t p2ul_CONJ_XOR = vld1q_u64( p2ul_conj_XOR_DATA ); +#endif + +struct Packet1cd +{ + EIGEN_STRONG_INLINE Packet1cd() {} + EIGEN_STRONG_INLINE explicit Packet1cd(const Packet2d& a) : v(a) {} + Packet2d v; +}; + +template<> struct packet_traits > : default_packet_traits +{ + typedef Packet1cd type; + typedef Packet1cd half; + enum + { + Vectorizable = 1, + AlignedOnScalar = 0, + size = 1, + HasHalfPacket = 0, + + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 1, + HasNegate = 1, + HasAbs = 0, + HasAbs2 = 0, + HasMin = 0, + HasMax = 0, + HasSetLinear = 0 + }; +}; + +template<> struct unpacket_traits +{ + typedef std::complex type; + typedef Packet1cd half; + typedef Packet2d as_real; + enum + { + size=1, + alignment=Aligned16, + vectorizable=true, + masked_load_available=false, + masked_store_available=false + }; +}; + +template<> EIGEN_STRONG_INLINE Packet1cd pload(const std::complex* from) +{ EIGEN_DEBUG_ALIGNED_LOAD return Packet1cd(pload(reinterpret_cast(from))); } + +template<> EIGEN_STRONG_INLINE Packet1cd ploadu(const std::complex* from) +{ EIGEN_DEBUG_UNALIGNED_LOAD return Packet1cd(ploadu(reinterpret_cast(from))); } + +template<> EIGEN_STRONG_INLINE Packet1cd pset1(const std::complex& from) +{ + /* here we really have to use unaligned loads :( */ + return ploadu(&from); +} + +template<> EIGEN_STRONG_INLINE Packet1cd padd(const Packet1cd& a, const Packet1cd& b) +{ return Packet1cd(padd(a.v, b.v)); } + +template<> EIGEN_STRONG_INLINE Packet1cd psub(const Packet1cd& a, const Packet1cd& b) +{ return Packet1cd(psub(a.v, b.v)); } + +template<> EIGEN_STRONG_INLINE Packet1cd pnegate(const Packet1cd& a) +{ return Packet1cd(pnegate(a.v)); } + +template<> EIGEN_STRONG_INLINE Packet1cd pconj(const Packet1cd& a) +{ return Packet1cd(vreinterpretq_f64_u64(veorq_u64(vreinterpretq_u64_f64(a.v), p2ul_CONJ_XOR))); } + +template<> EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& a, const Packet1cd& b) +{ + Packet2d v1, v2; + + // Get the real values of a + v1 = vdupq_lane_f64(vget_low_f64(a.v), 0); + // Get the imag values of a + v2 = vdupq_lane_f64(vget_high_f64(a.v), 0); + // Multiply the real a with b + v1 = vmulq_f64(v1, b.v); + // Multiply the imag a with b + v2 = vmulq_f64(v2, b.v); + // Conjugate v2 + v2 = vreinterpretq_f64_u64(veorq_u64(vreinterpretq_u64_f64(v2), p2ul_CONJ_XOR)); + // Swap real/imag elements in v2. 
+ v2 = preverse(v2); + // Add and return the result + return Packet1cd(vaddq_f64(v1, v2)); +} + +template<> EIGEN_STRONG_INLINE Packet1cd pcmp_eq(const Packet1cd& a, const Packet1cd& b) +{ + // Compare real and imaginary parts of a and b to get the mask vector: + // [re(a)==re(b), im(a)==im(b)] + Packet2d eq = pcmp_eq(a.v, b.v); + // Swap real/imag elements in the mask in to get: + // [im(a)==im(b), re(a)==re(b)] + Packet2d eq_swapped = vreinterpretq_f64_u32(vrev64q_u32(vreinterpretq_u32_f64(eq))); + // Return re(a)==re(b) & im(a)==im(b) by computing bitwise AND of eq and eq_swapped + return Packet1cd(pand(eq, eq_swapped)); +} + +template<> EIGEN_STRONG_INLINE Packet1cd pand(const Packet1cd& a, const Packet1cd& b) +{ return Packet1cd(vreinterpretq_f64_u64(vandq_u64(vreinterpretq_u64_f64(a.v),vreinterpretq_u64_f64(b.v)))); } + +template<> EIGEN_STRONG_INLINE Packet1cd por(const Packet1cd& a, const Packet1cd& b) +{ return Packet1cd(vreinterpretq_f64_u64(vorrq_u64(vreinterpretq_u64_f64(a.v),vreinterpretq_u64_f64(b.v)))); } + +template<> EIGEN_STRONG_INLINE Packet1cd pxor(const Packet1cd& a, const Packet1cd& b) +{ return Packet1cd(vreinterpretq_f64_u64(veorq_u64(vreinterpretq_u64_f64(a.v),vreinterpretq_u64_f64(b.v)))); } + +template<> EIGEN_STRONG_INLINE Packet1cd pandnot(const Packet1cd& a, const Packet1cd& b) +{ return Packet1cd(vreinterpretq_f64_u64(vbicq_u64(vreinterpretq_u64_f64(a.v),vreinterpretq_u64_f64(b.v)))); } + +template<> EIGEN_STRONG_INLINE Packet1cd ploaddup(const std::complex* from) +{ return pset1(*from); } + +template<> EIGEN_STRONG_INLINE void pstore >(std::complex *to, const Packet1cd& from) +{ EIGEN_DEBUG_ALIGNED_STORE pstore(reinterpret_cast(to), from.v); } + +template<> EIGEN_STRONG_INLINE void pstoreu >(std::complex *to, const Packet1cd& from) +{ EIGEN_DEBUG_UNALIGNED_STORE pstoreu(reinterpret_cast(to), from.v); } + +template<> EIGEN_STRONG_INLINE void prefetch >(const std::complex *addr) +{ EIGEN_ARM_PREFETCH(reinterpret_cast(addr)); } + +template<> EIGEN_DEVICE_FUNC inline Packet1cd pgather, Packet1cd>( + const std::complex* from, Index stride) +{ + Packet2d res = pset1(0.0); + res = vsetq_lane_f64(std::real(from[0*stride]), res, 0); + res = vsetq_lane_f64(std::imag(from[0*stride]), res, 1); + return Packet1cd(res); +} + +template<> EIGEN_DEVICE_FUNC inline void pscatter, Packet1cd>( + std::complex* to, const Packet1cd& from, Index stride) +{ to[stride*0] = std::complex(vgetq_lane_f64(from.v, 0), vgetq_lane_f64(from.v, 1)); } + +template<> EIGEN_STRONG_INLINE std::complex pfirst(const Packet1cd& a) +{ + EIGEN_ALIGN16 std::complex res; + pstore >(&res, a); + return res; +} + +template<> EIGEN_STRONG_INLINE Packet1cd preverse(const Packet1cd& a) { return a; } + +template<> EIGEN_STRONG_INLINE std::complex predux(const Packet1cd& a) { return pfirst(a); } + +template<> EIGEN_STRONG_INLINE std::complex predux_mul(const Packet1cd& a) { return pfirst(a); } + +EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet1cd,Packet2d) + +template<> EIGEN_STRONG_INLINE Packet1cd pdiv(const Packet1cd& a, const Packet1cd& b) +{ + // TODO optimize it for NEON + Packet1cd res = pmul(a,pconj(b)); + Packet2d s = pmul(b.v, b.v); + Packet2d rev_s = preverse(s); + + return Packet1cd(pdiv(res.v, padd(s,rev_s))); +} + +EIGEN_STRONG_INLINE Packet1cd pcplxflip/**/(const Packet1cd& x) +{ return Packet1cd(preverse(Packet2d(x.v))); } + +EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) +{ + Packet2d tmp = vcombine_f64(vget_high_f64(kernel.packet[0].v), vget_high_f64(kernel.packet[1].v)); + kernel.packet[0].v 
= vcombine_f64(vget_low_f64(kernel.packet[0].v), vget_low_f64(kernel.packet[1].v)); + kernel.packet[1].v = tmp; +} + +template<> EIGEN_STRONG_INLINE Packet1cd psqrt(const Packet1cd& a) { + return psqrt_complex(a); +} + +#endif // EIGEN_ARCH_ARM64 + +} // end namespace internal + +} // end namespace Eigen + +#endif // EIGEN_COMPLEX_NEON_H diff --git a/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h b/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h new file mode 100644 index 0000000..3481f33 --- /dev/null +++ b/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h @@ -0,0 +1,183 @@ +namespace Eigen { +namespace internal { + +#if EIGEN_ARCH_ARM && EIGEN_COMP_CLANG + +// Clang seems to excessively spill registers in the GEBP kernel on 32-bit arm. +// Here we specialize gebp_traits to eliminate these register spills. +// See #2138. +template<> +struct gebp_traits + : gebp_traits +{ + EIGEN_STRONG_INLINE void acc(const AccPacket& c, const ResPacket& alpha, ResPacket& r) const + { + // This volatile inline ASM both acts as a barrier to prevent reordering, + // as well as enforces strict register use. + asm volatile( + "vmla.f32 %q[r], %q[c], %q[alpha]" + : [r] "+w" (r) + : [c] "w" (c), + [alpha] "w" (alpha) + : ); + } + + template + EIGEN_STRONG_INLINE void madd(const Packet4f& a, const Packet4f& b, + Packet4f& c, Packet4f& tmp, + const LaneIdType&) const { + acc(a, b, c); + } + + template + EIGEN_STRONG_INLINE void madd(const Packet4f& a, const QuadPacket& b, + Packet4f& c, Packet4f& tmp, + const LaneIdType& lane) const { + madd(a, b.get(lane), c, tmp, lane); + } +}; + +#endif // EIGEN_ARCH_ARM && EIGEN_COMP_CLANG + +#if EIGEN_ARCH_ARM64 + +template<> +struct gebp_traits + : gebp_traits +{ + typedef float RhsPacket; + typedef float32x4_t RhsPacketx4; + + EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const + { + dest = *b; + } + + EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const + { + dest = vld1q_f32(b); + } + + EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacket& dest) const + { + dest = *b; + } + + EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const + {} + + EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const + { + loadRhs(b,dest); + } + + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const + { + c = vfmaq_n_f32(c, a, b); + } + + // NOTE: Template parameter inference failed when compiled with Android NDK: + // "candidate template ignored: could not match 'FixedInt' against 'Eigen::internal::FixedInt<0>". 
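  // Illustrative only (a hypothetical sketch, not compiled code): if LaneID deduction
  // worked reliably on every toolchain, the four explicit overloads below could collapse
  // into a single template forwarding to madd_helper:
  //
  //   template <int LaneID>
  //   EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b,
  //                                 AccPacket& c, RhsPacket& /*tmp*/,
  //                                 const FixedInt<LaneID>&) const
  //   { madd_helper<LaneID>(a, b, c); }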
+ + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const + { madd_helper<0>(a, b, c); } + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<1>&) const + { madd_helper<1>(a, b, c); } + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<2>&) const + { madd_helper<2>(a, b, c); } + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<3>&) const + { madd_helper<3>(a, b, c); } + + private: + template + EIGEN_STRONG_INLINE void madd_helper(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c) const + { + #if EIGEN_COMP_GNUC_STRICT && !(EIGEN_GNUC_AT_LEAST(9,0)) + // workaround gcc issue https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89101 + // vfmaq_laneq_f32 is implemented through a costly dup + if(LaneID==0) asm("fmla %0.4s, %1.4s, %2.s[0]\n" : "+w" (c) : "w" (a), "w" (b) : ); + else if(LaneID==1) asm("fmla %0.4s, %1.4s, %2.s[1]\n" : "+w" (c) : "w" (a), "w" (b) : ); + else if(LaneID==2) asm("fmla %0.4s, %1.4s, %2.s[2]\n" : "+w" (c) : "w" (a), "w" (b) : ); + else if(LaneID==3) asm("fmla %0.4s, %1.4s, %2.s[3]\n" : "+w" (c) : "w" (a), "w" (b) : ); + #else + c = vfmaq_laneq_f32(c, a, b, LaneID); + #endif + } +}; + + +template<> +struct gebp_traits + : gebp_traits +{ + typedef double RhsPacket; + + struct RhsPacketx4 { + float64x2_t B_0, B_1; + }; + + EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const + { + dest = *b; + } + + EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const + { + dest.B_0 = vld1q_f64(b); + dest.B_1 = vld1q_f64(b+2); + } + + EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacket& dest) const + { + loadRhs(b,dest); + } + + EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const + {} + + EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const + { + loadRhs(b,dest); + } + + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const + { + c = vfmaq_n_f64(c, a, b); + } + + // NOTE: Template parameter inference failed when compiled with Android NDK: + // "candidate template ignored: could not match 'FixedInt' against 'Eigen::internal::FixedInt<0>". 
+ + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const + { madd_helper<0>(a, b, c); } + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<1>&) const + { madd_helper<1>(a, b, c); } + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<2>&) const + { madd_helper<2>(a, b, c); } + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<3>&) const + { madd_helper<3>(a, b, c); } + + private: + template + EIGEN_STRONG_INLINE void madd_helper(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c) const + { + #if EIGEN_COMP_GNUC_STRICT && !(EIGEN_GNUC_AT_LEAST(9,0)) + // workaround gcc issue https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89101 + // vfmaq_laneq_f64 is implemented through a costly dup + if(LaneID==0) asm("fmla %0.2d, %1.2d, %2.d[0]\n" : "+w" (c) : "w" (a), "w" (b.B_0) : ); + else if(LaneID==1) asm("fmla %0.2d, %1.2d, %2.d[1]\n" : "+w" (c) : "w" (a), "w" (b.B_0) : ); + else if(LaneID==2) asm("fmla %0.2d, %1.2d, %2.d[0]\n" : "+w" (c) : "w" (a), "w" (b.B_1) : ); + else if(LaneID==3) asm("fmla %0.2d, %1.2d, %2.d[1]\n" : "+w" (c) : "w" (a), "w" (b.B_1) : ); + #else + if(LaneID==0) c = vfmaq_laneq_f64(c, a, b.B_0, 0); + else if(LaneID==1) c = vfmaq_laneq_f64(c, a, b.B_0, 1); + else if(LaneID==2) c = vfmaq_laneq_f64(c, a, b.B_1, 0); + else if(LaneID==3) c = vfmaq_laneq_f64(c, a, b.B_1, 1); + #endif + } +}; + +#endif // EIGEN_ARCH_ARM64 + +} // namespace internal +} // namespace Eigen diff --git a/Eigen/src/Core/arch/NEON/MathFunctions.h b/Eigen/src/Core/arch/NEON/MathFunctions.h new file mode 100644 index 0000000..fa6615a --- /dev/null +++ b/Eigen/src/Core/arch/NEON/MathFunctions.h @@ -0,0 +1,75 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_MATH_FUNCTIONS_NEON_H +#define EIGEN_MATH_FUNCTIONS_NEON_H + +namespace Eigen { + +namespace internal { + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet2f pexp(const Packet2f& x) +{ return pexp_float(x); } +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f pexp(const Packet4f& x) +{ return pexp_float(x); } + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet2f plog(const Packet2f& x) +{ return plog_float(x); } +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f plog(const Packet4f& x) +{ return plog_float(x); } + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet2f psin(const Packet2f& x) +{ return psin_float(x); } +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f psin(const Packet4f& x) +{ return psin_float(x); } + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet2f pcos(const Packet2f& x) +{ return pcos_float(x); } +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f pcos(const Packet4f& x) +{ return pcos_float(x); } + +// Hyperbolic Tangent function. 
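// The ptanh specializations below forward to Eigen's generic generic_fast_tanh_float,
// which clamps the argument and then evaluates an odd rational approximation entirely
// with packet operations. A scalar sketch of that strategy, for illustration only (the
// helper name, clamp bound and coefficients here are not Eigen's; the real kernel uses a
// higher-degree minimax fit):
inline float tanh_rational_sketch(float x) {
  // tanh saturates quickly, so large |x| can be clamped before the polynomial evaluation.
  const float c = x < -4.0f ? -4.0f : (x > 4.0f ? 4.0f : x);
  const float c2 = c * c;
  // Odd rational approximation c * P(c^2) / Q(c^2); these are the [7/6] Pade coefficients.
  const float p = c * (135135.0f + c2 * (17325.0f + c2 * (378.0f + c2)));
  const float q = 135135.0f + c2 * (62370.0f + c2 * (3150.0f + c2 * 28.0f));
  return p / q;
}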
+template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet2f ptanh(const Packet2f& x) +{ return internal::generic_fast_tanh_float(x); } +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f ptanh(const Packet4f& x) +{ return internal::generic_fast_tanh_float(x); } + +BF16_PACKET_FUNCTION(Packet4f, Packet4bf, psin) +BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pcos) +BF16_PACKET_FUNCTION(Packet4f, Packet4bf, plog) +BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pexp) +BF16_PACKET_FUNCTION(Packet4f, Packet4bf, ptanh) + +template <> +EIGEN_STRONG_INLINE Packet4bf pfrexp(const Packet4bf& a, Packet4bf& exponent) { + Packet4f fexponent; + const Packet4bf out = F32ToBf16(pfrexp(Bf16ToF32(a), fexponent)); + exponent = F32ToBf16(fexponent); + return out; +} + +template <> +EIGEN_STRONG_INLINE Packet4bf pldexp(const Packet4bf& a, const Packet4bf& exponent) { + return F32ToBf16(pldexp(Bf16ToF32(a), Bf16ToF32(exponent))); +} + +//---------- double ---------- + +#if EIGEN_ARCH_ARM64 && !EIGEN_APPLE_DOUBLE_NEON_BUG +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet2d pexp(const Packet2d& x) +{ return pexp_double(x); } + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet2d plog(const Packet2d& x) +{ return plog_double(x); } + +#endif + +} // end namespace internal + +} // end namespace Eigen + +#endif // EIGEN_MATH_FUNCTIONS_NEON_H diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h new file mode 100644 index 0000000..d2aeef4 --- /dev/null +++ b/Eigen/src/Core/arch/NEON/PacketMath.h @@ -0,0 +1,4587 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2008-2009 Gael Guennebaud +// Copyright (C) 2010 Konstantinos Margaritis +// Heavily based on Gael's SSE version. +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_PACKET_MATH_NEON_H +#define EIGEN_PACKET_MATH_NEON_H + +namespace Eigen { + +namespace internal { + +#ifndef EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD +#define EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD 8 +#endif + +#ifndef EIGEN_HAS_SINGLE_INSTRUCTION_MADD +#define EIGEN_HAS_SINGLE_INSTRUCTION_MADD +#endif + +#ifndef EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS +#if EIGEN_ARCH_ARM64 +#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 32 +#else +#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 16 +#endif +#endif + +#if EIGEN_COMP_MSVC_STRICT + +// In MSVC's arm_neon.h header file, all NEON vector types +// are aliases to the same underlying type __n128. +// We thus have to wrap them to make them different C++ types. 
+// (See also bug 1428) +typedef eigen_packet_wrapper Packet2f; +typedef eigen_packet_wrapper Packet4f; +typedef eigen_packet_wrapper Packet4c; +typedef eigen_packet_wrapper Packet8c; +typedef eigen_packet_wrapper Packet16c; +typedef eigen_packet_wrapper Packet4uc; +typedef eigen_packet_wrapper Packet8uc; +typedef eigen_packet_wrapper Packet16uc; +typedef eigen_packet_wrapper Packet4s; +typedef eigen_packet_wrapper Packet8s; +typedef eigen_packet_wrapper Packet4us; +typedef eigen_packet_wrapper Packet8us; +typedef eigen_packet_wrapper Packet2i; +typedef eigen_packet_wrapper Packet4i; +typedef eigen_packet_wrapper Packet2ui; +typedef eigen_packet_wrapper Packet4ui; +typedef eigen_packet_wrapper Packet2l; +typedef eigen_packet_wrapper Packet2ul; + +#else + +typedef float32x2_t Packet2f; +typedef float32x4_t Packet4f; +typedef eigen_packet_wrapper Packet4c; +typedef int8x8_t Packet8c; +typedef int8x16_t Packet16c; +typedef eigen_packet_wrapper Packet4uc; +typedef uint8x8_t Packet8uc; +typedef uint8x16_t Packet16uc; +typedef int16x4_t Packet4s; +typedef int16x8_t Packet8s; +typedef uint16x4_t Packet4us; +typedef uint16x8_t Packet8us; +typedef int32x2_t Packet2i; +typedef int32x4_t Packet4i; +typedef uint32x2_t Packet2ui; +typedef uint32x4_t Packet4ui; +typedef int64x2_t Packet2l; +typedef uint64x2_t Packet2ul; + +#endif // EIGEN_COMP_MSVC_STRICT + +EIGEN_STRONG_INLINE Packet4f shuffle1(const Packet4f& m, int mask){ + const float* a = reinterpret_cast(&m); + Packet4f res = {*(a + (mask & 3)), *(a + ((mask >> 2) & 3)), *(a + ((mask >> 4) & 3 )), *(a + ((mask >> 6) & 3))}; + return res; +} + +// functionally equivalent to _mm_shuffle_ps in SSE when interleave +// == false (i.e. shuffle(m, n, mask) equals _mm_shuffle_ps(m, n, mask)), +// interleave m and n when interleave == true. Currently used in LU/arch/InverseSize4.h +// to enable a shared implementation for fast inversion of matrices of size 4.
+template +EIGEN_STRONG_INLINE Packet4f shuffle2(const Packet4f &m, const Packet4f &n, int mask) +{ + const float* a = reinterpret_cast(&m); + const float* b = reinterpret_cast(&n); + Packet4f res = {*(a + (mask & 3)), *(a + ((mask >> 2) & 3)), *(b + ((mask >> 4) & 3)), *(b + ((mask >> 6) & 3))}; + return res; +} + +template<> +EIGEN_STRONG_INLINE Packet4f shuffle2(const Packet4f &m, const Packet4f &n, int mask) +{ + const float* a = reinterpret_cast(&m); + const float* b = reinterpret_cast(&n); + Packet4f res = {*(a + (mask & 3)), *(b + ((mask >> 2) & 3)), *(a + ((mask >> 4) & 3)), *(b + ((mask >> 6) & 3))}; + return res; +} + +EIGEN_STRONG_INLINE static int eigen_neon_shuffle_mask(int p, int q, int r, int s) {return ((s)<<6|(r)<<4|(q)<<2|(p));} + +EIGEN_STRONG_INLINE Packet4f vec4f_swizzle1(const Packet4f& a, int p, int q, int r, int s) +{ + return shuffle1(a, eigen_neon_shuffle_mask(p, q, r, s)); +} +EIGEN_STRONG_INLINE Packet4f vec4f_swizzle2(const Packet4f& a, const Packet4f& b, int p, int q, int r, int s) +{ + return shuffle2(a,b,eigen_neon_shuffle_mask(p, q, r, s)); +} +EIGEN_STRONG_INLINE Packet4f vec4f_movelh(const Packet4f& a, const Packet4f& b) +{ + return shuffle2(a,b,eigen_neon_shuffle_mask(0, 1, 0, 1)); +} +EIGEN_STRONG_INLINE Packet4f vec4f_movehl(const Packet4f& a, const Packet4f& b) +{ + return shuffle2(b,a,eigen_neon_shuffle_mask(2, 3, 2, 3)); +} +EIGEN_STRONG_INLINE Packet4f vec4f_unpacklo(const Packet4f& a, const Packet4f& b) +{ + return shuffle2(a,b,eigen_neon_shuffle_mask(0, 0, 1, 1)); +} +EIGEN_STRONG_INLINE Packet4f vec4f_unpackhi(const Packet4f& a, const Packet4f& b) +{ + return shuffle2(a,b,eigen_neon_shuffle_mask(2, 2, 3, 3)); +} +#define vec4f_duplane(a, p) \ + vdupq_lane_f32(vget_low_f32(a), p) + +#define _EIGEN_DECLARE_CONST_Packet4f(NAME,X) \ + const Packet4f p4f_##NAME = pset1(X) + +#define _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(NAME,X) \ + const Packet4f p4f_##NAME = vreinterpretq_f32_u32(pset1(X)) + +#define _EIGEN_DECLARE_CONST_Packet4i(NAME,X) \ + const Packet4i p4i_##NAME = pset1(X) + +#if EIGEN_ARCH_ARM64 + // __builtin_prefetch tends to do nothing on ARM64 compilers because the + // prefetch instructions there are too detailed for __builtin_prefetch to map + // meaningfully to them. + #define EIGEN_ARM_PREFETCH(ADDR) __asm__ __volatile__("prfm pldl1keep, [%[addr]]\n" ::[addr] "r"(ADDR) : ); +#elif EIGEN_HAS_BUILTIN(__builtin_prefetch) || EIGEN_COMP_GNUC + #define EIGEN_ARM_PREFETCH(ADDR) __builtin_prefetch(ADDR); +#elif defined __pld + #define EIGEN_ARM_PREFETCH(ADDR) __pld(ADDR) +#elif EIGEN_ARCH_ARM32 + #define EIGEN_ARM_PREFETCH(ADDR) __asm__ __volatile__ ("pld [%[addr]]\n" :: [addr] "r" (ADDR) : ); +#else + // by default no explicit prefetching + #define EIGEN_ARM_PREFETCH(ADDR) +#endif + +template <> +struct packet_traits : default_packet_traits +{ + typedef Packet4f type; + typedef Packet2f half; + enum + { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 4, + HasHalfPacket = 1, + + HasAdd = 1, + HasSub = 1, + HasShift = 1, + HasMul = 1, + HasNegate = 1, + HasAbs = 1, + HasArg = 0, + HasAbs2 = 1, + HasAbsDiff = 1, + HasMin = 1, + HasMax = 1, + HasConj = 1, + HasSetLinear = 0, + HasBlend = 0, + + HasDiv = 1, + HasFloor = 1, + HasCeil = 1, + HasRint = 1, + + HasSin = EIGEN_FAST_MATH, + HasCos = EIGEN_FAST_MATH, + HasLog = 1, + HasExp = 1, + HasSqrt = 1, + HasRsqrt = 1, + HasTanh = EIGEN_FAST_MATH, + HasErf = EIGEN_FAST_MATH, + HasBessel = 0, // Issues with accuracy. 
+ HasNdtri = 0 + }; +}; + +template <> +struct packet_traits : default_packet_traits +{ + typedef Packet16c type; + typedef Packet8c half; + enum + { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 16, + HasHalfPacket = 1, + + HasAdd = 1, + HasSub = 1, + HasShift = 1, + HasMul = 1, + HasNegate = 1, + HasAbs = 1, + HasAbsDiff = 1, + HasArg = 0, + HasAbs2 = 1, + HasMin = 1, + HasMax = 1, + HasConj = 1, + HasSetLinear = 0, + HasBlend = 0 + }; +}; + +template <> +struct packet_traits : default_packet_traits +{ + typedef Packet16uc type; + typedef Packet8uc half; + enum + { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 16, + HasHalfPacket = 1, + + HasAdd = 1, + HasSub = 1, + HasShift = 1, + HasMul = 1, + HasNegate = 0, + HasAbs = 1, + HasAbsDiff = 1, + HasArg = 0, + HasAbs2 = 1, + HasMin = 1, + HasMax = 1, + HasConj = 1, + HasSetLinear = 0, + HasBlend = 0, + + HasSqrt = 1 + }; +}; + +template <> +struct packet_traits : default_packet_traits +{ + typedef Packet8s type; + typedef Packet4s half; + enum + { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 8, + HasHalfPacket = 1, + + HasAdd = 1, + HasSub = 1, + HasShift = 1, + HasMul = 1, + HasNegate = 1, + HasAbs = 1, + HasAbsDiff = 1, + HasArg = 0, + HasAbs2 = 1, + HasMin = 1, + HasMax = 1, + HasConj = 1, + HasSetLinear = 0, + HasBlend = 0 + }; +}; + +template <> +struct packet_traits : default_packet_traits +{ + typedef Packet8us type; + typedef Packet4us half; + enum + { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 8, + HasHalfPacket = 1, + + HasAdd = 1, + HasSub = 1, + HasShift = 1, + HasMul = 1, + HasNegate = 0, + HasAbs = 0, + HasAbsDiff = 1, + HasArg = 0, + HasAbs2 = 1, + HasMin = 1, + HasMax = 1, + HasConj = 1, + HasSetLinear = 0, + HasBlend = 0, + HasSqrt = 1 + }; +}; + +template <> +struct packet_traits : default_packet_traits +{ + typedef Packet4i type; + typedef Packet2i half; + enum + { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 4, + HasHalfPacket = 1, + + HasAdd = 1, + HasSub = 1, + HasShift = 1, + HasMul = 1, + HasNegate = 1, + HasAbs = 1, + HasArg = 0, + HasAbs2 = 1, + HasAbsDiff = 1, + HasMin = 1, + HasMax = 1, + HasConj = 1, + HasSetLinear = 0, + HasBlend = 0 + }; +}; + +template <> +struct packet_traits : default_packet_traits +{ + typedef Packet4ui type; + typedef Packet2ui half; + enum + { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 4, + HasHalfPacket = 1, + + HasAdd = 1, + HasSub = 1, + HasShift = 1, + HasMul = 1, + HasNegate = 0, + HasAbs = 0, + HasArg = 0, + HasAbs2 = 1, + HasAbsDiff = 1, + HasMin = 1, + HasMax = 1, + HasConj = 1, + HasSetLinear = 0, + HasBlend = 0, + + HasSqrt = 1 + }; +}; + +template <> +struct packet_traits : default_packet_traits +{ + typedef Packet2l type; + typedef Packet2l half; + enum + { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 2, + HasHalfPacket = 0, + + HasCmp = 1, + HasAdd = 1, + HasSub = 1, + HasShift = 1, + HasMul = 1, + HasNegate = 1, + HasAbs = 1, + HasArg = 0, + HasAbs2 = 1, + HasAbsDiff = 1, + HasMin = 1, + HasMax = 1, + HasConj = 1, + HasSetLinear = 0, + HasBlend = 0 + }; +}; + +template <> +struct packet_traits : default_packet_traits +{ + typedef Packet2ul type; + typedef Packet2ul half; + enum + { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 2, + HasHalfPacket = 0, + + HasCmp = 1, + HasAdd = 1, + HasSub = 1, + HasShift = 1, + HasMul = 1, + HasNegate = 0, + HasAbs = 0, + HasArg = 0, + HasAbs2 = 1, + HasAbsDiff = 1, + HasMin = 1, + HasMax = 1, + HasConj = 1, + HasSetLinear = 0, + HasBlend = 0 + }; +}; + +#if EIGEN_GNUC_AT_MOST(4, 4) 
&& !EIGEN_COMP_LLVM +// workaround gcc 4.2, 4.3 and 4.4 compilation issue +EIGEN_STRONG_INLINE float32x4_t vld1q_f32(const float* x) { return ::vld1q_f32((const float32_t*)x); } +EIGEN_STRONG_INLINE float32x2_t vld1_f32(const float* x) { return ::vld1_f32 ((const float32_t*)x); } +EIGEN_STRONG_INLINE float32x2_t vld1_dup_f32(const float* x) { return ::vld1_dup_f32 ((const float32_t*)x); } +EIGEN_STRONG_INLINE void vst1q_f32(float* to, float32x4_t from) { ::vst1q_f32((float32_t*)to,from); } +EIGEN_STRONG_INLINE void vst1_f32 (float* to, float32x2_t from) { ::vst1_f32 ((float32_t*)to,from); } +#endif + +template<> struct unpacket_traits +{ + typedef float type; + typedef Packet2f half; + typedef Packet2i integer_packet; + enum + { + size = 2, + alignment = Aligned16, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; +template<> struct unpacket_traits +{ + typedef float type; + typedef Packet2f half; + typedef Packet4i integer_packet; + enum + { + size = 4, + alignment = Aligned16, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; +template<> struct unpacket_traits +{ + typedef int8_t type; + typedef Packet4c half; + enum + { + size = 4, + alignment = Unaligned, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; +template<> struct unpacket_traits +{ + typedef int8_t type; + typedef Packet4c half; + enum + { + size = 8, + alignment = Aligned16, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; +template<> struct unpacket_traits +{ + typedef int8_t type; + typedef Packet8c half; + enum + { + size = 16, + alignment = Aligned16, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; +template<> struct unpacket_traits +{ + typedef uint8_t type; + typedef Packet4uc half; + enum + { + size = 4, + alignment = Unaligned, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; +template<> struct unpacket_traits +{ + typedef uint8_t type; + typedef Packet4uc half; + enum + { + size = 8, + alignment = Aligned16, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; +template<> struct unpacket_traits +{ + typedef uint8_t type; + typedef Packet8uc half; + enum + { + size = 16, + alignment = Aligned16, + vectorizable = true, + masked_load_available = false, + masked_store_available = false}; +}; +template<> struct unpacket_traits +{ + typedef int16_t type; + typedef Packet4s half; + enum + { + size = 4, + alignment = Aligned16, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; +template<> struct unpacket_traits +{ + typedef int16_t type; + typedef Packet4s half; + enum + { + size = 8, + alignment = Aligned16, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; +template<> struct unpacket_traits +{ + typedef uint16_t type; + typedef Packet4us half; + enum + { + size = 4, + alignment = Aligned16, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; +template<> struct unpacket_traits +{ + typedef uint16_t type; + typedef Packet4us half; + enum + { + size = 8, + alignment = Aligned16, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; +template<> struct unpacket_traits +{ + typedef int32_t type; + typedef Packet2i 
half; + enum + { + size = 2, + alignment = Aligned16, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; +template<> struct unpacket_traits +{ + typedef int32_t type; + typedef Packet2i half; + enum + { + size = 4, + alignment = Aligned16, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; +template<> struct unpacket_traits +{ + typedef uint32_t type; + typedef Packet2ui half; + enum + { + size = 2, + alignment = Aligned16, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; +template<> struct unpacket_traits +{ + typedef uint32_t type; + typedef Packet2ui half; + enum + { + size = 4, + alignment = Aligned16, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; +template<> struct unpacket_traits +{ + typedef int64_t type; + typedef Packet2l half; + enum + { + size = 2, + alignment = Aligned16, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; +template<> struct unpacket_traits +{ + typedef uint64_t type; + typedef Packet2ul half; + enum + { + size = 2, + alignment = Aligned16, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; + +template<> EIGEN_STRONG_INLINE Packet2f pset1(const float& from) { return vdup_n_f32(from); } +template<> EIGEN_STRONG_INLINE Packet4f pset1(const float& from) { return vdupq_n_f32(from); } +template<> EIGEN_STRONG_INLINE Packet4c pset1(const int8_t& from) +{ return vget_lane_s32(vreinterpret_s32_s8(vdup_n_s8(from)), 0); } +template<> EIGEN_STRONG_INLINE Packet8c pset1(const int8_t& from) { return vdup_n_s8(from); } +template<> EIGEN_STRONG_INLINE Packet16c pset1(const int8_t& from) { return vdupq_n_s8(from); } +template<> EIGEN_STRONG_INLINE Packet4uc pset1(const uint8_t& from) +{ return vget_lane_u32(vreinterpret_u32_u8(vdup_n_u8(from)), 0); } +template<> EIGEN_STRONG_INLINE Packet8uc pset1(const uint8_t& from) { return vdup_n_u8(from); } +template<> EIGEN_STRONG_INLINE Packet16uc pset1(const uint8_t& from) { return vdupq_n_u8(from); } +template<> EIGEN_STRONG_INLINE Packet4s pset1(const int16_t& from) { return vdup_n_s16(from); } +template<> EIGEN_STRONG_INLINE Packet8s pset1(const int16_t& from) { return vdupq_n_s16(from); } +template<> EIGEN_STRONG_INLINE Packet4us pset1(const uint16_t& from) { return vdup_n_u16(from); } +template<> EIGEN_STRONG_INLINE Packet8us pset1(const uint16_t& from) { return vdupq_n_u16(from); } +template<> EIGEN_STRONG_INLINE Packet2i pset1(const int32_t& from) { return vdup_n_s32(from); } +template<> EIGEN_STRONG_INLINE Packet4i pset1(const int32_t& from) { return vdupq_n_s32(from); } +template<> EIGEN_STRONG_INLINE Packet2ui pset1(const uint32_t& from) { return vdup_n_u32(from); } +template<> EIGEN_STRONG_INLINE Packet4ui pset1(const uint32_t& from) { return vdupq_n_u32(from); } +template<> EIGEN_STRONG_INLINE Packet2l pset1(const int64_t& from) { return vdupq_n_s64(from); } +template<> EIGEN_STRONG_INLINE Packet2ul pset1(const uint64_t& from) { return vdupq_n_u64(from); } + +template<> EIGEN_STRONG_INLINE Packet2f pset1frombits(unsigned int from) +{ return vreinterpret_f32_u32(vdup_n_u32(from)); } +template<> EIGEN_STRONG_INLINE Packet4f pset1frombits(unsigned int from) +{ return vreinterpretq_f32_u32(vdupq_n_u32(from)); } + +template<> EIGEN_STRONG_INLINE Packet2f plset(const float& a) +{ + const float c[] = {0.0f,1.0f}; + return vadd_f32(pset1(a), 
vld1_f32(c)); +} +template<> EIGEN_STRONG_INLINE Packet4f plset(const float& a) +{ + const float c[] = {0.0f,1.0f,2.0f,3.0f}; + return vaddq_f32(pset1(a), vld1q_f32(c)); +} +template<> EIGEN_STRONG_INLINE Packet4c plset(const int8_t& a) +{ return vget_lane_s32(vreinterpret_s32_s8(vadd_s8(vreinterpret_s8_u32(vdup_n_u32(0x03020100)), vdup_n_s8(a))), 0); } +template<> EIGEN_STRONG_INLINE Packet8c plset(const int8_t& a) +{ + const int8_t c[] = {0,1,2,3,4,5,6,7}; + return vadd_s8(pset1(a), vld1_s8(c)); +} +template<> EIGEN_STRONG_INLINE Packet16c plset(const int8_t& a) +{ + const int8_t c[] = {0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}; + return vaddq_s8(pset1(a), vld1q_s8(c)); +} +template<> EIGEN_STRONG_INLINE Packet4uc plset(const uint8_t& a) +{ return vget_lane_u32(vreinterpret_u32_u8(vadd_u8(vreinterpret_u8_u32(vdup_n_u32(0x03020100)), vdup_n_u8(a))), 0); } +template<> EIGEN_STRONG_INLINE Packet8uc plset(const uint8_t& a) +{ + const uint8_t c[] = {0,1,2,3,4,5,6,7}; + return vadd_u8(pset1(a), vld1_u8(c)); +} +template<> EIGEN_STRONG_INLINE Packet16uc plset(const uint8_t& a) +{ + const uint8_t c[] = {0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}; + return vaddq_u8(pset1(a), vld1q_u8(c)); +} +template<> EIGEN_STRONG_INLINE Packet4s plset(const int16_t& a) +{ + const int16_t c[] = {0,1,2,3}; + return vadd_s16(pset1(a), vld1_s16(c)); +} +template<> EIGEN_STRONG_INLINE Packet4us plset(const uint16_t& a) +{ + const uint16_t c[] = {0,1,2,3}; + return vadd_u16(pset1(a), vld1_u16(c)); +} +template<> EIGEN_STRONG_INLINE Packet8s plset(const int16_t& a) +{ + const int16_t c[] = {0,1,2,3,4,5,6,7}; + return vaddq_s16(pset1(a), vld1q_s16(c)); +} +template<> EIGEN_STRONG_INLINE Packet8us plset(const uint16_t& a) +{ + const uint16_t c[] = {0,1,2,3,4,5,6,7}; + return vaddq_u16(pset1(a), vld1q_u16(c)); +} +template<> EIGEN_STRONG_INLINE Packet2i plset(const int32_t& a) +{ + const int32_t c[] = {0,1}; + return vadd_s32(pset1(a), vld1_s32(c)); +} +template<> EIGEN_STRONG_INLINE Packet4i plset(const int32_t& a) +{ + const int32_t c[] = {0,1,2,3}; + return vaddq_s32(pset1(a), vld1q_s32(c)); +} +template<> EIGEN_STRONG_INLINE Packet2ui plset(const uint32_t& a) +{ + const uint32_t c[] = {0,1}; + return vadd_u32(pset1(a), vld1_u32(c)); +} +template<> EIGEN_STRONG_INLINE Packet4ui plset(const uint32_t& a) +{ + const uint32_t c[] = {0,1,2,3}; + return vaddq_u32(pset1(a), vld1q_u32(c)); +} +template<> EIGEN_STRONG_INLINE Packet2l plset(const int64_t& a) +{ + const int64_t c[] = {0,1}; + return vaddq_s64(pset1(a), vld1q_s64(c)); +} +template<> EIGEN_STRONG_INLINE Packet2ul plset(const uint64_t& a) +{ + const uint64_t c[] = {0,1}; + return vaddq_u64(pset1(a), vld1q_u64(c)); +} + +template<> EIGEN_STRONG_INLINE Packet2f padd(const Packet2f& a, const Packet2f& b) { return vadd_f32(a,b); } +template<> EIGEN_STRONG_INLINE Packet4f padd(const Packet4f& a, const Packet4f& b) { return vaddq_f32(a,b); } +template<> EIGEN_STRONG_INLINE Packet4c padd(const Packet4c& a, const Packet4c& b) +{ + return vget_lane_s32(vreinterpret_s32_s8(vadd_s8( + vreinterpret_s8_s32(vdup_n_s32(a)), + vreinterpret_s8_s32(vdup_n_s32(b)))), 0); +} +template<> EIGEN_STRONG_INLINE Packet8c padd(const Packet8c& a, const Packet8c& b) { return vadd_s8(a,b); } +template<> EIGEN_STRONG_INLINE Packet16c padd(const Packet16c& a, const Packet16c& b) { return vaddq_s8(a,b); } +template<> EIGEN_STRONG_INLINE Packet4uc padd(const Packet4uc& a, const Packet4uc& b) +{ + return vget_lane_u32(vreinterpret_u32_u8(vadd_u8( + vreinterpret_u8_u32(vdup_n_u32(a)), + 
vreinterpret_u8_u32(vdup_n_u32(b)))), 0); +} +template<> EIGEN_STRONG_INLINE Packet8uc padd(const Packet8uc& a, const Packet8uc& b) { return vadd_u8(a,b); } +template<> EIGEN_STRONG_INLINE Packet16uc padd(const Packet16uc& a, const Packet16uc& b) { return vaddq_u8(a,b); } +template<> EIGEN_STRONG_INLINE Packet4s padd(const Packet4s& a, const Packet4s& b) { return vadd_s16(a,b); } +template<> EIGEN_STRONG_INLINE Packet8s padd(const Packet8s& a, const Packet8s& b) { return vaddq_s16(a,b); } +template<> EIGEN_STRONG_INLINE Packet4us padd(const Packet4us& a, const Packet4us& b) { return vadd_u16(a,b); } +template<> EIGEN_STRONG_INLINE Packet8us padd(const Packet8us& a, const Packet8us& b) { return vaddq_u16(a,b); } +template<> EIGEN_STRONG_INLINE Packet2i padd(const Packet2i& a, const Packet2i& b) { return vadd_s32(a,b); } +template<> EIGEN_STRONG_INLINE Packet4i padd(const Packet4i& a, const Packet4i& b) { return vaddq_s32(a,b); } +template<> EIGEN_STRONG_INLINE Packet2ui padd(const Packet2ui& a, const Packet2ui& b) { return vadd_u32(a,b); } +template<> EIGEN_STRONG_INLINE Packet4ui padd(const Packet4ui& a, const Packet4ui& b) { return vaddq_u32(a,b); } +template<> EIGEN_STRONG_INLINE Packet2l padd(const Packet2l& a, const Packet2l& b) { return vaddq_s64(a,b); } +template<> EIGEN_STRONG_INLINE Packet2ul padd(const Packet2ul& a, const Packet2ul& b) { return vaddq_u64(a,b); } + +template<> EIGEN_STRONG_INLINE Packet2f psub(const Packet2f& a, const Packet2f& b) { return vsub_f32(a,b); } +template<> EIGEN_STRONG_INLINE Packet4f psub(const Packet4f& a, const Packet4f& b) { return vsubq_f32(a,b); } +template<> EIGEN_STRONG_INLINE Packet4c psub(const Packet4c& a, const Packet4c& b) +{ + return vget_lane_s32(vreinterpret_s32_s8(vsub_s8( + vreinterpret_s8_s32(vdup_n_s32(a)), + vreinterpret_s8_s32(vdup_n_s32(b)))), 0); +} +template<> EIGEN_STRONG_INLINE Packet8c psub(const Packet8c& a, const Packet8c& b) { return vsub_s8(a,b); } +template<> EIGEN_STRONG_INLINE Packet16c psub(const Packet16c& a, const Packet16c& b) { return vsubq_s8(a,b); } +template<> EIGEN_STRONG_INLINE Packet4uc psub(const Packet4uc& a, const Packet4uc& b) +{ + return vget_lane_u32(vreinterpret_u32_u8(vsub_u8( + vreinterpret_u8_u32(vdup_n_u32(a)), + vreinterpret_u8_u32(vdup_n_u32(b)))), 0); +} +template<> EIGEN_STRONG_INLINE Packet8uc psub(const Packet8uc& a, const Packet8uc& b) { return vsub_u8(a,b); } +template<> EIGEN_STRONG_INLINE Packet16uc psub(const Packet16uc& a, const Packet16uc& b) { return vsubq_u8(a,b); } +template<> EIGEN_STRONG_INLINE Packet4s psub(const Packet4s& a, const Packet4s& b) { return vsub_s16(a,b); } +template<> EIGEN_STRONG_INLINE Packet8s psub(const Packet8s& a, const Packet8s& b) { return vsubq_s16(a,b); } +template<> EIGEN_STRONG_INLINE Packet4us psub(const Packet4us& a, const Packet4us& b) { return vsub_u16(a,b); } +template<> EIGEN_STRONG_INLINE Packet8us psub(const Packet8us& a, const Packet8us& b) { return vsubq_u16(a,b); } +template<> EIGEN_STRONG_INLINE Packet2i psub(const Packet2i& a, const Packet2i& b) { return vsub_s32(a,b); } +template<> EIGEN_STRONG_INLINE Packet4i psub(const Packet4i& a, const Packet4i& b) { return vsubq_s32(a,b); } +template<> EIGEN_STRONG_INLINE Packet2ui psub(const Packet2ui& a, const Packet2ui& b) { return vsub_u32(a,b); } +template<> EIGEN_STRONG_INLINE Packet4ui psub(const Packet4ui& a, const Packet4ui& b) { return vsubq_u32(a,b); } +template<> EIGEN_STRONG_INLINE Packet2l psub(const Packet2l& a, const Packet2l& b) { return vsubq_s64(a,b); } +template<> 
EIGEN_STRONG_INLINE Packet2ul psub(const Packet2ul& a, const Packet2ul& b) { return vsubq_u64(a,b); } + +template<> EIGEN_STRONG_INLINE Packet2f pxor(const Packet2f& a, const Packet2f& b); +template<> EIGEN_STRONG_INLINE Packet2f paddsub(const Packet2f& a, const Packet2f & b) { + Packet2f mask = {numext::bit_cast(0x80000000u), 0.0f}; + return padd(a, pxor(mask, b)); +} +template<> EIGEN_STRONG_INLINE Packet4f pxor(const Packet4f& a, const Packet4f& b); +template<> EIGEN_STRONG_INLINE Packet4f paddsub(const Packet4f& a, const Packet4f& b) { + Packet4f mask = {numext::bit_cast(0x80000000u), 0.0f, numext::bit_cast(0x80000000u), 0.0f}; + return padd(a, pxor(mask, b)); +} + +template<> EIGEN_STRONG_INLINE Packet2f pnegate(const Packet2f& a) { return vneg_f32(a); } +template<> EIGEN_STRONG_INLINE Packet4f pnegate(const Packet4f& a) { return vnegq_f32(a); } +template<> EIGEN_STRONG_INLINE Packet4c pnegate(const Packet4c& a) +{ return vget_lane_s32(vreinterpret_s32_s8(vneg_s8(vreinterpret_s8_s32(vdup_n_s32(a)))), 0); } +template<> EIGEN_STRONG_INLINE Packet8c pnegate(const Packet8c& a) { return vneg_s8(a); } +template<> EIGEN_STRONG_INLINE Packet16c pnegate(const Packet16c& a) { return vnegq_s8(a); } +template<> EIGEN_STRONG_INLINE Packet4s pnegate(const Packet4s& a) { return vneg_s16(a); } +template<> EIGEN_STRONG_INLINE Packet8s pnegate(const Packet8s& a) { return vnegq_s16(a); } +template<> EIGEN_STRONG_INLINE Packet2i pnegate(const Packet2i& a) { return vneg_s32(a); } +template<> EIGEN_STRONG_INLINE Packet4i pnegate(const Packet4i& a) { return vnegq_s32(a); } +template<> EIGEN_STRONG_INLINE Packet2l pnegate(const Packet2l& a) { +#if EIGEN_ARCH_ARM64 + return vnegq_s64(a); +#else + return vcombine_s64( + vdup_n_s64(-vgetq_lane_s64(a, 0)), + vdup_n_s64(-vgetq_lane_s64(a, 1))); +#endif +} + +template<> EIGEN_STRONG_INLINE Packet2f pconj(const Packet2f& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet4f pconj(const Packet4f& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet4c pconj(const Packet4c& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet8c pconj(const Packet8c& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet16c pconj(const Packet16c& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet4uc pconj(const Packet4uc& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet8uc pconj(const Packet8uc& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet16uc pconj(const Packet16uc& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet4s pconj(const Packet4s& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet8s pconj(const Packet8s& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet4us pconj(const Packet4us& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet8us pconj(const Packet8us& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet2i pconj(const Packet2i& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet4i pconj(const Packet4i& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet2ui pconj(const Packet2ui& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet4ui pconj(const Packet4ui& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet2l pconj(const Packet2l& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet2ul pconj(const Packet2ul& a) { return a; } + +template<> EIGEN_STRONG_INLINE Packet2f pmul(const Packet2f& a, const Packet2f& b) { return vmul_f32(a,b); } +template<> EIGEN_STRONG_INLINE Packet4f pmul(const Packet4f& a, const Packet4f& b) { return vmulq_f32(a,b); } +template<> EIGEN_STRONG_INLINE Packet4c pmul(const 
Packet4c& a, const Packet4c& b)
+{
+  return vget_lane_s32(vreinterpret_s32_s8(vmul_s8(
+      vreinterpret_s8_s32(vdup_n_s32(a)),
+      vreinterpret_s8_s32(vdup_n_s32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8c pmul<Packet8c>(const Packet8c& a, const Packet8c& b) { return vmul_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16c pmul<Packet16c>(const Packet16c& a, const Packet16c& b) { return vmulq_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4uc pmul<Packet4uc>(const Packet4uc& a, const Packet4uc& b)
+{
+  return vget_lane_u32(vreinterpret_u32_u8(vmul_u8(
+      vreinterpret_u8_u32(vdup_n_u32(a)),
+      vreinterpret_u8_u32(vdup_n_u32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8uc pmul<Packet8uc>(const Packet8uc& a, const Packet8uc& b) { return vmul_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc pmul<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return vmulq_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4s pmul<Packet4s>(const Packet4s& a, const Packet4s& b) { return vmul_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8s pmul<Packet8s>(const Packet8s& a, const Packet8s& b) { return vmulq_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4us pmul<Packet4us>(const Packet4us& a, const Packet4us& b) { return vmul_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us pmul<Packet8us>(const Packet8us& a, const Packet8us& b) { return vmulq_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2i pmul<Packet2i>(const Packet2i& a, const Packet2i& b) { return vmul_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4i pmul<Packet4i>(const Packet4i& a, const Packet4i& b) { return vmulq_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ui pmul<Packet2ui>(const Packet2ui& a, const Packet2ui& b) { return vmul_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui pmul<Packet4ui>(const Packet4ui& a, const Packet4ui& b) { return vmulq_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2l pmul<Packet2l>(const Packet2l& a, const Packet2l& b) {
+  return vcombine_s64(
+    vdup_n_s64(vgetq_lane_s64(a, 0)*vgetq_lane_s64(b, 0)),
+    vdup_n_s64(vgetq_lane_s64(a, 1)*vgetq_lane_s64(b, 1)));
+}
+template<> EIGEN_STRONG_INLINE Packet2ul pmul<Packet2ul>(const Packet2ul& a, const Packet2ul& b) {
+  return vcombine_u64(
+    vdup_n_u64(vgetq_lane_u64(a, 0)*vgetq_lane_u64(b, 0)),
+    vdup_n_u64(vgetq_lane_u64(a, 1)*vgetq_lane_u64(b, 1)));
+}
+
+template<> EIGEN_STRONG_INLINE Packet2f pdiv<Packet2f>(const Packet2f& a, const Packet2f& b)
+{
+#if EIGEN_ARCH_ARM64
+  return vdiv_f32(a,b);
+#else
+  Packet2f inv, restep, div;
+
+  // NEON does not offer a divide instruction; we have to use a reciprocal approximation.
+  // However, in contrast to other SIMD engines (AltiVec/SSE), NEON offers
+  // a reciprocal estimate AND a reciprocal step, which saves a few instructions:
+  // vrecpe_f32() returns an estimate of 1/b, which we then refine with
+  // a Newton-Raphson iteration using vrecps_f32().
+  inv = vrecpe_f32(b);
+
+  // This returns a correction factor by which we multiply inv to get a better
+  // approximation of 1/b.
+  restep = vrecps_f32(b, inv);
+  inv = vmul_f32(restep, inv);
+
+  // Finally, multiply a by 1/b to get the desired result of the division.
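+  // Note on the refinement above: it is a standard Newton-Raphson step for 1/b.
+  // Writing the estimate as inv = (1/b)*(1 + e), vrecps_f32(b, inv) computes
+  // 2 - b*inv = 1 - e, so inv * restep = (1/b)*(1 - e^2): each such step roughly
+  // squares the relative error of the initial vrecpe_f32() estimate.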
+  div = vmul_f32(a, inv);
+
+  return div;
+#endif
+}
+template<> EIGEN_STRONG_INLINE Packet4f pdiv<Packet4f>(const Packet4f& a, const Packet4f& b)
+{
+#if EIGEN_ARCH_ARM64
+  return vdivq_f32(a,b);
+#else
+  Packet4f inv, restep, div;
+
+  // NEON does not offer a divide instruction; we have to use a reciprocal approximation.
+  // However, in contrast to other SIMD engines (AltiVec/SSE), NEON offers
+  // a reciprocal estimate AND a reciprocal step, which saves a few instructions:
+  // vrecpeq_f32() returns an estimate of 1/b, which we then refine with
+  // a Newton-Raphson iteration using vrecpsq_f32().
+  inv = vrecpeq_f32(b);
+
+  // This returns a correction factor by which we multiply inv to get a better
+  // approximation of 1/b.
+  restep = vrecpsq_f32(b, inv);
+  inv = vmulq_f32(restep, inv);
+
+  // Finally, multiply a by 1/b to get the desired result of the division.
+  div = vmulq_f32(a, inv);
+
+  return div;
+#endif
+}
+
+template<> EIGEN_STRONG_INLINE Packet4c pdiv<Packet4c>(const Packet4c& /*a*/, const Packet4c& /*b*/)
+{
+  eigen_assert(false && "packet integer division are not supported by NEON");
+  return pset1<Packet4c>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet8c pdiv<Packet8c>(const Packet8c& /*a*/, const Packet8c& /*b*/)
+{
+  eigen_assert(false && "packet integer division are not supported by NEON");
+  return pset1<Packet8c>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet16c pdiv<Packet16c>(const Packet16c& /*a*/, const Packet16c& /*b*/)
+{
+  eigen_assert(false && "packet integer division are not supported by NEON");
+  return pset1<Packet16c>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet4uc pdiv<Packet4uc>(const Packet4uc& /*a*/, const Packet4uc& /*b*/)
+{
+  eigen_assert(false && "packet integer division are not supported by NEON");
+  return pset1<Packet4uc>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet8uc pdiv<Packet8uc>(const Packet8uc& /*a*/, const Packet8uc& /*b*/)
+{
+  eigen_assert(false && "packet integer division are not supported by NEON");
+  return pset1<Packet8uc>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet16uc pdiv<Packet16uc>(const Packet16uc& /*a*/, const Packet16uc& /*b*/)
+{
+  eigen_assert(false && "packet integer division are not supported by NEON");
+  return pset1<Packet16uc>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet4s pdiv<Packet4s>(const Packet4s& /*a*/, const Packet4s& /*b*/)
+{
+  eigen_assert(false && "packet integer division are not supported by NEON");
+  return pset1<Packet4s>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet8s pdiv<Packet8s>(const Packet8s& /*a*/, const Packet8s& /*b*/)
+{
+  eigen_assert(false && "packet integer division are not supported by NEON");
+  return pset1<Packet8s>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet4us pdiv<Packet4us>(const Packet4us& /*a*/, const Packet4us& /*b*/)
+{
+  eigen_assert(false && "packet integer division are not supported by NEON");
+  return pset1<Packet4us>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet8us pdiv<Packet8us>(const Packet8us& /*a*/, const Packet8us& /*b*/)
+{
+  eigen_assert(false && "packet integer division are not supported by NEON");
+  return pset1<Packet8us>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet2i pdiv<Packet2i>(const Packet2i& /*a*/, const Packet2i& /*b*/)
+{
+  eigen_assert(false && "packet integer division are not supported by NEON");
+  return pset1<Packet2i>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet4i pdiv<Packet4i>(const Packet4i& /*a*/, const Packet4i& /*b*/)
+{
+  eigen_assert(false && "packet integer division are not supported by NEON");
+  return pset1<Packet4i>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet2ui pdiv<Packet2ui>(const Packet2ui& /*a*/, const Packet2ui& /*b*/)
+{
+  eigen_assert(false && "packet integer division are not supported by NEON");
+  return pset1<Packet2ui>(0);
+}
+template<> EIGEN_STRONG_INLINE Packet4ui pdiv<Packet4ui>(const Packet4ui& /*a*/, const
Packet4ui& /*b*/) +{ + eigen_assert(false && "packet integer division are not supported by NEON"); + return pset1(0); +} +template<> EIGEN_STRONG_INLINE Packet2l pdiv(const Packet2l& /*a*/, const Packet2l& /*b*/) +{ + eigen_assert(false && "packet integer division are not supported by NEON"); + return pset1(0LL); +} +template<> EIGEN_STRONG_INLINE Packet2ul pdiv(const Packet2ul& /*a*/, const Packet2ul& /*b*/) +{ + eigen_assert(false && "packet integer division are not supported by NEON"); + return pset1(0ULL); +} + + +#ifdef __ARM_FEATURE_FMA +template<> EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c) +{ return vfmaq_f32(c,a,b); } +template<> EIGEN_STRONG_INLINE Packet2f pmadd(const Packet2f& a, const Packet2f& b, const Packet2f& c) +{ return vfma_f32(c,a,b); } +#else +template<> EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c) +{ + return vmlaq_f32(c,a,b); +} +template<> EIGEN_STRONG_INLINE Packet2f pmadd(const Packet2f& a, const Packet2f& b, const Packet2f& c) +{ + return vmla_f32(c,a,b); +} +#endif + +// No FMA instruction for int, so use MLA unconditionally. +template<> EIGEN_STRONG_INLINE Packet4c pmadd(const Packet4c& a, const Packet4c& b, const Packet4c& c) +{ + return vget_lane_s32(vreinterpret_s32_s8(vmla_s8( + vreinterpret_s8_s32(vdup_n_s32(c)), + vreinterpret_s8_s32(vdup_n_s32(a)), + vreinterpret_s8_s32(vdup_n_s32(b)))), 0); +} +template<> EIGEN_STRONG_INLINE Packet8c pmadd(const Packet8c& a, const Packet8c& b, const Packet8c& c) +{ return vmla_s8(c,a,b); } +template<> EIGEN_STRONG_INLINE Packet16c pmadd(const Packet16c& a, const Packet16c& b, const Packet16c& c) +{ return vmlaq_s8(c,a,b); } +template<> EIGEN_STRONG_INLINE Packet4uc pmadd(const Packet4uc& a, const Packet4uc& b, const Packet4uc& c) +{ + return vget_lane_u32(vreinterpret_u32_u8(vmla_u8( + vreinterpret_u8_u32(vdup_n_u32(c)), + vreinterpret_u8_u32(vdup_n_u32(a)), + vreinterpret_u8_u32(vdup_n_u32(b)))), 0); +} +template<> EIGEN_STRONG_INLINE Packet8uc pmadd(const Packet8uc& a, const Packet8uc& b, const Packet8uc& c) +{ return vmla_u8(c,a,b); } +template<> EIGEN_STRONG_INLINE Packet16uc pmadd(const Packet16uc& a, const Packet16uc& b, const Packet16uc& c) +{ return vmlaq_u8(c,a,b); } +template<> EIGEN_STRONG_INLINE Packet4s pmadd(const Packet4s& a, const Packet4s& b, const Packet4s& c) +{ return vmla_s16(c,a,b); } +template<> EIGEN_STRONG_INLINE Packet8s pmadd(const Packet8s& a, const Packet8s& b, const Packet8s& c) +{ return vmlaq_s16(c,a,b); } +template<> EIGEN_STRONG_INLINE Packet4us pmadd(const Packet4us& a, const Packet4us& b, const Packet4us& c) +{ return vmla_u16(c,a,b); } +template<> EIGEN_STRONG_INLINE Packet8us pmadd(const Packet8us& a, const Packet8us& b, const Packet8us& c) +{ return vmlaq_u16(c,a,b); } +template<> EIGEN_STRONG_INLINE Packet2i pmadd(const Packet2i& a, const Packet2i& b, const Packet2i& c) +{ return vmla_s32(c,a,b); } +template<> EIGEN_STRONG_INLINE Packet4i pmadd(const Packet4i& a, const Packet4i& b, const Packet4i& c) +{ return vmlaq_s32(c,a,b); } +template<> EIGEN_STRONG_INLINE Packet2ui pmadd(const Packet2ui& a, const Packet2ui& b, const Packet2ui& c) +{ return vmla_u32(c,a,b); } +template<> EIGEN_STRONG_INLINE Packet4ui pmadd(const Packet4ui& a, const Packet4ui& b, const Packet4ui& c) +{ return vmlaq_u32(c,a,b); } + +template<> EIGEN_STRONG_INLINE Packet2f pabsdiff(const Packet2f& a, const Packet2f& b) +{ return vabd_f32(a,b); } +template<> EIGEN_STRONG_INLINE Packet4f pabsdiff(const Packet4f& a, 
const Packet4f& b) +{ return vabdq_f32(a,b); } +template<> EIGEN_STRONG_INLINE Packet4c pabsdiff(const Packet4c& a, const Packet4c& b) +{ + return vget_lane_s32(vreinterpret_s32_s8(vabd_s8( + vreinterpret_s8_s32(vdup_n_s32(a)), + vreinterpret_s8_s32(vdup_n_s32(b)))), 0); +} +template<> EIGEN_STRONG_INLINE Packet8c pabsdiff(const Packet8c& a, const Packet8c& b) +{ return vabd_s8(a,b); } +template<> EIGEN_STRONG_INLINE Packet16c pabsdiff(const Packet16c& a, const Packet16c& b) +{ return vabdq_s8(a,b); } +template<> EIGEN_STRONG_INLINE Packet4uc pabsdiff(const Packet4uc& a, const Packet4uc& b) +{ + return vget_lane_u32(vreinterpret_u32_u8(vabd_u8( + vreinterpret_u8_u32(vdup_n_u32(a)), + vreinterpret_u8_u32(vdup_n_u32(b)))), 0); +} +template<> EIGEN_STRONG_INLINE Packet8uc pabsdiff(const Packet8uc& a, const Packet8uc& b) +{ return vabd_u8(a,b); } +template<> EIGEN_STRONG_INLINE Packet16uc pabsdiff(const Packet16uc& a, const Packet16uc& b) +{ return vabdq_u8(a,b); } +template<> EIGEN_STRONG_INLINE Packet4s pabsdiff(const Packet4s& a, const Packet4s& b) +{ return vabd_s16(a,b); } +template<> EIGEN_STRONG_INLINE Packet8s pabsdiff(const Packet8s& a, const Packet8s& b) +{ return vabdq_s16(a,b); } +template<> EIGEN_STRONG_INLINE Packet4us pabsdiff(const Packet4us& a, const Packet4us& b) +{ return vabd_u16(a,b); } +template<> EIGEN_STRONG_INLINE Packet8us pabsdiff(const Packet8us& a, const Packet8us& b) +{ return vabdq_u16(a,b); } +template<> EIGEN_STRONG_INLINE Packet2i pabsdiff(const Packet2i& a, const Packet2i& b) +{ return vabd_s32(a,b); } +template<> EIGEN_STRONG_INLINE Packet4i pabsdiff(const Packet4i& a, const Packet4i& b) +{ return vabdq_s32(a,b); } +template<> EIGEN_STRONG_INLINE Packet2ui pabsdiff(const Packet2ui& a, const Packet2ui& b) +{ return vabd_u32(a,b); } +template<> EIGEN_STRONG_INLINE Packet4ui pabsdiff(const Packet4ui& a, const Packet4ui& b) +{ return vabdq_u32(a,b); } + +template<> EIGEN_STRONG_INLINE Packet2f pmin(const Packet2f& a, const Packet2f& b) { return vmin_f32(a,b); } +template<> EIGEN_STRONG_INLINE Packet4f pmin(const Packet4f& a, const Packet4f& b) { return vminq_f32(a,b); } + +#ifdef __ARM_FEATURE_NUMERIC_MAXMIN +// numeric max and min are only available if ARM_FEATURE_NUMERIC_MAXMIN is defined (which can only be the case for Armv8 systems). 
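+// vminnm_f32/vminnmq_f32 implement IEEE 754 minNum semantics: if exactly one operand is NaN,
+// the other (numeric) operand is returned, i.e. numbers are propagated over NaNs, which is
+// what the PropagateNumbers variants below require. The plain vmin/vminq instructions used by
+// the default pmin above instead return NaN whenever either operand is NaN.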
+template<> EIGEN_STRONG_INLINE Packet4f pmin<PropagateNumbers, Packet4f>(const Packet4f& a, const Packet4f& b) { return vminnmq_f32(a, b); }
+template<> EIGEN_STRONG_INLINE Packet2f pmin<PropagateNumbers, Packet2f>(const Packet2f& a, const Packet2f& b) { return vminnm_f32(a, b); }
+#endif
+
+template<> EIGEN_STRONG_INLINE Packet4f pmin<PropagateNaN, Packet4f>(const Packet4f& a, const Packet4f& b) { return pmin<Packet4f>(a, b); }
+
+template<> EIGEN_STRONG_INLINE Packet2f pmin<PropagateNaN, Packet2f>(const Packet2f& a, const Packet2f& b) { return pmin<Packet2f>(a, b); }
+
+template<> EIGEN_STRONG_INLINE Packet4c pmin<Packet4c>(const Packet4c& a, const Packet4c& b)
+{
+  return vget_lane_s32(vreinterpret_s32_s8(vmin_s8(
+      vreinterpret_s8_s32(vdup_n_s32(a)),
+      vreinterpret_s8_s32(vdup_n_s32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8c pmin<Packet8c>(const Packet8c& a, const Packet8c& b) { return vmin_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16c pmin<Packet16c>(const Packet16c& a, const Packet16c& b) { return vminq_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4uc pmin<Packet4uc>(const Packet4uc& a, const Packet4uc& b)
+{
+  return vget_lane_u32(vreinterpret_u32_u8(vmin_u8(
+      vreinterpret_u8_u32(vdup_n_u32(a)),
+      vreinterpret_u8_u32(vdup_n_u32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8uc pmin<Packet8uc>(const Packet8uc& a, const Packet8uc& b) { return vmin_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc pmin<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return vminq_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4s pmin<Packet4s>(const Packet4s& a, const Packet4s& b) { return vmin_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8s pmin<Packet8s>(const Packet8s& a, const Packet8s& b) { return vminq_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4us pmin<Packet4us>(const Packet4us& a, const Packet4us& b) { return vmin_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us pmin<Packet8us>(const Packet8us& a, const Packet8us& b) { return vminq_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2i pmin<Packet2i>(const Packet2i& a, const Packet2i& b) { return vmin_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4i pmin<Packet4i>(const Packet4i& a, const Packet4i& b) { return vminq_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ui pmin<Packet2ui>(const Packet2ui& a, const Packet2ui& b) { return vmin_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui pmin<Packet4ui>(const Packet4ui& a, const Packet4ui& b) { return vminq_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2l pmin<Packet2l>(const Packet2l& a, const Packet2l& b) {
+  return vcombine_s64(
+    vdup_n_s64((std::min)(vgetq_lane_s64(a, 0), vgetq_lane_s64(b, 0))),
+    vdup_n_s64((std::min)(vgetq_lane_s64(a, 1), vgetq_lane_s64(b, 1))));
+}
+template<> EIGEN_STRONG_INLINE Packet2ul pmin<Packet2ul>(const Packet2ul& a, const Packet2ul& b) {
+  return vcombine_u64(
+    vdup_n_u64((std::min)(vgetq_lane_u64(a, 0), vgetq_lane_u64(b, 0))),
+    vdup_n_u64((std::min)(vgetq_lane_u64(a, 1), vgetq_lane_u64(b, 1))));
+}
+
+template<> EIGEN_STRONG_INLINE Packet2f pmax<Packet2f>(const Packet2f& a, const Packet2f& b) { return vmax_f32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4f pmax<Packet4f>(const Packet4f& a, const Packet4f& b) { return vmaxq_f32(a,b); }
+
+#ifdef __ARM_FEATURE_NUMERIC_MAXMIN
+// numeric max and min are only available if ARM_FEATURE_NUMERIC_MAXMIN is defined (which can only be the case for Armv8 systems).
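+// As for pmin above: vmaxnm_f32/vmaxnmq_f32 provide IEEE 754 maxNum semantics for the
+// PropagateNumbers variants, while the PropagateNaN variants further below simply forward to
+// the default pmax, since vmax/vmaxq already return NaN whenever either operand is NaN.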
+template<> EIGEN_STRONG_INLINE Packet4f pmax<PropagateNumbers, Packet4f>(const Packet4f& a, const Packet4f& b) { return vmaxnmq_f32(a, b); }
+template<> EIGEN_STRONG_INLINE Packet2f pmax<PropagateNumbers, Packet2f>(const Packet2f& a, const Packet2f& b) { return vmaxnm_f32(a, b); }
+#endif
+
+template<> EIGEN_STRONG_INLINE Packet4f pmax<PropagateNaN, Packet4f>(const Packet4f& a, const Packet4f& b) { return pmax<Packet4f>(a, b); }
+
+template<> EIGEN_STRONG_INLINE Packet2f pmax<PropagateNaN, Packet2f>(const Packet2f& a, const Packet2f& b) { return pmax<Packet2f>(a, b); }
+
+template<> EIGEN_STRONG_INLINE Packet4c pmax<Packet4c>(const Packet4c& a, const Packet4c& b)
+{
+  return vget_lane_s32(vreinterpret_s32_s8(vmax_s8(
+      vreinterpret_s8_s32(vdup_n_s32(a)),
+      vreinterpret_s8_s32(vdup_n_s32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8c pmax<Packet8c>(const Packet8c& a, const Packet8c& b) { return vmax_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16c pmax<Packet16c>(const Packet16c& a, const Packet16c& b) { return vmaxq_s8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4uc pmax<Packet4uc>(const Packet4uc& a, const Packet4uc& b)
+{
+  return vget_lane_u32(vreinterpret_u32_u8(vmax_u8(
+      vreinterpret_u8_u32(vdup_n_u32(a)),
+      vreinterpret_u8_u32(vdup_n_u32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8uc pmax<Packet8uc>(const Packet8uc& a, const Packet8uc& b) { return vmax_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet16uc pmax<Packet16uc>(const Packet16uc& a, const Packet16uc& b) { return vmaxq_u8(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4s pmax<Packet4s>(const Packet4s& a, const Packet4s& b) { return vmax_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8s pmax<Packet8s>(const Packet8s& a, const Packet8s& b) { return vmaxq_s16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4us pmax<Packet4us>(const Packet4us& a, const Packet4us& b) { return vmax_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet8us pmax<Packet8us>(const Packet8us& a, const Packet8us& b) { return vmaxq_u16(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2i pmax<Packet2i>(const Packet2i& a, const Packet2i& b) { return vmax_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4i pmax<Packet4i>(const Packet4i& a, const Packet4i& b) { return vmaxq_s32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2ui pmax<Packet2ui>(const Packet2ui& a, const Packet2ui& b) { return vmax_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet4ui pmax<Packet4ui>(const Packet4ui& a, const Packet4ui& b) { return vmaxq_u32(a,b); }
+template<> EIGEN_STRONG_INLINE Packet2l pmax<Packet2l>(const Packet2l& a, const Packet2l& b) {
+  return vcombine_s64(
+    vdup_n_s64((std::max)(vgetq_lane_s64(a, 0), vgetq_lane_s64(b, 0))),
+    vdup_n_s64((std::max)(vgetq_lane_s64(a, 1), vgetq_lane_s64(b, 1))));
+}
+template<> EIGEN_STRONG_INLINE Packet2ul pmax<Packet2ul>(const Packet2ul& a, const Packet2ul& b) {
+  return vcombine_u64(
+    vdup_n_u64((std::max)(vgetq_lane_u64(a, 0), vgetq_lane_u64(b, 0))),
+    vdup_n_u64((std::max)(vgetq_lane_u64(a, 1), vgetq_lane_u64(b, 1))));
+}
+
+template<> EIGEN_STRONG_INLINE Packet2f pcmp_le<Packet2f>(const Packet2f& a, const Packet2f& b)
+{ return vreinterpret_f32_u32(vcle_f32(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4f pcmp_le<Packet4f>(const Packet4f& a, const Packet4f& b)
+{ return vreinterpretq_f32_u32(vcleq_f32(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet4c pcmp_le<Packet4c>(const Packet4c& a, const Packet4c& b)
+{
+  return vget_lane_s32(vreinterpret_s32_u8(vcle_s8(
+      vreinterpret_s8_s32(vdup_n_s32(a)),
+      vreinterpret_s8_s32(vdup_n_s32(b)))), 0);
+}
+template<> EIGEN_STRONG_INLINE Packet8c pcmp_le<Packet8c>(const Packet8c& a, const Packet8c& b)
+{ return vreinterpret_s8_u8(vcle_s8(a,b)); }
+template<> EIGEN_STRONG_INLINE Packet16c pcmp_le<Packet16c>(const Packet16c& a, const Packet16c& b)
+{ return vreinterpretq_s8_u8(vcleq_s8(a,b)); }
+template<>
EIGEN_STRONG_INLINE Packet4uc pcmp_le(const Packet4uc& a, const Packet4uc& b) +{ + return vget_lane_u32(vreinterpret_u32_u8(vcle_u8( + vreinterpret_u8_u32(vdup_n_u32(a)), + vreinterpret_u8_u32(vdup_n_u32(b)))), 0); +} +template<> EIGEN_STRONG_INLINE Packet8uc pcmp_le(const Packet8uc& a, const Packet8uc& b) +{ return vcle_u8(a,b); } +template<> EIGEN_STRONG_INLINE Packet16uc pcmp_le(const Packet16uc& a, const Packet16uc& b) +{ return vcleq_u8(a,b); } +template<> EIGEN_STRONG_INLINE Packet4s pcmp_le(const Packet4s& a, const Packet4s& b) +{ return vreinterpret_s16_u16(vcle_s16(a,b)); } +template<> EIGEN_STRONG_INLINE Packet8s pcmp_le(const Packet8s& a, const Packet8s& b) +{ return vreinterpretq_s16_u16(vcleq_s16(a,b)); } +template<> EIGEN_STRONG_INLINE Packet4us pcmp_le(const Packet4us& a, const Packet4us& b) +{ return vcle_u16(a,b); } +template<> EIGEN_STRONG_INLINE Packet8us pcmp_le(const Packet8us& a, const Packet8us& b) +{ return vcleq_u16(a,b); } +template<> EIGEN_STRONG_INLINE Packet2i pcmp_le(const Packet2i& a, const Packet2i& b) +{ return vreinterpret_s32_u32(vcle_s32(a,b)); } +template<> EIGEN_STRONG_INLINE Packet4i pcmp_le(const Packet4i& a, const Packet4i& b) +{ return vreinterpretq_s32_u32(vcleq_s32(a,b)); } +template<> EIGEN_STRONG_INLINE Packet2ui pcmp_le(const Packet2ui& a, const Packet2ui& b) +{ return vcle_u32(a,b); } +template<> EIGEN_STRONG_INLINE Packet4ui pcmp_le(const Packet4ui& a, const Packet4ui& b) +{ return vcleq_u32(a,b); } +template<> EIGEN_STRONG_INLINE Packet2l pcmp_le(const Packet2l& a, const Packet2l& b) +{ +#if EIGEN_ARCH_ARM64 + return vreinterpretq_s64_u64(vcleq_s64(a,b)); +#else + return vcombine_s64( + vdup_n_s64(vgetq_lane_s64(a, 0) <= vgetq_lane_s64(b, 0) ? numext::int64_t(-1) : 0), + vdup_n_s64(vgetq_lane_s64(a, 1) <= vgetq_lane_s64(b, 1) ? numext::int64_t(-1) : 0)); +#endif +} +template<> EIGEN_STRONG_INLINE Packet2ul pcmp_le(const Packet2ul& a, const Packet2ul& b) +{ +#if EIGEN_ARCH_ARM64 + return vcleq_u64(a,b); +#else + return vcombine_u64( + vdup_n_u64(vgetq_lane_u64(a, 0) <= vgetq_lane_u64(b, 0) ? numext::uint64_t(-1) : 0), + vdup_n_u64(vgetq_lane_u64(a, 1) <= vgetq_lane_u64(b, 1) ? 
numext::uint64_t(-1) : 0)); +#endif +} + +template<> EIGEN_STRONG_INLINE Packet2f pcmp_lt(const Packet2f& a, const Packet2f& b) +{ return vreinterpret_f32_u32(vclt_f32(a,b)); } +template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt(const Packet4f& a, const Packet4f& b) +{ return vreinterpretq_f32_u32(vcltq_f32(a,b)); } +template<> EIGEN_STRONG_INLINE Packet4c pcmp_lt(const Packet4c& a, const Packet4c& b) +{ + return vget_lane_s32(vreinterpret_s32_u8(vclt_s8( + vreinterpret_s8_s32(vdup_n_s32(a)), + vreinterpret_s8_s32(vdup_n_s32(b)))), 0); +} +template<> EIGEN_STRONG_INLINE Packet8c pcmp_lt(const Packet8c& a, const Packet8c& b) +{ return vreinterpret_s8_u8(vclt_s8(a,b)); } +template<> EIGEN_STRONG_INLINE Packet16c pcmp_lt(const Packet16c& a, const Packet16c& b) +{ return vreinterpretq_s8_u8(vcltq_s8(a,b)); } +template<> EIGEN_STRONG_INLINE Packet4uc pcmp_lt(const Packet4uc& a, const Packet4uc& b) +{ + return vget_lane_u32(vreinterpret_u32_u8(vclt_u8( + vreinterpret_u8_u32(vdup_n_u32(a)), + vreinterpret_u8_u32(vdup_n_u32(b)))), 0); +} +template<> EIGEN_STRONG_INLINE Packet8uc pcmp_lt(const Packet8uc& a, const Packet8uc& b) +{ return vclt_u8(a,b); } +template<> EIGEN_STRONG_INLINE Packet16uc pcmp_lt(const Packet16uc& a, const Packet16uc& b) +{ return vcltq_u8(a,b); } +template<> EIGEN_STRONG_INLINE Packet4s pcmp_lt(const Packet4s& a, const Packet4s& b) +{ return vreinterpret_s16_u16(vclt_s16(a,b)); } +template<> EIGEN_STRONG_INLINE Packet8s pcmp_lt(const Packet8s& a, const Packet8s& b) +{ return vreinterpretq_s16_u16(vcltq_s16(a,b)); } +template<> EIGEN_STRONG_INLINE Packet4us pcmp_lt(const Packet4us& a, const Packet4us& b) +{ return vclt_u16(a,b); } +template<> EIGEN_STRONG_INLINE Packet8us pcmp_lt(const Packet8us& a, const Packet8us& b) +{ return vcltq_u16(a,b); } +template<> EIGEN_STRONG_INLINE Packet2i pcmp_lt(const Packet2i& a, const Packet2i& b) +{ return vreinterpret_s32_u32(vclt_s32(a,b)); } +template<> EIGEN_STRONG_INLINE Packet4i pcmp_lt(const Packet4i& a, const Packet4i& b) +{ return vreinterpretq_s32_u32(vcltq_s32(a,b)); } +template<> EIGEN_STRONG_INLINE Packet2ui pcmp_lt(const Packet2ui& a, const Packet2ui& b) +{ return vclt_u32(a,b); } +template<> EIGEN_STRONG_INLINE Packet4ui pcmp_lt(const Packet4ui& a, const Packet4ui& b) +{ return vcltq_u32(a,b); } +template<> EIGEN_STRONG_INLINE Packet2l pcmp_lt(const Packet2l& a, const Packet2l& b) +{ +#if EIGEN_ARCH_ARM64 + return vreinterpretq_s64_u64(vcltq_s64(a,b)); +#else + return vcombine_s64( + vdup_n_s64(vgetq_lane_s64(a, 0) < vgetq_lane_s64(b, 0) ? numext::int64_t(-1) : 0), + vdup_n_s64(vgetq_lane_s64(a, 1) < vgetq_lane_s64(b, 1) ? numext::int64_t(-1) : 0)); +#endif +} +template<> EIGEN_STRONG_INLINE Packet2ul pcmp_lt(const Packet2ul& a, const Packet2ul& b) +{ +#if EIGEN_ARCH_ARM64 + return vcltq_u64(a,b); +#else + return vcombine_u64( + vdup_n_u64(vgetq_lane_u64(a, 0) < vgetq_lane_u64(b, 0) ? numext::uint64_t(-1) : 0), + vdup_n_u64(vgetq_lane_u64(a, 1) < vgetq_lane_u64(b, 1) ? 
numext::uint64_t(-1) : 0)); +#endif +} + +template<> EIGEN_STRONG_INLINE Packet2f pcmp_eq(const Packet2f& a, const Packet2f& b) +{ return vreinterpret_f32_u32(vceq_f32(a,b)); } +template<> EIGEN_STRONG_INLINE Packet4f pcmp_eq(const Packet4f& a, const Packet4f& b) +{ return vreinterpretq_f32_u32(vceqq_f32(a,b)); } +template<> EIGEN_STRONG_INLINE Packet4c pcmp_eq(const Packet4c& a, const Packet4c& b) +{ + return vget_lane_s32(vreinterpret_s32_u8(vceq_s8( + vreinterpret_s8_s32(vdup_n_s32(a)), + vreinterpret_s8_s32(vdup_n_s32(b)))), 0); +} +template<> EIGEN_STRONG_INLINE Packet8c pcmp_eq(const Packet8c& a, const Packet8c& b) +{ return vreinterpret_s8_u8(vceq_s8(a,b)); } +template<> EIGEN_STRONG_INLINE Packet16c pcmp_eq(const Packet16c& a, const Packet16c& b) +{ return vreinterpretq_s8_u8(vceqq_s8(a,b)); } +template<> EIGEN_STRONG_INLINE Packet4uc pcmp_eq(const Packet4uc& a, const Packet4uc& b) +{ + return vget_lane_u32(vreinterpret_u32_u8(vceq_u8( + vreinterpret_u8_u32(vdup_n_u32(a)), + vreinterpret_u8_u32(vdup_n_u32(b)))), 0); +} +template<> EIGEN_STRONG_INLINE Packet8uc pcmp_eq(const Packet8uc& a, const Packet8uc& b) +{ return vceq_u8(a,b); } +template<> EIGEN_STRONG_INLINE Packet16uc pcmp_eq(const Packet16uc& a, const Packet16uc& b) +{ return vceqq_u8(a,b); } +template<> EIGEN_STRONG_INLINE Packet4s pcmp_eq(const Packet4s& a, const Packet4s& b) +{ return vreinterpret_s16_u16(vceq_s16(a,b)); } +template<> EIGEN_STRONG_INLINE Packet8s pcmp_eq(const Packet8s& a, const Packet8s& b) +{ return vreinterpretq_s16_u16(vceqq_s16(a,b)); } +template<> EIGEN_STRONG_INLINE Packet4us pcmp_eq(const Packet4us& a, const Packet4us& b) +{ return vceq_u16(a,b); } +template<> EIGEN_STRONG_INLINE Packet8us pcmp_eq(const Packet8us& a, const Packet8us& b) +{ return vceqq_u16(a,b); } +template<> EIGEN_STRONG_INLINE Packet2i pcmp_eq(const Packet2i& a, const Packet2i& b) +{ return vreinterpret_s32_u32(vceq_s32(a,b)); } +template<> EIGEN_STRONG_INLINE Packet4i pcmp_eq(const Packet4i& a, const Packet4i& b) +{ return vreinterpretq_s32_u32(vceqq_s32(a,b)); } +template<> EIGEN_STRONG_INLINE Packet2ui pcmp_eq(const Packet2ui& a, const Packet2ui& b) +{ return vceq_u32(a,b); } +template<> EIGEN_STRONG_INLINE Packet4ui pcmp_eq(const Packet4ui& a, const Packet4ui& b) +{ return vceqq_u32(a,b); } +template<> EIGEN_STRONG_INLINE Packet2l pcmp_eq(const Packet2l& a, const Packet2l& b) +{ +#if EIGEN_ARCH_ARM64 + return vreinterpretq_s64_u64(vceqq_s64(a,b)); +#else + return vcombine_s64( + vdup_n_s64(vgetq_lane_s64(a, 0) == vgetq_lane_s64(b, 0) ? numext::int64_t(-1) : 0), + vdup_n_s64(vgetq_lane_s64(a, 1) == vgetq_lane_s64(b, 1) ? numext::int64_t(-1) : 0)); +#endif +} +template<> EIGEN_STRONG_INLINE Packet2ul pcmp_eq(const Packet2ul& a, const Packet2ul& b) +{ +#if EIGEN_ARCH_ARM64 + return vceqq_u64(a,b); +#else + return vcombine_u64( + vdup_n_u64(vgetq_lane_u64(a, 0) == vgetq_lane_u64(b, 0) ? numext::uint64_t(-1) : 0), + vdup_n_u64(vgetq_lane_u64(a, 1) == vgetq_lane_u64(b, 1) ? 
numext::uint64_t(-1) : 0)); +#endif +} + +template<> EIGEN_STRONG_INLINE Packet2f pcmp_lt_or_nan(const Packet2f& a, const Packet2f& b) +{ return vreinterpret_f32_u32(vmvn_u32(vcge_f32(a,b))); } +template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt_or_nan(const Packet4f& a, const Packet4f& b) +{ return vreinterpretq_f32_u32(vmvnq_u32(vcgeq_f32(a,b))); } + +// Logical Operations are not supported for float, so we have to reinterpret casts using NEON intrinsics +template<> EIGEN_STRONG_INLINE Packet2f pand(const Packet2f& a, const Packet2f& b) +{ return vreinterpret_f32_u32(vand_u32(vreinterpret_u32_f32(a),vreinterpret_u32_f32(b))); } +template<> EIGEN_STRONG_INLINE Packet4f pand(const Packet4f& a, const Packet4f& b) +{ return vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(a),vreinterpretq_u32_f32(b))); } +template<> EIGEN_STRONG_INLINE Packet4c pand(const Packet4c& a, const Packet4c& b) +{ return a & b; } +template<> EIGEN_STRONG_INLINE Packet8c pand(const Packet8c& a, const Packet8c& b) +{ return vand_s8(a,b); } +template<> EIGEN_STRONG_INLINE Packet16c pand(const Packet16c& a, const Packet16c& b) +{ return vandq_s8(a,b); } +template<> EIGEN_STRONG_INLINE Packet4uc pand(const Packet4uc& a, const Packet4uc& b) +{ return a & b; } +template<> EIGEN_STRONG_INLINE Packet8uc pand(const Packet8uc& a, const Packet8uc& b) +{ return vand_u8(a,b); } +template<> EIGEN_STRONG_INLINE Packet16uc pand(const Packet16uc& a, const Packet16uc& b) +{ return vandq_u8(a,b); } +template<> EIGEN_STRONG_INLINE Packet4s pand(const Packet4s& a, const Packet4s& b) { return vand_s16(a,b); } +template<> EIGEN_STRONG_INLINE Packet8s pand(const Packet8s& a, const Packet8s& b) { return vandq_s16(a,b); } +template<> EIGEN_STRONG_INLINE Packet4us pand(const Packet4us& a, const Packet4us& b) +{ return vand_u16(a,b); } +template<> EIGEN_STRONG_INLINE Packet8us pand(const Packet8us& a, const Packet8us& b) +{ return vandq_u16(a,b); } +template<> EIGEN_STRONG_INLINE Packet2i pand(const Packet2i& a, const Packet2i& b) { return vand_s32(a,b); } +template<> EIGEN_STRONG_INLINE Packet4i pand(const Packet4i& a, const Packet4i& b) { return vandq_s32(a,b); } +template<> EIGEN_STRONG_INLINE Packet2ui pand(const Packet2ui& a, const Packet2ui& b) +{ return vand_u32(a,b); } +template<> EIGEN_STRONG_INLINE Packet4ui pand(const Packet4ui& a, const Packet4ui& b) +{ return vandq_u32(a,b); } +template<> EIGEN_STRONG_INLINE Packet2l pand(const Packet2l& a, const Packet2l& b) { return vandq_s64(a,b); } +template<> EIGEN_STRONG_INLINE Packet2ul pand(const Packet2ul& a, const Packet2ul& b) +{ return vandq_u64(a,b); } + +template<> EIGEN_STRONG_INLINE Packet2f por(const Packet2f& a, const Packet2f& b) +{ return vreinterpret_f32_u32(vorr_u32(vreinterpret_u32_f32(a),vreinterpret_u32_f32(b))); } +template<> EIGEN_STRONG_INLINE Packet4f por(const Packet4f& a, const Packet4f& b) +{ return vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(a),vreinterpretq_u32_f32(b))); } +template<> EIGEN_STRONG_INLINE Packet4c por(const Packet4c& a, const Packet4c& b) +{ return a | b; } +template<> EIGEN_STRONG_INLINE Packet8c por(const Packet8c& a, const Packet8c& b) { return vorr_s8(a,b); } +template<> EIGEN_STRONG_INLINE Packet16c por(const Packet16c& a, const Packet16c& b) +{ return vorrq_s8(a,b); } +template<> EIGEN_STRONG_INLINE Packet4uc por(const Packet4uc& a, const Packet4uc& b) +{ return a | b; } +template<> EIGEN_STRONG_INLINE Packet8uc por(const Packet8uc& a, const Packet8uc& b) +{ return vorr_u8(a,b); } +template<> EIGEN_STRONG_INLINE Packet16uc 
por(const Packet16uc& a, const Packet16uc& b) +{ return vorrq_u8(a,b); } +template<> EIGEN_STRONG_INLINE Packet4s por(const Packet4s& a, const Packet4s& b) +{ return vorr_s16(a,b); } +template<> EIGEN_STRONG_INLINE Packet8s por(const Packet8s& a, const Packet8s& b) +{ return vorrq_s16(a,b); } +template<> EIGEN_STRONG_INLINE Packet4us por(const Packet4us& a, const Packet4us& b) +{ return vorr_u16(a,b); } +template<> EIGEN_STRONG_INLINE Packet8us por(const Packet8us& a, const Packet8us& b) +{ return vorrq_u16(a,b); } +template<> EIGEN_STRONG_INLINE Packet2i por(const Packet2i& a, const Packet2i& b) { return vorr_s32(a,b); } +template<> EIGEN_STRONG_INLINE Packet4i por(const Packet4i& a, const Packet4i& b) { return vorrq_s32(a,b); } +template<> EIGEN_STRONG_INLINE Packet2ui por(const Packet2ui& a, const Packet2ui& b) +{ return vorr_u32(a,b); } +template<> EIGEN_STRONG_INLINE Packet4ui por(const Packet4ui& a, const Packet4ui& b) +{ return vorrq_u32(a,b); } +template<> EIGEN_STRONG_INLINE Packet2l por(const Packet2l& a, const Packet2l& b) +{ return vorrq_s64(a,b); } +template<> EIGEN_STRONG_INLINE Packet2ul por(const Packet2ul& a, const Packet2ul& b) +{ return vorrq_u64(a,b); } + +template<> EIGEN_STRONG_INLINE Packet2f pxor(const Packet2f& a, const Packet2f& b) +{ return vreinterpret_f32_u32(veor_u32(vreinterpret_u32_f32(a),vreinterpret_u32_f32(b))); } +template<> EIGEN_STRONG_INLINE Packet4f pxor(const Packet4f& a, const Packet4f& b) +{ return vreinterpretq_f32_u32(veorq_u32(vreinterpretq_u32_f32(a),vreinterpretq_u32_f32(b))); } +template<> EIGEN_STRONG_INLINE Packet4c pxor(const Packet4c& a, const Packet4c& b) +{ return a ^ b; } +template<> EIGEN_STRONG_INLINE Packet8c pxor(const Packet8c& a, const Packet8c& b) +{ return veor_s8(a,b); } +template<> EIGEN_STRONG_INLINE Packet16c pxor(const Packet16c& a, const Packet16c& b) +{ return veorq_s8(a,b); } +template<> EIGEN_STRONG_INLINE Packet4uc pxor(const Packet4uc& a, const Packet4uc& b) +{ return a ^ b; } +template<> EIGEN_STRONG_INLINE Packet8uc pxor(const Packet8uc& a, const Packet8uc& b) +{ return veor_u8(a,b); } +template<> EIGEN_STRONG_INLINE Packet16uc pxor(const Packet16uc& a, const Packet16uc& b) +{ return veorq_u8(a,b); } +template<> EIGEN_STRONG_INLINE Packet4s pxor(const Packet4s& a, const Packet4s& b) { return veor_s16(a,b); } +template<> EIGEN_STRONG_INLINE Packet8s pxor(const Packet8s& a, const Packet8s& b) { return veorq_s16(a,b); } +template<> EIGEN_STRONG_INLINE Packet4us pxor(const Packet4us& a, const Packet4us& b) +{ return veor_u16(a,b); } +template<> EIGEN_STRONG_INLINE Packet8us pxor(const Packet8us& a, const Packet8us& b) +{ return veorq_u16(a,b); } +template<> EIGEN_STRONG_INLINE Packet2i pxor(const Packet2i& a, const Packet2i& b) { return veor_s32(a,b); } +template<> EIGEN_STRONG_INLINE Packet4i pxor(const Packet4i& a, const Packet4i& b) { return veorq_s32(a,b); } +template<> EIGEN_STRONG_INLINE Packet2ui pxor(const Packet2ui& a, const Packet2ui& b) +{ return veor_u32(a,b); } +template<> EIGEN_STRONG_INLINE Packet4ui pxor(const Packet4ui& a, const Packet4ui& b) +{ return veorq_u32(a,b); } +template<> EIGEN_STRONG_INLINE Packet2l pxor(const Packet2l& a, const Packet2l& b) +{ return veorq_s64(a,b); } +template<> EIGEN_STRONG_INLINE Packet2ul pxor(const Packet2ul& a, const Packet2ul& b) +{ return veorq_u64(a,b); } + +template<> EIGEN_STRONG_INLINE Packet2f pandnot(const Packet2f& a, const Packet2f& b) +{ return vreinterpret_f32_u32(vbic_u32(vreinterpret_u32_f32(a),vreinterpret_u32_f32(b))); } +template<> 
EIGEN_STRONG_INLINE Packet4f pandnot(const Packet4f& a, const Packet4f& b) +{ return vreinterpretq_f32_u32(vbicq_u32(vreinterpretq_u32_f32(a),vreinterpretq_u32_f32(b))); } +template<> EIGEN_STRONG_INLINE Packet4c pandnot(const Packet4c& a, const Packet4c& b) +{ return a & ~b; } +template<> EIGEN_STRONG_INLINE Packet8c pandnot(const Packet8c& a, const Packet8c& b) { return vbic_s8(a,b); } +template<> EIGEN_STRONG_INLINE Packet16c pandnot(const Packet16c& a, const Packet16c& b) { return vbicq_s8(a,b); } +template<> EIGEN_STRONG_INLINE Packet4uc pandnot(const Packet4uc& a, const Packet4uc& b) +{ return a & ~b; } +template<> EIGEN_STRONG_INLINE Packet8uc pandnot(const Packet8uc& a, const Packet8uc& b) +{ return vbic_u8(a,b); } +template<> EIGEN_STRONG_INLINE Packet16uc pandnot(const Packet16uc& a, const Packet16uc& b) +{ return vbicq_u8(a,b); } +template<> EIGEN_STRONG_INLINE Packet4s pandnot(const Packet4s& a, const Packet4s& b) +{ return vbic_s16(a,b); } +template<> EIGEN_STRONG_INLINE Packet8s pandnot(const Packet8s& a, const Packet8s& b) +{ return vbicq_s16(a,b); } +template<> EIGEN_STRONG_INLINE Packet4us pandnot(const Packet4us& a, const Packet4us& b) +{ return vbic_u16(a,b); } +template<> EIGEN_STRONG_INLINE Packet8us pandnot(const Packet8us& a, const Packet8us& b) +{ return vbicq_u16(a,b); } +template<> EIGEN_STRONG_INLINE Packet2i pandnot(const Packet2i& a, const Packet2i& b) +{ return vbic_s32(a,b); } +template<> EIGEN_STRONG_INLINE Packet4i pandnot(const Packet4i& a, const Packet4i& b) +{ return vbicq_s32(a,b); } +template<> EIGEN_STRONG_INLINE Packet2ui pandnot(const Packet2ui& a, const Packet2ui& b) +{ return vbic_u32(a,b); } +template<> EIGEN_STRONG_INLINE Packet4ui pandnot(const Packet4ui& a, const Packet4ui& b) +{ return vbicq_u32(a,b); } +template<> EIGEN_STRONG_INLINE Packet2l pandnot(const Packet2l& a, const Packet2l& b) +{ return vbicq_s64(a,b); } +template<> EIGEN_STRONG_INLINE Packet2ul pandnot(const Packet2ul& a, const Packet2ul& b) +{ return vbicq_u64(a,b); } + + +template EIGEN_STRONG_INLINE Packet4c parithmetic_shift_right(Packet4c& a) +{ return vget_lane_s32(vreinterpret_s32_s8(vshr_n_s8(vreinterpret_s8_s32(vdup_n_s32(a)), N)), 0); } +template EIGEN_STRONG_INLINE Packet8c parithmetic_shift_right(Packet8c a) { return vshr_n_s8(a,N); } +template EIGEN_STRONG_INLINE Packet16c parithmetic_shift_right(Packet16c a) { return vshrq_n_s8(a,N); } +template EIGEN_STRONG_INLINE Packet4uc parithmetic_shift_right(Packet4uc& a) +{ return vget_lane_u32(vreinterpret_u32_u8(vshr_n_u8(vreinterpret_u8_u32(vdup_n_u32(a)), N)), 0); } +template EIGEN_STRONG_INLINE Packet8uc parithmetic_shift_right(Packet8uc a) { return vshr_n_u8(a,N); } +template EIGEN_STRONG_INLINE Packet16uc parithmetic_shift_right(Packet16uc a) { return vshrq_n_u8(a,N); } +template EIGEN_STRONG_INLINE Packet4s parithmetic_shift_right(Packet4s a) { return vshr_n_s16(a,N); } +template EIGEN_STRONG_INLINE Packet8s parithmetic_shift_right(Packet8s a) { return vshrq_n_s16(a,N); } +template EIGEN_STRONG_INLINE Packet4us parithmetic_shift_right(Packet4us a) { return vshr_n_u16(a,N); } +template EIGEN_STRONG_INLINE Packet8us parithmetic_shift_right(Packet8us a) { return vshrq_n_u16(a,N); } +template EIGEN_STRONG_INLINE Packet2i parithmetic_shift_right(Packet2i a) { return vshr_n_s32(a,N); } +template EIGEN_STRONG_INLINE Packet4i parithmetic_shift_right(Packet4i a) { return vshrq_n_s32(a,N); } +template EIGEN_STRONG_INLINE Packet2ui parithmetic_shift_right(Packet2ui a) { return vshr_n_u32(a,N); } +template EIGEN_STRONG_INLINE 
Packet4ui parithmetic_shift_right(Packet4ui a) { return vshrq_n_u32(a,N); } +template EIGEN_STRONG_INLINE Packet2l parithmetic_shift_right(Packet2l a) { return vshrq_n_s64(a,N); } +template EIGEN_STRONG_INLINE Packet2ul parithmetic_shift_right(Packet2ul a) { return vshrq_n_u64(a,N); } + +template EIGEN_STRONG_INLINE Packet4c plogical_shift_right(Packet4c& a) +{ return vget_lane_s32(vreinterpret_s32_u8(vshr_n_u8(vreinterpret_u8_s32(vdup_n_s32(a)), N)), 0); } +template EIGEN_STRONG_INLINE Packet8c plogical_shift_right(Packet8c a) +{ return vreinterpret_s8_u8(vshr_n_u8(vreinterpret_u8_s8(a),N)); } +template EIGEN_STRONG_INLINE Packet16c plogical_shift_right(Packet16c a) +{ return vreinterpretq_s8_u8(vshrq_n_u8(vreinterpretq_u8_s8(a),N)); } +template EIGEN_STRONG_INLINE Packet4uc plogical_shift_right(Packet4uc& a) +{ return vget_lane_u32(vreinterpret_u32_s8(vshr_n_s8(vreinterpret_s8_u32(vdup_n_u32(a)), N)), 0); } +template EIGEN_STRONG_INLINE Packet8uc plogical_shift_right(Packet8uc a) { return vshr_n_u8(a,N); } +template EIGEN_STRONG_INLINE Packet16uc plogical_shift_right(Packet16uc a) { return vshrq_n_u8(a,N); } +template EIGEN_STRONG_INLINE Packet4s plogical_shift_right(Packet4s a) +{ return vreinterpret_s16_u16(vshr_n_u16(vreinterpret_u16_s16(a),N)); } +template EIGEN_STRONG_INLINE Packet8s plogical_shift_right(Packet8s a) +{ return vreinterpretq_s16_u16(vshrq_n_u16(vreinterpretq_u16_s16(a),N)); } +template EIGEN_STRONG_INLINE Packet4us plogical_shift_right(Packet4us a) { return vshr_n_u16(a,N); } +template EIGEN_STRONG_INLINE Packet8us plogical_shift_right(Packet8us a) { return vshrq_n_u16(a,N); } +template EIGEN_STRONG_INLINE Packet2i plogical_shift_right(Packet2i a) +{ return vreinterpret_s32_u32(vshr_n_u32(vreinterpret_u32_s32(a),N)); } +template EIGEN_STRONG_INLINE Packet4i plogical_shift_right(Packet4i a) +{ return vreinterpretq_s32_u32(vshrq_n_u32(vreinterpretq_u32_s32(a),N)); } +template EIGEN_STRONG_INLINE Packet2ui plogical_shift_right(Packet2ui a) { return vshr_n_u32(a,N); } +template EIGEN_STRONG_INLINE Packet4ui plogical_shift_right(Packet4ui a) { return vshrq_n_u32(a,N); } +template EIGEN_STRONG_INLINE Packet2l plogical_shift_right(Packet2l a) +{ return vreinterpretq_s64_u64(vshrq_n_u64(vreinterpretq_u64_s64(a),N)); } +template EIGEN_STRONG_INLINE Packet2ul plogical_shift_right(Packet2ul a) { return vshrq_n_u64(a,N); } + +template EIGEN_STRONG_INLINE Packet4c plogical_shift_left(Packet4c& a) +{ return vget_lane_s32(vreinterpret_s32_s8(vshl_n_s8(vreinterpret_s8_s32(vdup_n_s32(a)), N)), 0); } +template EIGEN_STRONG_INLINE Packet8c plogical_shift_left(Packet8c a) { return vshl_n_s8(a,N); } +template EIGEN_STRONG_INLINE Packet16c plogical_shift_left(Packet16c a) { return vshlq_n_s8(a,N); } +template EIGEN_STRONG_INLINE Packet4uc plogical_shift_left(Packet4uc& a) +{ return vget_lane_u32(vreinterpret_u32_u8(vshl_n_u8(vreinterpret_u8_u32(vdup_n_u32(a)), N)), 0); } +template EIGEN_STRONG_INLINE Packet8uc plogical_shift_left(Packet8uc a) { return vshl_n_u8(a,N); } +template EIGEN_STRONG_INLINE Packet16uc plogical_shift_left(Packet16uc a) { return vshlq_n_u8(a,N); } +template EIGEN_STRONG_INLINE Packet4s plogical_shift_left(Packet4s a) { return vshl_n_s16(a,N); } +template EIGEN_STRONG_INLINE Packet8s plogical_shift_left(Packet8s a) { return vshlq_n_s16(a,N); } +template EIGEN_STRONG_INLINE Packet4us plogical_shift_left(Packet4us a) { return vshl_n_u16(a,N); } +template EIGEN_STRONG_INLINE Packet8us plogical_shift_left(Packet8us a) { return vshlq_n_u16(a,N); } +template 
EIGEN_STRONG_INLINE Packet2i plogical_shift_left(Packet2i a) { return vshl_n_s32(a,N); } +template EIGEN_STRONG_INLINE Packet4i plogical_shift_left(Packet4i a) { return vshlq_n_s32(a,N); } +template EIGEN_STRONG_INLINE Packet2ui plogical_shift_left(Packet2ui a) { return vshl_n_u32(a,N); } +template EIGEN_STRONG_INLINE Packet4ui plogical_shift_left(Packet4ui a) { return vshlq_n_u32(a,N); } +template EIGEN_STRONG_INLINE Packet2l plogical_shift_left(Packet2l a) { return vshlq_n_s64(a,N); } +template EIGEN_STRONG_INLINE Packet2ul plogical_shift_left(Packet2ul a) { return vshlq_n_u64(a,N); } + +template<> EIGEN_STRONG_INLINE Packet2f pload(const float* from) +{ EIGEN_DEBUG_ALIGNED_LOAD return vld1_f32(from); } +template<> EIGEN_STRONG_INLINE Packet4f pload(const float* from) +{ EIGEN_DEBUG_ALIGNED_LOAD return vld1q_f32(from); } +template<> EIGEN_STRONG_INLINE Packet4c pload(const int8_t* from) +{ + Packet4c res; + memcpy(&res, from, sizeof(Packet4c)); + return res; +} +template<> EIGEN_STRONG_INLINE Packet8c pload(const int8_t* from) +{ EIGEN_DEBUG_ALIGNED_LOAD return vld1_s8(from); } +template<> EIGEN_STRONG_INLINE Packet16c pload(const int8_t* from) +{ EIGEN_DEBUG_ALIGNED_LOAD return vld1q_s8(from); } +template<> EIGEN_STRONG_INLINE Packet4uc pload(const uint8_t* from) +{ + Packet4uc res; + memcpy(&res, from, sizeof(Packet4uc)); + return res; +} +template<> EIGEN_STRONG_INLINE Packet8uc pload(const uint8_t* from) +{ EIGEN_DEBUG_ALIGNED_LOAD return vld1_u8(from); } +template<> EIGEN_STRONG_INLINE Packet16uc pload(const uint8_t* from) +{ EIGEN_DEBUG_ALIGNED_LOAD return vld1q_u8(from); } +template<> EIGEN_STRONG_INLINE Packet4s pload(const int16_t* from) +{ EIGEN_DEBUG_ALIGNED_LOAD return vld1_s16(from); } +template<> EIGEN_STRONG_INLINE Packet8s pload(const int16_t* from) +{ EIGEN_DEBUG_ALIGNED_LOAD return vld1q_s16(from); } +template<> EIGEN_STRONG_INLINE Packet4us pload(const uint16_t* from) +{ EIGEN_DEBUG_ALIGNED_LOAD return vld1_u16(from); } +template<> EIGEN_STRONG_INLINE Packet8us pload(const uint16_t* from) +{ EIGEN_DEBUG_ALIGNED_LOAD return vld1q_u16(from); } +template<> EIGEN_STRONG_INLINE Packet2i pload(const int32_t* from) +{ EIGEN_DEBUG_ALIGNED_LOAD return vld1_s32(from); } +template<> EIGEN_STRONG_INLINE Packet4i pload(const int32_t* from) +{ EIGEN_DEBUG_ALIGNED_LOAD return vld1q_s32(from); } +template<> EIGEN_STRONG_INLINE Packet2ui pload(const uint32_t* from) +{ EIGEN_DEBUG_ALIGNED_LOAD return vld1_u32(from); } +template<> EIGEN_STRONG_INLINE Packet4ui pload(const uint32_t* from) +{ EIGEN_DEBUG_ALIGNED_LOAD return vld1q_u32(from); } +template<> EIGEN_STRONG_INLINE Packet2l pload(const int64_t* from) +{ EIGEN_DEBUG_ALIGNED_LOAD return vld1q_s64(from); } +template<> EIGEN_STRONG_INLINE Packet2ul pload(const uint64_t* from) +{ EIGEN_DEBUG_ALIGNED_LOAD return vld1q_u64(from); } + +template<> EIGEN_STRONG_INLINE Packet2f ploadu(const float* from) +{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1_f32(from); } +template<> EIGEN_STRONG_INLINE Packet4f ploadu(const float* from) +{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_f32(from); } +template<> EIGEN_STRONG_INLINE Packet4c ploadu(const int8_t* from) +{ + Packet4c res; + memcpy(&res, from, sizeof(Packet4c)); + return res; +} +template<> EIGEN_STRONG_INLINE Packet8c ploadu(const int8_t* from) +{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1_s8(from); } +template<> EIGEN_STRONG_INLINE Packet16c ploadu(const int8_t* from) +{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_s8(from); } +template<> EIGEN_STRONG_INLINE Packet4uc ploadu(const uint8_t* from) +{ + 
Packet4uc res; + memcpy(&res, from, sizeof(Packet4uc)); + return res; +} +template<> EIGEN_STRONG_INLINE Packet8uc ploadu(const uint8_t* from) +{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1_u8(from); } +template<> EIGEN_STRONG_INLINE Packet16uc ploadu(const uint8_t* from) +{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_u8(from); } +template<> EIGEN_STRONG_INLINE Packet4s ploadu(const int16_t* from) +{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1_s16(from); } +template<> EIGEN_STRONG_INLINE Packet8s ploadu(const int16_t* from) +{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_s16(from); } +template<> EIGEN_STRONG_INLINE Packet4us ploadu(const uint16_t* from) +{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1_u16(from); } +template<> EIGEN_STRONG_INLINE Packet8us ploadu(const uint16_t* from) +{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_u16(from); } +template<> EIGEN_STRONG_INLINE Packet2i ploadu(const int32_t* from) +{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1_s32(from); } +template<> EIGEN_STRONG_INLINE Packet4i ploadu(const int32_t* from) +{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_s32(from); } +template<> EIGEN_STRONG_INLINE Packet2ui ploadu(const uint32_t* from) +{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1_u32(from); } +template<> EIGEN_STRONG_INLINE Packet4ui ploadu(const uint32_t* from) +{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_u32(from); } +template<> EIGEN_STRONG_INLINE Packet2l ploadu(const int64_t* from) +{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_s64(from); } +template<> EIGEN_STRONG_INLINE Packet2ul ploadu(const uint64_t* from) +{ EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_u64(from); } + +template<> EIGEN_STRONG_INLINE Packet2f ploaddup(const float* from) +{ return vld1_dup_f32(from); } +template<> EIGEN_STRONG_INLINE Packet4f ploaddup(const float* from) +{ return vcombine_f32(vld1_dup_f32(from), vld1_dup_f32(from+1)); } +template<> EIGEN_STRONG_INLINE Packet4c ploaddup(const int8_t* from) +{ + const int8x8_t a = vreinterpret_s8_s32(vdup_n_s32(pload(from))); + return vget_lane_s32(vreinterpret_s32_s8(vzip_s8(a,a).val[0]), 0); +} +template<> EIGEN_STRONG_INLINE Packet8c ploaddup(const int8_t* from) +{ + const int8x8_t a = vld1_s8(from); + return vzip_s8(a,a).val[0]; +} +template<> EIGEN_STRONG_INLINE Packet16c ploaddup(const int8_t* from) +{ + const int8x8_t a = vld1_s8(from); + const int8x8x2_t b = vzip_s8(a,a); + return vcombine_s8(b.val[0], b.val[1]); +} +template<> EIGEN_STRONG_INLINE Packet4uc ploaddup(const uint8_t* from) +{ + const uint8x8_t a = vreinterpret_u8_u32(vdup_n_u32(pload(from))); + return vget_lane_u32(vreinterpret_u32_u8(vzip_u8(a,a).val[0]), 0); +} +template<> EIGEN_STRONG_INLINE Packet8uc ploaddup(const uint8_t* from) +{ + const uint8x8_t a = vld1_u8(from); + return vzip_u8(a,a).val[0]; +} +template<> EIGEN_STRONG_INLINE Packet16uc ploaddup(const uint8_t* from) +{ + const uint8x8_t a = vld1_u8(from); + const uint8x8x2_t b = vzip_u8(a,a); + return vcombine_u8(b.val[0], b.val[1]); +} +template<> EIGEN_STRONG_INLINE Packet4s ploaddup(const int16_t* from) +{ + return vreinterpret_s16_u32(vzip_u32(vreinterpret_u32_s16(vld1_dup_s16(from)), + vreinterpret_u32_s16(vld1_dup_s16(from+1))).val[0]); +} +template<> EIGEN_STRONG_INLINE Packet8s ploaddup(const int16_t* from) +{ + const int16x4_t a = vld1_s16(from); + const int16x4x2_t b = vzip_s16(a,a); + return vcombine_s16(b.val[0], b.val[1]); +} +template<> EIGEN_STRONG_INLINE Packet4us ploaddup(const uint16_t* from) +{ + return vreinterpret_u16_u32(vzip_u32(vreinterpret_u32_u16(vld1_dup_u16(from)), + vreinterpret_u32_u16(vld1_dup_u16(from+1))).val[0]); +} 
+template<> EIGEN_STRONG_INLINE Packet8us ploaddup(const uint16_t* from) +{ + const uint16x4_t a = vld1_u16(from); + const uint16x4x2_t b = vzip_u16(a,a); + return vcombine_u16(b.val[0], b.val[1]); +} +template<> EIGEN_STRONG_INLINE Packet2i ploaddup(const int32_t* from) +{ return vld1_dup_s32(from); } +template<> EIGEN_STRONG_INLINE Packet4i ploaddup(const int32_t* from) +{ return vcombine_s32(vld1_dup_s32(from), vld1_dup_s32(from+1)); } +template<> EIGEN_STRONG_INLINE Packet2ui ploaddup(const uint32_t* from) +{ return vld1_dup_u32(from); } +template<> EIGEN_STRONG_INLINE Packet4ui ploaddup(const uint32_t* from) +{ return vcombine_u32(vld1_dup_u32(from), vld1_dup_u32(from+1)); } +template<> EIGEN_STRONG_INLINE Packet2l ploaddup(const int64_t* from) +{ return vld1q_dup_s64(from); } +template<> EIGEN_STRONG_INLINE Packet2ul ploaddup(const uint64_t* from) +{ return vld1q_dup_u64(from); } + +template<> EIGEN_STRONG_INLINE Packet4f ploadquad(const float* from) { return vld1q_dup_f32(from); } +template<> EIGEN_STRONG_INLINE Packet4c ploadquad(const int8_t* from) +{ return vget_lane_s32(vreinterpret_s32_s8(vld1_dup_s8(from)), 0); } +template<> EIGEN_STRONG_INLINE Packet8c ploadquad(const int8_t* from) +{ + return vreinterpret_s8_u32(vzip_u32( + vreinterpret_u32_s8(vld1_dup_s8(from)), + vreinterpret_u32_s8(vld1_dup_s8(from+1))).val[0]); +} +template<> EIGEN_STRONG_INLINE Packet16c ploadquad(const int8_t* from) +{ + const int8x8_t a = vreinterpret_s8_u32(vzip_u32( + vreinterpret_u32_s8(vld1_dup_s8(from)), + vreinterpret_u32_s8(vld1_dup_s8(from+1))).val[0]); + const int8x8_t b = vreinterpret_s8_u32(vzip_u32( + vreinterpret_u32_s8(vld1_dup_s8(from+2)), + vreinterpret_u32_s8(vld1_dup_s8(from+3))).val[0]); + return vcombine_s8(a,b); +} +template<> EIGEN_STRONG_INLINE Packet4uc ploadquad(const uint8_t* from) +{ return vget_lane_u32(vreinterpret_u32_u8(vld1_dup_u8(from)), 0); } +template<> EIGEN_STRONG_INLINE Packet8uc ploadquad(const uint8_t* from) +{ + return vreinterpret_u8_u32(vzip_u32( + vreinterpret_u32_u8(vld1_dup_u8(from)), + vreinterpret_u32_u8(vld1_dup_u8(from+1))).val[0]); +} +template<> EIGEN_STRONG_INLINE Packet16uc ploadquad(const uint8_t* from) +{ + const uint8x8_t a = vreinterpret_u8_u32(vzip_u32( + vreinterpret_u32_u8(vld1_dup_u8(from)), + vreinterpret_u32_u8(vld1_dup_u8(from+1))).val[0]); + const uint8x8_t b = vreinterpret_u8_u32(vzip_u32( + vreinterpret_u32_u8(vld1_dup_u8(from+2)), + vreinterpret_u32_u8(vld1_dup_u8(from+3))).val[0]); + return vcombine_u8(a,b); +} +template<> EIGEN_STRONG_INLINE Packet8s ploadquad(const int16_t* from) +{ return vcombine_s16(vld1_dup_s16(from), vld1_dup_s16(from+1)); } +template<> EIGEN_STRONG_INLINE Packet8us ploadquad(const uint16_t* from) +{ return vcombine_u16(vld1_dup_u16(from), vld1_dup_u16(from+1)); } +template<> EIGEN_STRONG_INLINE Packet4i ploadquad(const int32_t* from) { return vld1q_dup_s32(from); } +template<> EIGEN_STRONG_INLINE Packet4ui ploadquad(const uint32_t* from) { return vld1q_dup_u32(from); } + +template<> EIGEN_STRONG_INLINE void pstore(float* to, const Packet2f& from) +{ EIGEN_DEBUG_ALIGNED_STORE vst1_f32(to,from); } +template<> EIGEN_STRONG_INLINE void pstore(float* to, const Packet4f& from) +{ EIGEN_DEBUG_ALIGNED_STORE vst1q_f32(to,from); } +template<> EIGEN_STRONG_INLINE void pstore(int8_t* to, const Packet4c& from) +{ memcpy(to, &from, sizeof(from)); } +template<> EIGEN_STRONG_INLINE void pstore(int8_t* to, const Packet8c& from) +{ EIGEN_DEBUG_ALIGNED_STORE vst1_s8(to,from); } +template<> EIGEN_STRONG_INLINE void 
pstore(int8_t* to, const Packet16c& from) +{ EIGEN_DEBUG_ALIGNED_STORE vst1q_s8(to,from); } +template<> EIGEN_STRONG_INLINE void pstore(uint8_t* to, const Packet4uc& from) +{ memcpy(to, &from, sizeof(from)); } +template<> EIGEN_STRONG_INLINE void pstore(uint8_t* to, const Packet8uc& from) +{ EIGEN_DEBUG_ALIGNED_STORE vst1_u8(to,from); } +template<> EIGEN_STRONG_INLINE void pstore(uint8_t* to, const Packet16uc& from) +{ EIGEN_DEBUG_ALIGNED_STORE vst1q_u8(to,from); } +template<> EIGEN_STRONG_INLINE void pstore(int16_t* to, const Packet4s& from) +{ EIGEN_DEBUG_ALIGNED_STORE vst1_s16(to,from); } +template<> EIGEN_STRONG_INLINE void pstore(int16_t* to, const Packet8s& from) +{ EIGEN_DEBUG_ALIGNED_STORE vst1q_s16(to,from); } +template<> EIGEN_STRONG_INLINE void pstore(uint16_t* to, const Packet4us& from) +{ EIGEN_DEBUG_ALIGNED_STORE vst1_u16(to,from); } +template<> EIGEN_STRONG_INLINE void pstore(uint16_t* to, const Packet8us& from) +{ EIGEN_DEBUG_ALIGNED_STORE vst1q_u16(to,from); } +template<> EIGEN_STRONG_INLINE void pstore(int32_t* to, const Packet2i& from) +{ EIGEN_DEBUG_ALIGNED_STORE vst1_s32(to,from); } +template<> EIGEN_STRONG_INLINE void pstore(int32_t* to, const Packet4i& from) +{ EIGEN_DEBUG_ALIGNED_STORE vst1q_s32(to,from); } +template<> EIGEN_STRONG_INLINE void pstore(uint32_t* to, const Packet2ui& from) +{ EIGEN_DEBUG_ALIGNED_STORE vst1_u32(to,from); } +template<> EIGEN_STRONG_INLINE void pstore(uint32_t* to, const Packet4ui& from) +{ EIGEN_DEBUG_ALIGNED_STORE vst1q_u32(to,from); } +template<> EIGEN_STRONG_INLINE void pstore(int64_t* to, const Packet2l& from) +{ EIGEN_DEBUG_ALIGNED_STORE vst1q_s64(to,from); } +template<> EIGEN_STRONG_INLINE void pstore(uint64_t* to, const Packet2ul& from) +{ EIGEN_DEBUG_ALIGNED_STORE vst1q_u64(to,from); } + +template<> EIGEN_STRONG_INLINE void pstoreu(float* to, const Packet2f& from) +{ EIGEN_DEBUG_UNALIGNED_STORE vst1_f32(to,from); } +template<> EIGEN_STRONG_INLINE void pstoreu(float* to, const Packet4f& from) +{ EIGEN_DEBUG_UNALIGNED_STORE vst1q_f32(to,from); } +template<> EIGEN_STRONG_INLINE void pstoreu(int8_t* to, const Packet4c& from) +{ memcpy(to, &from, sizeof(from)); } +template<> EIGEN_STRONG_INLINE void pstoreu(int8_t* to, const Packet8c& from) +{ EIGEN_DEBUG_UNALIGNED_STORE vst1_s8(to,from); } +template<> EIGEN_STRONG_INLINE void pstoreu(int8_t* to, const Packet16c& from) +{ EIGEN_DEBUG_UNALIGNED_STORE vst1q_s8(to,from); } +template<> EIGEN_STRONG_INLINE void pstoreu(uint8_t* to, const Packet4uc& from) +{ memcpy(to, &from, sizeof(from)); } +template<> EIGEN_STRONG_INLINE void pstoreu(uint8_t* to, const Packet8uc& from) +{ EIGEN_DEBUG_UNALIGNED_STORE vst1_u8(to,from); } +template<> EIGEN_STRONG_INLINE void pstoreu(uint8_t* to, const Packet16uc& from) +{ EIGEN_DEBUG_UNALIGNED_STORE vst1q_u8(to,from); } +template<> EIGEN_STRONG_INLINE void pstoreu(int16_t* to, const Packet4s& from) +{ EIGEN_DEBUG_UNALIGNED_STORE vst1_s16(to,from); } +template<> EIGEN_STRONG_INLINE void pstoreu(int16_t* to, const Packet8s& from) +{ EIGEN_DEBUG_UNALIGNED_STORE vst1q_s16(to,from); } +template<> EIGEN_STRONG_INLINE void pstoreu(uint16_t* to, const Packet4us& from) +{ EIGEN_DEBUG_UNALIGNED_STORE vst1_u16(to,from); } +template<> EIGEN_STRONG_INLINE void pstoreu(uint16_t* to, const Packet8us& from) +{ EIGEN_DEBUG_UNALIGNED_STORE vst1q_u16(to,from); } +template<> EIGEN_STRONG_INLINE void pstoreu(int32_t* to, const Packet2i& from) +{ EIGEN_DEBUG_UNALIGNED_STORE vst1_s32(to,from); } +template<> EIGEN_STRONG_INLINE void pstoreu(int32_t* to, const Packet4i& from) +{ 
EIGEN_DEBUG_UNALIGNED_STORE vst1q_s32(to,from); } +template<> EIGEN_STRONG_INLINE void pstoreu(uint32_t* to, const Packet2ui& from) +{ EIGEN_DEBUG_UNALIGNED_STORE vst1_u32(to,from); } +template<> EIGEN_STRONG_INLINE void pstoreu(uint32_t* to, const Packet4ui& from) +{ EIGEN_DEBUG_UNALIGNED_STORE vst1q_u32(to,from); } +template<> EIGEN_STRONG_INLINE void pstoreu(int64_t* to, const Packet2l& from) +{ EIGEN_DEBUG_UNALIGNED_STORE vst1q_s64(to,from); } +template<> EIGEN_STRONG_INLINE void pstoreu(uint64_t* to, const Packet2ul& from) +{ EIGEN_DEBUG_UNALIGNED_STORE vst1q_u64(to,from); } + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2f pgather(const float* from, Index stride) +{ + Packet2f res = vld1_dup_f32(from); + res = vld1_lane_f32(from + 1*stride, res, 1); + return res; +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4f pgather(const float* from, Index stride) +{ + Packet4f res = vld1q_dup_f32(from); + res = vld1q_lane_f32(from + 1*stride, res, 1); + res = vld1q_lane_f32(from + 2*stride, res, 2); + res = vld1q_lane_f32(from + 3*stride, res, 3); + return res; +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4c pgather(const int8_t* from, Index stride) +{ + Packet4c res; + for (int i = 0; i != 4; i++) + reinterpret_cast(&res)[i] = *(from + i * stride); + return res; +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8c pgather(const int8_t* from, Index stride) +{ + Packet8c res = vld1_dup_s8(from); + res = vld1_lane_s8(from + 1*stride, res, 1); + res = vld1_lane_s8(from + 2*stride, res, 2); + res = vld1_lane_s8(from + 3*stride, res, 3); + res = vld1_lane_s8(from + 4*stride, res, 4); + res = vld1_lane_s8(from + 5*stride, res, 5); + res = vld1_lane_s8(from + 6*stride, res, 6); + res = vld1_lane_s8(from + 7*stride, res, 7); + return res; +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet16c pgather(const int8_t* from, Index stride) +{ + Packet16c res = vld1q_dup_s8(from); + res = vld1q_lane_s8(from + 1*stride, res, 1); + res = vld1q_lane_s8(from + 2*stride, res, 2); + res = vld1q_lane_s8(from + 3*stride, res, 3); + res = vld1q_lane_s8(from + 4*stride, res, 4); + res = vld1q_lane_s8(from + 5*stride, res, 5); + res = vld1q_lane_s8(from + 6*stride, res, 6); + res = vld1q_lane_s8(from + 7*stride, res, 7); + res = vld1q_lane_s8(from + 8*stride, res, 8); + res = vld1q_lane_s8(from + 9*stride, res, 9); + res = vld1q_lane_s8(from + 10*stride, res, 10); + res = vld1q_lane_s8(from + 11*stride, res, 11); + res = vld1q_lane_s8(from + 12*stride, res, 12); + res = vld1q_lane_s8(from + 13*stride, res, 13); + res = vld1q_lane_s8(from + 14*stride, res, 14); + res = vld1q_lane_s8(from + 15*stride, res, 15); + return res; +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4uc pgather(const uint8_t* from, Index stride) +{ + Packet4uc res; + for (int i = 0; i != 4; i++) + reinterpret_cast(&res)[i] = *(from + i * stride); + return res; +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8uc pgather(const uint8_t* from, Index stride) +{ + Packet8uc res = vld1_dup_u8(from); + res = vld1_lane_u8(from + 1*stride, res, 1); + res = vld1_lane_u8(from + 2*stride, res, 2); + res = vld1_lane_u8(from + 3*stride, res, 3); + res = vld1_lane_u8(from + 4*stride, res, 4); + res = vld1_lane_u8(from + 5*stride, res, 5); + res = vld1_lane_u8(from + 6*stride, res, 6); + res = vld1_lane_u8(from + 7*stride, res, 7); + return res; +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet16uc pgather(const uint8_t* from, Index stride) +{ + Packet16uc res = vld1q_dup_u8(from); 
+ res = vld1q_lane_u8(from + 1*stride, res, 1); + res = vld1q_lane_u8(from + 2*stride, res, 2); + res = vld1q_lane_u8(from + 3*stride, res, 3); + res = vld1q_lane_u8(from + 4*stride, res, 4); + res = vld1q_lane_u8(from + 5*stride, res, 5); + res = vld1q_lane_u8(from + 6*stride, res, 6); + res = vld1q_lane_u8(from + 7*stride, res, 7); + res = vld1q_lane_u8(from + 8*stride, res, 8); + res = vld1q_lane_u8(from + 9*stride, res, 9); + res = vld1q_lane_u8(from + 10*stride, res, 10); + res = vld1q_lane_u8(from + 11*stride, res, 11); + res = vld1q_lane_u8(from + 12*stride, res, 12); + res = vld1q_lane_u8(from + 13*stride, res, 13); + res = vld1q_lane_u8(from + 14*stride, res, 14); + res = vld1q_lane_u8(from + 15*stride, res, 15); + return res; +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4s pgather(const int16_t* from, Index stride) +{ + Packet4s res = vld1_dup_s16(from); + res = vld1_lane_s16(from + 1*stride, res, 1); + res = vld1_lane_s16(from + 2*stride, res, 2); + res = vld1_lane_s16(from + 3*stride, res, 3); + return res; +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8s pgather(const int16_t* from, Index stride) +{ + Packet8s res = vld1q_dup_s16(from); + res = vld1q_lane_s16(from + 1*stride, res, 1); + res = vld1q_lane_s16(from + 2*stride, res, 2); + res = vld1q_lane_s16(from + 3*stride, res, 3); + res = vld1q_lane_s16(from + 4*stride, res, 4); + res = vld1q_lane_s16(from + 5*stride, res, 5); + res = vld1q_lane_s16(from + 6*stride, res, 6); + res = vld1q_lane_s16(from + 7*stride, res, 7); + return res; +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4us pgather(const uint16_t* from, Index stride) +{ + Packet4us res = vld1_dup_u16(from); + res = vld1_lane_u16(from + 1*stride, res, 1); + res = vld1_lane_u16(from + 2*stride, res, 2); + res = vld1_lane_u16(from + 3*stride, res, 3); + return res; +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8us pgather(const uint16_t* from, Index stride) +{ + Packet8us res = vld1q_dup_u16(from); + res = vld1q_lane_u16(from + 1*stride, res, 1); + res = vld1q_lane_u16(from + 2*stride, res, 2); + res = vld1q_lane_u16(from + 3*stride, res, 3); + res = vld1q_lane_u16(from + 4*stride, res, 4); + res = vld1q_lane_u16(from + 5*stride, res, 5); + res = vld1q_lane_u16(from + 6*stride, res, 6); + res = vld1q_lane_u16(from + 7*stride, res, 7); + return res; +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2i pgather(const int32_t* from, Index stride) +{ + Packet2i res = vld1_dup_s32(from); + res = vld1_lane_s32(from + 1*stride, res, 1); + return res; +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4i pgather(const int32_t* from, Index stride) +{ + Packet4i res = vld1q_dup_s32(from); + res = vld1q_lane_s32(from + 1*stride, res, 1); + res = vld1q_lane_s32(from + 2*stride, res, 2); + res = vld1q_lane_s32(from + 3*stride, res, 3); + return res; +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2ui pgather(const uint32_t* from, Index stride) +{ + Packet2ui res = vld1_dup_u32(from); + res = vld1_lane_u32(from + 1*stride, res, 1); + return res; +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4ui pgather(const uint32_t* from, Index stride) +{ + Packet4ui res = vld1q_dup_u32(from); + res = vld1q_lane_u32(from + 1*stride, res, 1); + res = vld1q_lane_u32(from + 2*stride, res, 2); + res = vld1q_lane_u32(from + 3*stride, res, 3); + return res; +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2l pgather(const int64_t* from, Index stride) +{ + Packet2l res = vld1q_dup_s64(from); + res = 
vld1q_lane_s64(from + 1*stride, res, 1); + return res; +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2ul pgather(const uint64_t* from, Index stride) +{ + Packet2ul res = vld1q_dup_u64(from); + res = vld1q_lane_u64(from + 1*stride, res, 1); + return res; +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter(float* to, const Packet2f& from, Index stride) +{ + vst1_lane_f32(to + stride*0, from, 0); + vst1_lane_f32(to + stride*1, from, 1); +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter(float* to, const Packet4f& from, Index stride) +{ + vst1q_lane_f32(to + stride*0, from, 0); + vst1q_lane_f32(to + stride*1, from, 1); + vst1q_lane_f32(to + stride*2, from, 2); + vst1q_lane_f32(to + stride*3, from, 3); +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter(int8_t* to, const Packet4c& from, Index stride) +{ + for (int i = 0; i != 4; i++) + *(to + i * stride) = reinterpret_cast(&from)[i]; +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter(int8_t* to, const Packet8c& from, Index stride) +{ + vst1_lane_s8(to + stride*0, from, 0); + vst1_lane_s8(to + stride*1, from, 1); + vst1_lane_s8(to + stride*2, from, 2); + vst1_lane_s8(to + stride*3, from, 3); + vst1_lane_s8(to + stride*4, from, 4); + vst1_lane_s8(to + stride*5, from, 5); + vst1_lane_s8(to + stride*6, from, 6); + vst1_lane_s8(to + stride*7, from, 7); +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter(int8_t* to, const Packet16c& from, Index stride) +{ + vst1q_lane_s8(to + stride*0, from, 0); + vst1q_lane_s8(to + stride*1, from, 1); + vst1q_lane_s8(to + stride*2, from, 2); + vst1q_lane_s8(to + stride*3, from, 3); + vst1q_lane_s8(to + stride*4, from, 4); + vst1q_lane_s8(to + stride*5, from, 5); + vst1q_lane_s8(to + stride*6, from, 6); + vst1q_lane_s8(to + stride*7, from, 7); + vst1q_lane_s8(to + stride*8, from, 8); + vst1q_lane_s8(to + stride*9, from, 9); + vst1q_lane_s8(to + stride*10, from, 10); + vst1q_lane_s8(to + stride*11, from, 11); + vst1q_lane_s8(to + stride*12, from, 12); + vst1q_lane_s8(to + stride*13, from, 13); + vst1q_lane_s8(to + stride*14, from, 14); + vst1q_lane_s8(to + stride*15, from, 15); +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter(uint8_t* to, const Packet4uc& from, Index stride) +{ + for (int i = 0; i != 4; i++) + *(to + i * stride) = reinterpret_cast(&from)[i]; +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter(uint8_t* to, const Packet8uc& from, Index stride) +{ + vst1_lane_u8(to + stride*0, from, 0); + vst1_lane_u8(to + stride*1, from, 1); + vst1_lane_u8(to + stride*2, from, 2); + vst1_lane_u8(to + stride*3, from, 3); + vst1_lane_u8(to + stride*4, from, 4); + vst1_lane_u8(to + stride*5, from, 5); + vst1_lane_u8(to + stride*6, from, 6); + vst1_lane_u8(to + stride*7, from, 7); +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter(uint8_t* to, const Packet16uc& from, Index stride) +{ + vst1q_lane_u8(to + stride*0, from, 0); + vst1q_lane_u8(to + stride*1, from, 1); + vst1q_lane_u8(to + stride*2, from, 2); + vst1q_lane_u8(to + stride*3, from, 3); + vst1q_lane_u8(to + stride*4, from, 4); + vst1q_lane_u8(to + stride*5, from, 5); + vst1q_lane_u8(to + stride*6, from, 6); + vst1q_lane_u8(to + stride*7, from, 7); + vst1q_lane_u8(to + stride*8, from, 8); + vst1q_lane_u8(to + stride*9, from, 9); + vst1q_lane_u8(to + stride*10, from, 10); + vst1q_lane_u8(to + stride*11, from, 11); + vst1q_lane_u8(to + stride*12, from, 12); + vst1q_lane_u8(to + stride*13, from, 13); + vst1q_lane_u8(to + stride*14, 
from, 14); + vst1q_lane_u8(to + stride*15, from, 15); +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter(int16_t* to, const Packet4s& from, Index stride) +{ + vst1_lane_s16(to + stride*0, from, 0); + vst1_lane_s16(to + stride*1, from, 1); + vst1_lane_s16(to + stride*2, from, 2); + vst1_lane_s16(to + stride*3, from, 3); +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter(int16_t* to, const Packet8s& from, Index stride) +{ + vst1q_lane_s16(to + stride*0, from, 0); + vst1q_lane_s16(to + stride*1, from, 1); + vst1q_lane_s16(to + stride*2, from, 2); + vst1q_lane_s16(to + stride*3, from, 3); + vst1q_lane_s16(to + stride*4, from, 4); + vst1q_lane_s16(to + stride*5, from, 5); + vst1q_lane_s16(to + stride*6, from, 6); + vst1q_lane_s16(to + stride*7, from, 7); +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter(uint16_t* to, const Packet4us& from, Index stride) +{ + vst1_lane_u16(to + stride*0, from, 0); + vst1_lane_u16(to + stride*1, from, 1); + vst1_lane_u16(to + stride*2, from, 2); + vst1_lane_u16(to + stride*3, from, 3); +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter(uint16_t* to, const Packet8us& from, Index stride) +{ + vst1q_lane_u16(to + stride*0, from, 0); + vst1q_lane_u16(to + stride*1, from, 1); + vst1q_lane_u16(to + stride*2, from, 2); + vst1q_lane_u16(to + stride*3, from, 3); + vst1q_lane_u16(to + stride*4, from, 4); + vst1q_lane_u16(to + stride*5, from, 5); + vst1q_lane_u16(to + stride*6, from, 6); + vst1q_lane_u16(to + stride*7, from, 7); +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter(int32_t* to, const Packet2i& from, Index stride) +{ + vst1_lane_s32(to + stride*0, from, 0); + vst1_lane_s32(to + stride*1, from, 1); +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter(int32_t* to, const Packet4i& from, Index stride) +{ + vst1q_lane_s32(to + stride*0, from, 0); + vst1q_lane_s32(to + stride*1, from, 1); + vst1q_lane_s32(to + stride*2, from, 2); + vst1q_lane_s32(to + stride*3, from, 3); +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter(uint32_t* to, const Packet2ui& from, Index stride) +{ + vst1_lane_u32(to + stride*0, from, 0); + vst1_lane_u32(to + stride*1, from, 1); +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter(uint32_t* to, const Packet4ui& from, Index stride) +{ + vst1q_lane_u32(to + stride*0, from, 0); + vst1q_lane_u32(to + stride*1, from, 1); + vst1q_lane_u32(to + stride*2, from, 2); + vst1q_lane_u32(to + stride*3, from, 3); +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter(int64_t* to, const Packet2l& from, Index stride) +{ + vst1q_lane_s64(to + stride*0, from, 0); + vst1q_lane_s64(to + stride*1, from, 1); +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter(uint64_t* to, const Packet2ul& from, Index stride) +{ + vst1q_lane_u64(to + stride*0, from, 0); + vst1q_lane_u64(to + stride*1, from, 1); +} + +template<> EIGEN_STRONG_INLINE void prefetch(const float* addr) { EIGEN_ARM_PREFETCH(addr); } +template<> EIGEN_STRONG_INLINE void prefetch(const int8_t* addr) { EIGEN_ARM_PREFETCH(addr); } +template<> EIGEN_STRONG_INLINE void prefetch(const uint8_t* addr) { EIGEN_ARM_PREFETCH(addr); } +template<> EIGEN_STRONG_INLINE void prefetch(const int16_t* addr) { EIGEN_ARM_PREFETCH(addr); } +template<> EIGEN_STRONG_INLINE void prefetch(const uint16_t* addr) { EIGEN_ARM_PREFETCH(addr); } +template<> EIGEN_STRONG_INLINE void prefetch(const int32_t* addr) { EIGEN_ARM_PREFETCH(addr); } +template<> EIGEN_STRONG_INLINE void 
prefetch<uint32_t>(const uint32_t* addr) { EIGEN_ARM_PREFETCH(addr); }
+template<> EIGEN_STRONG_INLINE void prefetch<int64_t>(const int64_t* addr) { EIGEN_ARM_PREFETCH(addr); }
+template<> EIGEN_STRONG_INLINE void prefetch<uint64_t>(const uint64_t* addr) { EIGEN_ARM_PREFETCH(addr); }
+
+template<> EIGEN_STRONG_INLINE float pfirst<Packet2f>(const Packet2f& a) { return vget_lane_f32(a,0); }
+template<> EIGEN_STRONG_INLINE float pfirst<Packet4f>(const Packet4f& a) { return vgetq_lane_f32(a,0); }
+template<> EIGEN_STRONG_INLINE int8_t pfirst<Packet4c>(const Packet4c& a) { return static_cast<int8_t>(a & 0xff); }
+template<> EIGEN_STRONG_INLINE int8_t pfirst<Packet8c>(const Packet8c& a) { return vget_lane_s8(a,0); }
+template<> EIGEN_STRONG_INLINE int8_t pfirst<Packet16c>(const Packet16c& a) { return vgetq_lane_s8(a,0); }
+template<> EIGEN_STRONG_INLINE uint8_t pfirst<Packet4uc>(const Packet4uc& a) { return static_cast<uint8_t>(a & 0xff); }
+template<> EIGEN_STRONG_INLINE uint8_t pfirst<Packet8uc>(const Packet8uc& a) { return vget_lane_u8(a,0); }
+template<> EIGEN_STRONG_INLINE uint8_t pfirst<Packet16uc>(const Packet16uc& a) { return vgetq_lane_u8(a,0); }
+template<> EIGEN_STRONG_INLINE int16_t pfirst<Packet4s>(const Packet4s& a) { return vget_lane_s16(a,0); }
+template<> EIGEN_STRONG_INLINE int16_t pfirst<Packet8s>(const Packet8s& a) { return vgetq_lane_s16(a,0); }
+template<> EIGEN_STRONG_INLINE uint16_t pfirst<Packet4us>(const Packet4us& a) { return vget_lane_u16(a,0); }
+template<> EIGEN_STRONG_INLINE uint16_t pfirst<Packet8us>(const Packet8us& a) { return vgetq_lane_u16(a,0); }
+template<> EIGEN_STRONG_INLINE int32_t pfirst<Packet2i>(const Packet2i& a) { return vget_lane_s32(a,0); }
+template<> EIGEN_STRONG_INLINE int32_t pfirst<Packet4i>(const Packet4i& a) { return vgetq_lane_s32(a,0); }
+template<> EIGEN_STRONG_INLINE uint32_t pfirst<Packet2ui>(const Packet2ui& a) { return vget_lane_u32(a,0); }
+template<> EIGEN_STRONG_INLINE uint32_t pfirst<Packet4ui>(const Packet4ui& a) { return vgetq_lane_u32(a,0); }
+template<> EIGEN_STRONG_INLINE int64_t pfirst<Packet2l>(const Packet2l& a) { return vgetq_lane_s64(a,0); }
+template<> EIGEN_STRONG_INLINE uint64_t pfirst<Packet2ul>(const Packet2ul& a) { return vgetq_lane_u64(a,0); }
+
+template<> EIGEN_STRONG_INLINE Packet2f preverse(const Packet2f& a) { return vrev64_f32(a); }
+template<> EIGEN_STRONG_INLINE Packet4f preverse(const Packet4f& a)
+{
+  const float32x4_t a_r64 = vrev64q_f32(a);
+  return vcombine_f32(vget_high_f32(a_r64), vget_low_f32(a_r64));
+}
+template<> EIGEN_STRONG_INLINE Packet4c preverse(const Packet4c& a)
+{ return vget_lane_s32(vreinterpret_s32_s8(vrev64_s8(vreinterpret_s8_s32(vdup_n_s32(a)))), 0); }
+template<> EIGEN_STRONG_INLINE Packet8c preverse(const Packet8c& a) { return vrev64_s8(a); }
+template<> EIGEN_STRONG_INLINE Packet16c preverse(const Packet16c& a)
+{
+  const int8x16_t a_r64 = vrev64q_s8(a);
+  return vcombine_s8(vget_high_s8(a_r64), vget_low_s8(a_r64));
+}
+template<> EIGEN_STRONG_INLINE Packet4uc preverse(const Packet4uc& a)
+{ return vget_lane_u32(vreinterpret_u32_u8(vrev64_u8(vreinterpret_u8_u32(vdup_n_u32(a)))), 0); }
+template<> EIGEN_STRONG_INLINE Packet8uc preverse(const Packet8uc& a) { return vrev64_u8(a); }
+template<> EIGEN_STRONG_INLINE Packet16uc preverse(const Packet16uc& a)
+{
+  const uint8x16_t a_r64 = vrev64q_u8(a);
+  return vcombine_u8(vget_high_u8(a_r64), vget_low_u8(a_r64));
+}
+template<> EIGEN_STRONG_INLINE Packet4s preverse(const Packet4s& a) { return vrev64_s16(a); }
+template<> EIGEN_STRONG_INLINE Packet8s preverse(const Packet8s& a)
+{
+  const int16x8_t a_r64 = vrev64q_s16(a);
+  return vcombine_s16(vget_high_s16(a_r64), vget_low_s16(a_r64));
+}
+template<> EIGEN_STRONG_INLINE Packet4us preverse(const Packet4us& a) { return 
vrev64_u16(a); } +template<> EIGEN_STRONG_INLINE Packet8us preverse(const Packet8us& a) +{ + const uint16x8_t a_r64 = vrev64q_u16(a); + return vcombine_u16(vget_high_u16(a_r64), vget_low_u16(a_r64)); +} +template<> EIGEN_STRONG_INLINE Packet2i preverse(const Packet2i& a) { return vrev64_s32(a); } +template<> EIGEN_STRONG_INLINE Packet4i preverse(const Packet4i& a) +{ + const int32x4_t a_r64 = vrev64q_s32(a); + return vcombine_s32(vget_high_s32(a_r64), vget_low_s32(a_r64)); +} +template<> EIGEN_STRONG_INLINE Packet2ui preverse(const Packet2ui& a) { return vrev64_u32(a); } +template<> EIGEN_STRONG_INLINE Packet4ui preverse(const Packet4ui& a) +{ + const uint32x4_t a_r64 = vrev64q_u32(a); + return vcombine_u32(vget_high_u32(a_r64), vget_low_u32(a_r64)); +} +template<> EIGEN_STRONG_INLINE Packet2l preverse(const Packet2l& a) +{ return vcombine_s64(vget_high_s64(a), vget_low_s64(a)); } +template<> EIGEN_STRONG_INLINE Packet2ul preverse(const Packet2ul& a) +{ return vcombine_u64(vget_high_u64(a), vget_low_u64(a)); } + +template<> EIGEN_STRONG_INLINE Packet2f pabs(const Packet2f& a) { return vabs_f32(a); } +template<> EIGEN_STRONG_INLINE Packet4f pabs(const Packet4f& a) { return vabsq_f32(a); } +template<> EIGEN_STRONG_INLINE Packet4c pabs(const Packet4c& a) +{ return vget_lane_s32(vreinterpret_s32_s8(vabs_s8(vreinterpret_s8_s32(vdup_n_s32(a)))), 0); } +template<> EIGEN_STRONG_INLINE Packet8c pabs(const Packet8c& a) { return vabs_s8(a); } +template<> EIGEN_STRONG_INLINE Packet16c pabs(const Packet16c& a) { return vabsq_s8(a); } +template<> EIGEN_STRONG_INLINE Packet4uc pabs(const Packet4uc& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet8uc pabs(const Packet8uc& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet16uc pabs(const Packet16uc& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet4s pabs(const Packet4s& a) { return vabs_s16(a); } +template<> EIGEN_STRONG_INLINE Packet8s pabs(const Packet8s& a) { return vabsq_s16(a); } +template<> EIGEN_STRONG_INLINE Packet4us pabs(const Packet4us& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet8us pabs(const Packet8us& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet2i pabs(const Packet2i& a) { return vabs_s32(a); } +template<> EIGEN_STRONG_INLINE Packet4i pabs(const Packet4i& a) { return vabsq_s32(a); } +template<> EIGEN_STRONG_INLINE Packet2ui pabs(const Packet2ui& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet4ui pabs(const Packet4ui& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet2l pabs(const Packet2l& a) { +#if EIGEN_ARCH_ARM64 + return vabsq_s64(a); +#else + return vcombine_s64( + vdup_n_s64((std::abs)(vgetq_lane_s64(a, 0))), + vdup_n_s64((std::abs)(vgetq_lane_s64(a, 1)))); +#endif +} +template<> EIGEN_STRONG_INLINE Packet2ul pabs(const Packet2ul& a) { return a; } + +template<> EIGEN_STRONG_INLINE Packet2f pfrexp(const Packet2f& a, Packet2f& exponent) +{ return pfrexp_generic(a,exponent); } +template<> EIGEN_STRONG_INLINE Packet4f pfrexp(const Packet4f& a, Packet4f& exponent) +{ return pfrexp_generic(a,exponent); } + +template<> EIGEN_STRONG_INLINE Packet2f pldexp(const Packet2f& a, const Packet2f& exponent) +{ return pldexp_generic(a,exponent); } +template<> EIGEN_STRONG_INLINE Packet4f pldexp(const Packet4f& a, const Packet4f& exponent) +{ return pldexp_generic(a,exponent); } + +template<> EIGEN_STRONG_INLINE float predux(const Packet2f& a) { return vget_lane_f32(vpadd_f32(a,a), 0); } +template<> EIGEN_STRONG_INLINE float predux(const Packet4f& a) +{ + const float32x2_t sum = 
vadd_f32(vget_low_f32(a), vget_high_f32(a)); + return vget_lane_f32(vpadd_f32(sum, sum), 0); +} +template<> EIGEN_STRONG_INLINE int8_t predux(const Packet4c& a) +{ + const int8x8_t a_dup = vreinterpret_s8_s32(vdup_n_s32(a)); + int8x8_t sum = vpadd_s8(a_dup, a_dup); + sum = vpadd_s8(sum, sum); + return vget_lane_s8(sum, 0); +} +template<> EIGEN_STRONG_INLINE int8_t predux(const Packet8c& a) +{ + int8x8_t sum = vpadd_s8(a,a); + sum = vpadd_s8(sum, sum); + sum = vpadd_s8(sum, sum); + return vget_lane_s8(sum, 0); +} +template<> EIGEN_STRONG_INLINE int8_t predux(const Packet16c& a) +{ + int8x8_t sum = vadd_s8(vget_low_s8(a), vget_high_s8(a)); + sum = vpadd_s8(sum, sum); + sum = vpadd_s8(sum, sum); + sum = vpadd_s8(sum, sum); + return vget_lane_s8(sum, 0); +} +template<> EIGEN_STRONG_INLINE uint8_t predux(const Packet4uc& a) +{ + const uint8x8_t a_dup = vreinterpret_u8_u32(vdup_n_u32(a)); + uint8x8_t sum = vpadd_u8(a_dup, a_dup); + sum = vpadd_u8(sum, sum); + return vget_lane_u8(sum, 0); +} +template<> EIGEN_STRONG_INLINE uint8_t predux(const Packet8uc& a) +{ + uint8x8_t sum = vpadd_u8(a,a); + sum = vpadd_u8(sum, sum); + sum = vpadd_u8(sum, sum); + return vget_lane_u8(sum, 0); +} +template<> EIGEN_STRONG_INLINE uint8_t predux(const Packet16uc& a) +{ + uint8x8_t sum = vadd_u8(vget_low_u8(a), vget_high_u8(a)); + sum = vpadd_u8(sum, sum); + sum = vpadd_u8(sum, sum); + sum = vpadd_u8(sum, sum); + return vget_lane_u8(sum, 0); +} +template<> EIGEN_STRONG_INLINE int16_t predux(const Packet4s& a) +{ + const int16x4_t sum = vpadd_s16(a,a); + return vget_lane_s16(vpadd_s16(sum, sum), 0); +} +template<> EIGEN_STRONG_INLINE int16_t predux(const Packet8s& a) +{ + int16x4_t sum = vadd_s16(vget_low_s16(a), vget_high_s16(a)); + sum = vpadd_s16(sum, sum); + sum = vpadd_s16(sum, sum); + return vget_lane_s16(sum, 0); +} +template<> EIGEN_STRONG_INLINE uint16_t predux(const Packet4us& a) +{ + const uint16x4_t sum = vpadd_u16(a,a); + return vget_lane_u16(vpadd_u16(sum, sum), 0); +} +template<> EIGEN_STRONG_INLINE uint16_t predux(const Packet8us& a) +{ + uint16x4_t sum = vadd_u16(vget_low_u16(a), vget_high_u16(a)); + sum = vpadd_u16(sum, sum); + sum = vpadd_u16(sum, sum); + return vget_lane_u16(sum, 0); +} +template<> EIGEN_STRONG_INLINE int32_t predux(const Packet2i& a) { return vget_lane_s32(vpadd_s32(a,a), 0); } +template<> EIGEN_STRONG_INLINE int32_t predux(const Packet4i& a) +{ + const int32x2_t sum = vadd_s32(vget_low_s32(a), vget_high_s32(a)); + return vget_lane_s32(vpadd_s32(sum, sum), 0); +} +template<> EIGEN_STRONG_INLINE uint32_t predux(const Packet2ui& a) { return vget_lane_u32(vpadd_u32(a,a), 0); } +template<> EIGEN_STRONG_INLINE uint32_t predux(const Packet4ui& a) +{ + const uint32x2_t sum = vadd_u32(vget_low_u32(a), vget_high_u32(a)); + return vget_lane_u32(vpadd_u32(sum, sum), 0); +} +template<> EIGEN_STRONG_INLINE int64_t predux(const Packet2l& a) +{ return vgetq_lane_s64(a, 0) + vgetq_lane_s64(a, 1); } +template<> EIGEN_STRONG_INLINE uint64_t predux(const Packet2ul& a) +{ return vgetq_lane_u64(a, 0) + vgetq_lane_u64(a, 1); } + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4c predux_half_dowto4(const Packet8c& a) +{ + return vget_lane_s32(vreinterpret_s32_s8(vadd_s8(a, + vreinterpret_s8_s32(vrev64_s32(vreinterpret_s32_s8(a))))), 0); +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8c predux_half_dowto4(const Packet16c& a) +{ return vadd_s8(vget_high_s8(a), vget_low_s8(a)); } +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4uc predux_half_dowto4(const Packet8uc& a) +{ + 
return vget_lane_u32(vreinterpret_u32_u8(vadd_u8(a, + vreinterpret_u8_u32(vrev64_u32(vreinterpret_u32_u8(a))))), 0); +} +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8uc predux_half_dowto4(const Packet16uc& a) +{ return vadd_u8(vget_high_u8(a), vget_low_u8(a)); } +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4s predux_half_dowto4(const Packet8s& a) +{ return vadd_s16(vget_high_s16(a), vget_low_s16(a)); } +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4us predux_half_dowto4(const Packet8us& a) +{ return vadd_u16(vget_high_u16(a), vget_low_u16(a)); } + +// Other reduction functions: +// mul +template<> EIGEN_STRONG_INLINE float predux_mul(const Packet2f& a) +{ return vget_lane_f32(a, 0) * vget_lane_f32(a, 1); } +template<> EIGEN_STRONG_INLINE float predux_mul(const Packet4f& a) +{ return predux_mul(vmul_f32(vget_low_f32(a), vget_high_f32(a))); } +template<> EIGEN_STRONG_INLINE int8_t predux_mul(const Packet4c& a) +{ + int8x8_t prod = vreinterpret_s8_s32(vdup_n_s32(a)); + prod = vmul_s8(prod, vrev16_s8(prod)); + return vget_lane_s8(prod, 0) * vget_lane_s8(prod, 2); +} +template<> EIGEN_STRONG_INLINE int8_t predux_mul(const Packet8c& a) +{ + int8x8_t prod = vmul_s8(a, vrev16_s8(a)); + prod = vmul_s8(prod, vrev32_s8(prod)); + return vget_lane_s8(prod, 0) * vget_lane_s8(prod, 4); +} +template<> EIGEN_STRONG_INLINE int8_t predux_mul(const Packet16c& a) +{ return predux_mul(vmul_s8(vget_low_s8(a), vget_high_s8(a))); } +template<> EIGEN_STRONG_INLINE uint8_t predux_mul(const Packet4uc& a) +{ + uint8x8_t prod = vreinterpret_u8_u32(vdup_n_u32(a)); + prod = vmul_u8(prod, vrev16_u8(prod)); + return vget_lane_u8(prod, 0) * vget_lane_u8(prod, 2); +} +template<> EIGEN_STRONG_INLINE uint8_t predux_mul(const Packet8uc& a) +{ + uint8x8_t prod = vmul_u8(a, vrev16_u8(a)); + prod = vmul_u8(prod, vrev32_u8(prod)); + return vget_lane_u8(prod, 0) * vget_lane_u8(prod, 4); +} +template<> EIGEN_STRONG_INLINE uint8_t predux_mul(const Packet16uc& a) +{ return predux_mul(vmul_u8(vget_low_u8(a), vget_high_u8(a))); } +template<> EIGEN_STRONG_INLINE int16_t predux_mul(const Packet4s& a) +{ + const int16x4_t prod = vmul_s16(a, vrev32_s16(a)); + return vget_lane_s16(prod, 0) * vget_lane_s16(prod, 2); +} +template<> EIGEN_STRONG_INLINE int16_t predux_mul(const Packet8s& a) +{ + int16x4_t prod; + + // Get the product of a_lo * a_hi -> |a1*a5|a2*a6|a3*a7|a4*a8| + prod = vmul_s16(vget_low_s16(a), vget_high_s16(a)); + // Swap and multiply |a1*a5*a2*a6|a3*a7*a4*a8| + prod = vmul_s16(prod, vrev32_s16(prod)); + // Multiply |a1*a5*a2*a6*a3*a7*a4*a8| + return vget_lane_s16(prod, 0) * vget_lane_s16(prod, 2); +} +template<> EIGEN_STRONG_INLINE uint16_t predux_mul(const Packet4us& a) +{ + const uint16x4_t prod = vmul_u16(a, vrev32_u16(a)); + return vget_lane_u16(prod, 0) * vget_lane_u16(prod, 2); +} +template<> EIGEN_STRONG_INLINE uint16_t predux_mul(const Packet8us& a) +{ + uint16x4_t prod; + + // Get the product of a_lo * a_hi -> |a1*a5|a2*a6|a3*a7|a4*a8| + prod = vmul_u16(vget_low_u16(a), vget_high_u16(a)); + // Swap and multiply |a1*a5*a2*a6|a3*a7*a4*a8| + prod = vmul_u16(prod, vrev32_u16(prod)); + // Multiply |a1*a5*a2*a6*a3*a7*a4*a8| + return vget_lane_u16(prod, 0) * vget_lane_u16(prod, 2); +} +template<> EIGEN_STRONG_INLINE int32_t predux_mul(const Packet2i& a) +{ return vget_lane_s32(a, 0) * vget_lane_s32(a, 1); } +template<> EIGEN_STRONG_INLINE int32_t predux_mul(const Packet4i& a) +{ return predux_mul(vmul_s32(vget_low_s32(a), vget_high_s32(a))); } +template<> EIGEN_STRONG_INLINE uint32_t 
predux_mul(const Packet2ui& a) +{ return vget_lane_u32(a, 0) * vget_lane_u32(a, 1); } +template<> EIGEN_STRONG_INLINE uint32_t predux_mul(const Packet4ui& a) +{ return predux_mul(vmul_u32(vget_low_u32(a), vget_high_u32(a))); } +template<> EIGEN_STRONG_INLINE int64_t predux_mul(const Packet2l& a) +{ return vgetq_lane_s64(a, 0) * vgetq_lane_s64(a, 1); } +template<> EIGEN_STRONG_INLINE uint64_t predux_mul(const Packet2ul& a) +{ return vgetq_lane_u64(a, 0) * vgetq_lane_u64(a, 1); } + +// min +template<> EIGEN_STRONG_INLINE float predux_min(const Packet2f& a) +{ return vget_lane_f32(vpmin_f32(a,a), 0); } +template<> EIGEN_STRONG_INLINE float predux_min(const Packet4f& a) +{ + const float32x2_t min = vmin_f32(vget_low_f32(a), vget_high_f32(a)); + return vget_lane_f32(vpmin_f32(min, min), 0); +} +template<> EIGEN_STRONG_INLINE int8_t predux_min(const Packet4c& a) +{ + const int8x8_t a_dup = vreinterpret_s8_s32(vdup_n_s32(a)); + int8x8_t min = vpmin_s8(a_dup, a_dup); + min = vpmin_s8(min, min); + return vget_lane_s8(min, 0); +} +template<> EIGEN_STRONG_INLINE int8_t predux_min(const Packet8c& a) +{ + int8x8_t min = vpmin_s8(a,a); + min = vpmin_s8(min, min); + min = vpmin_s8(min, min); + return vget_lane_s8(min, 0); +} +template<> EIGEN_STRONG_INLINE int8_t predux_min(const Packet16c& a) +{ + int8x8_t min = vmin_s8(vget_low_s8(a), vget_high_s8(a)); + min = vpmin_s8(min, min); + min = vpmin_s8(min, min); + min = vpmin_s8(min, min); + return vget_lane_s8(min, 0); +} +template<> EIGEN_STRONG_INLINE uint8_t predux_min(const Packet4uc& a) +{ + const uint8x8_t a_dup = vreinterpret_u8_u32(vdup_n_u32(a)); + uint8x8_t min = vpmin_u8(a_dup, a_dup); + min = vpmin_u8(min, min); + return vget_lane_u8(min, 0); +} +template<> EIGEN_STRONG_INLINE uint8_t predux_min(const Packet8uc& a) +{ + uint8x8_t min = vpmin_u8(a,a); + min = vpmin_u8(min, min); + min = vpmin_u8(min, min); + return vget_lane_u8(min, 0); +} +template<> EIGEN_STRONG_INLINE uint8_t predux_min(const Packet16uc& a) +{ + uint8x8_t min = vmin_u8(vget_low_u8(a), vget_high_u8(a)); + min = vpmin_u8(min, min); + min = vpmin_u8(min, min); + min = vpmin_u8(min, min); + return vget_lane_u8(min, 0); +} +template<> EIGEN_STRONG_INLINE int16_t predux_min(const Packet4s& a) +{ + const int16x4_t min = vpmin_s16(a,a); + return vget_lane_s16(vpmin_s16(min, min), 0); +} +template<> EIGEN_STRONG_INLINE int16_t predux_min(const Packet8s& a) +{ + int16x4_t min = vmin_s16(vget_low_s16(a), vget_high_s16(a)); + min = vpmin_s16(min, min); + min = vpmin_s16(min, min); + return vget_lane_s16(min, 0); +} +template<> EIGEN_STRONG_INLINE uint16_t predux_min(const Packet4us& a) +{ + const uint16x4_t min = vpmin_u16(a,a); + return vget_lane_u16(vpmin_u16(min, min), 0); +} +template<> EIGEN_STRONG_INLINE uint16_t predux_min(const Packet8us& a) +{ + uint16x4_t min = vmin_u16(vget_low_u16(a), vget_high_u16(a)); + min = vpmin_u16(min, min); + min = vpmin_u16(min, min); + return vget_lane_u16(min, 0); +} +template<> EIGEN_STRONG_INLINE int32_t predux_min(const Packet2i& a) +{ return vget_lane_s32(vpmin_s32(a,a), 0); } +template<> EIGEN_STRONG_INLINE int32_t predux_min(const Packet4i& a) +{ + const int32x2_t min = vmin_s32(vget_low_s32(a), vget_high_s32(a)); + return vget_lane_s32(vpmin_s32(min, min), 0); +} +template<> EIGEN_STRONG_INLINE uint32_t predux_min(const Packet2ui& a) +{ return vget_lane_u32(vpmin_u32(a,a), 0); } +template<> EIGEN_STRONG_INLINE uint32_t predux_min(const Packet4ui& a) +{ + const uint32x2_t min = vmin_u32(vget_low_u32(a), vget_high_u32(a)); + return 
vget_lane_u32(vpmin_u32(min, min), 0); +} +template<> EIGEN_STRONG_INLINE int64_t predux_min(const Packet2l& a) +{ return (std::min)(vgetq_lane_s64(a, 0), vgetq_lane_s64(a, 1)); } +template<> EIGEN_STRONG_INLINE uint64_t predux_min(const Packet2ul& a) +{ return (std::min)(vgetq_lane_u64(a, 0), vgetq_lane_u64(a, 1)); } + +// max +template<> EIGEN_STRONG_INLINE float predux_max(const Packet2f& a) +{ return vget_lane_f32(vpmax_f32(a,a), 0); } +template<> EIGEN_STRONG_INLINE float predux_max(const Packet4f& a) +{ + const float32x2_t max = vmax_f32(vget_low_f32(a), vget_high_f32(a)); + return vget_lane_f32(vpmax_f32(max, max), 0); +} +template<> EIGEN_STRONG_INLINE int8_t predux_max(const Packet4c& a) +{ + const int8x8_t a_dup = vreinterpret_s8_s32(vdup_n_s32(a)); + int8x8_t max = vpmax_s8(a_dup, a_dup); + max = vpmax_s8(max, max); + return vget_lane_s8(max, 0); +} +template<> EIGEN_STRONG_INLINE int8_t predux_max(const Packet8c& a) +{ + int8x8_t max = vpmax_s8(a,a); + max = vpmax_s8(max, max); + max = vpmax_s8(max, max); + return vget_lane_s8(max, 0); +} +template<> EIGEN_STRONG_INLINE int8_t predux_max(const Packet16c& a) +{ + int8x8_t max = vmax_s8(vget_low_s8(a), vget_high_s8(a)); + max = vpmax_s8(max, max); + max = vpmax_s8(max, max); + max = vpmax_s8(max, max); + return vget_lane_s8(max, 0); +} +template<> EIGEN_STRONG_INLINE uint8_t predux_max(const Packet4uc& a) +{ + const uint8x8_t a_dup = vreinterpret_u8_u32(vdup_n_u32(a)); + uint8x8_t max = vpmax_u8(a_dup, a_dup); + max = vpmax_u8(max, max); + return vget_lane_u8(max, 0); +} +template<> EIGEN_STRONG_INLINE uint8_t predux_max(const Packet8uc& a) +{ + uint8x8_t max = vpmax_u8(a,a); + max = vpmax_u8(max, max); + max = vpmax_u8(max, max); + return vget_lane_u8(max, 0); +} +template<> EIGEN_STRONG_INLINE uint8_t predux_max(const Packet16uc& a) +{ + uint8x8_t max = vmax_u8(vget_low_u8(a), vget_high_u8(a)); + max = vpmax_u8(max, max); + max = vpmax_u8(max, max); + max = vpmax_u8(max, max); + return vget_lane_u8(max, 0); +} +template<> EIGEN_STRONG_INLINE int16_t predux_max(const Packet4s& a) +{ + const int16x4_t max = vpmax_s16(a,a); + return vget_lane_s16(vpmax_s16(max, max), 0); +} +template<> EIGEN_STRONG_INLINE int16_t predux_max(const Packet8s& a) +{ + int16x4_t max = vmax_s16(vget_low_s16(a), vget_high_s16(a)); + max = vpmax_s16(max, max); + max = vpmax_s16(max, max); + return vget_lane_s16(max, 0); +} +template<> EIGEN_STRONG_INLINE uint16_t predux_max(const Packet4us& a) +{ + const uint16x4_t max = vpmax_u16(a,a); + return vget_lane_u16(vpmax_u16(max, max), 0); +} +template<> EIGEN_STRONG_INLINE uint16_t predux_max(const Packet8us& a) +{ + uint16x4_t max = vmax_u16(vget_low_u16(a), vget_high_u16(a)); + max = vpmax_u16(max, max); + max = vpmax_u16(max, max); + return vget_lane_u16(max, 0); +} +template<> EIGEN_STRONG_INLINE int32_t predux_max(const Packet2i& a) +{ return vget_lane_s32(vpmax_s32(a,a), 0); } +template<> EIGEN_STRONG_INLINE int32_t predux_max(const Packet4i& a) +{ + const int32x2_t max = vmax_s32(vget_low_s32(a), vget_high_s32(a)); + return vget_lane_s32(vpmax_s32(max, max), 0); +} +template<> EIGEN_STRONG_INLINE uint32_t predux_max(const Packet2ui& a) +{ return vget_lane_u32(vpmax_u32(a,a), 0); } +template<> EIGEN_STRONG_INLINE uint32_t predux_max(const Packet4ui& a) +{ + const uint32x2_t max = vmax_u32(vget_low_u32(a), vget_high_u32(a)); + return vget_lane_u32(vpmax_u32(max, max), 0); +} +template<> EIGEN_STRONG_INLINE int64_t predux_max(const Packet2l& a) +{ return (std::max)(vgetq_lane_s64(a, 0), vgetq_lane_s64(a, 
1)); } +template<> EIGEN_STRONG_INLINE uint64_t predux_max(const Packet2ul& a) +{ return (std::max)(vgetq_lane_u64(a, 0), vgetq_lane_u64(a, 1)); } + +template<> EIGEN_STRONG_INLINE bool predux_any(const Packet4f& x) +{ + uint32x2_t tmp = vorr_u32(vget_low_u32( vreinterpretq_u32_f32(x)), + vget_high_u32(vreinterpretq_u32_f32(x))); + return vget_lane_u32(vpmax_u32(tmp, tmp), 0); +} + +// Helpers for ptranspose. +namespace detail { + +template +void zip_in_place(Packet& p1, Packet& p2); + +template<> +EIGEN_ALWAYS_INLINE void zip_in_place(Packet2f& p1, Packet2f& p2) { + const float32x2x2_t tmp = vzip_f32(p1, p2); + p1 = tmp.val[0]; + p2 = tmp.val[1]; +} + +template<> +EIGEN_ALWAYS_INLINE void zip_in_place(Packet4f& p1, Packet4f& p2) { + const float32x4x2_t tmp = vzipq_f32(p1, p2); + p1 = tmp.val[0]; + p2 = tmp.val[1]; +} + +template<> +EIGEN_ALWAYS_INLINE void zip_in_place(Packet8c& p1, Packet8c& p2) { + const int8x8x2_t tmp = vzip_s8(p1, p2); + p1 = tmp.val[0]; + p2 = tmp.val[1]; +} + +template<> +EIGEN_ALWAYS_INLINE void zip_in_place(Packet16c& p1, Packet16c& p2) { + const int8x16x2_t tmp = vzipq_s8(p1, p2); + p1 = tmp.val[0]; + p2 = tmp.val[1]; +} + +template<> +EIGEN_ALWAYS_INLINE void zip_in_place(Packet8uc& p1, Packet8uc& p2) { + const uint8x8x2_t tmp = vzip_u8(p1, p2); + p1 = tmp.val[0]; + p2 = tmp.val[1]; +} + +template<> +EIGEN_ALWAYS_INLINE void zip_in_place(Packet16uc& p1, Packet16uc& p2) { + const uint8x16x2_t tmp = vzipq_u8(p1, p2); + p1 = tmp.val[0]; + p2 = tmp.val[1]; +} + +template<> +EIGEN_ALWAYS_INLINE void zip_in_place(Packet2i& p1, Packet2i& p2) { + const int32x2x2_t tmp = vzip_s32(p1, p2); + p1 = tmp.val[0]; + p2 = tmp.val[1]; +} + +template<> +EIGEN_ALWAYS_INLINE void zip_in_place(Packet4i& p1, Packet4i& p2) { + const int32x4x2_t tmp = vzipq_s32(p1, p2); + p1 = tmp.val[0]; + p2 = tmp.val[1]; +} + +template<> +EIGEN_ALWAYS_INLINE void zip_in_place(Packet2ui& p1, Packet2ui& p2) { + const uint32x2x2_t tmp = vzip_u32(p1, p2); + p1 = tmp.val[0]; + p2 = tmp.val[1]; +} + +template<> +EIGEN_ALWAYS_INLINE void zip_in_place(Packet4ui& p1, Packet4ui& p2) { + const uint32x4x2_t tmp = vzipq_u32(p1, p2); + p1 = tmp.val[0]; + p2 = tmp.val[1]; +} + +template<> +EIGEN_ALWAYS_INLINE void zip_in_place(Packet4s& p1, Packet4s& p2) { + const int16x4x2_t tmp = vzip_s16(p1, p2); + p1 = tmp.val[0]; + p2 = tmp.val[1]; +} + +template<> +EIGEN_ALWAYS_INLINE void zip_in_place(Packet8s& p1, Packet8s& p2) { + const int16x8x2_t tmp = vzipq_s16(p1, p2); + p1 = tmp.val[0]; + p2 = tmp.val[1]; +} + +template<> +EIGEN_ALWAYS_INLINE void zip_in_place(Packet4us& p1, Packet4us& p2) { + const uint16x4x2_t tmp = vzip_u16(p1, p2); + p1 = tmp.val[0]; + p2 = tmp.val[1]; +} + +template<> +EIGEN_ALWAYS_INLINE void zip_in_place(Packet8us& p1, Packet8us& p2) { + const uint16x8x2_t tmp = vzipq_u16(p1, p2); + p1 = tmp.val[0]; + p2 = tmp.val[1]; +} + +template +EIGEN_ALWAYS_INLINE void ptranspose_impl(PacketBlock& kernel) { + zip_in_place(kernel.packet[0], kernel.packet[1]); +} + +template +EIGEN_ALWAYS_INLINE void ptranspose_impl(PacketBlock& kernel) { + zip_in_place(kernel.packet[0], kernel.packet[2]); + zip_in_place(kernel.packet[1], kernel.packet[3]); + zip_in_place(kernel.packet[0], kernel.packet[1]); + zip_in_place(kernel.packet[2], kernel.packet[3]); +} + +template +EIGEN_ALWAYS_INLINE void ptranspose_impl(PacketBlock& kernel) { + zip_in_place(kernel.packet[0], kernel.packet[4]); + zip_in_place(kernel.packet[1], kernel.packet[5]); + zip_in_place(kernel.packet[2], kernel.packet[6]); + zip_in_place(kernel.packet[3], 
kernel.packet[7]); + + zip_in_place(kernel.packet[0], kernel.packet[2]); + zip_in_place(kernel.packet[1], kernel.packet[3]); + zip_in_place(kernel.packet[4], kernel.packet[6]); + zip_in_place(kernel.packet[5], kernel.packet[7]); + + zip_in_place(kernel.packet[0], kernel.packet[1]); + zip_in_place(kernel.packet[2], kernel.packet[3]); + zip_in_place(kernel.packet[4], kernel.packet[5]); + zip_in_place(kernel.packet[6], kernel.packet[7]); +} + +template +EIGEN_ALWAYS_INLINE void ptranspose_impl(PacketBlock& kernel) { + EIGEN_UNROLL_LOOP + for (int i=0; i<4; ++i) { + const int m = (1 << i); + EIGEN_UNROLL_LOOP + for (int j=0; j& kernel) { + detail::ptranspose_impl(kernel); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) +{ + const int8x8_t a = vreinterpret_s8_s32(vset_lane_s32(kernel.packet[2], vdup_n_s32(kernel.packet[0]), 1)); + const int8x8_t b = vreinterpret_s8_s32(vset_lane_s32(kernel.packet[3], vdup_n_s32(kernel.packet[1]), 1)); + + const int8x8x2_t zip8 = vzip_s8(a,b); + const int16x4x2_t zip16 = vzip_s16(vreinterpret_s16_s8(zip8.val[0]), vreinterpret_s16_s8(zip8.val[1])); + + kernel.packet[0] = vget_lane_s32(vreinterpret_s32_s16(zip16.val[0]), 0); + kernel.packet[1] = vget_lane_s32(vreinterpret_s32_s16(zip16.val[0]), 1); + kernel.packet[2] = vget_lane_s32(vreinterpret_s32_s16(zip16.val[1]), 0); + kernel.packet[3] = vget_lane_s32(vreinterpret_s32_s16(zip16.val[1]), 1); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) +{ + const uint8x8_t a = vreinterpret_u8_u32(vset_lane_u32(kernel.packet[2], vdup_n_u32(kernel.packet[0]), 1)); + const uint8x8_t b = vreinterpret_u8_u32(vset_lane_u32(kernel.packet[3], vdup_n_u32(kernel.packet[1]), 1)); + + const uint8x8x2_t zip8 = vzip_u8(a,b); + const uint16x4x2_t zip16 = vzip_u16(vreinterpret_u16_u8(zip8.val[0]), vreinterpret_u16_u8(zip8.val[1])); + + kernel.packet[0] = vget_lane_u32(vreinterpret_u32_u16(zip16.val[0]), 0); + kernel.packet[1] = vget_lane_u32(vreinterpret_u32_u16(zip16.val[0]), 1); + kernel.packet[2] = vget_lane_u32(vreinterpret_u32_u16(zip16.val[1]), 0); + kernel.packet[3] = vget_lane_u32(vreinterpret_u32_u16(zip16.val[1]), 1); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + 
detail::ptranspose_impl(kernel); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::zip_in_place(kernel.packet[0], kernel.packet[1]); +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + detail::ptranspose_impl(kernel); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void +ptranspose(PacketBlock& kernel) +{ +#if EIGEN_ARCH_ARM64 + const int64x2_t tmp1 = vzip1q_s64(kernel.packet[0], kernel.packet[1]); + kernel.packet[1] = vzip2q_s64(kernel.packet[0], kernel.packet[1]); + kernel.packet[0] = tmp1; +#else + const int64x1_t tmp[2][2] = { + { vget_low_s64(kernel.packet[0]), vget_high_s64(kernel.packet[0]) }, + { vget_low_s64(kernel.packet[1]), vget_high_s64(kernel.packet[1]) } + }; + + kernel.packet[0] = vcombine_s64(tmp[0][0], tmp[1][0]); + kernel.packet[1] = vcombine_s64(tmp[0][1], tmp[1][1]); +#endif +} +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void +ptranspose(PacketBlock& kernel) +{ +#if EIGEN_ARCH_ARM64 + const uint64x2_t tmp1 = vzip1q_u64(kernel.packet[0], kernel.packet[1]); + kernel.packet[1] = vzip2q_u64(kernel.packet[0], kernel.packet[1]); + kernel.packet[0] = tmp1; +#else + const uint64x1_t tmp[2][2] = { + { vget_low_u64(kernel.packet[0]), vget_high_u64(kernel.packet[0]) }, + { vget_low_u64(kernel.packet[1]), vget_high_u64(kernel.packet[1]) } + }; + + kernel.packet[0] = vcombine_u64(tmp[0][0], tmp[1][0]); + kernel.packet[1] = vcombine_u64(tmp[0][1], tmp[1][1]); +#endif +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2f pselect( const Packet2f& mask, const Packet2f& a, const Packet2f& b) +{ return vbsl_f32(vreinterpret_u32_f32(mask), a, b); } +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4f pselect(const Packet4f& mask, const Packet4f& a, const Packet4f& b) +{ return vbslq_f32(vreinterpretq_u32_f32(mask), a, b); } +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8c pselect(const Packet8c& mask, const Packet8c& a, const Packet8c& b) +{ return vbsl_s8(vreinterpret_u8_s8(mask), a, b); } +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet16c pselect(const Packet16c& mask, const Packet16c& a, const Packet16c& b) +{ return vbslq_s8(vreinterpretq_u8_s8(mask), a, b); } +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8uc pselect(const Packet8uc& mask, const Packet8uc& a, const Packet8uc& b) +{ return vbsl_u8(mask, a, b); } +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet16uc pselect(const Packet16uc& mask, const Packet16uc& a, const Packet16uc& b) +{ return vbslq_u8(mask, a, b); } +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4s pselect(const Packet4s& mask, const Packet4s& a, const Packet4s& b) +{ return vbsl_s16(vreinterpret_u16_s16(mask), a, b); } +template<> 
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8s pselect(const Packet8s& mask, const Packet8s& a, const Packet8s& b)
+{ return vbslq_s16(vreinterpretq_u16_s16(mask), a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4us pselect(const Packet4us& mask, const Packet4us& a, const Packet4us& b)
+{ return vbsl_u16(mask, a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8us pselect(const Packet8us& mask, const Packet8us& a, const Packet8us& b)
+{ return vbslq_u16(mask, a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2i pselect(const Packet2i& mask, const Packet2i& a, const Packet2i& b)
+{ return vbsl_s32(vreinterpret_u32_s32(mask), a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4i pselect(const Packet4i& mask, const Packet4i& a, const Packet4i& b)
+{ return vbslq_s32(vreinterpretq_u32_s32(mask), a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2ui pselect(const Packet2ui& mask, const Packet2ui& a, const Packet2ui& b)
+{ return vbsl_u32(mask, a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4ui pselect(const Packet4ui& mask, const Packet4ui& a, const Packet4ui& b)
+{ return vbslq_u32(mask, a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2l pselect(const Packet2l& mask, const Packet2l& a, const Packet2l& b)
+{ return vbslq_s64(vreinterpretq_u64_s64(mask), a, b); }
+template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2ul pselect(const Packet2ul& mask, const Packet2ul& a, const Packet2ul& b)
+{ return vbslq_u64(mask, a, b); }
+
+// Use armv8 rounding intrinsics if available.
+#if EIGEN_ARCH_ARMV8
+template<> EIGEN_STRONG_INLINE Packet2f print(const Packet2f& a)
+{ return vrndn_f32(a); }
+
+template<> EIGEN_STRONG_INLINE Packet4f print(const Packet4f& a)
+{ return vrndnq_f32(a); }
+
+template<> EIGEN_STRONG_INLINE Packet2f pfloor(const Packet2f& a)
+{ return vrndm_f32(a); }
+
+template<> EIGEN_STRONG_INLINE Packet4f pfloor(const Packet4f& a)
+{ return vrndmq_f32(a); }
+
+template<> EIGEN_STRONG_INLINE Packet2f pceil(const Packet2f& a)
+{ return vrndp_f32(a); }
+
+template<> EIGEN_STRONG_INLINE Packet4f pceil(const Packet4f& a)
+{ return vrndpq_f32(a); }
+
+#else
+
+template<> EIGEN_STRONG_INLINE Packet4f print(const Packet4f& a) {
+  // Adds and subtracts signum(a) * 2^23 to force rounding.
+  const Packet4f limit = pset1<Packet4f>(static_cast<float>(1<<23));
+  const Packet4f abs_a = pabs(a);
+  Packet4f r = padd(abs_a, limit);
+  // Don't compile-away addition and subtraction.
+  EIGEN_OPTIMIZATION_BARRIER(r);
+  r = psub(r, limit);
+  // If greater than limit, simply return a. Otherwise, account for sign.
+  r = pselect(pcmp_lt(abs_a, limit),
+              pselect(pcmp_lt(a, pzero(a)), pnegate(r), r), a);
+  return r;
+}
+
+template<> EIGEN_STRONG_INLINE Packet2f print(const Packet2f& a) {
+  // Adds and subtracts signum(a) * 2^23 to force rounding.
+  const Packet2f limit = pset1<Packet2f>(static_cast<float>(1<<23));
+  const Packet2f abs_a = pabs(a);
+  Packet2f r = padd(abs_a, limit);
+  // Don't compile-away addition and subtraction.
+  EIGEN_OPTIMIZATION_BARRIER(r);
+  r = psub(r, limit);
+  // If greater than limit, simply return a. Otherwise, account for sign.
+  r = pselect(pcmp_lt(abs_a, limit),
+              pselect(pcmp_lt(a, pzero(a)), pnegate(r), r), a);
+  return r;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pfloor(const Packet4f& a)
+{
+  const Packet4f cst_1 = pset1<Packet4f>(1.0f);
+  Packet4f tmp = print(a);
+  // If greater, subtract one.
+ Packet4f mask = pcmp_lt(a, tmp); + mask = pand(mask, cst_1); + return psub(tmp, mask); +} + +template<> EIGEN_STRONG_INLINE Packet2f pfloor(const Packet2f& a) +{ + const Packet2f cst_1 = pset1(1.0f); + Packet2f tmp = print(a); + // If greater, subtract one. + Packet2f mask = pcmp_lt(a, tmp); + mask = pand(mask, cst_1); + return psub(tmp, mask); +} + +template<> EIGEN_STRONG_INLINE Packet4f pceil(const Packet4f& a) +{ + const Packet4f cst_1 = pset1(1.0f); + Packet4f tmp = print(a); + // If smaller, add one. + Packet4f mask = pcmp_lt(tmp, a); + mask = pand(mask, cst_1); + return padd(tmp, mask); +} + +template<> EIGEN_STRONG_INLINE Packet2f pceil(const Packet2f& a) +{ + const Packet2f cst_1 = pset1(1.0); + Packet2f tmp = print(a); + // If smaller, add one. + Packet2f mask = pcmp_lt(tmp, a); + mask = pand(mask, cst_1); + return padd(tmp, mask); +} + +#endif + +/** + * Computes the integer square root + * @remarks The calculation is performed using an algorithm which iterates through each binary digit of the result + * and tests whether setting that digit to 1 would cause the square of the value to be greater than the argument + * value. The algorithm is described in detail here: http://ww1.microchip.com/downloads/en/AppNotes/91040a.pdf . + */ +template<> EIGEN_STRONG_INLINE Packet4uc psqrt(const Packet4uc& a) { + uint8x8_t x = vreinterpret_u8_u32(vdup_n_u32(a)); + uint8x8_t res = vdup_n_u8(0); + uint8x8_t add = vdup_n_u8(0x8); + for (int i = 0; i < 4; i++) + { + const uint8x8_t temp = vorr_u8(res, add); + res = vbsl_u8(vcge_u8(x, vmul_u8(temp, temp)), temp, res); + add = vshr_n_u8(add, 1); + } + return vget_lane_u32(vreinterpret_u32_u8(res), 0); +} +/// @copydoc Eigen::internal::psqrt(const Packet4uc& a) +template<> EIGEN_STRONG_INLINE Packet8uc psqrt(const Packet8uc& a) { + uint8x8_t res = vdup_n_u8(0); + uint8x8_t add = vdup_n_u8(0x8); + for (int i = 0; i < 4; i++) + { + const uint8x8_t temp = vorr_u8(res, add); + res = vbsl_u8(vcge_u8(a, vmul_u8(temp, temp)), temp, res); + add = vshr_n_u8(add, 1); + } + return res; +} +/// @copydoc Eigen::internal::psqrt(const Packet4uc& a) +template<> EIGEN_STRONG_INLINE Packet16uc psqrt(const Packet16uc& a) { + uint8x16_t res = vdupq_n_u8(0); + uint8x16_t add = vdupq_n_u8(0x8); + for (int i = 0; i < 4; i++) + { + const uint8x16_t temp = vorrq_u8(res, add); + res = vbslq_u8(vcgeq_u8(a, vmulq_u8(temp, temp)), temp, res); + add = vshrq_n_u8(add, 1); + } + return res; +} +/// @copydoc Eigen::internal::psqrt(const Packet4uc& a) +template<> EIGEN_STRONG_INLINE Packet4us psqrt(const Packet4us& a) { + uint16x4_t res = vdup_n_u16(0); + uint16x4_t add = vdup_n_u16(0x80); + for (int i = 0; i < 8; i++) + { + const uint16x4_t temp = vorr_u16(res, add); + res = vbsl_u16(vcge_u16(a, vmul_u16(temp, temp)), temp, res); + add = vshr_n_u16(add, 1); + } + return res; +} +/// @copydoc Eigen::internal::psqrt(const Packet4uc& a) +template<> EIGEN_STRONG_INLINE Packet8us psqrt(const Packet8us& a) { + uint16x8_t res = vdupq_n_u16(0); + uint16x8_t add = vdupq_n_u16(0x80); + for (int i = 0; i < 8; i++) + { + const uint16x8_t temp = vorrq_u16(res, add); + res = vbslq_u16(vcgeq_u16(a, vmulq_u16(temp, temp)), temp, res); + add = vshrq_n_u16(add, 1); + } + return res; +} +/// @copydoc Eigen::internal::psqrt(const Packet4uc& a) +template<> EIGEN_STRONG_INLINE Packet2ui psqrt(const Packet2ui& a) { + uint32x2_t res = vdup_n_u32(0); + uint32x2_t add = vdup_n_u32(0x8000); + for (int i = 0; i < 16; i++) + { + const uint32x2_t temp = vorr_u32(res, add); + res = vbsl_u32(vcge_u32(a, 
vmul_u32(temp, temp)), temp, res); + add = vshr_n_u32(add, 1); + } + return res; +} +/// @copydoc Eigen::internal::psqrt(const Packet4uc& a) +template<> EIGEN_STRONG_INLINE Packet4ui psqrt(const Packet4ui& a) { + uint32x4_t res = vdupq_n_u32(0); + uint32x4_t add = vdupq_n_u32(0x8000); + for (int i = 0; i < 16; i++) + { + const uint32x4_t temp = vorrq_u32(res, add); + res = vbslq_u32(vcgeq_u32(a, vmulq_u32(temp, temp)), temp, res); + add = vshrq_n_u32(add, 1); + } + return res; +} + +template<> EIGEN_STRONG_INLINE Packet4f prsqrt(const Packet4f& a) { + // Compute approximate reciprocal sqrt. + Packet4f x = vrsqrteq_f32(a); + // Do Newton iterations for 1/sqrt(x). + x = vmulq_f32(vrsqrtsq_f32(vmulq_f32(a, x), x), x); + x = vmulq_f32(vrsqrtsq_f32(vmulq_f32(a, x), x), x); + const Packet4f infinity = pset1(NumTraits::infinity()); + return pselect(pcmp_eq(a, pzero(a)), infinity, x); +} + +template<> EIGEN_STRONG_INLINE Packet2f prsqrt(const Packet2f& a) { + // Compute approximate reciprocal sqrt. + Packet2f x = vrsqrte_f32(a); + // Do Newton iterations for 1/sqrt(x). + x = vmul_f32(vrsqrts_f32(vmul_f32(a, x), x), x); + x = vmul_f32(vrsqrts_f32(vmul_f32(a, x), x), x); + const Packet2f infinity = pset1(NumTraits::infinity()); + return pselect(pcmp_eq(a, pzero(a)), infinity, x); +} + +// Unfortunately vsqrt_f32 is only available for A64. +#if EIGEN_ARCH_ARM64 +template<> EIGEN_STRONG_INLINE Packet4f psqrt(const Packet4f& _x){return vsqrtq_f32(_x);} +template<> EIGEN_STRONG_INLINE Packet2f psqrt(const Packet2f& _x){return vsqrt_f32(_x); } +#else +template<> EIGEN_STRONG_INLINE Packet4f psqrt(const Packet4f& a) { + const Packet4f infinity = pset1(NumTraits::infinity()); + const Packet4f is_zero_or_inf = por(pcmp_eq(a, pzero(a)), pcmp_eq(a, infinity)); + return pselect(is_zero_or_inf, a, pmul(a, prsqrt(a))); +} +template<> EIGEN_STRONG_INLINE Packet2f psqrt(const Packet2f& a) { + const Packet2f infinity = pset1(NumTraits::infinity()); + const Packet2f is_zero_or_inf = por(pcmp_eq(a, pzero(a)), pcmp_eq(a, infinity)); + return pselect(is_zero_or_inf, a, pmul(a, prsqrt(a))); +} +#endif + +//---------- bfloat16 ---------- +// TODO: Add support for native armv8.6-a bfloat16_t + +// TODO: Guard if we have native bfloat16 support +typedef eigen_packet_wrapper Packet4bf; + +template<> struct is_arithmetic { enum { value = true }; }; + +template<> struct packet_traits : default_packet_traits +{ + typedef Packet4bf type; + typedef Packet4bf half; + enum + { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 4, + HasHalfPacket = 0, + + HasCmp = 1, + HasAdd = 1, + HasSub = 1, + HasShift = 1, + HasMul = 1, + HasNegate = 1, + HasAbs = 1, + HasArg = 0, + HasAbs2 = 1, + HasAbsDiff = 1, + HasMin = 1, + HasMax = 1, + HasConj = 1, + HasSetLinear = 0, + HasBlend = 0, + HasDiv = 1, + HasFloor = 1, + HasCeil = 1, + HasRint = 1, + + HasSin = EIGEN_FAST_MATH, + HasCos = EIGEN_FAST_MATH, + HasLog = 1, + HasExp = 1, + HasSqrt = 0, + HasTanh = EIGEN_FAST_MATH, + HasErf = EIGEN_FAST_MATH, + HasBessel = 0, // Issues with accuracy. 
+ HasNdtri = 0 + }; +}; + +template<> struct unpacket_traits +{ + typedef bfloat16 type; + typedef Packet4bf half; + enum + { + size = 4, + alignment = Aligned16, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; + +namespace detail { +template<> +EIGEN_ALWAYS_INLINE void zip_in_place(Packet4bf& p1, Packet4bf& p2) { + const uint16x4x2_t tmp = vzip_u16(p1, p2); + p1 = tmp.val[0]; + p2 = tmp.val[1]; +} +} // namespace detail + +EIGEN_STRONG_INLINE Packet4bf F32ToBf16(const Packet4f& p) +{ + // See the scalar implemention in BFloat16.h for a comprehensible explanation + // of this fast rounding algorithm + Packet4ui input = reinterpret_cast(p); + + // lsb = (input >> 16) & 1 + Packet4ui lsb = vandq_u32(vshrq_n_u32(input, 16), vdupq_n_u32(1)); + + // rounding_bias = 0x7fff + lsb + Packet4ui rounding_bias = vaddq_u32(lsb, vdupq_n_u32(0x7fff)); + + // input += rounding_bias + input = vaddq_u32(input, rounding_bias); + + // input = input >> 16 + input = vshrq_n_u32(input, 16); + + // Replace float-nans by bfloat16-nans, that is 0x7fc0 + const Packet4ui bf16_nan = vdupq_n_u32(0x7fc0); + const Packet4ui mask = vceqq_f32(p, p); + input = vbslq_u32(mask, input, bf16_nan); + + // output = static_cast(input) + return vmovn_u32(input); +} + +EIGEN_STRONG_INLINE Packet4f Bf16ToF32(const Packet4bf& p) +{ + return reinterpret_cast(vshlq_n_u32(vmovl_u16(p), 16)); +} + +EIGEN_STRONG_INLINE Packet4bf F32MaskToBf16Mask(const Packet4f& p) { + return vmovn_u32(vreinterpretq_u32_f32(p)); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pset1(const bfloat16& from) { + return pset1(from.value); +} + +template<> EIGEN_STRONG_INLINE bfloat16 pfirst(const Packet4bf& from) { + return bfloat16_impl::raw_uint16_to_bfloat16(static_cast(pfirst(from))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pload(const bfloat16* from) +{ + return pload(reinterpret_cast(from)); +} + +template<> EIGEN_STRONG_INLINE Packet4bf ploadu(const bfloat16* from) +{ + return ploadu(reinterpret_cast(from)); +} + +template<> EIGEN_STRONG_INLINE void pstore(bfloat16* to, const Packet4bf& from) +{ + EIGEN_DEBUG_ALIGNED_STORE vst1_u16(reinterpret_cast(to), from); +} + +template<> EIGEN_STRONG_INLINE void pstoreu(bfloat16* to, const Packet4bf& from) +{ + EIGEN_DEBUG_UNALIGNED_STORE vst1_u16(reinterpret_cast(to), from); +} + +template<> EIGEN_STRONG_INLINE Packet4bf ploaddup(const bfloat16* from) +{ + return ploaddup(reinterpret_cast(from)); +} + +template <> EIGEN_STRONG_INLINE Packet4bf pabs(const Packet4bf& a) { + return F32ToBf16(pabs(Bf16ToF32(a))); +} + +template <> EIGEN_STRONG_INLINE Packet4bf pmin(const Packet4bf &a, + const Packet4bf &b) +{ + return F32ToBf16(pmin(Bf16ToF32(a), Bf16ToF32(b))); +} +template <> EIGEN_STRONG_INLINE Packet4bf pmin(const Packet4bf &a, + const Packet4bf &b) +{ + return F32ToBf16(pmin(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> EIGEN_STRONG_INLINE Packet4bf pmin(const Packet4bf &a, + const Packet4bf &b) +{ + return F32ToBf16(pmin(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> EIGEN_STRONG_INLINE Packet4bf pmax(const Packet4bf &a, + const Packet4bf &b) +{ + return F32ToBf16(pmax(Bf16ToF32(a), Bf16ToF32(b))); +} +template <> EIGEN_STRONG_INLINE Packet4bf pmax(const Packet4bf &a, + const Packet4bf &b) +{ + return F32ToBf16(pmax(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> EIGEN_STRONG_INLINE Packet4bf pmax(const Packet4bf &a, + const Packet4bf &b) +{ + return F32ToBf16(pmax(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf plset(const 
bfloat16& a) +{ + return F32ToBf16(plset(static_cast(a))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf por(const Packet4bf& a,const Packet4bf& b) { + return por(a, b); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pxor(const Packet4bf& a,const Packet4bf& b) { + return pxor(a, b); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pand(const Packet4bf& a,const Packet4bf& b) { + return pand(a, b); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pandnot(const Packet4bf& a,const Packet4bf& b) { + return pandnot(a, b); +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4bf pselect(const Packet4bf& mask, const Packet4bf& a, + const Packet4bf& b) +{ + return pselect(mask, a, b); +} + +template<> EIGEN_STRONG_INLINE Packet4bf print(const Packet4bf& a) +{ + return F32ToBf16(print(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pfloor(const Packet4bf& a) +{ + return F32ToBf16(pfloor(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pceil(const Packet4bf& a) +{ + return F32ToBf16(pceil(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pconj(const Packet4bf& a) { return a; } + +template<> EIGEN_STRONG_INLINE Packet4bf padd(const Packet4bf& a, const Packet4bf& b) { + return F32ToBf16(padd(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf psub(const Packet4bf& a, const Packet4bf& b) { + return F32ToBf16(psub(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pmul(const Packet4bf& a, const Packet4bf& b) { + return F32ToBf16(pmul(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pdiv(const Packet4bf& a, const Packet4bf& b) { + return F32ToBf16(pdiv(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> +EIGEN_STRONG_INLINE Packet4bf pgather(const bfloat16* from, Index stride) +{ + return pgather(reinterpret_cast(from), stride); +} + +template<> +EIGEN_STRONG_INLINE void pscatter(bfloat16* to, const Packet4bf& from, Index stride) +{ + pscatter(reinterpret_cast(to), from, stride); +} + +template<> EIGEN_STRONG_INLINE bfloat16 predux(const Packet4bf& a) +{ + return static_cast(predux(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE bfloat16 predux_max(const Packet4bf& a) +{ + return static_cast(predux_max(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE bfloat16 predux_min(const Packet4bf& a) +{ + return static_cast(predux_min(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE bfloat16 predux_mul(const Packet4bf& a) +{ + return static_cast(predux_mul(Bf16ToF32(a))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf preverse(const Packet4bf& a) +{ + return preverse(a); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) +{ + detail::ptranspose_impl(kernel); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pabsdiff(const Packet4bf& a, const Packet4bf& b) +{ + return F32ToBf16(pabsdiff(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pcmp_eq(const Packet4bf& a, const Packet4bf& b) +{ + return F32MaskToBf16Mask(pcmp_eq(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pcmp_lt(const Packet4bf& a, const Packet4bf& b) +{ + return F32MaskToBf16Mask(pcmp_lt(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pcmp_lt_or_nan(const Packet4bf& a, const Packet4bf& b) +{ + return F32MaskToBf16Mask(pcmp_lt_or_nan(Bf16ToF32(a), Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pcmp_le(const Packet4bf& a, const Packet4bf& b) +{ + return F32MaskToBf16Mask(pcmp_le(Bf16ToF32(a), 
Bf16ToF32(b))); +} + +template<> EIGEN_STRONG_INLINE Packet4bf pnegate(const Packet4bf& a) +{ + return pxor(a, pset1(static_cast(0x8000))); +} + +//---------- double ---------- + +// Clang 3.5 in the iOS toolchain has an ICE triggered by NEON intrisics for double. +// Confirmed at least with __apple_build_version__ = 6000054. +#ifdef __apple_build_version__ +// Let's hope that by the time __apple_build_version__ hits the 601* range, the bug will be fixed. +// https://gist.github.com/yamaya/2924292 suggests that the 3 first digits are only updated with +// major toolchain updates. +#define EIGEN_APPLE_DOUBLE_NEON_BUG (__apple_build_version__ < 6010000) +#else +#define EIGEN_APPLE_DOUBLE_NEON_BUG 0 +#endif + +#if EIGEN_ARCH_ARM64 && !EIGEN_APPLE_DOUBLE_NEON_BUG + +// Bug 907: workaround missing declarations of the following two functions in the ADK +// Defining these functions as templates ensures that if these intrinsics are +// already defined in arm_neon.h, then our workaround doesn't cause a conflict +// and has lower priority in overload resolution. +template uint64x2_t vreinterpretq_u64_f64(T a) { return (uint64x2_t) a; } + +template float64x2_t vreinterpretq_f64_u64(T a) { return (float64x2_t) a; } + +typedef float64x2_t Packet2d; +typedef float64x1_t Packet1d; + +// fuctionally equivalent to _mm_shuffle_pd in SSE (i.e. shuffle(m, n, mask) equals _mm_shuffle_pd(m,n,mask)) +// Currently used in LU/arch/InverseSize4.h to enable a shared implementation +// for fast inversion of matrices of size 4. +EIGEN_STRONG_INLINE Packet2d shuffle(const Packet2d& m, const Packet2d& n, int mask) +{ + const double* a = reinterpret_cast(&m); + const double* b = reinterpret_cast(&n); + Packet2d res = {*(a + (mask & 1)), *(b + ((mask >> 1) & 1))}; + return res; +} + +EIGEN_STRONG_INLINE Packet2d vec2d_swizzle2(const Packet2d& a, const Packet2d& b, int mask) +{ + return shuffle(a, b, mask); +} +EIGEN_STRONG_INLINE Packet2d vec2d_unpacklo(const Packet2d& a,const Packet2d& b) +{ + return shuffle(a, b, 0); +} +EIGEN_STRONG_INLINE Packet2d vec2d_unpackhi(const Packet2d& a,const Packet2d& b) +{ + return shuffle(a, b, 3); +} +#define vec2d_duplane(a, p) \ + vdupq_laneq_f64(a, p) + +template<> struct packet_traits : default_packet_traits +{ + typedef Packet2d type; + typedef Packet2d half; + enum + { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 2, + HasHalfPacket = 0, + + HasCmp = 1, + HasAdd = 1, + HasSub = 1, + HasShift = 1, + HasMul = 1, + HasNegate = 1, + HasAbs = 1, + HasArg = 0, + HasAbs2 = 1, + HasAbsDiff = 1, + HasMin = 1, + HasMax = 1, + HasConj = 1, + HasSetLinear = 0, + HasBlend = 0, + + HasDiv = 1, + HasFloor = 1, + HasCeil = 1, + HasRint = 1, + + HasSin = 0, + HasCos = 0, + HasLog = 1, + HasExp = 1, + HasSqrt = 1, + HasRsqrt = 1, + HasTanh = 0, + HasErf = 0 + }; +}; + +template<> struct unpacket_traits +{ + typedef double type; + typedef Packet2d half; + typedef Packet2l integer_packet; + enum + { + size = 2, + alignment = Aligned16, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; + +template<> EIGEN_STRONG_INLINE Packet2d pset1(const double& from) { return vdupq_n_f64(from); } + +template<> EIGEN_STRONG_INLINE Packet2d plset(const double& a) +{ + const double c[] = {0.0,1.0}; + return vaddq_f64(pset1(a), vld1q_f64(c)); +} + +template<> EIGEN_STRONG_INLINE Packet2d padd(const Packet2d& a, const Packet2d& b) { return vaddq_f64(a,b); } + +template<> EIGEN_STRONG_INLINE Packet2d psub(const Packet2d& a, const Packet2d& b) { return vsubq_f64(a,b); } 
+ +template<> EIGEN_STRONG_INLINE Packet2d pxor(const Packet2d& , const Packet2d& ); +template<> EIGEN_STRONG_INLINE Packet2d paddsub(const Packet2d& a, const Packet2d& b){ + const Packet2d mask = {numext::bit_cast(0x8000000000000000ull),0.0}; + return padd(a, pxor(mask, b)); +} + +template<> EIGEN_STRONG_INLINE Packet2d pnegate(const Packet2d& a) { return vnegq_f64(a); } + +template<> EIGEN_STRONG_INLINE Packet2d pconj(const Packet2d& a) { return a; } + +template<> EIGEN_STRONG_INLINE Packet2d pmul(const Packet2d& a, const Packet2d& b) { return vmulq_f64(a,b); } + +template<> EIGEN_STRONG_INLINE Packet2d pdiv(const Packet2d& a, const Packet2d& b) { return vdivq_f64(a,b); } + +#ifdef __ARM_FEATURE_FMA +// See bug 936. See above comment about FMA for float. +template<> EIGEN_STRONG_INLINE Packet2d pmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) +{ return vfmaq_f64(c,a,b); } +#else +template<> EIGEN_STRONG_INLINE Packet2d pmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) +{ return vmlaq_f64(c,a,b); } +#endif + +template<> EIGEN_STRONG_INLINE Packet2d pmin(const Packet2d& a, const Packet2d& b) { return vminq_f64(a,b); } + +#ifdef __ARM_FEATURE_NUMERIC_MAXMIN +// numeric max and min are only available if ARM_FEATURE_NUMERIC_MAXMIN is defined (which can only be the case for Armv8 systems). +template<> EIGEN_STRONG_INLINE Packet2d pmin(const Packet2d& a, const Packet2d& b) { return vminnmq_f64(a, b); } +template<> EIGEN_STRONG_INLINE Packet2d pmax(const Packet2d& a, const Packet2d& b) { return vmaxnmq_f64(a, b); } + +#endif + +template<> EIGEN_STRONG_INLINE Packet2d pmin(const Packet2d& a, const Packet2d& b) { return pmin(a, b); } + +template<> EIGEN_STRONG_INLINE Packet2d pmax(const Packet2d& a, const Packet2d& b) { return vmaxq_f64(a,b); } + + +template<> EIGEN_STRONG_INLINE Packet2d pmax(const Packet2d& a, const Packet2d& b) { return pmax(a, b); } + +// Logical Operations are not supported for float, so we have to reinterpret casts using NEON intrinsics +template<> EIGEN_STRONG_INLINE Packet2d pand(const Packet2d& a, const Packet2d& b) +{ return vreinterpretq_f64_u64(vandq_u64(vreinterpretq_u64_f64(a),vreinterpretq_u64_f64(b))); } + +template<> EIGEN_STRONG_INLINE Packet2d por(const Packet2d& a, const Packet2d& b) +{ return vreinterpretq_f64_u64(vorrq_u64(vreinterpretq_u64_f64(a),vreinterpretq_u64_f64(b))); } + +template<> EIGEN_STRONG_INLINE Packet2d pxor(const Packet2d& a, const Packet2d& b) +{ return vreinterpretq_f64_u64(veorq_u64(vreinterpretq_u64_f64(a),vreinterpretq_u64_f64(b))); } + +template<> EIGEN_STRONG_INLINE Packet2d pandnot(const Packet2d& a, const Packet2d& b) +{ return vreinterpretq_f64_u64(vbicq_u64(vreinterpretq_u64_f64(a),vreinterpretq_u64_f64(b))); } + +template<> EIGEN_STRONG_INLINE Packet2d pcmp_le(const Packet2d& a, const Packet2d& b) +{ return vreinterpretq_f64_u64(vcleq_f64(a,b)); } + +template<> EIGEN_STRONG_INLINE Packet2d pcmp_lt(const Packet2d& a, const Packet2d& b) +{ return vreinterpretq_f64_u64(vcltq_f64(a,b)); } + +template<> EIGEN_STRONG_INLINE Packet2d pcmp_lt_or_nan(const Packet2d& a, const Packet2d& b) +{ return vreinterpretq_f64_u32(vmvnq_u32(vreinterpretq_u32_u64(vcgeq_f64(a,b)))); } + +template<> EIGEN_STRONG_INLINE Packet2d pcmp_eq(const Packet2d& a, const Packet2d& b) +{ return vreinterpretq_f64_u64(vceqq_f64(a,b)); } + +template<> EIGEN_STRONG_INLINE Packet2d pload(const double* from) +{ EIGEN_DEBUG_ALIGNED_LOAD return vld1q_f64(from); } + +template<> EIGEN_STRONG_INLINE Packet2d ploadu(const double* from) +{ 
EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_f64(from); } + +template<> EIGEN_STRONG_INLINE Packet2d ploaddup(const double* from) { return vld1q_dup_f64(from); } +template<> EIGEN_STRONG_INLINE void pstore(double* to, const Packet2d& from) +{ EIGEN_DEBUG_ALIGNED_STORE vst1q_f64(to,from); } + +template<> EIGEN_STRONG_INLINE void pstoreu(double* to, const Packet2d& from) +{ EIGEN_DEBUG_UNALIGNED_STORE vst1q_f64(to,from); } + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2d pgather(const double* from, Index stride) +{ + Packet2d res = pset1(0.0); + res = vld1q_lane_f64(from + 0*stride, res, 0); + res = vld1q_lane_f64(from + 1*stride, res, 1); + return res; +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter(double* to, const Packet2d& from, Index stride) +{ + vst1q_lane_f64(to + stride*0, from, 0); + vst1q_lane_f64(to + stride*1, from, 1); +} + +template<> EIGEN_STRONG_INLINE void prefetch(const double* addr) { EIGEN_ARM_PREFETCH(addr); } + +// FIXME only store the 2 first elements ? +template<> EIGEN_STRONG_INLINE double pfirst(const Packet2d& a) { return vgetq_lane_f64(a,0); } + +template<> EIGEN_STRONG_INLINE Packet2d preverse(const Packet2d& a) +{ return vcombine_f64(vget_high_f64(a), vget_low_f64(a)); } + +template<> EIGEN_STRONG_INLINE Packet2d pabs(const Packet2d& a) { return vabsq_f64(a); } + +#if EIGEN_COMP_CLANG && defined(__apple_build_version__) +// workaround ICE, see bug 907 +template<> EIGEN_STRONG_INLINE double predux(const Packet2d& a) +{ return (vget_low_f64(a) + vget_high_f64(a))[0]; } +#else +template<> EIGEN_STRONG_INLINE double predux(const Packet2d& a) +{ return vget_lane_f64(vget_low_f64(a) + vget_high_f64(a), 0); } +#endif + +// Other reduction functions: +// mul +#if EIGEN_COMP_CLANG && defined(__apple_build_version__) +template<> EIGEN_STRONG_INLINE double predux_mul(const Packet2d& a) +{ return (vget_low_f64(a) * vget_high_f64(a))[0]; } +#else +template<> EIGEN_STRONG_INLINE double predux_mul(const Packet2d& a) +{ return vget_lane_f64(vget_low_f64(a) * vget_high_f64(a), 0); } +#endif + +// min +template<> EIGEN_STRONG_INLINE double predux_min(const Packet2d& a) +{ return vgetq_lane_f64(vpminq_f64(a,a), 0); } + +// max +template<> EIGEN_STRONG_INLINE double predux_max(const Packet2d& a) +{ return vgetq_lane_f64(vpmaxq_f64(a,a), 0); } + + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void +ptranspose(PacketBlock& kernel) +{ + const float64x2_t tmp1 = vzip1q_f64(kernel.packet[0], kernel.packet[1]); + const float64x2_t tmp2 = vzip2q_f64(kernel.packet[0], kernel.packet[1]); + + kernel.packet[0] = tmp1; + kernel.packet[1] = tmp2; +} + +template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2d pselect( const Packet2d& mask, const Packet2d& a, const Packet2d& b) +{ return vbslq_f64(vreinterpretq_u64_f64(mask), a, b); } + +template<> EIGEN_STRONG_INLINE Packet2d print(const Packet2d& a) +{ return vrndnq_f64(a); } + +template<> EIGEN_STRONG_INLINE Packet2d pfloor(const Packet2d& a) +{ return vrndmq_f64(a); } + +template<> EIGEN_STRONG_INLINE Packet2d pceil(const Packet2d& a) +{ return vrndpq_f64(a); } + +template<> EIGEN_STRONG_INLINE Packet2d pldexp(const Packet2d& a, const Packet2d& exponent) +{ return pldexp_generic(a, exponent); } + +template<> EIGEN_STRONG_INLINE Packet2d pfrexp(const Packet2d& a, Packet2d& exponent) +{ return pfrexp_generic(a,exponent); } + +template<> EIGEN_STRONG_INLINE Packet2d pset1frombits(uint64_t from) +{ return vreinterpretq_f64_u64(vdupq_n_u64(from)); } + +template<> EIGEN_STRONG_INLINE Packet2d prsqrt(const Packet2d& a) { + // 
Compute approximate reciprocal sqrt. + Packet2d x = vrsqrteq_f64(a); + // Do Newton iterations for 1/sqrt(x). + x = vmulq_f64(vrsqrtsq_f64(vmulq_f64(a, x), x), x); + x = vmulq_f64(vrsqrtsq_f64(vmulq_f64(a, x), x), x); + x = vmulq_f64(vrsqrtsq_f64(vmulq_f64(a, x), x), x); + const Packet2d infinity = pset1(NumTraits::infinity()); + return pselect(pcmp_eq(a, pzero(a)), infinity, x); +} + +template<> EIGEN_STRONG_INLINE Packet2d psqrt(const Packet2d& _x){ return vsqrtq_f64(_x); } + +#endif // EIGEN_ARCH_ARM64 + +// Do we have an fp16 types and supporting Neon intrinsics? +#if EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC +typedef float16x4_t Packet4hf; +typedef float16x8_t Packet8hf; + +template <> +struct packet_traits : default_packet_traits { + typedef Packet8hf type; + typedef Packet4hf half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 8, + HasHalfPacket = 1, + + HasCmp = 1, + HasCast = 1, + HasAdd = 1, + HasSub = 1, + HasShift = 1, + HasMul = 1, + HasNegate = 1, + HasAbs = 1, + HasArg = 0, + HasAbs2 = 1, + HasAbsDiff = 0, + HasMin = 1, + HasMax = 1, + HasConj = 1, + HasSetLinear = 0, + HasBlend = 0, + HasInsert = 1, + HasReduxp = 1, + HasDiv = 1, + HasFloor = 1, + HasCeil = 1, + HasRint = 1, + HasSin = 0, + HasCos = 0, + HasLog = 0, + HasExp = 0, + HasSqrt = 1, + HasRsqrt = 1, + HasErf = EIGEN_FAST_MATH, + HasBessel = 0, // Issues with accuracy. + HasNdtri = 0 + }; +}; + +template <> +struct unpacket_traits { + typedef Eigen::half type; + typedef Packet4hf half; + enum { + size = 4, + alignment = Aligned16, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; + +template <> +struct unpacket_traits { + typedef Eigen::half type; + typedef Packet4hf half; + enum { + size = 8, + alignment = Aligned16, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; + +template<> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4hf predux_half_dowto4(const Packet8hf& a) { + return vadd_f16(vget_low_f16(a), vget_high_f16(a)); +} + +template <> +EIGEN_STRONG_INLINE Packet8hf pset1(const Eigen::half& from) { + return vdupq_n_f16(from.x); +} + +template <> +EIGEN_STRONG_INLINE Packet4hf pset1(const Eigen::half& from) { + return vdup_n_f16(from.x); +} + +template <> +EIGEN_STRONG_INLINE Packet8hf plset(const Eigen::half& a) { + const float16_t f[] = {0, 1, 2, 3, 4, 5, 6, 7}; + Packet8hf countdown = vld1q_f16(f); + return vaddq_f16(pset1(a), countdown); +} + +template <> +EIGEN_STRONG_INLINE Packet4hf plset(const Eigen::half& a) { + const float16_t f[] = {0, 1, 2, 3}; + Packet4hf countdown = vld1_f16(f); + return vadd_f16(pset1(a), countdown); +} + +template <> +EIGEN_STRONG_INLINE Packet8hf padd(const Packet8hf& a, const Packet8hf& b) { + return vaddq_f16(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet4hf padd(const Packet4hf& a, const Packet4hf& b) { + return vadd_f16(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet8hf psub(const Packet8hf& a, const Packet8hf& b) { + return vsubq_f16(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet4hf psub(const Packet4hf& a, const Packet4hf& b) { + return vsub_f16(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet8hf pnegate(const Packet8hf& a) { + return vnegq_f16(a); +} + +template <> +EIGEN_STRONG_INLINE Packet4hf pnegate(const Packet4hf& a) { + return vneg_f16(a); +} + +template <> +EIGEN_STRONG_INLINE Packet8hf pconj(const Packet8hf& a) { + return a; +} + +template <> +EIGEN_STRONG_INLINE Packet4hf pconj(const Packet4hf& a) { + return a; +} + +template <> 
+EIGEN_STRONG_INLINE Packet8hf pmul(const Packet8hf& a, const Packet8hf& b) { + return vmulq_f16(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet4hf pmul(const Packet4hf& a, const Packet4hf& b) { + return vmul_f16(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet8hf pdiv(const Packet8hf& a, const Packet8hf& b) { + return vdivq_f16(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet4hf pdiv(const Packet4hf& a, const Packet4hf& b) { + return vdiv_f16(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet8hf pmadd(const Packet8hf& a, const Packet8hf& b, const Packet8hf& c) { + return vfmaq_f16(c, a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet4hf pmadd(const Packet4hf& a, const Packet4hf& b, const Packet4hf& c) { + return vfma_f16(c, a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet8hf pmin(const Packet8hf& a, const Packet8hf& b) { + return vminq_f16(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet4hf pmin(const Packet4hf& a, const Packet4hf& b) { + return vmin_f16(a, b); +} + +#ifdef __ARM_FEATURE_NUMERIC_MAXMIN +// numeric max and min are only available if ARM_FEATURE_NUMERIC_MAXMIN is defined (which can only be the case for Armv8 systems). +template<> EIGEN_STRONG_INLINE Packet4hf pmin(const Packet4hf& a, const Packet4hf& b) { return vminnm_f16(a, b); } +template<> EIGEN_STRONG_INLINE Packet8hf pmin(const Packet8hf& a, const Packet8hf& b) { return vminnmq_f16(a, b); } +#endif + +template<> EIGEN_STRONG_INLINE Packet4hf pmin(const Packet4hf& a, const Packet4hf& b) { return pmin(a, b); } + +template<> EIGEN_STRONG_INLINE Packet8hf pmin(const Packet8hf& a, const Packet8hf& b) { return pmin(a, b); } + +template <> +EIGEN_STRONG_INLINE Packet8hf pmax(const Packet8hf& a, const Packet8hf& b) { + return vmaxq_f16(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet4hf pmax(const Packet4hf& a, const Packet4hf& b) { + return vmax_f16(a, b); +} + +#ifdef __ARM_FEATURE_NUMERIC_MAXMIN +// numeric max and min are only available if ARM_FEATURE_NUMERIC_MAXMIN is defined (which can only be the case for Armv8 systems). 
+template<> EIGEN_STRONG_INLINE Packet4hf pmax(const Packet4hf& a, const Packet4hf& b) { return vmaxnm_f16(a, b); } +template<> EIGEN_STRONG_INLINE Packet8hf pmax(const Packet8hf& a, const Packet8hf& b) { return vmaxnmq_f16(a, b); } +#endif + +template<> EIGEN_STRONG_INLINE Packet4hf pmax(const Packet4hf& a, const Packet4hf& b) { return pmax(a, b); } + +template<> EIGEN_STRONG_INLINE Packet8hf pmax(const Packet8hf& a, const Packet8hf& b) { return pmax(a, b); } + +#define EIGEN_MAKE_ARM_FP16_CMP_8(name) \ + template <> \ + EIGEN_STRONG_INLINE Packet8hf pcmp_##name(const Packet8hf& a, const Packet8hf& b) { \ + return vreinterpretq_f16_u16(vc##name##q_f16(a, b)); \ + } + +#define EIGEN_MAKE_ARM_FP16_CMP_4(name) \ + template <> \ + EIGEN_STRONG_INLINE Packet4hf pcmp_##name(const Packet4hf& a, const Packet4hf& b) { \ + return vreinterpret_f16_u16(vc##name##_f16(a, b)); \ + } + +EIGEN_MAKE_ARM_FP16_CMP_8(eq) +EIGEN_MAKE_ARM_FP16_CMP_8(lt) +EIGEN_MAKE_ARM_FP16_CMP_8(le) + +EIGEN_MAKE_ARM_FP16_CMP_4(eq) +EIGEN_MAKE_ARM_FP16_CMP_4(lt) +EIGEN_MAKE_ARM_FP16_CMP_4(le) + +#undef EIGEN_MAKE_ARM_FP16_CMP_8 +#undef EIGEN_MAKE_ARM_FP16_CMP_4 + +template <> +EIGEN_STRONG_INLINE Packet8hf pcmp_lt_or_nan(const Packet8hf& a, const Packet8hf& b) { + return vreinterpretq_f16_u16(vmvnq_u16(vcgeq_f16(a, b))); +} + +template <> +EIGEN_STRONG_INLINE Packet4hf pcmp_lt_or_nan(const Packet4hf& a, const Packet4hf& b) { + return vreinterpret_f16_u16(vmvn_u16(vcge_f16(a, b))); +} + +template <> +EIGEN_STRONG_INLINE Packet8hf print(const Packet8hf& a) +{ return vrndnq_f16(a); } + +template <> +EIGEN_STRONG_INLINE Packet4hf print(const Packet4hf& a) +{ return vrndn_f16(a); } + +template <> +EIGEN_STRONG_INLINE Packet8hf pfloor(const Packet8hf& a) +{ return vrndmq_f16(a); } + +template <> +EIGEN_STRONG_INLINE Packet4hf pfloor(const Packet4hf& a) +{ return vrndm_f16(a); } + +template <> +EIGEN_STRONG_INLINE Packet8hf pceil(const Packet8hf& a) +{ return vrndpq_f16(a); } + +template <> +EIGEN_STRONG_INLINE Packet4hf pceil(const Packet4hf& a) +{ return vrndp_f16(a); } + +template <> +EIGEN_STRONG_INLINE Packet8hf psqrt(const Packet8hf& a) { + return vsqrtq_f16(a); +} + +template <> +EIGEN_STRONG_INLINE Packet4hf psqrt(const Packet4hf& a) { + return vsqrt_f16(a); +} + +template <> +EIGEN_STRONG_INLINE Packet8hf pand(const Packet8hf& a, const Packet8hf& b) { + return vreinterpretq_f16_u16(vandq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet4hf pand(const Packet4hf& a, const Packet4hf& b) { + return vreinterpret_f16_u16(vand_u16(vreinterpret_u16_f16(a), vreinterpret_u16_f16(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet8hf por(const Packet8hf& a, const Packet8hf& b) { + return vreinterpretq_f16_u16(vorrq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet4hf por(const Packet4hf& a, const Packet4hf& b) { + return vreinterpret_f16_u16(vorr_u16(vreinterpret_u16_f16(a), vreinterpret_u16_f16(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet8hf pxor(const Packet8hf& a, const Packet8hf& b) { + return vreinterpretq_f16_u16(veorq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet4hf pxor(const Packet4hf& a, const Packet4hf& b) { + return vreinterpret_f16_u16(veor_u16(vreinterpret_u16_f16(a), vreinterpret_u16_f16(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet8hf pandnot(const Packet8hf& a, const Packet8hf& b) { + return 
vreinterpretq_f16_u16(vbicq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet4hf pandnot(const Packet4hf& a, const Packet4hf& b) { + return vreinterpret_f16_u16(vbic_u16(vreinterpret_u16_f16(a), vreinterpret_u16_f16(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet8hf pload(const Eigen::half* from) { + EIGEN_DEBUG_ALIGNED_LOAD return vld1q_f16(reinterpret_cast(from)); +} + +template <> +EIGEN_STRONG_INLINE Packet4hf pload(const Eigen::half* from) { + EIGEN_DEBUG_ALIGNED_LOAD return vld1_f16(reinterpret_cast(from)); +} + +template <> +EIGEN_STRONG_INLINE Packet8hf ploadu(const Eigen::half* from) { + EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_f16(reinterpret_cast(from)); +} + +template <> +EIGEN_STRONG_INLINE Packet4hf ploadu(const Eigen::half* from) { + EIGEN_DEBUG_UNALIGNED_LOAD return vld1_f16(reinterpret_cast(from)); +} + +template <> +EIGEN_STRONG_INLINE Packet8hf ploaddup(const Eigen::half* from) { + Packet8hf packet; + packet[0] = from[0].x; + packet[1] = from[0].x; + packet[2] = from[1].x; + packet[3] = from[1].x; + packet[4] = from[2].x; + packet[5] = from[2].x; + packet[6] = from[3].x; + packet[7] = from[3].x; + return packet; +} + +template <> +EIGEN_STRONG_INLINE Packet4hf ploaddup(const Eigen::half* from) { + float16x4_t packet; + float16_t* tmp; + tmp = (float16_t*)&packet; + tmp[0] = from[0].x; + tmp[1] = from[0].x; + tmp[2] = from[1].x; + tmp[3] = from[1].x; + return packet; +} + +template <> +EIGEN_STRONG_INLINE Packet8hf ploadquad(const Eigen::half* from) { + Packet4hf lo, hi; + lo = vld1_dup_f16(reinterpret_cast(from)); + hi = vld1_dup_f16(reinterpret_cast(from+1)); + return vcombine_f16(lo, hi); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8hf pinsertfirst(const Packet8hf& a, Eigen::half b) { return vsetq_lane_f16(b.x, a, 0); } + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4hf pinsertfirst(const Packet4hf& a, Eigen::half b) { return vset_lane_f16(b.x, a, 0); } + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8hf pselect(const Packet8hf& mask, const Packet8hf& a, const Packet8hf& b) { + return vbslq_f16(vreinterpretq_u16_f16(mask), a, b); +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4hf pselect(const Packet4hf& mask, const Packet4hf& a, const Packet4hf& b) { + return vbsl_f16(vreinterpret_u16_f16(mask), a, b); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8hf pinsertlast(const Packet8hf& a, Eigen::half b) { return vsetq_lane_f16(b.x, a, 7); } + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4hf pinsertlast(const Packet4hf& a, Eigen::half b) { return vset_lane_f16(b.x, a, 3); } + +template <> +EIGEN_STRONG_INLINE void pstore(Eigen::half* to, const Packet8hf& from) { + EIGEN_DEBUG_ALIGNED_STORE vst1q_f16(reinterpret_cast(to), from); +} + +template <> +EIGEN_STRONG_INLINE void pstore(Eigen::half* to, const Packet4hf& from) { + EIGEN_DEBUG_ALIGNED_STORE vst1_f16(reinterpret_cast(to), from); +} + +template <> +EIGEN_STRONG_INLINE void pstoreu(Eigen::half* to, const Packet8hf& from) { + EIGEN_DEBUG_UNALIGNED_STORE vst1q_f16(reinterpret_cast(to), from); +} + +template <> +EIGEN_STRONG_INLINE void pstoreu(Eigen::half* to, const Packet4hf& from) { + EIGEN_DEBUG_UNALIGNED_STORE vst1_f16(reinterpret_cast(to), from); +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet8hf pgather(const Eigen::half* from, Index stride) { + Packet8hf res = pset1(Eigen::half(0.f)); + res = vsetq_lane_f16(from[0 * stride].x, res, 0); + res = vsetq_lane_f16(from[1 * stride].x, res, 1); + res = 
vsetq_lane_f16(from[2 * stride].x, res, 2); + res = vsetq_lane_f16(from[3 * stride].x, res, 3); + res = vsetq_lane_f16(from[4 * stride].x, res, 4); + res = vsetq_lane_f16(from[5 * stride].x, res, 5); + res = vsetq_lane_f16(from[6 * stride].x, res, 6); + res = vsetq_lane_f16(from[7 * stride].x, res, 7); + return res; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4hf pgather(const Eigen::half* from, Index stride) { + Packet4hf res = pset1(Eigen::half(0.f)); + res = vset_lane_f16(from[0 * stride].x, res, 0); + res = vset_lane_f16(from[1 * stride].x, res, 1); + res = vset_lane_f16(from[2 * stride].x, res, 2); + res = vset_lane_f16(from[3 * stride].x, res, 3); + return res; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter(Eigen::half* to, const Packet8hf& from, Index stride) { + to[stride * 0].x = vgetq_lane_f16(from, 0); + to[stride * 1].x = vgetq_lane_f16(from, 1); + to[stride * 2].x = vgetq_lane_f16(from, 2); + to[stride * 3].x = vgetq_lane_f16(from, 3); + to[stride * 4].x = vgetq_lane_f16(from, 4); + to[stride * 5].x = vgetq_lane_f16(from, 5); + to[stride * 6].x = vgetq_lane_f16(from, 6); + to[stride * 7].x = vgetq_lane_f16(from, 7); +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter(Eigen::half* to, const Packet4hf& from, Index stride) { + to[stride * 0].x = vget_lane_f16(from, 0); + to[stride * 1].x = vget_lane_f16(from, 1); + to[stride * 2].x = vget_lane_f16(from, 2); + to[stride * 3].x = vget_lane_f16(from, 3); +} + +template <> +EIGEN_STRONG_INLINE void prefetch(const Eigen::half* addr) { + EIGEN_ARM_PREFETCH(addr); +} + +template <> +EIGEN_STRONG_INLINE Eigen::half pfirst(const Packet8hf& a) { + float16_t x[8]; + vst1q_f16(x, a); + Eigen::half h; + h.x = x[0]; + return h; +} + +template <> +EIGEN_STRONG_INLINE Eigen::half pfirst(const Packet4hf& a) { + float16_t x[4]; + vst1_f16(x, a); + Eigen::half h; + h.x = x[0]; + return h; +} + +template<> EIGEN_STRONG_INLINE Packet8hf preverse(const Packet8hf& a) { + float16x4_t a_lo, a_hi; + Packet8hf a_r64; + + a_r64 = vrev64q_f16(a); + a_lo = vget_low_f16(a_r64); + a_hi = vget_high_f16(a_r64); + return vcombine_f16(a_hi, a_lo); +} + +template <> +EIGEN_STRONG_INLINE Packet4hf preverse(const Packet4hf& a) { + return vrev64_f16(a); +} + +template <> +EIGEN_STRONG_INLINE Packet8hf pabs(const Packet8hf& a) { + return vabsq_f16(a); +} + +template <> +EIGEN_STRONG_INLINE Packet4hf pabs(const Packet4hf& a) { + return vabs_f16(a); +} + +template <> +EIGEN_STRONG_INLINE Eigen::half predux(const Packet8hf& a) { + float16x4_t a_lo, a_hi, sum; + + a_lo = vget_low_f16(a); + a_hi = vget_high_f16(a); + sum = vpadd_f16(a_lo, a_hi); + sum = vpadd_f16(sum, sum); + sum = vpadd_f16(sum, sum); + + Eigen::half h; + h.x = vget_lane_f16(sum, 0); + return h; +} + +template <> +EIGEN_STRONG_INLINE Eigen::half predux(const Packet4hf& a) { + float16x4_t sum; + + sum = vpadd_f16(a, a); + sum = vpadd_f16(sum, sum); + Eigen::half h; + h.x = vget_lane_f16(sum, 0); + return h; +} + +template <> +EIGEN_STRONG_INLINE Eigen::half predux_mul(const Packet8hf& a) { + float16x4_t a_lo, a_hi, prod; + + a_lo = vget_low_f16(a); + a_hi = vget_high_f16(a); + prod = vmul_f16(a_lo, a_hi); + prod = vmul_f16(prod, vrev64_f16(prod)); + + Eigen::half h; + h.x = vmulh_f16(vget_lane_f16(prod, 0), vget_lane_f16(prod, 1)); + return h; +} + +template <> +EIGEN_STRONG_INLINE Eigen::half predux_mul(const Packet4hf& a) { + float16x4_t prod; + prod = vmul_f16(a, vrev64_f16(a)); + Eigen::half h; + h.x = vmulh_f16(vget_lane_f16(prod, 0), 
vget_lane_f16(prod, 1)); + return h; +} + +template <> +EIGEN_STRONG_INLINE Eigen::half predux_min(const Packet8hf& a) { + float16x4_t a_lo, a_hi, min; + + a_lo = vget_low_f16(a); + a_hi = vget_high_f16(a); + min = vpmin_f16(a_lo, a_hi); + min = vpmin_f16(min, min); + min = vpmin_f16(min, min); + + Eigen::half h; + h.x = vget_lane_f16(min, 0); + return h; +} + +template <> +EIGEN_STRONG_INLINE Eigen::half predux_min(const Packet4hf& a) { + Packet4hf tmp; + tmp = vpmin_f16(a, a); + tmp = vpmin_f16(tmp, tmp); + Eigen::half h; + h.x = vget_lane_f16(tmp, 0); + return h; +} + +template <> +EIGEN_STRONG_INLINE Eigen::half predux_max(const Packet8hf& a) { + float16x4_t a_lo, a_hi, max; + + a_lo = vget_low_f16(a); + a_hi = vget_high_f16(a); + max = vpmax_f16(a_lo, a_hi); + max = vpmax_f16(max, max); + max = vpmax_f16(max, max); + + Eigen::half h; + h.x = vget_lane_f16(max, 0); + return h; +} + +template <> +EIGEN_STRONG_INLINE Eigen::half predux_max(const Packet4hf& a) { + Packet4hf tmp; + tmp = vpmax_f16(a, a); + tmp = vpmax_f16(tmp, tmp); + Eigen::half h; + h.x = vget_lane_f16(tmp, 0); + return h; +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) +{ + const float16x8x2_t zip16_1 = vzipq_f16(kernel.packet[0], kernel.packet[1]); + const float16x8x2_t zip16_2 = vzipq_f16(kernel.packet[2], kernel.packet[3]); + + const float32x4x2_t zip32_1 = vzipq_f32(vreinterpretq_f32_f16(zip16_1.val[0]), vreinterpretq_f32_f16(zip16_2.val[0])); + const float32x4x2_t zip32_2 = vzipq_f32(vreinterpretq_f32_f16(zip16_1.val[1]), vreinterpretq_f32_f16(zip16_2.val[1])); + + kernel.packet[0] = vreinterpretq_f16_f32(zip32_1.val[0]); + kernel.packet[1] = vreinterpretq_f16_f32(zip32_1.val[1]); + kernel.packet[2] = vreinterpretq_f16_f32(zip32_2.val[0]); + kernel.packet[3] = vreinterpretq_f16_f32(zip32_2.val[1]); +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + EIGEN_ALIGN16 float16x4x4_t tmp_x4; + float16_t* tmp = (float16_t*)&kernel; + tmp_x4 = vld4_f16(tmp); + + kernel.packet[0] = tmp_x4.val[0]; + kernel.packet[1] = tmp_x4.val[1]; + kernel.packet[2] = tmp_x4.val[2]; + kernel.packet[3] = tmp_x4.val[3]; +} + +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + float16x8x2_t T_1[4]; + + T_1[0] = vuzpq_f16(kernel.packet[0], kernel.packet[1]); + T_1[1] = vuzpq_f16(kernel.packet[2], kernel.packet[3]); + T_1[2] = vuzpq_f16(kernel.packet[4], kernel.packet[5]); + T_1[3] = vuzpq_f16(kernel.packet[6], kernel.packet[7]); + + float16x8x2_t T_2[4]; + T_2[0] = vuzpq_f16(T_1[0].val[0], T_1[1].val[0]); + T_2[1] = vuzpq_f16(T_1[0].val[1], T_1[1].val[1]); + T_2[2] = vuzpq_f16(T_1[2].val[0], T_1[3].val[0]); + T_2[3] = vuzpq_f16(T_1[2].val[1], T_1[3].val[1]); + + float16x8x2_t T_3[4]; + T_3[0] = vuzpq_f16(T_2[0].val[0], T_2[2].val[0]); + T_3[1] = vuzpq_f16(T_2[0].val[1], T_2[2].val[1]); + T_3[2] = vuzpq_f16(T_2[1].val[0], T_2[3].val[0]); + T_3[3] = vuzpq_f16(T_2[1].val[1], T_2[3].val[1]); + + kernel.packet[0] = T_3[0].val[0]; + kernel.packet[1] = T_3[2].val[0]; + kernel.packet[2] = T_3[1].val[0]; + kernel.packet[3] = T_3[3].val[0]; + kernel.packet[4] = T_3[0].val[1]; + kernel.packet[5] = T_3[2].val[1]; + kernel.packet[6] = T_3[1].val[1]; + kernel.packet[7] = T_3[3].val[1]; +} +#endif // end EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC + +} // end namespace internal + +} // end namespace Eigen + +#endif // EIGEN_PACKET_MATH_NEON_H diff --git a/Eigen/src/Core/arch/NEON/TypeCasting.h b/Eigen/src/Core/arch/NEON/TypeCasting.h new file mode 100644 index 
0000000..54f9733 --- /dev/null +++ b/Eigen/src/Core/arch/NEON/TypeCasting.h @@ -0,0 +1,1419 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2018 Rasmus Munk Larsen +// Copyright (C) 2020 Antonio Sanchez +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_TYPE_CASTING_NEON_H +#define EIGEN_TYPE_CASTING_NEON_H + +namespace Eigen { + +namespace internal { + +//============================================================================== +// pcast, SrcType = float +//============================================================================== +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet4f pcast(const Packet4f& a) { + return a; +} +template <> +EIGEN_STRONG_INLINE Packet2f pcast(const Packet2f& a) { + return a; +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 }; +}; +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 }; +}; +// If float64 exists, first convert to that to keep as much precision as possible. +#if EIGEN_ARCH_ARM64 +template <> +EIGEN_STRONG_INLINE Packet2l pcast(const Packet4f& a) { + // Discard second half of input. + return vcvtq_s64_f64(vcvt_f64_f32(vget_low_f32(a))); +} +template <> +EIGEN_STRONG_INLINE Packet2ul pcast(const Packet4f& a) { + // Discard second half of input. + return vcvtq_u64_f64(vcvt_f64_f32(vget_low_f32(a))); +} +#else +template <> +EIGEN_STRONG_INLINE Packet2l pcast(const Packet4f& a) { + // Discard second half of input. + return vmovl_s32(vget_low_s32(vcvtq_s32_f32(a))); +} +template <> +EIGEN_STRONG_INLINE Packet2ul pcast(const Packet4f& a) { + // Discard second half of input. 
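+  // Without a float64 path (armv7), the value is first converted to uint32 and
+  // then widened to 64 bits, so inputs outside the 32-bit unsigned range are
+  // limited by that intermediate conversion.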
+ return vmovl_u32(vget_low_u32(vcvtq_u32_f32(a))); +} +#endif // EIGEN_ARCH_ARM64 + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet4i pcast(const Packet4f& a) { + return vcvtq_s32_f32(a); +} +template <> +EIGEN_STRONG_INLINE Packet2i pcast(const Packet2f& a) { + return vcvt_s32_f32(a); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet4ui pcast(const Packet4f& a) { + return vcvtq_u32_f32(a); +} +template <> +EIGEN_STRONG_INLINE Packet2ui pcast(const Packet2f& a) { + return vcvt_u32_f32(a); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet8s pcast(const Packet4f& a, const Packet4f& b) { + return vcombine_s16(vmovn_s32(vcvtq_s32_f32(a)), vmovn_s32(vcvtq_s32_f32(b))); +} +template <> +EIGEN_STRONG_INLINE Packet4s pcast(const Packet2f& a, const Packet2f& b) { + return vmovn_s32(vcombine_s32(vcvt_s32_f32(a), vcvt_s32_f32(b))); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet8us pcast(const Packet4f& a, const Packet4f& b) { + return vcombine_u16(vmovn_u32(vcvtq_u32_f32(a)), vmovn_u32(vcvtq_u32_f32(b))); +} +template <> +EIGEN_STRONG_INLINE Packet4us pcast(const Packet2f& a, const Packet2f& b) { + return vmovn_u32(vcombine_u32(vcvt_u32_f32(a), vcvt_u32_f32(b))); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet16c pcast(const Packet4f& a, const Packet4f& b, const Packet4f& c, + const Packet4f& d) { + const int16x8_t ab_s16 = pcast(a, b); + const int16x8_t cd_s16 = pcast(c, d); + return vcombine_s8(vmovn_s16(ab_s16), vmovn_s16(cd_s16)); +} +template <> +EIGEN_STRONG_INLINE Packet8c pcast(const Packet2f& a, const Packet2f& b, const Packet2f& c, + const Packet2f& d) { + const int16x4_t ab_s16 = pcast(a, b); + const int16x4_t cd_s16 = pcast(c, d); + return vmovn_s16(vcombine_s16(ab_s16, cd_s16)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet16uc pcast(const Packet4f& a, const Packet4f& b, const Packet4f& c, + const Packet4f& d) { + const uint16x8_t ab_u16 = pcast(a, b); + const uint16x8_t cd_u16 = pcast(c, d); + return vcombine_u8(vmovn_u16(ab_u16), vmovn_u16(cd_u16)); +} +template <> +EIGEN_STRONG_INLINE Packet8uc pcast(const Packet2f& a, const Packet2f& b, const Packet2f& c, + const Packet2f& d) { + const uint16x4_t ab_u16 = pcast(a, b); + const uint16x4_t cd_u16 = pcast(c, d); + return vmovn_u16(vcombine_u16(ab_u16, cd_u16)); +} + +//============================================================================== +// pcast, SrcType = int8_t +//============================================================================== +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 }; +}; +template <> +EIGEN_STRONG_INLINE Packet4f pcast(const Packet16c& a) { + // Discard all but first 4 bytes. + return vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vget_low_s8(a))))); +} +template <> +EIGEN_STRONG_INLINE Packet2f pcast(const Packet8c& a) { + // Discard all but first 2 bytes. 
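+  // The bytes are sign-extended in two steps (vmovl_s8 to 16 bits, then
+  // vmovl_s16 to 32 bits), keeping only the low lanes after each step, before
+  // the final conversion to float; e.g. the first two lanes {-3, 7} become
+  // {-3.0f, 7.0f}.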
+ return vcvt_f32_s32(vget_low_s32(vmovl_s16(vget_low_s16(vmovl_s8(a))))); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 8 }; +}; +template <> +EIGEN_STRONG_INLINE Packet2l pcast(const Packet16c& a) { + // Discard all but first two bytes. + return vmovl_s32(vget_low_s32(vmovl_s16(vget_low_s16(vmovl_s8(vget_low_s8(a)))))); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 8 }; +}; +template <> +EIGEN_STRONG_INLINE Packet2ul pcast(const Packet16c& a) { + return vreinterpretq_u64_s64(pcast(a)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 }; +}; +template <> +EIGEN_STRONG_INLINE Packet4i pcast(const Packet16c& a) { + // Discard all but first 4 bytes. + return vmovl_s16(vget_low_s16(vmovl_s8(vget_low_s8(a)))); +} +template <> +EIGEN_STRONG_INLINE Packet2i pcast(const Packet8c& a) { + // Discard all but first 2 bytes. + return vget_low_s32(vmovl_s16(vget_low_s16(vmovl_s8(a)))); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 }; +}; +template <> +EIGEN_STRONG_INLINE Packet4ui pcast(const Packet16c& a) { + return vreinterpretq_u32_s32(pcast(a)); +} +template <> +EIGEN_STRONG_INLINE Packet2ui pcast(const Packet8c& a) { + return vreinterpret_u32_s32(pcast(a)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 }; +}; +template <> +EIGEN_STRONG_INLINE Packet8s pcast(const Packet16c& a) { + // Discard second half of input. + return vmovl_s8(vget_low_s8(a)); +} +template <> +EIGEN_STRONG_INLINE Packet4s pcast(const Packet8c& a) { + // Discard second half of input. + return vget_low_s16(vmovl_s8(a)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 }; +}; +template <> +EIGEN_STRONG_INLINE Packet8us pcast(const Packet16c& a) { + return vreinterpretq_u16_s16(pcast(a)); +} +template <> +EIGEN_STRONG_INLINE Packet4us pcast(const Packet8c& a) { + return vreinterpret_u16_s16(pcast(a)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet16c pcast(const Packet16c& a) { + return a; +} +template <> +EIGEN_STRONG_INLINE Packet8c pcast(const Packet8c& a) { + return a; +} +template <> +EIGEN_STRONG_INLINE Packet4c pcast(const Packet4c& a) { + return a; +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet16uc pcast(const Packet16c& a) { + return vreinterpretq_u8_s8(a); +} +template <> +EIGEN_STRONG_INLINE Packet8uc pcast(const Packet8c& a) { + return vreinterpret_u8_s8(a); +} +template <> +EIGEN_STRONG_INLINE Packet4uc pcast(const Packet4c& a) { + return static_cast(a); +} + +//============================================================================== +// pcast, SrcType = uint8_t +//============================================================================== +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 }; +}; +template <> +EIGEN_STRONG_INLINE Packet4f pcast(const Packet16uc& a) { + // Discard all but first 4 bytes. 
+ return vcvtq_f32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(a))))); +} +template <> +EIGEN_STRONG_INLINE Packet2f pcast(const Packet8uc& a) { + // Discard all but first 2 bytes. + return vcvt_f32_u32(vget_low_u32(vmovl_u16(vget_low_u16(vmovl_u8(a))))); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 8 }; +}; +template <> +EIGEN_STRONG_INLINE Packet2ul pcast(const Packet16uc& a) { + // Discard all but first two bytes. + return vmovl_u32(vget_low_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(a)))))); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 8 }; +}; +template <> +EIGEN_STRONG_INLINE Packet2l pcast(const Packet16uc& a) { + return vreinterpretq_s64_u64(pcast(a)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 }; +}; +template <> +EIGEN_STRONG_INLINE Packet4ui pcast(const Packet16uc& a) { + // Discard all but first 4 bytes. + return vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(a)))); +} +template <> +EIGEN_STRONG_INLINE Packet2ui pcast(const Packet8uc& a) { + // Discard all but first 2 bytes. + return vget_low_u32(vmovl_u16(vget_low_u16(vmovl_u8(a)))); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 }; +}; +template <> +EIGEN_STRONG_INLINE Packet4i pcast(const Packet16uc& a) { + return vreinterpretq_s32_u32(pcast(a)); +} +template <> +EIGEN_STRONG_INLINE Packet2i pcast(const Packet8uc& a) { + return vreinterpret_s32_u32(pcast(a)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 }; +}; +template <> +EIGEN_STRONG_INLINE Packet8us pcast(const Packet16uc& a) { + // Discard second half of input. + return vmovl_u8(vget_low_u8(a)); +} +template <> +EIGEN_STRONG_INLINE Packet4us pcast(const Packet8uc& a) { + // Discard second half of input. 
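+  // vmovl_u8 zero-extends all eight bytes to 16 bits; vget_low_u16 then keeps
+  // only the first four lanes, matching the "discard second half" note above.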
+ return vget_low_u16(vmovl_u8(a)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 }; +}; +template <> +EIGEN_STRONG_INLINE Packet8s pcast(const Packet16uc& a) { + return vreinterpretq_s16_u16(pcast(a)); +} +template <> +EIGEN_STRONG_INLINE Packet4s pcast(const Packet8uc& a) { + return vreinterpret_s16_u16(pcast(a)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet16uc pcast(const Packet16uc& a) { + return a; +} +template <> +EIGEN_STRONG_INLINE Packet8uc pcast(const Packet8uc& a) { + return a; +} +template <> +EIGEN_STRONG_INLINE Packet4uc pcast(const Packet4uc& a) { + return a; +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet16c pcast(const Packet16uc& a) { + return vreinterpretq_s8_u8(a); +} +template <> +EIGEN_STRONG_INLINE Packet8c pcast(const Packet8uc& a) { + return vreinterpret_s8_u8(a); +} +template <> +EIGEN_STRONG_INLINE Packet4c pcast(const Packet4uc& a) { + return static_cast(a); +} + +//============================================================================== +// pcast, SrcType = int16_t +//============================================================================== +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 }; +}; +template <> +EIGEN_STRONG_INLINE Packet4f pcast(const Packet8s& a) { + // Discard second half of input. + return vcvtq_f32_s32(vmovl_s16(vget_low_s16(a))); +} +template <> +EIGEN_STRONG_INLINE Packet2f pcast(const Packet4s& a) { + // Discard second half of input. + return vcvt_f32_s32(vget_low_s32(vmovl_s16(a))); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 }; +}; +template <> +EIGEN_STRONG_INLINE Packet2l pcast(const Packet8s& a) { + // Discard all but first two values. + return vmovl_s32(vget_low_s32(vmovl_s16(vget_low_s16(a)))); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 }; +}; +template <> +EIGEN_STRONG_INLINE Packet2ul pcast(const Packet8s& a) { + return vreinterpretq_u64_s64(pcast(a)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 }; +}; +template <> +EIGEN_STRONG_INLINE Packet4i pcast(const Packet8s& a) { + // Discard second half of input. + return vmovl_s16(vget_low_s16(a)); +} +template <> +EIGEN_STRONG_INLINE Packet2i pcast(const Packet4s& a) { + // Discard second half of input. 
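+  // Note the reversed order relative to the Packet8s overload above: a Packet4s
+  // already fits in a 64-bit register, so it is widened in full and the low two
+  // lanes are extracted afterwards.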
+ return vget_low_s32(vmovl_s16(a)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 }; +}; +template <> +EIGEN_STRONG_INLINE Packet4ui pcast(const Packet8s& a) { + return vreinterpretq_u32_s32(pcast(a)); +} +template <> +EIGEN_STRONG_INLINE Packet2ui pcast(const Packet4s& a) { + return vreinterpret_u32_s32(pcast(a)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet8s pcast(const Packet8s& a) { + return a; +} +template <> +EIGEN_STRONG_INLINE Packet4s pcast(const Packet4s& a) { + return a; +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet8us pcast(const Packet8s& a) { + return vreinterpretq_u16_s16(a); +} +template <> +EIGEN_STRONG_INLINE Packet4us pcast(const Packet4s& a) { + return vreinterpret_u16_s16(a); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet16c pcast(const Packet8s& a, const Packet8s& b) { + return vcombine_s8(vmovn_s16(a), vmovn_s16(b)); +} +template <> +EIGEN_STRONG_INLINE Packet8c pcast(const Packet4s& a, const Packet4s& b) { + return vmovn_s16(vcombine_s16(a, b)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet16uc pcast(const Packet8s& a, const Packet8s& b) { + return vcombine_u8(vmovn_u16(vreinterpretq_u16_s16(a)), vmovn_u16(vreinterpretq_u16_s16(b))); +} +template <> +EIGEN_STRONG_INLINE Packet8uc pcast(const Packet4s& a, const Packet4s& b) { + return vmovn_u16(vcombine_u16(vreinterpret_u16_s16(a), vreinterpret_u16_s16(b))); +} + +//============================================================================== +// pcast, SrcType = uint16_t +//============================================================================== +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 }; +}; +template <> +EIGEN_STRONG_INLINE Packet4f pcast(const Packet8us& a) { + // Discard second half of input. + return vcvtq_f32_u32(vmovl_u16(vget_low_u16(a))); +} +template <> +EIGEN_STRONG_INLINE Packet2f pcast(const Packet4us& a) { + // Discard second half of input. + return vcvt_f32_u32(vget_low_u32(vmovl_u16(a))); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 }; +}; +template <> +EIGEN_STRONG_INLINE Packet2ul pcast(const Packet8us& a) { + // Discard all but first two values. + return vmovl_u32(vget_low_u32(vmovl_u16(vget_low_u16(a)))); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 }; +}; +template <> +EIGEN_STRONG_INLINE Packet2l pcast(const Packet8us& a) { + return vreinterpretq_s64_u64(pcast(a)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 }; +}; +template <> +EIGEN_STRONG_INLINE Packet4ui pcast(const Packet8us& a) { + // Discard second half of input. + return vmovl_u16(vget_low_u16(a)); +} +template <> +EIGEN_STRONG_INLINE Packet2ui pcast(const Packet4us& a) { + // Discard second half of input. 
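+  // The signed variants just below reinterpret this unsigned result, which is
+  // exact because a zero-extended 16-bit value can never set the sign bit of a
+  // 32-bit lane.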
+ return vget_low_u32(vmovl_u16(a)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 }; +}; +template <> +EIGEN_STRONG_INLINE Packet4i pcast(const Packet8us& a) { + return vreinterpretq_s32_u32(pcast(a)); +} +template <> +EIGEN_STRONG_INLINE Packet2i pcast(const Packet4us& a) { + return vreinterpret_s32_u32(pcast(a)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet8us pcast(const Packet8us& a) { + return a; +} +template <> +EIGEN_STRONG_INLINE Packet4us pcast(const Packet4us& a) { + return a; +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet8s pcast(const Packet8us& a) { + return vreinterpretq_s16_u16(a); +} +template <> +EIGEN_STRONG_INLINE Packet4s pcast(const Packet4us& a) { + return vreinterpret_s16_u16(a); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet16uc pcast(const Packet8us& a, const Packet8us& b) { + return vcombine_u8(vmovn_u16(a), vmovn_u16(b)); +} +template <> +EIGEN_STRONG_INLINE Packet8uc pcast(const Packet4us& a, const Packet4us& b) { + return vmovn_u16(vcombine_u16(a, b)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet16c pcast(const Packet8us& a, const Packet8us& b) { + return vreinterpretq_s8_u8(pcast(a, b)); +} +template <> +EIGEN_STRONG_INLINE Packet8c pcast(const Packet4us& a, const Packet4us& b) { + return vreinterpret_s8_u8(pcast(a, b)); +} + +//============================================================================== +// pcast, SrcType = int32_t +//============================================================================== +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet4f pcast(const Packet4i& a) { + return vcvtq_f32_s32(a); +} +template <> +EIGEN_STRONG_INLINE Packet2f pcast(const Packet2i& a) { + return vcvt_f32_s32(a); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 }; +}; +template <> +EIGEN_STRONG_INLINE Packet2l pcast(const Packet4i& a) { + // Discard second half of input. 
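+  // Doubling the lane width halves the lane count: vget_low_s32 keeps the two
+  // low lanes and vmovl_s32 sign-extends them to 64 bits.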
+ return vmovl_s32(vget_low_s32(a)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 }; +}; +template <> +EIGEN_STRONG_INLINE Packet2ul pcast(const Packet4i& a) { + return vreinterpretq_u64_s64(pcast(a)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet4i pcast(const Packet4i& a) { + return a; +} +template <> +EIGEN_STRONG_INLINE Packet2i pcast(const Packet2i& a) { + return a; +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet4ui pcast(const Packet4i& a) { + return vreinterpretq_u32_s32(a); +} +template <> +EIGEN_STRONG_INLINE Packet2ui pcast(const Packet2i& a) { + return vreinterpret_u32_s32(a); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet8s pcast(const Packet4i& a, const Packet4i& b) { + return vcombine_s16(vmovn_s32(a), vmovn_s32(b)); +} +template <> +EIGEN_STRONG_INLINE Packet4s pcast(const Packet2i& a, const Packet2i& b) { + return vmovn_s32(vcombine_s32(a, b)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet8us pcast(const Packet4i& a, const Packet4i& b) { + return vcombine_u16(vmovn_u32(vreinterpretq_u32_s32(a)), vmovn_u32(vreinterpretq_u32_s32(b))); +} +template <> +EIGEN_STRONG_INLINE Packet4us pcast(const Packet2i& a, const Packet2i& b) { + return vmovn_u32(vreinterpretq_u32_s32(vcombine_s32(a, b))); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet16c pcast(const Packet4i& a, const Packet4i& b, const Packet4i& c, + const Packet4i& d) { + const int16x8_t ab_s16 = pcast(a, b); + const int16x8_t cd_s16 = pcast(c, d); + return vcombine_s8(vmovn_s16(ab_s16), vmovn_s16(cd_s16)); +} +template <> +EIGEN_STRONG_INLINE Packet8c pcast(const Packet2i& a, const Packet2i& b, const Packet2i& c, + const Packet2i& d) { + const int16x4_t ab_s16 = vmovn_s32(vcombine_s32(a, b)); + const int16x4_t cd_s16 = vmovn_s32(vcombine_s32(c, d)); + return vmovn_s16(vcombine_s16(ab_s16, cd_s16)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet16uc pcast(const Packet4i& a, const Packet4i& b, const Packet4i& c, + const Packet4i& d) { + const uint16x8_t ab_u16 = pcast(a, b); + const uint16x8_t cd_u16 = pcast(c, d); + return vcombine_u8(vmovn_u16(ab_u16), vmovn_u16(cd_u16)); +} +template <> +EIGEN_STRONG_INLINE Packet8uc pcast(const Packet2i& a, const Packet2i& b, const Packet2i& c, + const Packet2i& d) { + const uint16x4_t ab_u16 = pcast(a, b); + const uint16x4_t cd_u16 = pcast(c, d); + return vmovn_u16(vcombine_u16(ab_u16, cd_u16)); +} + +//============================================================================== +// pcast, SrcType = uint32_t +//============================================================================== +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet4f pcast(const Packet4ui& a) { + return vcvtq_f32_u32(a); +} +template <> +EIGEN_STRONG_INLINE 
Packet2f pcast(const Packet2ui& a) { + return vcvt_f32_u32(a); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 }; +}; +template <> +EIGEN_STRONG_INLINE Packet2ul pcast(const Packet4ui& a) { + // Discard second half of input. + return vmovl_u32(vget_low_u32(a)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 }; +}; +template <> +EIGEN_STRONG_INLINE Packet2l pcast(const Packet4ui& a) { + return vreinterpretq_s64_u64(pcast(a)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet4ui pcast(const Packet4ui& a) { + return a; +} +template <> +EIGEN_STRONG_INLINE Packet2ui pcast(const Packet2ui& a) { + return a; +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet4i pcast(const Packet4ui& a) { + return vreinterpretq_s32_u32(a); +} +template <> +EIGEN_STRONG_INLINE Packet2i pcast(const Packet2ui& a) { + return vreinterpret_s32_u32(a); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet8us pcast(const Packet4ui& a, const Packet4ui& b) { + return vcombine_u16(vmovn_u32(a), vmovn_u32(b)); +} +template <> +EIGEN_STRONG_INLINE Packet4us pcast(const Packet2ui& a, const Packet2ui& b) { + return vmovn_u32(vcombine_u32(a, b)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet8s pcast(const Packet4ui& a, const Packet4ui& b) { + return vreinterpretq_s16_u16(pcast(a, b)); +} +template <> +EIGEN_STRONG_INLINE Packet4s pcast(const Packet2ui& a, const Packet2ui& b) { + return vreinterpret_s16_u16(pcast(a, b)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet16uc pcast(const Packet4ui& a, const Packet4ui& b, const Packet4ui& c, + const Packet4ui& d) { + const uint16x8_t ab_u16 = vcombine_u16(vmovn_u32(a), vmovn_u32(b)); + const uint16x8_t cd_u16 = vcombine_u16(vmovn_u32(c), vmovn_u32(d)); + return vcombine_u8(vmovn_u16(ab_u16), vmovn_u16(cd_u16)); +} +template <> +EIGEN_STRONG_INLINE Packet8uc pcast(const Packet2ui& a, const Packet2ui& b, const Packet2ui& c, + const Packet2ui& d) { + const uint16x4_t ab_u16 = vmovn_u32(vcombine_u32(a, b)); + const uint16x4_t cd_u16 = vmovn_u32(vcombine_u32(c, d)); + return vmovn_u16(vcombine_u16(ab_u16, cd_u16)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet16c pcast(const Packet4ui& a, const Packet4ui& b, const Packet4ui& c, + const Packet4ui& d) { + return vreinterpretq_s8_u8(pcast(a, b, c, d)); +} +template <> +EIGEN_STRONG_INLINE Packet8c pcast(const Packet2ui& a, const Packet2ui& b, const Packet2ui& c, + const Packet2ui& d) { + return vreinterpret_s8_u8(pcast(a, b, c, d)); +} + +//============================================================================== +// pcast, SrcType = int64_t +//============================================================================== +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 }; +}; 
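The Src/TgtCoeffRatio pairs in these traits describe the packet bookkeeping of each vectorized cast: a 2:1 ratio (as in the int64-to-float cast that follows) means pcast consumes two source packets to fill one target packet, while 1:2 means a single source packet only provides the low half of the wider target type. A minimal scalar model of the 2:1 case, with illustrative names that are not part of Eigen:

#include <array>
#include <cstdint>
#include <cstdio>

// Scalar model of pcast<Packet2l, Packet4f>(a, b): two 2-lane int64 packets are
// consumed to produce one 4-lane float packet (SrcCoeffRatio = 2, TgtCoeffRatio = 1).
static std::array<float, 4> scalar_pcast(const std::array<std::int64_t, 2>& a,
                                         const std::array<std::int64_t, 2>& b) {
  return {static_cast<float>(a[0]), static_cast<float>(a[1]),
          static_cast<float>(b[0]), static_cast<float>(b[1])};
}

int main() {
  const std::array<float, 4> f = scalar_pcast({1, 2}, {3, 4});
  std::printf("%g %g %g %g\n", f[0], f[1], f[2], f[3]);  // 1 2 3 4
  return 0;
}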
+template <> +EIGEN_STRONG_INLINE Packet4f pcast(const Packet2l& a, const Packet2l& b) { + return vcvtq_f32_s32(vcombine_s32(vmovn_s64(a), vmovn_s64(b))); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet2l pcast(const Packet2l& a) { + return a; +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet2ul pcast(const Packet2l& a) { + return vreinterpretq_u64_s64(a); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet4i pcast(const Packet2l& a, const Packet2l& b) { + return vcombine_s32(vmovn_s64(a), vmovn_s64(b)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet4ui pcast(const Packet2l& a, const Packet2l& b) { + return vcombine_u32(vmovn_u64(vreinterpretq_u64_s64(a)), vmovn_u64(vreinterpretq_u64_s64(b))); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet8s pcast(const Packet2l& a, const Packet2l& b, const Packet2l& c, + const Packet2l& d) { + const int32x4_t ab_s32 = pcast(a, b); + const int32x4_t cd_s32 = pcast(c, d); + return vcombine_s16(vmovn_s32(ab_s32), vmovn_s32(cd_s32)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet8us pcast(const Packet2l& a, const Packet2l& b, const Packet2l& c, + const Packet2l& d) { + const uint32x4_t ab_u32 = pcast(a, b); + const uint32x4_t cd_u32 = pcast(c, d); + return vcombine_u16(vmovn_u32(ab_u32), vmovn_u32(cd_u32)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 8, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet16c pcast(const Packet2l& a, const Packet2l& b, const Packet2l& c, + const Packet2l& d, const Packet2l& e, const Packet2l& f, + const Packet2l& g, const Packet2l& h) { + const int16x8_t abcd_s16 = pcast(a, b, c, d); + const int16x8_t efgh_s16 = pcast(e, f, g, h); + return vcombine_s8(vmovn_s16(abcd_s16), vmovn_s16(efgh_s16)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 8, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet16uc pcast(const Packet2l& a, const Packet2l& b, const Packet2l& c, + const Packet2l& d, const Packet2l& e, const Packet2l& f, + const Packet2l& g, const Packet2l& h) { + const uint16x8_t abcd_u16 = pcast(a, b, c, d); + const uint16x8_t efgh_u16 = pcast(e, f, g, h); + return vcombine_u8(vmovn_u16(abcd_u16), vmovn_u16(efgh_u16)); +} + +//============================================================================== +// pcast, SrcType = uint64_t +//============================================================================== +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet4f pcast(const Packet2ul& a, const Packet2ul& b) { + return vcvtq_f32_u32(vcombine_u32(vmovn_u64(a), vmovn_u64(b))); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE 
Packet2ul pcast(const Packet2ul& a) { + return a; +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet2l pcast(const Packet2ul& a) { + return vreinterpretq_s64_u64(a); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet4ui pcast(const Packet2ul& a, const Packet2ul& b) { + return vcombine_u32(vmovn_u64(a), vmovn_u64(b)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet4i pcast(const Packet2ul& a, const Packet2ul& b) { + return vreinterpretq_s32_u32(pcast(a, b)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet8us pcast(const Packet2ul& a, const Packet2ul& b, const Packet2ul& c, + const Packet2ul& d) { + const uint16x4_t ab_u16 = vmovn_u32(vcombine_u32(vmovn_u64(a), vmovn_u64(b))); + const uint16x4_t cd_u16 = vmovn_u32(vcombine_u32(vmovn_u64(c), vmovn_u64(d))); + return vcombine_u16(ab_u16, cd_u16); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet8s pcast(const Packet2ul& a, const Packet2ul& b, const Packet2ul& c, + const Packet2ul& d) { + return vreinterpretq_s16_u16(pcast(a, b, c, d)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 8, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet16uc pcast(const Packet2ul& a, const Packet2ul& b, const Packet2ul& c, + const Packet2ul& d, const Packet2ul& e, const Packet2ul& f, + const Packet2ul& g, const Packet2ul& h) { + const uint16x8_t abcd_u16 = pcast(a, b, c, d); + const uint16x8_t efgh_u16 = pcast(e, f, g, h); + return vcombine_u8(vmovn_u16(abcd_u16), vmovn_u16(efgh_u16)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 8, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet16c pcast(const Packet2ul& a, const Packet2ul& b, const Packet2ul& c, + const Packet2ul& d, const Packet2ul& e, const Packet2ul& f, + const Packet2ul& g, const Packet2ul& h) { + return vreinterpretq_s8_u8(pcast(a, b, c, d, e, f, g, h)); +} + +//============================================================================== +// preinterpret +//============================================================================== +template <> +EIGEN_STRONG_INLINE Packet2f preinterpret(const Packet2i& a) { + return vreinterpret_f32_s32(a); +} +template <> +EIGEN_STRONG_INLINE Packet2f preinterpret(const Packet2ui& a) { + return vreinterpret_f32_u32(a); +} +template <> +EIGEN_STRONG_INLINE Packet4f preinterpret(const Packet4i& a) { + return vreinterpretq_f32_s32(a); +} +template <> +EIGEN_STRONG_INLINE Packet4f preinterpret(const Packet4ui& a) { + return vreinterpretq_f32_u32(a); +} + +template <> +EIGEN_STRONG_INLINE Packet4c preinterpret(const Packet4uc& a) { + return static_cast(a); +} +template <> +EIGEN_STRONG_INLINE Packet8c preinterpret(const Packet8uc& a) { + return vreinterpret_s8_u8(a); +} +template <> +EIGEN_STRONG_INLINE Packet16c preinterpret(const Packet16uc& a) { + return vreinterpretq_s8_u8(a); +} + +template <> +EIGEN_STRONG_INLINE Packet4uc preinterpret(const Packet4c& a) { + return static_cast(a); +} 
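A distinction worth keeping in mind while reading this block: pcast converts values (the integer 3 becomes the float 3.0f), whereas preinterpret keeps the bits and merely relabels their type, exactly like the vreinterpret_* intrinsics it wraps. A scalar sketch of the difference per lane, using std::memcpy for the bit copy (function names are illustrative only):

#include <cstdint>
#include <cstring>
#include <cstdio>

// Per-lane behaviour of a value cast (pcast) versus a bit reinterpretation (preinterpret).
static float cast_lane(std::int32_t x) { return static_cast<float>(x); }

static float reinterpret_lane(std::int32_t x) {
  float f;
  std::memcpy(&f, &x, sizeof f);  // same 32 bits, read back as a float
  return f;
}

int main() {
  std::printf("cast(1) = %g, reinterpret(1) = %g\n", cast_lane(1), reinterpret_lane(1));
  // cast(1) = 1; reinterpret(1) is the tiny denormal whose bit pattern is 0x00000001.
  return 0;
}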
+template <> +EIGEN_STRONG_INLINE Packet8uc preinterpret(const Packet8c& a) { + return vreinterpret_u8_s8(a); +} +template <> +EIGEN_STRONG_INLINE Packet16uc preinterpret(const Packet16c& a) { + return vreinterpretq_u8_s8(a); +} + +template <> +EIGEN_STRONG_INLINE Packet4s preinterpret(const Packet4us& a) { + return vreinterpret_s16_u16(a); +} +template <> +EIGEN_STRONG_INLINE Packet8s preinterpret(const Packet8us& a) { + return vreinterpretq_s16_u16(a); +} + +template <> +EIGEN_STRONG_INLINE Packet4us preinterpret(const Packet4s& a) { + return vreinterpret_u16_s16(a); +} +template <> +EIGEN_STRONG_INLINE Packet8us preinterpret(const Packet8s& a) { + return vreinterpretq_u16_s16(a); +} + +template <> +EIGEN_STRONG_INLINE Packet2i preinterpret(const Packet2f& a) { + return vreinterpret_s32_f32(a); +} +template <> +EIGEN_STRONG_INLINE Packet2i preinterpret(const Packet2ui& a) { + return vreinterpret_s32_u32(a); +} +template <> +EIGEN_STRONG_INLINE Packet4i preinterpret(const Packet4f& a) { + return vreinterpretq_s32_f32(a); +} +template <> +EIGEN_STRONG_INLINE Packet4i preinterpret(const Packet4ui& a) { + return vreinterpretq_s32_u32(a); +} + +template <> +EIGEN_STRONG_INLINE Packet2ui preinterpret(const Packet2f& a) { + return vreinterpret_u32_f32(a); +} +template <> +EIGEN_STRONG_INLINE Packet2ui preinterpret(const Packet2i& a) { + return vreinterpret_u32_s32(a); +} +template <> +EIGEN_STRONG_INLINE Packet4ui preinterpret(const Packet4f& a) { + return vreinterpretq_u32_f32(a); +} +template <> +EIGEN_STRONG_INLINE Packet4ui preinterpret(const Packet4i& a) { + return vreinterpretq_u32_s32(a); +} + +template <> +EIGEN_STRONG_INLINE Packet2l preinterpret(const Packet2ul& a) { + return vreinterpretq_s64_u64(a); +} +template <> +EIGEN_STRONG_INLINE Packet2ul preinterpret(const Packet2l& a) { + return vreinterpretq_u64_s64(a); +} + +#if EIGEN_ARCH_ARM64 + +//============================================================================== +// pcast/preinterpret, Double +//============================================================================== + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet2d pcast(const Packet2d& a) { + return a; +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet4f pcast(const Packet2d& a, const Packet2d& b) { + return vcombine_f32(vcvt_f32_f64(a), vcvt_f32_f64(b)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet2l pcast(const Packet2d& a) { + return vcvtq_s64_f64(a); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet2ul pcast(const Packet2d& a) { + return vcvtq_u64_f64(a); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet4i pcast(const Packet2d& a, const Packet2d& b) { + return vcombine_s32(vmovn_s64(vcvtq_s64_f64(a)), vmovn_s64(vcvtq_s64_f64(b))); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet4ui pcast(const Packet2d& a, const Packet2d& b) { + return vcombine_u32(vmovn_u64(vcvtq_u64_f64(a)), 
vmovn_u64(vcvtq_u64_f64(b))); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet8s pcast(const Packet2d& a, const Packet2d& b, const Packet2d& c, + const Packet2d& d) { + const int32x4_t ab_s32 = pcast(a, b); + const int32x4_t cd_s32 = pcast(c, d); + return vcombine_s16(vmovn_s32(ab_s32), vmovn_s32(cd_s32)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 4, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet8us pcast(const Packet2d& a, const Packet2d& b, const Packet2d& c, + const Packet2d& d) { + const uint32x4_t ab_u32 = pcast(a, b); + const uint32x4_t cd_u32 = pcast(c, d); + return vcombine_u16(vmovn_u32(ab_u32), vmovn_u32(cd_u32)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 8, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet16c pcast(const Packet2d& a, const Packet2d& b, const Packet2d& c, + const Packet2d& d, const Packet2d& e, const Packet2d& f, + const Packet2d& g, const Packet2d& h) { + const int16x8_t abcd_s16 = pcast(a, b, c, d); + const int16x8_t efgh_s16 = pcast(e, f, g, h); + return vcombine_s8(vmovn_s16(abcd_s16), vmovn_s16(efgh_s16)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 8, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet16uc pcast(const Packet2d& a, const Packet2d& b, const Packet2d& c, + const Packet2d& d, const Packet2d& e, const Packet2d& f, + const Packet2d& g, const Packet2d& h) { + const uint16x8_t abcd_u16 = pcast(a, b, c, d); + const uint16x8_t efgh_u16 = pcast(e, f, g, h); + return vcombine_u8(vmovn_u16(abcd_u16), vmovn_u16(efgh_u16)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 }; +}; +template <> +EIGEN_STRONG_INLINE Packet2d pcast(const Packet4f& a) { + // Discard second-half of input. + return vcvt_f64_f32(vget_low_f32(a)); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 8 }; +}; +template <> +EIGEN_STRONG_INLINE Packet2d pcast(const Packet16c& a) { + // Discard all but first two values. + return vcvt_f64_f32(pcast(vget_low_s8(a))); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 8 }; +}; +template <> +EIGEN_STRONG_INLINE Packet2d pcast(const Packet16uc& a) { + // Discard all but first two values. + return vcvt_f64_f32(pcast(vget_low_u8(a))); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 }; +}; +template <> +EIGEN_STRONG_INLINE Packet2d pcast(const Packet8s& a) { + // Discard all but first two values. + return vcvt_f64_f32(pcast(vget_low_s16(a))); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 4 }; +}; +template <> +EIGEN_STRONG_INLINE Packet2d pcast(const Packet8us& a) { + // Discard all but first two values. + return vcvt_f64_f32(pcast(vget_low_u16(a))); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 }; +}; +template <> +EIGEN_STRONG_INLINE Packet2d pcast(const Packet4i& a) { + // Discard second half of input. 
+ return vcvtq_f64_s64(vmovl_s32(vget_low_s32(a))); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 }; +}; +template <> +EIGEN_STRONG_INLINE Packet2d pcast(const Packet4ui& a) { + // Discard second half of input. + return vcvtq_f64_u64(vmovl_u32(vget_low_u32(a))); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet2d pcast(const Packet2l& a) { + return vcvtq_f64_s64(a); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; +template <> +EIGEN_STRONG_INLINE Packet2d pcast(const Packet2ul& a) { + return vcvtq_f64_u64(a); +} + +template <> +EIGEN_STRONG_INLINE Packet2d preinterpret(const Packet2l& a) { + return vreinterpretq_f64_s64(a); +} +template <> +EIGEN_STRONG_INLINE Packet2d preinterpret(const Packet2ul& a) { + return vreinterpretq_f64_u64(a); +} +template <> +EIGEN_STRONG_INLINE Packet2l preinterpret(const Packet2d& a) { + return vreinterpretq_s64_f64(a); +} +template <> +EIGEN_STRONG_INLINE Packet2ul preinterpret(const Packet2d& a) { + return vreinterpretq_u64_f64(a); +} +template <> +EIGEN_STRONG_INLINE Packet2d preinterpret(const Packet4i& a) { + return vreinterpretq_f64_s32(a); +} +template <> +EIGEN_STRONG_INLINE Packet4i preinterpret(const Packet2d& a) { + return vreinterpretq_s32_f64(a); +} + +#endif // EIGEN_ARCH_ARM64 + +} // end namespace internal + +} // end namespace Eigen + +#endif // EIGEN_TYPE_CASTING_NEON_H diff --git a/Eigen/src/Core/arch/SSE/Complex.h b/Eigen/src/Core/arch/SSE/Complex.h new file mode 100644 index 0000000..8fe22da --- /dev/null +++ b/Eigen/src/Core/arch/SSE/Complex.h @@ -0,0 +1,351 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2010 Gael Guennebaud +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_COMPLEX_SSE_H +#define EIGEN_COMPLEX_SSE_H + +namespace Eigen { + +namespace internal { + +//---------- float ---------- +struct Packet2cf +{ + EIGEN_STRONG_INLINE Packet2cf() {} + EIGEN_STRONG_INLINE explicit Packet2cf(const __m128& a) : v(a) {} + Packet4f v; +}; + +// Use the packet_traits defined in AVX/PacketMath.h instead if we're going +// to leverage AVX instructions. 
+#ifndef EIGEN_VECTORIZE_AVX +template<> struct packet_traits > : default_packet_traits +{ + typedef Packet2cf type; + typedef Packet2cf half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 2, + HasHalfPacket = 0, + + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 1, + HasNegate = 1, + HasSqrt = 1, + HasAbs = 0, + HasAbs2 = 0, + HasMin = 0, + HasMax = 0, + HasSetLinear = 0, + HasBlend = 1 + }; +}; +#endif + +template<> struct unpacket_traits { + typedef std::complex type; + typedef Packet2cf half; + typedef Packet4f as_real; + enum { + size=2, + alignment=Aligned16, + vectorizable=true, + masked_load_available=false, + masked_store_available=false + }; +}; + +template<> EIGEN_STRONG_INLINE Packet2cf padd(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(_mm_add_ps(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet2cf psub(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(_mm_sub_ps(a.v,b.v)); } + +template<> EIGEN_STRONG_INLINE Packet2cf pnegate(const Packet2cf& a) +{ + const __m128 mask = _mm_castsi128_ps(_mm_setr_epi32(0x80000000,0x80000000,0x80000000,0x80000000)); + return Packet2cf(_mm_xor_ps(a.v,mask)); +} +template<> EIGEN_STRONG_INLINE Packet2cf pconj(const Packet2cf& a) +{ + const __m128 mask = _mm_castsi128_ps(_mm_setr_epi32(0x00000000,0x80000000,0x00000000,0x80000000)); + return Packet2cf(_mm_xor_ps(a.v,mask)); +} + +template<> EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& a, const Packet2cf& b) +{ + #ifdef EIGEN_VECTORIZE_SSE3 + return Packet2cf(_mm_addsub_ps(_mm_mul_ps(_mm_moveldup_ps(a.v), b.v), + _mm_mul_ps(_mm_movehdup_ps(a.v), + vec4f_swizzle1(b.v, 1, 0, 3, 2)))); +// return Packet2cf(_mm_addsub_ps(_mm_mul_ps(vec4f_swizzle1(a.v, 0, 0, 2, 2), b.v), +// _mm_mul_ps(vec4f_swizzle1(a.v, 1, 1, 3, 3), +// vec4f_swizzle1(b.v, 1, 0, 3, 2)))); + #else + const __m128 mask = _mm_castsi128_ps(_mm_setr_epi32(0x80000000,0x00000000,0x80000000,0x00000000)); + return Packet2cf(_mm_add_ps(_mm_mul_ps(vec4f_swizzle1(a.v, 0, 0, 2, 2), b.v), + _mm_xor_ps(_mm_mul_ps(vec4f_swizzle1(a.v, 1, 1, 3, 3), + vec4f_swizzle1(b.v, 1, 0, 3, 2)), mask))); + #endif +} + +template<> EIGEN_STRONG_INLINE Packet2cf ptrue (const Packet2cf& a) { return Packet2cf(ptrue(Packet4f(a.v))); } +template<> EIGEN_STRONG_INLINE Packet2cf pand (const Packet2cf& a, const Packet2cf& b) { return Packet2cf(_mm_and_ps(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet2cf por (const Packet2cf& a, const Packet2cf& b) { return Packet2cf(_mm_or_ps(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet2cf pxor (const Packet2cf& a, const Packet2cf& b) { return Packet2cf(_mm_xor_ps(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet2cf pandnot(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(_mm_andnot_ps(b.v,a.v)); } + +template<> EIGEN_STRONG_INLINE Packet2cf pload (const std::complex* from) { EIGEN_DEBUG_ALIGNED_LOAD return Packet2cf(pload(&numext::real_ref(*from))); } +template<> EIGEN_STRONG_INLINE Packet2cf ploadu(const std::complex* from) { EIGEN_DEBUG_UNALIGNED_LOAD return Packet2cf(ploadu(&numext::real_ref(*from))); } + +template<> EIGEN_STRONG_INLINE Packet2cf pset1(const std::complex& from) +{ + Packet2cf res; +#ifdef EIGEN_VECTORIZE_SSE3 + res.v = _mm_castpd_ps(_mm_loaddup_pd(reinterpret_cast(&from))); +#else + res.v = _mm_castpd_ps(_mm_load_sd(reinterpret_cast(&from))); + res.v = _mm_movelh_ps(res.v, res.v); +#endif + return res; +} + +template<> EIGEN_STRONG_INLINE Packet2cf ploaddup(const std::complex* from) { return pset1(*from); } + +template<> EIGEN_STRONG_INLINE void 
pstore >(std::complex * to, const Packet2cf& from) { EIGEN_DEBUG_ALIGNED_STORE pstore(&numext::real_ref(*to), Packet4f(from.v)); } +template<> EIGEN_STRONG_INLINE void pstoreu >(std::complex * to, const Packet2cf& from) { EIGEN_DEBUG_UNALIGNED_STORE pstoreu(&numext::real_ref(*to), Packet4f(from.v)); } + + +template<> EIGEN_DEVICE_FUNC inline Packet2cf pgather, Packet2cf>(const std::complex* from, Index stride) +{ + return Packet2cf(_mm_set_ps(std::imag(from[1*stride]), std::real(from[1*stride]), + std::imag(from[0*stride]), std::real(from[0*stride]))); +} + +template<> EIGEN_DEVICE_FUNC inline void pscatter, Packet2cf>(std::complex* to, const Packet2cf& from, Index stride) +{ + to[stride*0] = std::complex(_mm_cvtss_f32(_mm_shuffle_ps(from.v, from.v, 0)), + _mm_cvtss_f32(_mm_shuffle_ps(from.v, from.v, 1))); + to[stride*1] = std::complex(_mm_cvtss_f32(_mm_shuffle_ps(from.v, from.v, 2)), + _mm_cvtss_f32(_mm_shuffle_ps(from.v, from.v, 3))); +} + +template<> EIGEN_STRONG_INLINE void prefetch >(const std::complex * addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); } + +template<> EIGEN_STRONG_INLINE std::complex pfirst(const Packet2cf& a) +{ + #if EIGEN_GNUC_AT_MOST(4,3) + // Workaround gcc 4.2 ICE - this is not performance wise ideal, but who cares... + // This workaround also fix invalid code generation with gcc 4.3 + EIGEN_ALIGN16 std::complex res[2]; + _mm_store_ps((float*)res, a.v); + return res[0]; + #else + std::complex res; + _mm_storel_pi((__m64*)&res, a.v); + return res; + #endif +} + +template<> EIGEN_STRONG_INLINE Packet2cf preverse(const Packet2cf& a) { return Packet2cf(_mm_castpd_ps(preverse(Packet2d(_mm_castps_pd(a.v))))); } + +template<> EIGEN_STRONG_INLINE std::complex predux(const Packet2cf& a) +{ + return pfirst(Packet2cf(_mm_add_ps(a.v, _mm_movehl_ps(a.v,a.v)))); +} + +template<> EIGEN_STRONG_INLINE std::complex predux_mul(const Packet2cf& a) +{ + return pfirst(pmul(a, Packet2cf(_mm_movehl_ps(a.v,a.v)))); +} + +EIGEN_STRONG_INLINE Packet2cf pcplxflip/* */(const Packet2cf& x) +{ + return Packet2cf(vec4f_swizzle1(x.v, 1, 0, 3, 2)); +} + +EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cf,Packet4f) + +template<> EIGEN_STRONG_INLINE Packet2cf pdiv(const Packet2cf& a, const Packet2cf& b) +{ + // TODO optimize it for SSE3 and 4 + Packet2cf res = pmul(a, pconj(b)); + __m128 s = _mm_mul_ps(b.v,b.v); + return Packet2cf(_mm_div_ps(res.v,_mm_add_ps(s,vec4f_swizzle1(s, 1, 0, 3, 2)))); +} + + + +//---------- double ---------- +struct Packet1cd +{ + EIGEN_STRONG_INLINE Packet1cd() {} + EIGEN_STRONG_INLINE explicit Packet1cd(const __m128d& a) : v(a) {} + Packet2d v; +}; + +// Use the packet_traits defined in AVX/PacketMath.h instead if we're going +// to leverage AVX instructions. 
+#ifndef EIGEN_VECTORIZE_AVX +template<> struct packet_traits > : default_packet_traits +{ + typedef Packet1cd type; + typedef Packet1cd half; + enum { + Vectorizable = 1, + AlignedOnScalar = 0, + size = 1, + HasHalfPacket = 0, + + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 1, + HasNegate = 1, + HasSqrt = 1, + HasAbs = 0, + HasAbs2 = 0, + HasMin = 0, + HasMax = 0, + HasSetLinear = 0 + }; +}; +#endif + +template<> struct unpacket_traits { + typedef std::complex type; + typedef Packet1cd half; + typedef Packet2d as_real; + enum { + size=1, + alignment=Aligned16, + vectorizable=true, + masked_load_available=false, + masked_store_available=false + }; +}; + +template<> EIGEN_STRONG_INLINE Packet1cd padd(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(_mm_add_pd(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet1cd psub(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(_mm_sub_pd(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet1cd pnegate(const Packet1cd& a) { return Packet1cd(pnegate(Packet2d(a.v))); } +template<> EIGEN_STRONG_INLINE Packet1cd pconj(const Packet1cd& a) +{ + const __m128d mask = _mm_castsi128_pd(_mm_set_epi32(0x80000000,0x0,0x0,0x0)); + return Packet1cd(_mm_xor_pd(a.v,mask)); +} + +template<> EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& a, const Packet1cd& b) +{ + #ifdef EIGEN_VECTORIZE_SSE3 + return Packet1cd(_mm_addsub_pd(_mm_mul_pd(_mm_movedup_pd(a.v), b.v), + _mm_mul_pd(vec2d_swizzle1(a.v, 1, 1), + vec2d_swizzle1(b.v, 1, 0)))); + #else + const __m128d mask = _mm_castsi128_pd(_mm_set_epi32(0x0,0x0,0x80000000,0x0)); + return Packet1cd(_mm_add_pd(_mm_mul_pd(vec2d_swizzle1(a.v, 0, 0), b.v), + _mm_xor_pd(_mm_mul_pd(vec2d_swizzle1(a.v, 1, 1), + vec2d_swizzle1(b.v, 1, 0)), mask))); + #endif +} + +template<> EIGEN_STRONG_INLINE Packet1cd ptrue (const Packet1cd& a) { return Packet1cd(ptrue(Packet2d(a.v))); } +template<> EIGEN_STRONG_INLINE Packet1cd pand (const Packet1cd& a, const Packet1cd& b) { return Packet1cd(_mm_and_pd(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet1cd por (const Packet1cd& a, const Packet1cd& b) { return Packet1cd(_mm_or_pd(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet1cd pxor (const Packet1cd& a, const Packet1cd& b) { return Packet1cd(_mm_xor_pd(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet1cd pandnot(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(_mm_andnot_pd(b.v,a.v)); } + +// FIXME force unaligned load, this is a temporary fix +template<> EIGEN_STRONG_INLINE Packet1cd pload (const std::complex* from) +{ EIGEN_DEBUG_ALIGNED_LOAD return Packet1cd(pload((const double*)from)); } +template<> EIGEN_STRONG_INLINE Packet1cd ploadu(const std::complex* from) +{ EIGEN_DEBUG_UNALIGNED_LOAD return Packet1cd(ploadu((const double*)from)); } +template<> EIGEN_STRONG_INLINE Packet1cd pset1(const std::complex& from) +{ /* here we really have to use unaligned loads :( */ return ploadu(&from); } + +template<> EIGEN_STRONG_INLINE Packet1cd ploaddup(const std::complex* from) { return pset1(*from); } + +// FIXME force unaligned store, this is a temporary fix +template<> EIGEN_STRONG_INLINE void pstore >(std::complex * to, const Packet1cd& from) { EIGEN_DEBUG_ALIGNED_STORE pstore((double*)to, Packet2d(from.v)); } +template<> EIGEN_STRONG_INLINE void pstoreu >(std::complex * to, const Packet1cd& from) { EIGEN_DEBUG_UNALIGNED_STORE pstoreu((double*)to, Packet2d(from.v)); } + +template<> EIGEN_STRONG_INLINE void prefetch >(const std::complex * addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); } + 
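Both complex pmul specializations above implement the same split of the product (a+bi)(c+di) = (ac - bd) + (ad + bc)i: the first operand is broadcast as (a,a) and (b,b), the second is used as (c,d) and (d,c), and the addsub instruction supplies the sign flip on the real lane. A scalar sketch of that dataflow, assuming the (real, imag) lane layout used here (std::complex is only used to cross-check the result):

#include <complex>
#include <cstdio>

// Scalar model of the addsub-based complex multiply used by pmul above.
static std::complex<double> addsub_mul(std::complex<double> x, std::complex<double> y) {
  const double a = x.real(), b = x.imag();
  const double c = y.real(), d = y.imag();
  // _mm_mul_pd(movedup(x) = (a,a), y = (c,d))   -> (a*c, a*d)
  // _mm_mul_pd((b,b), swizzle(y) = (d,c))       -> (b*d, b*c)
  // _mm_addsub_pd(p, q) = (p0 - q0, p1 + q1)    -> (a*c - b*d, a*d + b*c)
  return std::complex<double>(a * c - b * d, a * d + b * c);
}

int main() {
  const std::complex<double> x(1.0, 2.0), y(3.0, 4.0);
  const std::complex<double> p = addsub_mul(x, y);
  std::printf("(%g, %g) vs (%g, %g)\n", p.real(), p.imag(), (x * y).real(), (x * y).imag());
  return 0;
}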
+template<> EIGEN_STRONG_INLINE std::complex pfirst(const Packet1cd& a) +{ + EIGEN_ALIGN16 double res[2]; + _mm_store_pd(res, a.v); + return std::complex(res[0],res[1]); +} + +template<> EIGEN_STRONG_INLINE Packet1cd preverse(const Packet1cd& a) { return a; } + +template<> EIGEN_STRONG_INLINE std::complex predux(const Packet1cd& a) +{ + return pfirst(a); +} + +template<> EIGEN_STRONG_INLINE std::complex predux_mul(const Packet1cd& a) +{ + return pfirst(a); +} + +EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet1cd,Packet2d) + +template<> EIGEN_STRONG_INLINE Packet1cd pdiv(const Packet1cd& a, const Packet1cd& b) +{ + // TODO optimize it for SSE3 and 4 + Packet1cd res = pmul(a,pconj(b)); + __m128d s = _mm_mul_pd(b.v,b.v); + return Packet1cd(_mm_div_pd(res.v, _mm_add_pd(s,_mm_shuffle_pd(s, s, 0x1)))); +} + +EIGEN_STRONG_INLINE Packet1cd pcplxflip/* */(const Packet1cd& x) +{ + return Packet1cd(preverse(Packet2d(x.v))); +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + __m128d w1 = _mm_castps_pd(kernel.packet[0].v); + __m128d w2 = _mm_castps_pd(kernel.packet[1].v); + + __m128 tmp = _mm_castpd_ps(_mm_unpackhi_pd(w1, w2)); + kernel.packet[0].v = _mm_castpd_ps(_mm_unpacklo_pd(w1, w2)); + kernel.packet[1].v = tmp; +} + +template<> EIGEN_STRONG_INLINE Packet2cf pcmp_eq(const Packet2cf& a, const Packet2cf& b) +{ + __m128 eq = _mm_cmpeq_ps(a.v, b.v); + return Packet2cf(pand(eq, vec4f_swizzle1(eq, 1, 0, 3, 2))); +} + +template<> EIGEN_STRONG_INLINE Packet1cd pcmp_eq(const Packet1cd& a, const Packet1cd& b) +{ + __m128d eq = _mm_cmpeq_pd(a.v, b.v); + return Packet1cd(pand(eq, vec2d_swizzle1(eq, 1, 0))); +} + +template<> EIGEN_STRONG_INLINE Packet2cf pblend(const Selector<2>& ifPacket, const Packet2cf& thenPacket, const Packet2cf& elsePacket) { + __m128d result = pblend(ifPacket, _mm_castps_pd(thenPacket.v), _mm_castps_pd(elsePacket.v)); + return Packet2cf(_mm_castpd_ps(result)); +} + +template<> EIGEN_STRONG_INLINE Packet1cd psqrt(const Packet1cd& a) { + return psqrt_complex(a); +} + +template<> EIGEN_STRONG_INLINE Packet2cf psqrt(const Packet2cf& a) { + return psqrt_complex(a); +} + +} // end namespace internal +} // end namespace Eigen + +#endif // EIGEN_COMPLEX_SSE_H diff --git a/Eigen/src/Core/arch/SSE/MathFunctions.h b/Eigen/src/Core/arch/SSE/MathFunctions.h new file mode 100644 index 0000000..8736d0d --- /dev/null +++ b/Eigen/src/Core/arch/SSE/MathFunctions.h @@ -0,0 +1,199 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2007 Julien Pommier +// Copyright (C) 2009 Gael Guennebaud +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+ +/* The sin and cos and functions of this file come from + * Julien Pommier's sse math library: http://gruntthepeon.free.fr/ssemath/ + */ + +#ifndef EIGEN_MATH_FUNCTIONS_SSE_H +#define EIGEN_MATH_FUNCTIONS_SSE_H + +namespace Eigen { + +namespace internal { + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet4f plog(const Packet4f& _x) { + return plog_float(_x); +} + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet2d plog(const Packet2d& _x) { + return plog_double(_x); +} + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet4f plog2(const Packet4f& _x) { + return plog2_float(_x); +} + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet2d plog2(const Packet2d& _x) { + return plog2_double(_x); +} + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet4f plog1p(const Packet4f& _x) { + return generic_plog1p(_x); +} + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet4f pexpm1(const Packet4f& _x) { + return generic_expm1(_x); +} + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet4f pexp(const Packet4f& _x) +{ + return pexp_float(_x); +} + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet2d pexp(const Packet2d& x) +{ + return pexp_double(x); +} + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet4f psin(const Packet4f& _x) +{ + return psin_float(_x); +} + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet4f pcos(const Packet4f& _x) +{ + return pcos_float(_x); +} + +#if EIGEN_FAST_MATH + +// Functions for sqrt. +// The EIGEN_FAST_MATH version uses the _mm_rsqrt_ps approximation and one step +// of Newton's method, at a cost of 1-2 bits of precision as opposed to the +// exact solution. It does not handle +inf, or denormalized numbers correctly. +// The main advantage of this approach is not just speed, but also the fact that +// it can be inlined and pipelined with other computations, further reducing its +// effective latency. This is similar to Quake3's fast inverse square root. +// For detail see here: http://www.beyond3d.com/content/articles/8/ +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet4f psqrt(const Packet4f& _x) +{ + Packet4f minus_half_x = pmul(_x, pset1(-0.5f)); + Packet4f denormal_mask = pandnot( + pcmp_lt(_x, pset1((std::numeric_limits::min)())), + pcmp_lt(_x, pzero(_x))); + + // Compute approximate reciprocal sqrt. + Packet4f x = _mm_rsqrt_ps(_x); + // Do a single step of Newton's iteration. + x = pmul(x, pmadd(minus_half_x, pmul(x,x), pset1(1.5f))); + // Flush results for denormals to zero. 
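+  // Derivation of the Newton step above: the code refines y ~= 1/sqrt(_x).
+  // For f(y) = 1/y^2 - _x, Newton's method gives
+  //   y_{n+1} = y_n - f(y_n)/f'(y_n) = y_n * (1.5 - 0.5 * _x * y_n * y_n),
+  // which is exactly pmul(x, pmadd(minus_half_x, pmul(x, x), 1.5f)) with x holding
+  // the current estimate. Multiplying the refined estimate by _x below turns
+  // 1/sqrt(_x) into sqrt(_x), and the pandnot with denormal_mask zeroes the result
+  // for non-negative inputs smaller than the smallest normalized float.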
+ return pandnot(pmul(_x,x), denormal_mask); +} + +#else + +template<>EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet4f psqrt(const Packet4f& x) { return _mm_sqrt_ps(x); } + +#endif + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet2d psqrt(const Packet2d& x) { return _mm_sqrt_pd(x); } + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet16b psqrt(const Packet16b& x) { return x; } + +#if EIGEN_FAST_MATH + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet4f prsqrt(const Packet4f& _x) { + _EIGEN_DECLARE_CONST_Packet4f(one_point_five, 1.5f); + _EIGEN_DECLARE_CONST_Packet4f(minus_half, -0.5f); + _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(inf, 0x7f800000u); + _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(flt_min, 0x00800000u); + + Packet4f neg_half = pmul(_x, p4f_minus_half); + + // Identity infinite, zero, negative and denormal arguments. + Packet4f lt_min_mask = _mm_cmplt_ps(_x, p4f_flt_min); + Packet4f inf_mask = _mm_cmpeq_ps(_x, p4f_inf); + Packet4f not_normal_finite_mask = _mm_or_ps(lt_min_mask, inf_mask); + + // Compute an approximate result using the rsqrt intrinsic. + Packet4f y_approx = _mm_rsqrt_ps(_x); + + // Do a single step of Newton-Raphson iteration to improve the approximation. + // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n). + // It is essential to evaluate the inner term like this because forming + // y_n^2 may over- or underflow. + Packet4f y_newton = pmul( + y_approx, pmadd(y_approx, pmul(neg_half, y_approx), p4f_one_point_five)); + + // Select the result of the Newton-Raphson step for positive normal arguments. + // For other arguments, choose the output of the intrinsic. This will + // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if + // x is zero or a positive denormalized float (equivalent to flushing positive + // denormalized inputs to zero). + return pselect(not_normal_finite_mask, y_approx, y_newton); +} + +#else + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet4f prsqrt(const Packet4f& x) { + // Unfortunately we can't use the much faster mm_rsqrt_ps since it only provides an approximation. + return _mm_div_ps(pset1(1.0f), _mm_sqrt_ps(x)); +} + +#endif + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet2d prsqrt(const Packet2d& x) { + return _mm_div_pd(pset1(1.0), _mm_sqrt_pd(x)); +} + +// Hyperbolic Tangent function. 
+template <> +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f +ptanh(const Packet4f& x) { + return internal::generic_fast_tanh_float(x); +} + +} // end namespace internal + +namespace numext { + +template<> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE +float sqrt(const float &x) +{ + return internal::pfirst(internal::Packet4f(_mm_sqrt_ss(_mm_set_ss(x)))); +} + +template<> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE +double sqrt(const double &x) +{ +#if EIGEN_COMP_GNUC_STRICT + // This works around a GCC bug generating poor code for _mm_sqrt_pd + // See https://gitlab.com/libeigen/eigen/commit/8dca9f97e38970 + return internal::pfirst(internal::Packet2d(__builtin_ia32_sqrtsd(_mm_set_sd(x)))); +#else + return internal::pfirst(internal::Packet2d(_mm_sqrt_pd(_mm_set_sd(x)))); +#endif +} + +} // end namespace numex + +} // end namespace Eigen + +#endif // EIGEN_MATH_FUNCTIONS_SSE_H diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h new file mode 100755 index 0000000..db102c7 --- /dev/null +++ b/Eigen/src/Core/arch/SSE/PacketMath.h @@ -0,0 +1,1505 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2008-2009 Gael Guennebaud +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_PACKET_MATH_SSE_H +#define EIGEN_PACKET_MATH_SSE_H + +namespace Eigen { + +namespace internal { + +#ifndef EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD +#define EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD 8 +#endif + +#if !defined(EIGEN_VECTORIZE_AVX) && !defined(EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS) +// 32 bits => 8 registers +// 64 bits => 16 registers +#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS (2*sizeof(void*)) +#endif + +#ifdef EIGEN_VECTORIZE_FMA +#ifndef EIGEN_HAS_SINGLE_INSTRUCTION_MADD +#define EIGEN_HAS_SINGLE_INSTRUCTION_MADD +#endif +#endif + +#if ((defined EIGEN_VECTORIZE_AVX) && (EIGEN_COMP_GNUC_STRICT || EIGEN_COMP_MINGW) && (__GXX_ABI_VERSION < 1004)) || EIGEN_OS_QNX +// With GCC's default ABI version, a __m128 or __m256 are the same types and therefore we cannot +// have overloads for both types without linking error. +// One solution is to increase ABI version using -fabi-version=4 (or greater). 
+// Otherwise, we workaround this inconvenience by wrapping 128bit types into the following helper +// structure: +typedef eigen_packet_wrapper<__m128> Packet4f; +typedef eigen_packet_wrapper<__m128d> Packet2d; +#else +typedef __m128 Packet4f; +typedef __m128d Packet2d; +#endif + +typedef eigen_packet_wrapper<__m128i, 0> Packet4i; +typedef eigen_packet_wrapper<__m128i, 1> Packet16b; + +template<> struct is_arithmetic<__m128> { enum { value = true }; }; +template<> struct is_arithmetic<__m128i> { enum { value = true }; }; +template<> struct is_arithmetic<__m128d> { enum { value = true }; }; +template<> struct is_arithmetic { enum { value = true }; }; +template<> struct is_arithmetic { enum { value = true }; }; + +template +struct shuffle_mask{ + enum { mask = (s)<<6|(r)<<4|(q)<<2|(p) }; +}; + +// TODO: change the implementation of all swizzle* ops from macro to template, +#define vec4f_swizzle1(v,p,q,r,s) \ + Packet4f(_mm_castsi128_ps(_mm_shuffle_epi32( _mm_castps_si128(v), (shuffle_mask::mask)))) + +#define vec4i_swizzle1(v,p,q,r,s) \ + Packet4i(_mm_shuffle_epi32( v, (shuffle_mask::mask))) + +#define vec2d_swizzle1(v,p,q) \ + Packet2d(_mm_castsi128_pd(_mm_shuffle_epi32( _mm_castpd_si128(v), (shuffle_mask<2*p,2*p+1,2*q,2*q+1>::mask)))) + +#define vec4f_swizzle2(a,b,p,q,r,s) \ + Packet4f(_mm_shuffle_ps( (a), (b), (shuffle_mask::mask))) + +#define vec4i_swizzle2(a,b,p,q,r,s) \ + Packet4i(_mm_castps_si128( (_mm_shuffle_ps( _mm_castsi128_ps(a), _mm_castsi128_ps(b), (shuffle_mask::mask))))) + +EIGEN_STRONG_INLINE Packet4f vec4f_movelh(const Packet4f& a, const Packet4f& b) +{ + return Packet4f(_mm_movelh_ps(a,b)); +} +EIGEN_STRONG_INLINE Packet4f vec4f_movehl(const Packet4f& a, const Packet4f& b) +{ + return Packet4f(_mm_movehl_ps(a,b)); +} +EIGEN_STRONG_INLINE Packet4f vec4f_unpacklo(const Packet4f& a, const Packet4f& b) +{ + return Packet4f(_mm_unpacklo_ps(a,b)); +} +EIGEN_STRONG_INLINE Packet4f vec4f_unpackhi(const Packet4f& a, const Packet4f& b) +{ + return Packet4f(_mm_unpackhi_ps(a,b)); +} +#define vec4f_duplane(a,p) \ + vec4f_swizzle2(a,a,p,p,p,p) + +#define vec2d_swizzle2(a,b,mask) \ + Packet2d(_mm_shuffle_pd(a,b,mask)) + +EIGEN_STRONG_INLINE Packet2d vec2d_unpacklo(const Packet2d& a, const Packet2d& b) +{ + return Packet2d(_mm_unpacklo_pd(a,b)); +} +EIGEN_STRONG_INLINE Packet2d vec2d_unpackhi(const Packet2d& a, const Packet2d& b) +{ + return Packet2d(_mm_unpackhi_pd(a,b)); +} +#define vec2d_duplane(a,p) \ + vec2d_swizzle2(a,a,(p<<1)|p) + +#define _EIGEN_DECLARE_CONST_Packet4f(NAME,X) \ + const Packet4f p4f_##NAME = pset1(X) + +#define _EIGEN_DECLARE_CONST_Packet2d(NAME,X) \ + const Packet2d p2d_##NAME = pset1(X) + +#define _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(NAME,X) \ + const Packet4f p4f_##NAME = pset1frombits(X) + +#define _EIGEN_DECLARE_CONST_Packet4i(NAME,X) \ + const Packet4i p4i_##NAME = pset1(X) + + +// Use the packet_traits defined in AVX/PacketMath.h instead if we're going +// to leverage AVX instructions. 
+#ifndef EIGEN_VECTORIZE_AVX +template <> +struct packet_traits : default_packet_traits { + typedef Packet4f type; + typedef Packet4f half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 4, + HasHalfPacket = 0, + + HasCmp = 1, + HasDiv = 1, + HasSin = EIGEN_FAST_MATH, + HasCos = EIGEN_FAST_MATH, + HasLog = 1, + HasLog1p = 1, + HasExpm1 = 1, + HasNdtri = 1, + HasExp = 1, + HasBessel = 1, + HasSqrt = 1, + HasRsqrt = 1, + HasTanh = EIGEN_FAST_MATH, + HasErf = EIGEN_FAST_MATH, + HasBlend = 1, + HasCeil = 1, + HasFloor = 1, +#ifdef EIGEN_VECTORIZE_SSE4_1 + HasRound = 1, +#endif + HasRint = 1 + }; +}; +template <> +struct packet_traits : default_packet_traits { + typedef Packet2d type; + typedef Packet2d half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size=2, + HasHalfPacket = 0, + + HasCmp = 1, + HasDiv = 1, + HasLog = 1, + HasExp = 1, + HasSqrt = 1, + HasRsqrt = 1, + HasBlend = 1, + HasFloor = 1, + HasCeil = 1, +#ifdef EIGEN_VECTORIZE_SSE4_1 + HasRound = 1, +#endif + HasRint = 1 + }; +}; +#endif +template<> struct packet_traits : default_packet_traits +{ + typedef Packet4i type; + typedef Packet4i half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size=4, + + HasShift = 1, + HasBlend = 1 + }; +}; + +template<> struct packet_traits : default_packet_traits +{ + typedef Packet16b type; + typedef Packet16b half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + HasHalfPacket = 0, + size=16, + + HasAdd = 1, + HasSub = 1, + HasShift = 0, + HasMul = 1, + HasNegate = 1, + HasAbs = 0, + HasAbs2 = 0, + HasMin = 0, + HasMax = 0, + HasConj = 0, + HasSqrt = 1 + }; +}; + +template<> struct unpacket_traits { + typedef float type; + typedef Packet4f half; + typedef Packet4i integer_packet; + enum {size=4, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; +}; +template<> struct unpacket_traits { + typedef double type; + typedef Packet2d half; + enum {size=2, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; +}; +template<> struct unpacket_traits { + typedef int type; + typedef Packet4i half; + enum {size=4, alignment=Aligned16, vectorizable=false, masked_load_available=false, masked_store_available=false}; +}; +template<> struct unpacket_traits { + typedef bool type; + typedef Packet16b half; + enum {size=16, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; +}; + +#ifndef EIGEN_VECTORIZE_AVX +template<> struct scalar_div_cost { enum { value = 7 }; }; +template<> struct scalar_div_cost { enum { value = 8 }; }; +#endif + +#if EIGEN_COMP_MSVC==1500 +// Workaround MSVC 9 internal compiler error. +// TODO: It has been detected with win64 builds (amd64), so let's check whether it also happens in 32bits+SSE mode +// TODO: let's check whether there does not exist a better fix, like adding a pset0() function. (it crashed on pset1(0)). 
+template<> EIGEN_STRONG_INLINE Packet4f pset1(const float& from) { return _mm_set_ps(from,from,from,from); } +template<> EIGEN_STRONG_INLINE Packet2d pset1(const double& from) { return _mm_set_pd(from,from); } +template<> EIGEN_STRONG_INLINE Packet4i pset1(const int& from) { return _mm_set_epi32(from,from,from,from); } +#else +template<> EIGEN_STRONG_INLINE Packet4f pset1(const float& from) { return _mm_set_ps1(from); } +template<> EIGEN_STRONG_INLINE Packet2d pset1(const double& from) { return _mm_set1_pd(from); } +template<> EIGEN_STRONG_INLINE Packet4i pset1(const int& from) { return _mm_set1_epi32(from); } +#endif +template<> EIGEN_STRONG_INLINE Packet16b pset1(const bool& from) { return _mm_set1_epi8(static_cast(from)); } + +template<> EIGEN_STRONG_INLINE Packet4f pset1frombits(unsigned int from) { return _mm_castsi128_ps(pset1(from)); } +template<> EIGEN_STRONG_INLINE Packet2d pset1frombits(uint64_t from) { return _mm_castsi128_pd(_mm_set1_epi64x(from)); } + +template<> EIGEN_STRONG_INLINE Packet4f peven_mask(const Packet4f& /*a*/) { return _mm_castsi128_ps(_mm_set_epi32(0, -1, 0, -1)); } +template<> EIGEN_STRONG_INLINE Packet4i peven_mask(const Packet4i& /*a*/) { return _mm_set_epi32(0, -1, 0, -1); } +template<> EIGEN_STRONG_INLINE Packet2d peven_mask(const Packet2d& /*a*/) { return _mm_castsi128_pd(_mm_set_epi32(0, 0, -1, -1)); } + +template<> EIGEN_STRONG_INLINE Packet4f pzero(const Packet4f& /*a*/) { return _mm_setzero_ps(); } +template<> EIGEN_STRONG_INLINE Packet2d pzero(const Packet2d& /*a*/) { return _mm_setzero_pd(); } +template<> EIGEN_STRONG_INLINE Packet4i pzero(const Packet4i& /*a*/) { return _mm_setzero_si128(); } + +// GCC generates a shufps instruction for _mm_set1_ps/_mm_load1_ps instead of the more efficient pshufd instruction. +// However, using inrinsics for pset1 makes gcc to generate crappy code in some cases (see bug 203) +// Using inline assembly is also not an option because then gcc fails to reorder properly the instructions. +// Therefore, we introduced the pload1 functions to be used in product kernels for which bug 203 does not apply. +// Also note that with AVX, we want it to generate a vbroadcastss. 
+#if EIGEN_COMP_GNUC_STRICT && (!defined __AVX__) +template<> EIGEN_STRONG_INLINE Packet4f pload1(const float *from) { + return vec4f_swizzle1(_mm_load_ss(from),0,0,0,0); +} +#endif + +template<> EIGEN_STRONG_INLINE Packet4f plset(const float& a) { return _mm_add_ps(pset1(a), _mm_set_ps(3,2,1,0)); } +template<> EIGEN_STRONG_INLINE Packet2d plset(const double& a) { return _mm_add_pd(pset1(a),_mm_set_pd(1,0)); } +template<> EIGEN_STRONG_INLINE Packet4i plset(const int& a) { return _mm_add_epi32(pset1(a),_mm_set_epi32(3,2,1,0)); } + +template<> EIGEN_STRONG_INLINE Packet4f padd(const Packet4f& a, const Packet4f& b) { return _mm_add_ps(a,b); } +template<> EIGEN_STRONG_INLINE Packet2d padd(const Packet2d& a, const Packet2d& b) { return _mm_add_pd(a,b); } +template<> EIGEN_STRONG_INLINE Packet4i padd(const Packet4i& a, const Packet4i& b) { return _mm_add_epi32(a,b); } + +template<> EIGEN_STRONG_INLINE Packet16b padd(const Packet16b& a, const Packet16b& b) { return _mm_or_si128(a,b); } + +template<> EIGEN_STRONG_INLINE Packet4f psub(const Packet4f& a, const Packet4f& b) { return _mm_sub_ps(a,b); } +template<> EIGEN_STRONG_INLINE Packet2d psub(const Packet2d& a, const Packet2d& b) { return _mm_sub_pd(a,b); } +template<> EIGEN_STRONG_INLINE Packet4i psub(const Packet4i& a, const Packet4i& b) { return _mm_sub_epi32(a,b); } +template<> EIGEN_STRONG_INLINE Packet16b psub(const Packet16b& a, const Packet16b& b) { return _mm_xor_si128(a,b); } + +template<> EIGEN_STRONG_INLINE Packet4f pxor(const Packet4f& a, const Packet4f& b); +template<> EIGEN_STRONG_INLINE Packet4f paddsub(const Packet4f& a, const Packet4f& b) +{ +#ifdef EIGEN_VECTORIZE_SSE3 + return _mm_addsub_ps(a,b); +#else + const Packet4f mask = _mm_castsi128_ps(_mm_setr_epi32(0x80000000,0x0,0x80000000,0x0)); + return padd(a, pxor(mask, b)); +#endif +} + +template<> EIGEN_STRONG_INLINE Packet2d pxor(const Packet2d& , const Packet2d& ); +template<> EIGEN_STRONG_INLINE Packet2d paddsub(const Packet2d& a, const Packet2d& b) +{ +#ifdef EIGEN_VECTORIZE_SSE3 + return _mm_addsub_pd(a,b); +#else + const Packet2d mask = _mm_castsi128_pd(_mm_setr_epi32(0x0,0x80000000,0x0,0x0)); + return padd(a, pxor(mask, b)); +#endif +} + +template<> EIGEN_STRONG_INLINE Packet4f pnegate(const Packet4f& a) +{ + const Packet4f mask = _mm_castsi128_ps(_mm_setr_epi32(0x80000000,0x80000000,0x80000000,0x80000000)); + return _mm_xor_ps(a,mask); +} +template<> EIGEN_STRONG_INLINE Packet2d pnegate(const Packet2d& a) +{ + const Packet2d mask = _mm_castsi128_pd(_mm_setr_epi32(0x0,0x80000000,0x0,0x80000000)); + return _mm_xor_pd(a,mask); +} +template<> EIGEN_STRONG_INLINE Packet4i pnegate(const Packet4i& a) +{ + return psub(Packet4i(_mm_setr_epi32(0,0,0,0)), a); +} + +template<> EIGEN_STRONG_INLINE Packet16b pnegate(const Packet16b& a) +{ + return psub(pset1(false), a); +} + +template<> EIGEN_STRONG_INLINE Packet4f pconj(const Packet4f& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet2d pconj(const Packet2d& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet4i pconj(const Packet4i& a) { return a; } + +template<> EIGEN_STRONG_INLINE Packet4f pmul(const Packet4f& a, const Packet4f& b) { return _mm_mul_ps(a,b); } +template<> EIGEN_STRONG_INLINE Packet2d pmul(const Packet2d& a, const Packet2d& b) { return _mm_mul_pd(a,b); } +template<> EIGEN_STRONG_INLINE Packet4i pmul(const Packet4i& a, const Packet4i& b) +{ +#ifdef EIGEN_VECTORIZE_SSE4_1 + return _mm_mullo_epi32(a,b); +#else + // this version is slightly faster than 4 scalar products + return vec4i_swizzle1( + 
vec4i_swizzle2( + _mm_mul_epu32(a,b), + _mm_mul_epu32(vec4i_swizzle1(a,1,0,3,2), + vec4i_swizzle1(b,1,0,3,2)), + 0,2,0,2), + 0,2,1,3); +#endif +} + +template<> EIGEN_STRONG_INLINE Packet16b pmul(const Packet16b& a, const Packet16b& b) { return _mm_and_si128(a,b); } + +template<> EIGEN_STRONG_INLINE Packet4f pdiv(const Packet4f& a, const Packet4f& b) { return _mm_div_ps(a,b); } +template<> EIGEN_STRONG_INLINE Packet2d pdiv(const Packet2d& a, const Packet2d& b) { return _mm_div_pd(a,b); } + +// for some weird raisons, it has to be overloaded for packet of integers +template<> EIGEN_STRONG_INLINE Packet4i pmadd(const Packet4i& a, const Packet4i& b, const Packet4i& c) { return padd(pmul(a,b), c); } +#ifdef EIGEN_VECTORIZE_FMA +template<> EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return _mm_fmadd_ps(a,b,c); } +template<> EIGEN_STRONG_INLINE Packet2d pmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fmadd_pd(a,b,c); } +#endif + +#ifdef EIGEN_VECTORIZE_SSE4_1 +template<> EIGEN_DEVICE_FUNC inline Packet4f pselect(const Packet4f& mask, const Packet4f& a, const Packet4f& b) { + return _mm_blendv_ps(b,a,mask); +} + +template<> EIGEN_DEVICE_FUNC inline Packet4i pselect(const Packet4i& mask, const Packet4i& a, const Packet4i& b) { + return _mm_castps_si128(_mm_blendv_ps(_mm_castsi128_ps(b),_mm_castsi128_ps(a),_mm_castsi128_ps(mask))); +} + +template<> EIGEN_DEVICE_FUNC inline Packet2d pselect(const Packet2d& mask, const Packet2d& a, const Packet2d& b) { return _mm_blendv_pd(b,a,mask); } + +template<> EIGEN_DEVICE_FUNC inline Packet16b pselect(const Packet16b& mask, const Packet16b& a, const Packet16b& b) { + return _mm_blendv_epi8(b,a,mask); +} +#else +template<> EIGEN_DEVICE_FUNC inline Packet16b pselect(const Packet16b& mask, const Packet16b& a, const Packet16b& b) { + Packet16b a_part = _mm_and_si128(mask, a); + Packet16b b_part = _mm_andnot_si128(mask, b); + return _mm_or_si128(a_part, b_part); +} +#endif + +template<> EIGEN_STRONG_INLINE Packet4i ptrue(const Packet4i& a) { return _mm_cmpeq_epi32(a, a); } +template<> EIGEN_STRONG_INLINE Packet16b ptrue(const Packet16b& a) { return _mm_cmpeq_epi8(a, a); } +template<> EIGEN_STRONG_INLINE Packet4f +ptrue(const Packet4f& a) { + Packet4i b = _mm_castps_si128(a); + return _mm_castsi128_ps(_mm_cmpeq_epi32(b, b)); +} +template<> EIGEN_STRONG_INLINE Packet2d +ptrue(const Packet2d& a) { + Packet4i b = _mm_castpd_si128(a); + return _mm_castsi128_pd(_mm_cmpeq_epi32(b, b)); +} + + +template<> EIGEN_STRONG_INLINE Packet4f pand(const Packet4f& a, const Packet4f& b) { return _mm_and_ps(a,b); } +template<> EIGEN_STRONG_INLINE Packet2d pand(const Packet2d& a, const Packet2d& b) { return _mm_and_pd(a,b); } +template<> EIGEN_STRONG_INLINE Packet4i pand(const Packet4i& a, const Packet4i& b) { return _mm_and_si128(a,b); } +template<> EIGEN_STRONG_INLINE Packet16b pand(const Packet16b& a, const Packet16b& b) { return _mm_and_si128(a,b); } + +template<> EIGEN_STRONG_INLINE Packet4f por(const Packet4f& a, const Packet4f& b) { return _mm_or_ps(a,b); } +template<> EIGEN_STRONG_INLINE Packet2d por(const Packet2d& a, const Packet2d& b) { return _mm_or_pd(a,b); } +template<> EIGEN_STRONG_INLINE Packet4i por(const Packet4i& a, const Packet4i& b) { return _mm_or_si128(a,b); } +template<> EIGEN_STRONG_INLINE Packet16b por(const Packet16b& a, const Packet16b& b) { return _mm_or_si128(a,b); } + +template<> EIGEN_STRONG_INLINE Packet4f pxor(const Packet4f& a, const Packet4f& b) { return _mm_xor_ps(a,b); } 
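The SSE2 fallback for the 32-bit integer pmul above works around the missing _mm_mullo_epi32: _mm_mul_epu32 only multiplies the even lanes (into 64-bit results), so the code runs it twice, the second time on copies whose odd lanes were swizzled into even positions, and then gathers the four low 32-bit halves back into lane order. A scalar model of that recombination (names are illustrative, not Eigen API):

#include <array>
#include <cstdint>
#include <cstdio>

using Lanes4 = std::array<std::int32_t, 4>;

// _mm_mul_epu32 analogue: unsigned 64-bit products of lanes 0 and 2 only.
static std::array<std::uint64_t, 2> mul_even(const Lanes4& x, const Lanes4& y) {
  return {std::uint64_t(std::uint32_t(x[0])) * std::uint32_t(y[0]),
          std::uint64_t(std::uint32_t(x[2])) * std::uint32_t(y[2])};
}

// Scalar model of the SSE2 pmul fallback for 4-lane 32-bit integers.
static Lanes4 mul_lo32(const Lanes4& a, const Lanes4& b) {
  const auto even = mul_even(a, b);                                    // covers lanes 0 and 2
  const auto odd  = mul_even({a[1], 0, a[3], 0}, {b[1], 0, b[3], 0});  // lanes 1 and 3, pre-swizzled
  // The low 32 bits of each unsigned product equal the wrapped signed product;
  // the final swizzle restores the original lane order.
  return {std::int32_t(std::uint32_t(even[0])), std::int32_t(std::uint32_t(odd[0])),
          std::int32_t(std::uint32_t(even[1])), std::int32_t(std::uint32_t(odd[1]))};
}

int main() {
  const Lanes4 r = mul_lo32({1, -2, 3, 4}, {5, 6, -7, 8});
  std::printf("%d %d %d %d\n", r[0], r[1], r[2], r[3]);  // 5 -12 -21 32
  return 0;
}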
+template<> EIGEN_STRONG_INLINE Packet2d pxor(const Packet2d& a, const Packet2d& b) { return _mm_xor_pd(a,b); } +template<> EIGEN_STRONG_INLINE Packet4i pxor(const Packet4i& a, const Packet4i& b) { return _mm_xor_si128(a,b); } +template<> EIGEN_STRONG_INLINE Packet16b pxor(const Packet16b& a, const Packet16b& b) { return _mm_xor_si128(a,b); } + +template<> EIGEN_STRONG_INLINE Packet4f pandnot(const Packet4f& a, const Packet4f& b) { return _mm_andnot_ps(b,a); } +template<> EIGEN_STRONG_INLINE Packet2d pandnot(const Packet2d& a, const Packet2d& b) { return _mm_andnot_pd(b,a); } +template<> EIGEN_STRONG_INLINE Packet4i pandnot(const Packet4i& a, const Packet4i& b) { return _mm_andnot_si128(b,a); } + +template<> EIGEN_STRONG_INLINE Packet4f pcmp_le(const Packet4f& a, const Packet4f& b) { return _mm_cmple_ps(a,b); } +template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt(const Packet4f& a, const Packet4f& b) { return _mm_cmplt_ps(a,b); } +template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt_or_nan(const Packet4f& a, const Packet4f& b) { return _mm_cmpnge_ps(a,b); } +template<> EIGEN_STRONG_INLINE Packet4f pcmp_eq(const Packet4f& a, const Packet4f& b) { return _mm_cmpeq_ps(a,b); } + +template<> EIGEN_STRONG_INLINE Packet2d pcmp_le(const Packet2d& a, const Packet2d& b) { return _mm_cmple_pd(a,b); } +template<> EIGEN_STRONG_INLINE Packet2d pcmp_lt(const Packet2d& a, const Packet2d& b) { return _mm_cmplt_pd(a,b); } +template<> EIGEN_STRONG_INLINE Packet2d pcmp_lt_or_nan(const Packet2d& a, const Packet2d& b) { return _mm_cmpnge_pd(a,b); } +template<> EIGEN_STRONG_INLINE Packet2d pcmp_eq(const Packet2d& a, const Packet2d& b) { return _mm_cmpeq_pd(a,b); } + +template<> EIGEN_STRONG_INLINE Packet4i pcmp_lt(const Packet4i& a, const Packet4i& b) { return _mm_cmplt_epi32(a,b); } +template<> EIGEN_STRONG_INLINE Packet4i pcmp_eq(const Packet4i& a, const Packet4i& b) { return _mm_cmpeq_epi32(a,b); } +template<> EIGEN_STRONG_INLINE Packet16b pcmp_eq(const Packet16b& a, const Packet16b& b) { return _mm_cmpeq_epi8(a,b); } +template<> EIGEN_STRONG_INLINE Packet4i pcmp_le(const Packet4i& a, const Packet4i& b) { return por(pcmp_lt(a,b), pcmp_eq(a,b)); } + +template<> EIGEN_STRONG_INLINE Packet4f pmin(const Packet4f& a, const Packet4f& b) { +#if EIGEN_COMP_GNUC && EIGEN_COMP_GNUC < 63 + // There appears to be a bug in GCC, by which the optimizer may + // flip the argument order in calls to _mm_min_ps, so we have to + // resort to inline ASM here. This is supposed to be fixed in gcc6.3, + // see also: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=72867 + #ifdef EIGEN_VECTORIZE_AVX + Packet4f res; + asm("vminps %[a], %[b], %[res]" : [res] "=x" (res) : [a] "x" (a), [b] "x" (b)); + #else + Packet4f res = b; + asm("minps %[a], %[res]" : [res] "+x" (res) : [a] "x" (a)); + #endif + return res; +#else + // Arguments are reversed to match NaN propagation behavior of std::min. + return _mm_min_ps(b, a); +#endif +} +template<> EIGEN_STRONG_INLINE Packet2d pmin(const Packet2d& a, const Packet2d& b) { +#if EIGEN_COMP_GNUC && EIGEN_COMP_GNUC < 63 + // There appears to be a bug in GCC, by which the optimizer may + // flip the argument order in calls to _mm_min_pd, so we have to + // resort to inline ASM here. 
This is supposed to be fixed in gcc6.3, + // see also: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=72867 + #ifdef EIGEN_VECTORIZE_AVX + Packet2d res; + asm("vminpd %[a], %[b], %[res]" : [res] "=x" (res) : [a] "x" (a), [b] "x" (b)); + #else + Packet2d res = b; + asm("minpd %[a], %[res]" : [res] "+x" (res) : [a] "x" (a)); + #endif + return res; +#else + // Arguments are reversed to match NaN propagation behavior of std::min. + return _mm_min_pd(b, a); +#endif +} +template<> EIGEN_STRONG_INLINE Packet4i pmin(const Packet4i& a, const Packet4i& b) +{ +#ifdef EIGEN_VECTORIZE_SSE4_1 + return _mm_min_epi32(a,b); +#else + // after some bench, this version *is* faster than a scalar implementation + Packet4i mask = _mm_cmplt_epi32(a,b); + return _mm_or_si128(_mm_and_si128(mask,a),_mm_andnot_si128(mask,b)); +#endif +} + + +template<> EIGEN_STRONG_INLINE Packet4f pmax(const Packet4f& a, const Packet4f& b) { +#if EIGEN_COMP_GNUC && EIGEN_COMP_GNUC < 63 + // There appears to be a bug in GCC, by which the optimizer may + // flip the argument order in calls to _mm_max_ps, so we have to + // resort to inline ASM here. This is supposed to be fixed in gcc6.3, + // see also: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=72867 + #ifdef EIGEN_VECTORIZE_AVX + Packet4f res; + asm("vmaxps %[a], %[b], %[res]" : [res] "=x" (res) : [a] "x" (a), [b] "x" (b)); + #else + Packet4f res = b; + asm("maxps %[a], %[res]" : [res] "+x" (res) : [a] "x" (a)); + #endif + return res; +#else + // Arguments are reversed to match NaN propagation behavior of std::max. + return _mm_max_ps(b, a); +#endif +} +template<> EIGEN_STRONG_INLINE Packet2d pmax(const Packet2d& a, const Packet2d& b) { +#if EIGEN_COMP_GNUC && EIGEN_COMP_GNUC < 63 + // There appears to be a bug in GCC, by which the optimizer may + // flip the argument order in calls to _mm_max_pd, so we have to + // resort to inline ASM here. This is supposed to be fixed in gcc6.3, + // see also: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=72867 + #ifdef EIGEN_VECTORIZE_AVX + Packet2d res; + asm("vmaxpd %[a], %[b], %[res]" : [res] "=x" (res) : [a] "x" (a), [b] "x" (b)); + #else + Packet2d res = b; + asm("maxpd %[a], %[res]" : [res] "+x" (res) : [a] "x" (a)); + #endif + return res; +#else + // Arguments are reversed to match NaN propagation behavior of std::max. + return _mm_max_pd(b, a); +#endif +} +template<> EIGEN_STRONG_INLINE Packet4i pmax(const Packet4i& a, const Packet4i& b) +{ +#ifdef EIGEN_VECTORIZE_SSE4_1 + return _mm_max_epi32(a,b); +#else + // after some bench, this version *is* faster than a scalar implementation + Packet4i mask = _mm_cmpgt_epi32(a,b); + return _mm_or_si128(_mm_and_si128(mask,a),_mm_andnot_si128(mask,b)); +#endif +} + +template +EIGEN_STRONG_INLINE Packet pminmax_propagate_numbers(const Packet& a, const Packet& b, Op op) { + // In this implementation, we take advantage of the fact that pmin/pmax for SSE + // always return a if either a or b is NaN. + Packet not_nan_mask_a = pcmp_eq(a, a); + Packet m = op(a, b); + return pselect(not_nan_mask_a, m, b); +} + +template +EIGEN_STRONG_INLINE Packet pminmax_propagate_nan(const Packet& a, const Packet& b, Op op) { + // In this implementation, we take advantage of the fact that pmin/pmax for SSE + // always return a if either a or b is NaN. + Packet not_nan_mask_a = pcmp_eq(a, a); + Packet m = op(b, a); + return pselect(not_nan_mask_a, m, a); +} + +// Add specializations for min/max with prescribed NaN progation. 
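The PropagateNumbers/PropagateNaN specializations follow immediately below. As a rough scalar sketch of what the two helpers above compute (hypothetical names, not part of the patch; sse_min mirrors pmin above, i.e. _mm_min_ps with reversed arguments, which yields its first argument a whenever either input is NaN):

#include <cmath>
#include <cassert>

// Scalar stand-in for pmin: NaN comparisons are false, so 'a' wins whenever either input is NaN.
static float sse_min(float a, float b) { return (b < a) ? b : a; }

// PropagateNumbers: prefer the non-NaN operand.
static float min_propagate_numbers(float a, float b) {
  bool a_not_nan = (a == a);
  float m = sse_min(a, b);      // equals a if either input is NaN
  return a_not_nan ? m : b;     // discard a when a itself is NaN
}

// PropagateNaN: return NaN as soon as either operand is NaN.
static float min_propagate_nan(float a, float b) {
  bool a_not_nan = (a == a);
  float m = sse_min(b, a);      // equals b if either input is NaN
  return a_not_nan ? m : a;     // keep the NaN coming from a
}

int main() {
  float nan = std::nanf("");
  assert(min_propagate_numbers(nan, 3.0f) == 3.0f);
  assert(std::isnan(min_propagate_nan(nan, 3.0f)));
  return 0;
}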
+template<> +EIGEN_STRONG_INLINE Packet4f pmin(const Packet4f& a, const Packet4f& b) { + return pminmax_propagate_numbers(a, b, pmin); +} +template<> +EIGEN_STRONG_INLINE Packet2d pmin(const Packet2d& a, const Packet2d& b) { + return pminmax_propagate_numbers(a, b, pmin); +} +template<> +EIGEN_STRONG_INLINE Packet4f pmax(const Packet4f& a, const Packet4f& b) { + return pminmax_propagate_numbers(a, b, pmax); +} +template<> +EIGEN_STRONG_INLINE Packet2d pmax(const Packet2d& a, const Packet2d& b) { + return pminmax_propagate_numbers(a, b, pmax); +} +template<> +EIGEN_STRONG_INLINE Packet4f pmin(const Packet4f& a, const Packet4f& b) { + return pminmax_propagate_nan(a, b, pmin); +} +template<> +EIGEN_STRONG_INLINE Packet2d pmin(const Packet2d& a, const Packet2d& b) { + return pminmax_propagate_nan(a, b, pmin); +} +template<> +EIGEN_STRONG_INLINE Packet4f pmax(const Packet4f& a, const Packet4f& b) { + return pminmax_propagate_nan(a, b, pmax); +} +template<> +EIGEN_STRONG_INLINE Packet2d pmax(const Packet2d& a, const Packet2d& b) { + return pminmax_propagate_nan(a, b, pmax); +} + +template EIGEN_STRONG_INLINE Packet4i parithmetic_shift_right(const Packet4i& a) { return _mm_srai_epi32(a,N); } +template EIGEN_STRONG_INLINE Packet4i plogical_shift_right (const Packet4i& a) { return _mm_srli_epi32(a,N); } +template EIGEN_STRONG_INLINE Packet4i plogical_shift_left (const Packet4i& a) { return _mm_slli_epi32(a,N); } + +template<> EIGEN_STRONG_INLINE Packet4f pabs(const Packet4f& a) +{ + const Packet4f mask = _mm_castsi128_ps(_mm_setr_epi32(0x7FFFFFFF,0x7FFFFFFF,0x7FFFFFFF,0x7FFFFFFF)); + return _mm_and_ps(a,mask); +} +template<> EIGEN_STRONG_INLINE Packet2d pabs(const Packet2d& a) +{ + const Packet2d mask = _mm_castsi128_pd(_mm_setr_epi32(0xFFFFFFFF,0x7FFFFFFF,0xFFFFFFFF,0x7FFFFFFF)); + return _mm_and_pd(a,mask); +} +template<> EIGEN_STRONG_INLINE Packet4i pabs(const Packet4i& a) +{ + #ifdef EIGEN_VECTORIZE_SSSE3 + return _mm_abs_epi32(a); + #else + Packet4i aux = _mm_srai_epi32(a,31); + return _mm_sub_epi32(_mm_xor_si128(a,aux),aux); + #endif +} + +#ifdef EIGEN_VECTORIZE_SSE4_1 +template<> EIGEN_STRONG_INLINE Packet4f pround(const Packet4f& a) +{ + // Unfortunatly _mm_round_ps doesn't have a rounding mode to implement numext::round. 
+ const Packet4f mask = pset1frombits(0x80000000u); + const Packet4f prev0dot5 = pset1frombits(0x3EFFFFFFu); + return _mm_round_ps(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO); +} + +template<> EIGEN_STRONG_INLINE Packet2d pround(const Packet2d& a) +{ + const Packet2d mask = _mm_castsi128_pd(_mm_set_epi64x(0x8000000000000000ull, 0x8000000000000000ull)); + const Packet2d prev0dot5 = _mm_castsi128_pd(_mm_set_epi64x(0x3FDFFFFFFFFFFFFFull, 0x3FDFFFFFFFFFFFFFull)); + return _mm_round_pd(padd(por(pand(a, mask), prev0dot5), a), _MM_FROUND_TO_ZERO); +} + +template<> EIGEN_STRONG_INLINE Packet4f print(const Packet4f& a) { return _mm_round_ps(a, _MM_FROUND_CUR_DIRECTION); } +template<> EIGEN_STRONG_INLINE Packet2d print(const Packet2d& a) { return _mm_round_pd(a, _MM_FROUND_CUR_DIRECTION); } + +template<> EIGEN_STRONG_INLINE Packet4f pceil(const Packet4f& a) { return _mm_ceil_ps(a); } +template<> EIGEN_STRONG_INLINE Packet2d pceil(const Packet2d& a) { return _mm_ceil_pd(a); } + +template<> EIGEN_STRONG_INLINE Packet4f pfloor(const Packet4f& a) { return _mm_floor_ps(a); } +template<> EIGEN_STRONG_INLINE Packet2d pfloor(const Packet2d& a) { return _mm_floor_pd(a); } +#else +template<> EIGEN_STRONG_INLINE Packet4f print(const Packet4f& a) { + // Adds and subtracts signum(a) * 2^23 to force rounding. + const Packet4f limit = pset1(static_cast(1<<23)); + const Packet4f abs_a = pabs(a); + Packet4f r = padd(abs_a, limit); + // Don't compile-away addition and subtraction. + EIGEN_OPTIMIZATION_BARRIER(r); + r = psub(r, limit); + // If greater than limit, simply return a. Otherwise, account for sign. + r = pselect(pcmp_lt(abs_a, limit), + pselect(pcmp_lt(a, pzero(a)), pnegate(r), r), a); + return r; +} + +template<> EIGEN_STRONG_INLINE Packet2d print(const Packet2d& a) { + // Adds and subtracts signum(a) * 2^52 to force rounding. + const Packet2d limit = pset1(static_cast(1ull<<52)); + const Packet2d abs_a = pabs(a); + Packet2d r = padd(abs_a, limit); + // Don't compile-away addition and subtraction. + EIGEN_OPTIMIZATION_BARRIER(r); + r = psub(r, limit); + // If greater than limit, simply return a. Otherwise, account for sign. + r = pselect(pcmp_lt(abs_a, limit), + pselect(pcmp_lt(a, pzero(a)), pnegate(r), r), a); + return r; +} + +template<> EIGEN_STRONG_INLINE Packet4f pfloor(const Packet4f& a) +{ + const Packet4f cst_1 = pset1(1.0f); + Packet4f tmp = print(a); + // If greater, subtract one. + Packet4f mask = _mm_cmpgt_ps(tmp, a); + mask = pand(mask, cst_1); + return psub(tmp, mask); +} + +template<> EIGEN_STRONG_INLINE Packet2d pfloor(const Packet2d& a) +{ + const Packet2d cst_1 = pset1(1.0); + Packet2d tmp = print(a); + // If greater, subtract one. + Packet2d mask = _mm_cmpgt_pd(tmp, a); + mask = pand(mask, cst_1); + return psub(tmp, mask); +} + +template<> EIGEN_STRONG_INLINE Packet4f pceil(const Packet4f& a) +{ + const Packet4f cst_1 = pset1(1.0f); + Packet4f tmp = print(a); + // If smaller, add one. + Packet4f mask = _mm_cmplt_ps(tmp, a); + mask = pand(mask, cst_1); + return padd(tmp, mask); +} + +template<> EIGEN_STRONG_INLINE Packet2d pceil(const Packet2d& a) +{ + const Packet2d cst_1 = pset1(1.0); + Packet2d tmp = print(a); + // If smaller, add one. 
+ Packet2d mask = _mm_cmplt_pd(tmp, a); + mask = pand(mask, cst_1); + return padd(tmp, mask); +} +#endif + +template<> EIGEN_STRONG_INLINE Packet4f pload(const float* from) { EIGEN_DEBUG_ALIGNED_LOAD return _mm_load_ps(from); } +template<> EIGEN_STRONG_INLINE Packet2d pload(const double* from) { EIGEN_DEBUG_ALIGNED_LOAD return _mm_load_pd(from); } +template<> EIGEN_STRONG_INLINE Packet4i pload(const int* from) { EIGEN_DEBUG_ALIGNED_LOAD return _mm_load_si128(reinterpret_cast(from)); } +template<> EIGEN_STRONG_INLINE Packet16b pload(const bool* from) { EIGEN_DEBUG_ALIGNED_LOAD return _mm_load_si128(reinterpret_cast(from)); } + +#if EIGEN_COMP_MSVC + template<> EIGEN_STRONG_INLINE Packet4f ploadu(const float* from) { + EIGEN_DEBUG_UNALIGNED_LOAD + #if (EIGEN_COMP_MSVC==1600) + // NOTE Some version of MSVC10 generates bad code when using _mm_loadu_ps + // (i.e., it does not generate an unaligned load!! + __m128 res = _mm_loadl_pi(_mm_set1_ps(0.0f), (const __m64*)(from)); + res = _mm_loadh_pi(res, (const __m64*)(from+2)); + return res; + #else + return _mm_loadu_ps(from); + #endif + } +#else +// NOTE: with the code below, MSVC's compiler crashes! + +template<> EIGEN_STRONG_INLINE Packet4f ploadu(const float* from) +{ + EIGEN_DEBUG_UNALIGNED_LOAD + return _mm_loadu_ps(from); +} +#endif + +template<> EIGEN_STRONG_INLINE Packet2d ploadu(const double* from) +{ + EIGEN_DEBUG_UNALIGNED_LOAD + return _mm_loadu_pd(from); +} +template<> EIGEN_STRONG_INLINE Packet4i ploadu(const int* from) +{ + EIGEN_DEBUG_UNALIGNED_LOAD + return _mm_loadu_si128(reinterpret_cast(from)); +} +template<> EIGEN_STRONG_INLINE Packet16b ploadu(const bool* from) { + EIGEN_DEBUG_UNALIGNED_LOAD + return _mm_loadu_si128(reinterpret_cast(from)); +} + + +template<> EIGEN_STRONG_INLINE Packet4f ploaddup(const float* from) +{ + return vec4f_swizzle1(_mm_castpd_ps(_mm_load_sd(reinterpret_cast(from))), 0, 0, 1, 1); +} +template<> EIGEN_STRONG_INLINE Packet2d ploaddup(const double* from) +{ return pset1(from[0]); } +template<> EIGEN_STRONG_INLINE Packet4i ploaddup(const int* from) +{ + Packet4i tmp; + tmp = _mm_loadl_epi64(reinterpret_cast(from)); + return vec4i_swizzle1(tmp, 0, 0, 1, 1); +} + +// Loads 8 bools from memory and returns the packet +// {b0, b0, b1, b1, b2, b2, b3, b3, b4, b4, b5, b5, b6, b6, b7, b7} +template<> EIGEN_STRONG_INLINE Packet16b ploaddup(const bool* from) +{ + __m128i tmp = _mm_castpd_si128(pload1(reinterpret_cast(from))); + return _mm_unpacklo_epi8(tmp, tmp); +} + +// Loads 4 bools from memory and returns the packet +// {b0, b0 b0, b0, b1, b1, b1, b1, b2, b2, b2, b2, b3, b3, b3, b3} +template<> EIGEN_STRONG_INLINE Packet16b +ploadquad(const bool* from) { + __m128i tmp = _mm_castps_si128(pload1(reinterpret_cast(from))); + tmp = _mm_unpacklo_epi8(tmp, tmp); + return _mm_unpacklo_epi16(tmp, tmp); +} + +template<> EIGEN_STRONG_INLINE void pstore(float* to, const Packet4f& from) { EIGEN_DEBUG_ALIGNED_STORE _mm_store_ps(to, from); } +template<> EIGEN_STRONG_INLINE void pstore(double* to, const Packet2d& from) { EIGEN_DEBUG_ALIGNED_STORE _mm_store_pd(to, from); } +template<> EIGEN_STRONG_INLINE void pstore(int* to, const Packet4i& from) { EIGEN_DEBUG_ALIGNED_STORE _mm_store_si128(reinterpret_cast<__m128i*>(to), from); } +template<> EIGEN_STRONG_INLINE void pstore(bool* to, const Packet16b& from) { EIGEN_DEBUG_ALIGNED_STORE _mm_store_si128(reinterpret_cast<__m128i*>(to), from); } + +template<> EIGEN_STRONG_INLINE void pstoreu(double* to, const Packet2d& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_pd(to, from); } 
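As an aside, here is a minimal standalone sketch (helper name made up, not part of the patch) of the duplicating load defined a few lines above: ploaddup for Packet4f reads two floats as one 64-bit chunk and repeats each element, yielding {a0, a0, a1, a1}.

#include <emmintrin.h>
#include <cstdio>

// Load {a0, a1} from memory and expand to {a0, a0, a1, a1}, as the ploaddup
// specialization above does via _mm_load_sd plus a swizzle.
static __m128 loaddup2(const float* from)
{
  __m128 v = _mm_castpd_ps(_mm_load_sd(reinterpret_cast<const double*>(from)));
  return _mm_shuffle_ps(v, v, _MM_SHUFFLE(1, 1, 0, 0));  // pick lanes {0, 0, 1, 1}
}

int main()
{
  float in[2] = {1.5f, -2.0f};
  float out[4];
  _mm_storeu_ps(out, loaddup2(in));
  std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);  // prints: 1.5 1.5 -2 -2
  return 0;
}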
+template<> EIGEN_STRONG_INLINE void pstoreu(float* to, const Packet4f& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_ps(to, from); } +template<> EIGEN_STRONG_INLINE void pstoreu(int* to, const Packet4i& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from); } +template<> EIGEN_STRONG_INLINE void pstoreu(bool* to, const Packet16b& from) { EIGEN_DEBUG_ALIGNED_STORE _mm_storeu_si128(reinterpret_cast<__m128i*>(to), from); } + +template<> EIGEN_DEVICE_FUNC inline Packet4f pgather(const float* from, Index stride) +{ + return _mm_set_ps(from[3*stride], from[2*stride], from[1*stride], from[0*stride]); +} +template<> EIGEN_DEVICE_FUNC inline Packet2d pgather(const double* from, Index stride) +{ + return _mm_set_pd(from[1*stride], from[0*stride]); +} +template<> EIGEN_DEVICE_FUNC inline Packet4i pgather(const int* from, Index stride) +{ + return _mm_set_epi32(from[3*stride], from[2*stride], from[1*stride], from[0*stride]); +} + +template<> EIGEN_DEVICE_FUNC inline Packet16b pgather(const bool* from, Index stride) +{ + return _mm_set_epi8(from[15*stride], from[14*stride], from[13*stride], from[12*stride], + from[11*stride], from[10*stride], from[9*stride], from[8*stride], + from[7*stride], from[6*stride], from[5*stride], from[4*stride], + from[3*stride], from[2*stride], from[1*stride], from[0*stride]); +} + +template<> EIGEN_DEVICE_FUNC inline void pscatter(float* to, const Packet4f& from, Index stride) +{ + to[stride*0] = _mm_cvtss_f32(from); + to[stride*1] = _mm_cvtss_f32(_mm_shuffle_ps(from, from, 1)); + to[stride*2] = _mm_cvtss_f32(_mm_shuffle_ps(from, from, 2)); + to[stride*3] = _mm_cvtss_f32(_mm_shuffle_ps(from, from, 3)); +} +template<> EIGEN_DEVICE_FUNC inline void pscatter(double* to, const Packet2d& from, Index stride) +{ + to[stride*0] = _mm_cvtsd_f64(from); + to[stride*1] = _mm_cvtsd_f64(_mm_shuffle_pd(from, from, 1)); +} +template<> EIGEN_DEVICE_FUNC inline void pscatter(int* to, const Packet4i& from, Index stride) +{ + to[stride*0] = _mm_cvtsi128_si32(from); + to[stride*1] = _mm_cvtsi128_si32(_mm_shuffle_epi32(from, 1)); + to[stride*2] = _mm_cvtsi128_si32(_mm_shuffle_epi32(from, 2)); + to[stride*3] = _mm_cvtsi128_si32(_mm_shuffle_epi32(from, 3)); +} +template<> EIGEN_DEVICE_FUNC inline void pscatter(bool* to, const Packet16b& from, Index stride) +{ + to[4*stride*0] = _mm_cvtsi128_si32(from); + to[4*stride*1] = _mm_cvtsi128_si32(_mm_shuffle_epi32(from, 1)); + to[4*stride*2] = _mm_cvtsi128_si32(_mm_shuffle_epi32(from, 2)); + to[4*stride*3] = _mm_cvtsi128_si32(_mm_shuffle_epi32(from, 3)); +} + + +// some compilers might be tempted to perform multiple moves instead of using a vector path. +template<> EIGEN_STRONG_INLINE void pstore1(float* to, const float& a) +{ + Packet4f pa = _mm_set_ss(a); + pstore(to, Packet4f(vec4f_swizzle1(pa,0,0,0,0))); +} +// some compilers might be tempted to perform multiple moves instead of using a vector path. 
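The pstore1 specializations continue right after this aside. For reference, the strided gather/scatter specializations defined above follow a simple scalar contract; the sketch below is illustrative only and not part of the patch (for Packet4f, Scalar is float and N is 4).

#include <cstddef>

// pgather reads N coefficients spaced 'stride' elements apart...
template <typename Scalar, int N>
void scalar_gather(const Scalar* from, std::ptrdiff_t stride, Scalar (&out)[N])
{
  for (int i = 0; i < N; ++i) out[i] = from[i * stride];
}

// ...and pscatter writes them back the same way.
template <typename Scalar, int N>
void scalar_scatter(Scalar* to, const Scalar (&in)[N], std::ptrdiff_t stride)
{
  for (int i = 0; i < N; ++i) to[i * stride] = in[i];
}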
+template<> EIGEN_STRONG_INLINE void pstore1<Packet2d>(double* to, const double& a)
+{
+  Packet2d pa = _mm_set_sd(a);
+  pstore(to, Packet2d(vec2d_swizzle1(pa,0,0)));
+}
+
+#if EIGEN_COMP_PGI && EIGEN_COMP_PGI < 1900
+typedef const void * SsePrefetchPtrType;
+#else
+typedef const char * SsePrefetchPtrType;
+#endif
+
+#ifndef EIGEN_VECTORIZE_AVX
+template<> EIGEN_STRONG_INLINE void prefetch<float>(const float* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); }
+template<> EIGEN_STRONG_INLINE void prefetch<double>(const double* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); }
+template<> EIGEN_STRONG_INLINE void prefetch<int>(const int* addr) { _mm_prefetch((SsePrefetchPtrType)(addr), _MM_HINT_T0); }
+#endif
+
+#if EIGEN_COMP_MSVC_STRICT && EIGEN_OS_WIN64
+// The temporary variable fixes an internal compilation error in vs <= 2008 and a wrong-result bug in vs 2010
+// Direct access to the struct members fixed bug #62.
+template<> EIGEN_STRONG_INLINE float  pfirst<Packet4f>(const Packet4f& a) { return a.m128_f32[0]; }
+template<> EIGEN_STRONG_INLINE double pfirst<Packet2d>(const Packet2d& a) { return a.m128d_f64[0]; }
+template<> EIGEN_STRONG_INLINE int    pfirst<Packet4i>(const Packet4i& a) { int x = _mm_cvtsi128_si32(a); return x; }
+#elif EIGEN_COMP_MSVC_STRICT
+// The temporary variable fixes an internal compilation error in vs <= 2008 and a wrong-result bug in vs 2010
+template<> EIGEN_STRONG_INLINE float  pfirst<Packet4f>(const Packet4f& a) { float x = _mm_cvtss_f32(a); return x; }
+template<> EIGEN_STRONG_INLINE double pfirst<Packet2d>(const Packet2d& a) { double x = _mm_cvtsd_f64(a); return x; }
+template<> EIGEN_STRONG_INLINE int    pfirst<Packet4i>(const Packet4i& a) { int x = _mm_cvtsi128_si32(a); return x; }
+#else
+template<> EIGEN_STRONG_INLINE float  pfirst<Packet4f>(const Packet4f& a) { return _mm_cvtss_f32(a); }
+template<> EIGEN_STRONG_INLINE double pfirst<Packet2d>(const Packet2d& a) { return _mm_cvtsd_f64(a); }
+template<> EIGEN_STRONG_INLINE int    pfirst<Packet4i>(const Packet4i& a) { return _mm_cvtsi128_si32(a); }
+#endif
+template<> EIGEN_STRONG_INLINE bool   pfirst<Packet16b>(const Packet16b& a) { int x = _mm_cvtsi128_si32(a); return static_cast<bool>(x & 1); }
+
+template<> EIGEN_STRONG_INLINE Packet4f preverse(const Packet4f& a) { return _mm_shuffle_ps(a,a,0x1B); }
+template<> EIGEN_STRONG_INLINE Packet2d preverse(const Packet2d& a) { return _mm_shuffle_pd(a,a,0x1); }
+template<> EIGEN_STRONG_INLINE Packet4i preverse(const Packet4i& a) { return _mm_shuffle_epi32(a,0x1B); }
+template<> EIGEN_STRONG_INLINE Packet16b preverse(const Packet16b& a) {
+#ifdef EIGEN_VECTORIZE_SSSE3
+  __m128i mask = _mm_set_epi8(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
+  return _mm_shuffle_epi8(a, mask);
+#else
+  Packet16b tmp = _mm_shuffle_epi32(a, _MM_SHUFFLE(0, 1, 2, 3));
+  tmp = _mm_shufflehi_epi16(_mm_shufflelo_epi16(tmp, _MM_SHUFFLE(2, 3, 0, 1)), _MM_SHUFFLE(2, 3, 0, 1));
+  return _mm_or_si128(_mm_slli_epi16(tmp, 8), _mm_srli_epi16(tmp, 8));
+#endif
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pfrexp<Packet4f>(const Packet4f& a, Packet4f& exponent) {
+  return pfrexp_generic(a,exponent);
+}
+
+// Extract exponent without existence of Packet2l.
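The Packet2d specialization of that helper follows just below. A scalar sketch of the same bit manipulation (illustrative only, not part of the patch): mask out the 11 exponent bits of an IEEE-754 double and shift them down past the 52-bit mantissa.

#include <cstdint>
#include <cstring>

// Scalar version of the biased-exponent extraction used by the vector code below.
static int biased_exponent(double x)
{
  std::uint64_t bits;
  std::memcpy(&bits, &x, sizeof bits);
  return static_cast<int>((bits & 0x7ff0000000000000ull) >> 52);
}
// e.g. biased_exponent(1.0) == 1023, biased_exponent(8.0) == 1026.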
+template<> +EIGEN_STRONG_INLINE +Packet2d pfrexp_generic_get_biased_exponent(const Packet2d& a) { + const Packet2d cst_exp_mask = pset1frombits(static_cast(0x7ff0000000000000ull)); + __m128i a_expo = _mm_srli_epi64(_mm_castpd_si128(pand(a, cst_exp_mask)), 52); + return _mm_cvtepi32_pd(vec4i_swizzle1(a_expo, 0, 2, 1, 3)); +} + +template<> EIGEN_STRONG_INLINE Packet2d pfrexp(const Packet2d& a, Packet2d& exponent) { + return pfrexp_generic(a, exponent); +} + +template<> EIGEN_STRONG_INLINE Packet4f pldexp(const Packet4f& a, const Packet4f& exponent) { + return pldexp_generic(a,exponent); +} + +// We specialize pldexp here, since the generic implementation uses Packet2l, which is not well +// supported by SSE, and has more range than is needed for exponents. +template<> EIGEN_STRONG_INLINE Packet2d pldexp(const Packet2d& a, const Packet2d& exponent) { + // Clamp exponent to [-2099, 2099] + const Packet2d max_exponent = pset1(2099.0); + const Packet2d e = pmin(pmax(exponent, pnegate(max_exponent)), max_exponent); + + // Convert e to integer and swizzle to low-order bits. + const Packet4i ei = vec4i_swizzle1(_mm_cvtpd_epi32(e), 0, 3, 1, 3); + + // Split 2^e into four factors and multiply: + const Packet4i bias = _mm_set_epi32(0, 1023, 0, 1023); + Packet4i b = parithmetic_shift_right<2>(ei); // floor(e/4) + Packet2d c = _mm_castsi128_pd(_mm_slli_epi64(padd(b, bias), 52)); // 2^b + Packet2d out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b) + b = psub(psub(psub(ei, b), b), b); // e - 3b + c = _mm_castsi128_pd(_mm_slli_epi64(padd(b, bias), 52)); // 2^(e - 3b) + out = pmul(out, c); // a * 2^e + return out; +} + +// with AVX, the default implementations based on pload1 are faster +#ifndef __AVX__ +template<> EIGEN_STRONG_INLINE void +pbroadcast4(const float *a, + Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3) +{ + a3 = pload(a); + a0 = vec4f_swizzle1(a3, 0,0,0,0); + a1 = vec4f_swizzle1(a3, 1,1,1,1); + a2 = vec4f_swizzle1(a3, 2,2,2,2); + a3 = vec4f_swizzle1(a3, 3,3,3,3); +} +template<> EIGEN_STRONG_INLINE void +pbroadcast4(const double *a, + Packet2d& a0, Packet2d& a1, Packet2d& a2, Packet2d& a3) +{ +#ifdef EIGEN_VECTORIZE_SSE3 + a0 = _mm_loaddup_pd(a+0); + a1 = _mm_loaddup_pd(a+1); + a2 = _mm_loaddup_pd(a+2); + a3 = _mm_loaddup_pd(a+3); +#else + a1 = pload(a); + a0 = vec2d_swizzle1(a1, 0,0); + a1 = vec2d_swizzle1(a1, 1,1); + a3 = pload(a+2); + a2 = vec2d_swizzle1(a3, 0,0); + a3 = vec2d_swizzle1(a3, 1,1); +#endif +} +#endif + +EIGEN_STRONG_INLINE void punpackp(Packet4f* vecs) +{ + vecs[1] = _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(vecs[0]), 0x55)); + vecs[2] = _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(vecs[0]), 0xAA)); + vecs[3] = _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(vecs[0]), 0xFF)); + vecs[0] = _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(vecs[0]), 0x00)); +} + +template<> EIGEN_STRONG_INLINE float predux(const Packet4f& a) +{ + // Disable SSE3 _mm_hadd_pd that is extremely slow on all existing Intel's architectures + // (from Nehalem to Haswell) +// #ifdef EIGEN_VECTORIZE_SSE3 +// Packet4f tmp = _mm_add_ps(a, vec4f_swizzle1(a,2,3,2,3)); +// return pfirst(_mm_hadd_ps(tmp, tmp)); +// #else + Packet4f tmp = _mm_add_ps(a, _mm_movehl_ps(a,a)); + return pfirst(_mm_add_ss(tmp, _mm_shuffle_ps(tmp,tmp, 1))); +// #endif +} + +template<> EIGEN_STRONG_INLINE double predux(const Packet2d& a) +{ + // Disable SSE3 _mm_hadd_pd that is extremely slow on all existing Intel's architectures + // (from Nehalem to Haswell) +// #ifdef EIGEN_VECTORIZE_SSE3 +// return 
pfirst(_mm_hadd_pd(a, a)); +// #else + return pfirst(_mm_add_sd(a, _mm_unpackhi_pd(a,a))); +// #endif +} + +#ifdef EIGEN_VECTORIZE_SSSE3 +template<> EIGEN_STRONG_INLINE int predux(const Packet4i& a) +{ + Packet4i tmp0 = _mm_hadd_epi32(a,a); + return pfirst(_mm_hadd_epi32(tmp0,tmp0)); +} + +#else +template<> EIGEN_STRONG_INLINE int predux(const Packet4i& a) +{ + Packet4i tmp = _mm_add_epi32(a, _mm_unpackhi_epi64(a,a)); + return pfirst(tmp) + pfirst(_mm_shuffle_epi32(tmp, 1)); +} +#endif + +template<> EIGEN_STRONG_INLINE bool predux(const Packet16b& a) { + Packet4i tmp = _mm_or_si128(a, _mm_unpackhi_epi64(a,a)); + return (pfirst(tmp) != 0) || (pfirst(_mm_shuffle_epi32(tmp, 1)) != 0); +} + +// Other reduction functions: + + +// mul +template<> EIGEN_STRONG_INLINE float predux_mul(const Packet4f& a) +{ + Packet4f tmp = _mm_mul_ps(a, _mm_movehl_ps(a,a)); + return pfirst(_mm_mul_ss(tmp, _mm_shuffle_ps(tmp,tmp, 1))); +} +template<> EIGEN_STRONG_INLINE double predux_mul(const Packet2d& a) +{ + return pfirst(_mm_mul_sd(a, _mm_unpackhi_pd(a,a))); +} +template<> EIGEN_STRONG_INLINE int predux_mul(const Packet4i& a) +{ + // after some experiments, it is seems this is the fastest way to implement it + // for GCC (eg., reusing pmul is very slow !) + // TODO try to call _mm_mul_epu32 directly + EIGEN_ALIGN16 int aux[4]; + pstore(aux, a); + return (aux[0] * aux[1]) * (aux[2] * aux[3]); +} + +template<> EIGEN_STRONG_INLINE bool predux_mul(const Packet16b& a) { + Packet4i tmp = _mm_and_si128(a, _mm_unpackhi_epi64(a,a)); + return ((pfirst(tmp) == 0x01010101) && + (pfirst(_mm_shuffle_epi32(tmp, 1)) == 0x01010101)); +} + +// min +template<> EIGEN_STRONG_INLINE float predux_min(const Packet4f& a) +{ + Packet4f tmp = _mm_min_ps(a, _mm_movehl_ps(a,a)); + return pfirst(_mm_min_ss(tmp, _mm_shuffle_ps(tmp,tmp, 1))); +} +template<> EIGEN_STRONG_INLINE double predux_min(const Packet2d& a) +{ + return pfirst(_mm_min_sd(a, _mm_unpackhi_pd(a,a))); +} +template<> EIGEN_STRONG_INLINE int predux_min(const Packet4i& a) +{ +#ifdef EIGEN_VECTORIZE_SSE4_1 + Packet4i tmp = _mm_min_epi32(a, _mm_shuffle_epi32(a, _MM_SHUFFLE(0,0,3,2))); + return pfirst(_mm_min_epi32(tmp,_mm_shuffle_epi32(tmp, 1))); +#else + // after some experiments, it is seems this is the fastest way to implement it + // for GCC (eg., it does not like using std::min after the pstore !!) + EIGEN_ALIGN16 int aux[4]; + pstore(aux, a); + int aux0 = aux[0] EIGEN_STRONG_INLINE float predux_max(const Packet4f& a) +{ + Packet4f tmp = _mm_max_ps(a, _mm_movehl_ps(a,a)); + return pfirst(_mm_max_ss(tmp, _mm_shuffle_ps(tmp,tmp, 1))); +} +template<> EIGEN_STRONG_INLINE double predux_max(const Packet2d& a) +{ + return pfirst(_mm_max_sd(a, _mm_unpackhi_pd(a,a))); +} +template<> EIGEN_STRONG_INLINE int predux_max(const Packet4i& a) +{ +#ifdef EIGEN_VECTORIZE_SSE4_1 + Packet4i tmp = _mm_max_epi32(a, _mm_shuffle_epi32(a, _MM_SHUFFLE(0,0,3,2))); + return pfirst(_mm_max_epi32(tmp,_mm_shuffle_epi32(tmp, 1))); +#else + // after some experiments, it is seems this is the fastest way to implement it + // for GCC (eg., it does not like using std::min after the pstore !!) + EIGEN_ALIGN16 int aux[4]; + pstore(aux, a); + int aux0 = aux[0]>aux[1] ? aux[0] : aux[1]; + int aux2 = aux[2]>aux[3] ? aux[2] : aux[3]; + return aux0>aux2 ? 
aux0 : aux2; +#endif // EIGEN_VECTORIZE_SSE4_1 +} + +// not needed yet +// template<> EIGEN_STRONG_INLINE bool predux_all(const Packet4f& x) +// { +// return _mm_movemask_ps(x) == 0xF; +// } + +template<> EIGEN_STRONG_INLINE bool predux_any(const Packet4f& x) +{ + return _mm_movemask_ps(x) != 0x0; +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + _MM_TRANSPOSE4_PS(kernel.packet[0], kernel.packet[1], kernel.packet[2], kernel.packet[3]); +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + __m128d tmp = _mm_unpackhi_pd(kernel.packet[0], kernel.packet[1]); + kernel.packet[0] = _mm_unpacklo_pd(kernel.packet[0], kernel.packet[1]); + kernel.packet[1] = tmp; +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + __m128i T0 = _mm_unpacklo_epi32(kernel.packet[0], kernel.packet[1]); + __m128i T1 = _mm_unpacklo_epi32(kernel.packet[2], kernel.packet[3]); + __m128i T2 = _mm_unpackhi_epi32(kernel.packet[0], kernel.packet[1]); + __m128i T3 = _mm_unpackhi_epi32(kernel.packet[2], kernel.packet[3]); + + kernel.packet[0] = _mm_unpacklo_epi64(T0, T1); + kernel.packet[1] = _mm_unpackhi_epi64(T0, T1); + kernel.packet[2] = _mm_unpacklo_epi64(T2, T3); + kernel.packet[3] = _mm_unpackhi_epi64(T2, T3); +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + __m128i T0 = _mm_unpacklo_epi8(kernel.packet[0], kernel.packet[1]); + __m128i T1 = _mm_unpackhi_epi8(kernel.packet[0], kernel.packet[1]); + __m128i T2 = _mm_unpacklo_epi8(kernel.packet[2], kernel.packet[3]); + __m128i T3 = _mm_unpackhi_epi8(kernel.packet[2], kernel.packet[3]); + kernel.packet[0] = _mm_unpacklo_epi16(T0, T2); + kernel.packet[1] = _mm_unpackhi_epi16(T0, T2); + kernel.packet[2] = _mm_unpacklo_epi16(T1, T3); + kernel.packet[3] = _mm_unpackhi_epi16(T1, T3); +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + // If we number the elements in the input thus: + // kernel.packet[ 0] = {00, 01, 02, 03, 04, 05, 06, 07, 08, 09, 0a, 0b, 0c, 0d, 0e, 0f} + // kernel.packet[ 1] = {10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1a, 1b, 1c, 1d, 1e, 1f} + // ... + // kernel.packet[15] = {f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, fa, fb, fc, fd, fe, ff}, + // + // the desired output is: + // kernel.packet[ 0] = {00, 10, 20, 30, 40, 50, 60, 70, 80, 90, a0, b0, c0, d0, e0, f0} + // kernel.packet[ 1] = {01, 11, 21, 31, 41, 51, 61, 71, 81, 91, a1, b1, c1, d1, e1, f1} + // ... + // kernel.packet[15] = {0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, af, bf, cf, df, ef, ff}, + __m128i t0 = _mm_unpacklo_epi8(kernel.packet[0], kernel.packet[1]); // 00 10 01 11 02 12 03 13 04 14 05 15 06 16 07 17 + __m128i t1 = _mm_unpackhi_epi8(kernel.packet[0], kernel.packet[1]); // 08 18 09 19 0a 1a 0b 1b 0c 1c 0d 1d 0e 1e 0f 1f + __m128i t2 = _mm_unpacklo_epi8(kernel.packet[2], kernel.packet[3]); // 20 30 21 31 22 32 ... 27 37 + __m128i t3 = _mm_unpackhi_epi8(kernel.packet[2], kernel.packet[3]); // 28 38 29 39 2a 3a ... 
2f 3f + __m128i t4 = _mm_unpacklo_epi8(kernel.packet[4], kernel.packet[5]); // 40 50 41 51 42 52 47 57 + __m128i t5 = _mm_unpackhi_epi8(kernel.packet[4], kernel.packet[5]); // 48 58 49 59 4a 5a + __m128i t6 = _mm_unpacklo_epi8(kernel.packet[6], kernel.packet[7]); + __m128i t7 = _mm_unpackhi_epi8(kernel.packet[6], kernel.packet[7]); + __m128i t8 = _mm_unpacklo_epi8(kernel.packet[8], kernel.packet[9]); + __m128i t9 = _mm_unpackhi_epi8(kernel.packet[8], kernel.packet[9]); + __m128i ta = _mm_unpacklo_epi8(kernel.packet[10], kernel.packet[11]); + __m128i tb = _mm_unpackhi_epi8(kernel.packet[10], kernel.packet[11]); + __m128i tc = _mm_unpacklo_epi8(kernel.packet[12], kernel.packet[13]); + __m128i td = _mm_unpackhi_epi8(kernel.packet[12], kernel.packet[13]); + __m128i te = _mm_unpacklo_epi8(kernel.packet[14], kernel.packet[15]); + __m128i tf = _mm_unpackhi_epi8(kernel.packet[14], kernel.packet[15]); + + __m128i s0 = _mm_unpacklo_epi16(t0, t2); // 00 10 20 30 01 11 21 31 02 12 22 32 03 13 23 33 + __m128i s1 = _mm_unpackhi_epi16(t0, t2); // 04 14 24 34 + __m128i s2 = _mm_unpacklo_epi16(t1, t3); // 08 18 28 38 ... + __m128i s3 = _mm_unpackhi_epi16(t1, t3); // 0c 1c 2c 3c ... + __m128i s4 = _mm_unpacklo_epi16(t4, t6); // 40 50 60 70 41 51 61 71 42 52 62 72 43 53 63 73 + __m128i s5 = _mm_unpackhi_epi16(t4, t6); // 44 54 64 74 ... + __m128i s6 = _mm_unpacklo_epi16(t5, t7); + __m128i s7 = _mm_unpackhi_epi16(t5, t7); + __m128i s8 = _mm_unpacklo_epi16(t8, ta); + __m128i s9 = _mm_unpackhi_epi16(t8, ta); + __m128i sa = _mm_unpacklo_epi16(t9, tb); + __m128i sb = _mm_unpackhi_epi16(t9, tb); + __m128i sc = _mm_unpacklo_epi16(tc, te); + __m128i sd = _mm_unpackhi_epi16(tc, te); + __m128i se = _mm_unpacklo_epi16(td, tf); + __m128i sf = _mm_unpackhi_epi16(td, tf); + + __m128i u0 = _mm_unpacklo_epi32(s0, s4); // 00 10 20 30 40 50 60 70 01 11 21 31 41 51 61 71 + __m128i u1 = _mm_unpackhi_epi32(s0, s4); // 02 12 22 32 42 52 62 72 03 13 23 33 43 53 63 73 + __m128i u2 = _mm_unpacklo_epi32(s1, s5); + __m128i u3 = _mm_unpackhi_epi32(s1, s5); + __m128i u4 = _mm_unpacklo_epi32(s2, s6); + __m128i u5 = _mm_unpackhi_epi32(s2, s6); + __m128i u6 = _mm_unpacklo_epi32(s3, s7); + __m128i u7 = _mm_unpackhi_epi32(s3, s7); + __m128i u8 = _mm_unpacklo_epi32(s8, sc); + __m128i u9 = _mm_unpackhi_epi32(s8, sc); + __m128i ua = _mm_unpacklo_epi32(s9, sd); + __m128i ub = _mm_unpackhi_epi32(s9, sd); + __m128i uc = _mm_unpacklo_epi32(sa, se); + __m128i ud = _mm_unpackhi_epi32(sa, se); + __m128i ue = _mm_unpacklo_epi32(sb, sf); + __m128i uf = _mm_unpackhi_epi32(sb, sf); + + kernel.packet[0] = _mm_unpacklo_epi64(u0, u8); + kernel.packet[1] = _mm_unpackhi_epi64(u0, u8); + kernel.packet[2] = _mm_unpacklo_epi64(u1, u9); + kernel.packet[3] = _mm_unpackhi_epi64(u1, u9); + kernel.packet[4] = _mm_unpacklo_epi64(u2, ua); + kernel.packet[5] = _mm_unpackhi_epi64(u2, ua); + kernel.packet[6] = _mm_unpacklo_epi64(u3, ub); + kernel.packet[7] = _mm_unpackhi_epi64(u3, ub); + kernel.packet[8] = _mm_unpacklo_epi64(u4, uc); + kernel.packet[9] = _mm_unpackhi_epi64(u4, uc); + kernel.packet[10] = _mm_unpacklo_epi64(u5, ud); + kernel.packet[11] = _mm_unpackhi_epi64(u5, ud); + kernel.packet[12] = _mm_unpacklo_epi64(u6, ue); + kernel.packet[13] = _mm_unpackhi_epi64(u6, ue); + kernel.packet[14] = _mm_unpacklo_epi64(u7, uf); + kernel.packet[15] = _mm_unpackhi_epi64(u7, uf); +} + +template<> EIGEN_STRONG_INLINE Packet4i pblend(const Selector<4>& ifPacket, const Packet4i& thenPacket, const Packet4i& elsePacket) { + const __m128i zero = _mm_setzero_si128(); + const __m128i 
select = _mm_set_epi32(ifPacket.select[3], ifPacket.select[2], ifPacket.select[1], ifPacket.select[0]); + __m128i false_mask = _mm_cmpeq_epi32(select, zero); +#ifdef EIGEN_VECTORIZE_SSE4_1 + return _mm_blendv_epi8(thenPacket, elsePacket, false_mask); +#else + return _mm_or_si128(_mm_andnot_si128(false_mask, thenPacket), _mm_and_si128(false_mask, elsePacket)); +#endif +} +template<> EIGEN_STRONG_INLINE Packet4f pblend(const Selector<4>& ifPacket, const Packet4f& thenPacket, const Packet4f& elsePacket) { + const __m128 zero = _mm_setzero_ps(); + const __m128 select = _mm_set_ps(ifPacket.select[3], ifPacket.select[2], ifPacket.select[1], ifPacket.select[0]); + __m128 false_mask = _mm_cmpeq_ps(select, zero); +#ifdef EIGEN_VECTORIZE_SSE4_1 + return _mm_blendv_ps(thenPacket, elsePacket, false_mask); +#else + return _mm_or_ps(_mm_andnot_ps(false_mask, thenPacket), _mm_and_ps(false_mask, elsePacket)); +#endif +} +template<> EIGEN_STRONG_INLINE Packet2d pblend(const Selector<2>& ifPacket, const Packet2d& thenPacket, const Packet2d& elsePacket) { + const __m128d zero = _mm_setzero_pd(); + const __m128d select = _mm_set_pd(ifPacket.select[1], ifPacket.select[0]); + __m128d false_mask = _mm_cmpeq_pd(select, zero); +#ifdef EIGEN_VECTORIZE_SSE4_1 + return _mm_blendv_pd(thenPacket, elsePacket, false_mask); +#else + return _mm_or_pd(_mm_andnot_pd(false_mask, thenPacket), _mm_and_pd(false_mask, elsePacket)); +#endif +} + +// Scalar path for pmadd with FMA to ensure consistency with vectorized path. +#ifdef EIGEN_VECTORIZE_FMA +template<> EIGEN_STRONG_INLINE float pmadd(const float& a, const float& b, const float& c) { + return ::fmaf(a,b,c); +} +template<> EIGEN_STRONG_INLINE double pmadd(const double& a, const double& b, const double& c) { + return ::fma(a,b,c); +} +#endif + + +// Packet math for Eigen::half +// Disable the following code since it's broken on too many platforms / compilers. +//#elif defined(EIGEN_VECTORIZE_SSE) && (!EIGEN_ARCH_x86_64) && (!EIGEN_COMP_MSVC) +#if 0 + +typedef struct { + __m64 x; +} Packet4h; + + +template<> struct is_arithmetic { enum { value = true }; }; + +template <> +struct packet_traits : default_packet_traits { + typedef Packet4h type; + // There is no half-size packet for Packet4h. 
+ typedef Packet4h half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 4, + HasHalfPacket = 0, + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 1, + HasNegate = 0, + HasAbs = 0, + HasAbs2 = 0, + HasMin = 0, + HasMax = 0, + HasConj = 0, + HasSetLinear = 0, + HasSqrt = 0, + HasRsqrt = 0, + HasExp = 0, + HasLog = 0, + HasBlend = 0 + }; +}; + + +template<> struct unpacket_traits { typedef Eigen::half type; enum {size=4, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet4h half; }; + +template<> EIGEN_STRONG_INLINE Packet4h pset1(const Eigen::half& from) { + Packet4h result; + result.x = _mm_set1_pi16(from.x); + return result; +} + +template<> EIGEN_STRONG_INLINE Eigen::half pfirst(const Packet4h& from) { + return half_impl::raw_uint16_to_half(static_cast(_mm_cvtsi64_si32(from.x))); +} + +template<> EIGEN_STRONG_INLINE Packet4h pconj(const Packet4h& a) { return a; } + +template<> EIGEN_STRONG_INLINE Packet4h padd(const Packet4h& a, const Packet4h& b) { + __int64_t a64 = _mm_cvtm64_si64(a.x); + __int64_t b64 = _mm_cvtm64_si64(b.x); + + Eigen::half h[4]; + + Eigen::half ha = half_impl::raw_uint16_to_half(static_cast(a64)); + Eigen::half hb = half_impl::raw_uint16_to_half(static_cast(b64)); + h[0] = ha + hb; + ha = half_impl::raw_uint16_to_half(static_cast(a64 >> 16)); + hb = half_impl::raw_uint16_to_half(static_cast(b64 >> 16)); + h[1] = ha + hb; + ha = half_impl::raw_uint16_to_half(static_cast(a64 >> 32)); + hb = half_impl::raw_uint16_to_half(static_cast(b64 >> 32)); + h[2] = ha + hb; + ha = half_impl::raw_uint16_to_half(static_cast(a64 >> 48)); + hb = half_impl::raw_uint16_to_half(static_cast(b64 >> 48)); + h[3] = ha + hb; + Packet4h result; + result.x = _mm_set_pi16(h[3].x, h[2].x, h[1].x, h[0].x); + return result; +} + +template<> EIGEN_STRONG_INLINE Packet4h psub(const Packet4h& a, const Packet4h& b) { + __int64_t a64 = _mm_cvtm64_si64(a.x); + __int64_t b64 = _mm_cvtm64_si64(b.x); + + Eigen::half h[4]; + + Eigen::half ha = half_impl::raw_uint16_to_half(static_cast(a64)); + Eigen::half hb = half_impl::raw_uint16_to_half(static_cast(b64)); + h[0] = ha - hb; + ha = half_impl::raw_uint16_to_half(static_cast(a64 >> 16)); + hb = half_impl::raw_uint16_to_half(static_cast(b64 >> 16)); + h[1] = ha - hb; + ha = half_impl::raw_uint16_to_half(static_cast(a64 >> 32)); + hb = half_impl::raw_uint16_to_half(static_cast(b64 >> 32)); + h[2] = ha - hb; + ha = half_impl::raw_uint16_to_half(static_cast(a64 >> 48)); + hb = half_impl::raw_uint16_to_half(static_cast(b64 >> 48)); + h[3] = ha - hb; + Packet4h result; + result.x = _mm_set_pi16(h[3].x, h[2].x, h[1].x, h[0].x); + return result; +} + +template<> EIGEN_STRONG_INLINE Packet4h pmul(const Packet4h& a, const Packet4h& b) { + __int64_t a64 = _mm_cvtm64_si64(a.x); + __int64_t b64 = _mm_cvtm64_si64(b.x); + + Eigen::half h[4]; + + Eigen::half ha = half_impl::raw_uint16_to_half(static_cast(a64)); + Eigen::half hb = half_impl::raw_uint16_to_half(static_cast(b64)); + h[0] = ha * hb; + ha = half_impl::raw_uint16_to_half(static_cast(a64 >> 16)); + hb = half_impl::raw_uint16_to_half(static_cast(b64 >> 16)); + h[1] = ha * hb; + ha = half_impl::raw_uint16_to_half(static_cast(a64 >> 32)); + hb = half_impl::raw_uint16_to_half(static_cast(b64 >> 32)); + h[2] = ha * hb; + ha = half_impl::raw_uint16_to_half(static_cast(a64 >> 48)); + hb = half_impl::raw_uint16_to_half(static_cast(b64 >> 48)); + h[3] = ha * hb; + Packet4h result; + result.x = _mm_set_pi16(h[3].x, h[2].x, h[1].x, h[0].x); + 
return result; +} + +template<> EIGEN_STRONG_INLINE Packet4h pdiv(const Packet4h& a, const Packet4h& b) { + __int64_t a64 = _mm_cvtm64_si64(a.x); + __int64_t b64 = _mm_cvtm64_si64(b.x); + + Eigen::half h[4]; + + Eigen::half ha = half_impl::raw_uint16_to_half(static_cast(a64)); + Eigen::half hb = half_impl::raw_uint16_to_half(static_cast(b64)); + h[0] = ha / hb; + ha = half_impl::raw_uint16_to_half(static_cast(a64 >> 16)); + hb = half_impl::raw_uint16_to_half(static_cast(b64 >> 16)); + h[1] = ha / hb; + ha = half_impl::raw_uint16_to_half(static_cast(a64 >> 32)); + hb = half_impl::raw_uint16_to_half(static_cast(b64 >> 32)); + h[2] = ha / hb; + ha = half_impl::raw_uint16_to_half(static_cast(a64 >> 48)); + hb = half_impl::raw_uint16_to_half(static_cast(b64 >> 48)); + h[3] = ha / hb; + Packet4h result; + result.x = _mm_set_pi16(h[3].x, h[2].x, h[1].x, h[0].x); + return result; +} + +template<> EIGEN_STRONG_INLINE Packet4h pload(const Eigen::half* from) { + Packet4h result; + result.x = _mm_cvtsi64_m64(*reinterpret_cast(from)); + return result; +} + +template<> EIGEN_STRONG_INLINE Packet4h ploadu(const Eigen::half* from) { + Packet4h result; + result.x = _mm_cvtsi64_m64(*reinterpret_cast(from)); + return result; +} + +template<> EIGEN_STRONG_INLINE void pstore(Eigen::half* to, const Packet4h& from) { + __int64_t r = _mm_cvtm64_si64(from.x); + *(reinterpret_cast<__int64_t*>(to)) = r; +} + +template<> EIGEN_STRONG_INLINE void pstoreu(Eigen::half* to, const Packet4h& from) { + __int64_t r = _mm_cvtm64_si64(from.x); + *(reinterpret_cast<__int64_t*>(to)) = r; +} + +template<> EIGEN_STRONG_INLINE Packet4h +ploadquad(const Eigen::half* from) { + return pset1(*from); +} + +template<> EIGEN_STRONG_INLINE Packet4h pgather(const Eigen::half* from, Index stride) +{ + Packet4h result; + result.x = _mm_set_pi16(from[3*stride].x, from[2*stride].x, from[1*stride].x, from[0*stride].x); + return result; +} + +template<> EIGEN_STRONG_INLINE void pscatter(Eigen::half* to, const Packet4h& from, Index stride) +{ + __int64_t a = _mm_cvtm64_si64(from.x); + to[stride*0].x = static_cast(a); + to[stride*1].x = static_cast(a >> 16); + to[stride*2].x = static_cast(a >> 32); + to[stride*3].x = static_cast(a >> 48); +} + +EIGEN_STRONG_INLINE void +ptranspose(PacketBlock& kernel) { + __m64 T0 = _mm_unpacklo_pi16(kernel.packet[0].x, kernel.packet[1].x); + __m64 T1 = _mm_unpacklo_pi16(kernel.packet[2].x, kernel.packet[3].x); + __m64 T2 = _mm_unpackhi_pi16(kernel.packet[0].x, kernel.packet[1].x); + __m64 T3 = _mm_unpackhi_pi16(kernel.packet[2].x, kernel.packet[3].x); + + kernel.packet[0].x = _mm_unpacklo_pi32(T0, T1); + kernel.packet[1].x = _mm_unpackhi_pi32(T0, T1); + kernel.packet[2].x = _mm_unpacklo_pi32(T2, T3); + kernel.packet[3].x = _mm_unpackhi_pi32(T2, T3); +} + +#endif + + +} // end namespace internal + +} // end namespace Eigen + +#if EIGEN_COMP_PGI && EIGEN_COMP_PGI < 1900 +// PGI++ does not define the following intrinsics in C++ mode. 
+static inline __m128 _mm_castpd_ps (__m128d x) { return reinterpret_cast<__m128&>(x); } +static inline __m128i _mm_castpd_si128(__m128d x) { return reinterpret_cast<__m128i&>(x); } +static inline __m128d _mm_castps_pd (__m128 x) { return reinterpret_cast<__m128d&>(x); } +static inline __m128i _mm_castps_si128(__m128 x) { return reinterpret_cast<__m128i&>(x); } +static inline __m128 _mm_castsi128_ps(__m128i x) { return reinterpret_cast<__m128&>(x); } +static inline __m128d _mm_castsi128_pd(__m128i x) { return reinterpret_cast<__m128d&>(x); } +#endif + +#endif // EIGEN_PACKET_MATH_SSE_H diff --git a/Eigen/src/Core/arch/SSE/TypeCasting.h b/Eigen/src/Core/arch/SSE/TypeCasting.h new file mode 100644 index 0000000..d2a0037 --- /dev/null +++ b/Eigen/src/Core/arch/SSE/TypeCasting.h @@ -0,0 +1,142 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2015 Benoit Steiner +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_TYPE_CASTING_SSE_H +#define EIGEN_TYPE_CASTING_SSE_H + +namespace Eigen { + +namespace internal { + +#ifndef EIGEN_VECTORIZE_AVX +template <> +struct type_casting_traits { + enum { + VectorizedCast = 1, + SrcCoeffRatio = 1, + TgtCoeffRatio = 1 + }; +}; + +template <> +struct type_casting_traits { + enum { + VectorizedCast = 1, + SrcCoeffRatio = 1, + TgtCoeffRatio = 1 + }; +}; + +template <> +struct type_casting_traits { + enum { + VectorizedCast = 1, + SrcCoeffRatio = 2, + TgtCoeffRatio = 1 + }; +}; + +template <> +struct type_casting_traits { + enum { + VectorizedCast = 1, + SrcCoeffRatio = 1, + TgtCoeffRatio = 2 + }; +}; +#endif + +template<> EIGEN_STRONG_INLINE Packet4i pcast(const Packet4f& a) { + return _mm_cvttps_epi32(a); +} + +template<> EIGEN_STRONG_INLINE Packet4f pcast(const Packet4i& a) { + return _mm_cvtepi32_ps(a); +} + +template<> EIGEN_STRONG_INLINE Packet4f pcast(const Packet2d& a, const Packet2d& b) { + return _mm_shuffle_ps(_mm_cvtpd_ps(a), _mm_cvtpd_ps(b), (1 << 2) | (1 << 6)); +} + +template<> EIGEN_STRONG_INLINE Packet2d pcast(const Packet4f& a) { + // Simply discard the second half of the input + return _mm_cvtps_pd(a); +} + +template<> EIGEN_STRONG_INLINE Packet4i preinterpret(const Packet4f& a) { + return _mm_castps_si128(a); +} + +template<> EIGEN_STRONG_INLINE Packet4f preinterpret(const Packet4i& a) { + return _mm_castsi128_ps(a); +} + +template<> EIGEN_STRONG_INLINE Packet2d preinterpret(const Packet4i& a) { + return _mm_castsi128_pd(a); +} + +template<> EIGEN_STRONG_INLINE Packet4i preinterpret(const Packet2d& a) { + return _mm_castpd_si128(a); +} + +// Disable the following code since it's broken on too many platforms / compilers. 
+//#elif defined(EIGEN_VECTORIZE_SSE) && (!EIGEN_ARCH_x86_64) && (!EIGEN_COMP_MSVC) +#if 0 + +template <> +struct type_casting_traits { + enum { + VectorizedCast = 1, + SrcCoeffRatio = 1, + TgtCoeffRatio = 1 + }; +}; + +template<> EIGEN_STRONG_INLINE Packet4f pcast(const Packet4h& a) { + __int64_t a64 = _mm_cvtm64_si64(a.x); + Eigen::half h = raw_uint16_to_half(static_cast(a64)); + float f1 = static_cast(h); + h = raw_uint16_to_half(static_cast(a64 >> 16)); + float f2 = static_cast(h); + h = raw_uint16_to_half(static_cast(a64 >> 32)); + float f3 = static_cast(h); + h = raw_uint16_to_half(static_cast(a64 >> 48)); + float f4 = static_cast(h); + return _mm_set_ps(f4, f3, f2, f1); +} + +template <> +struct type_casting_traits { + enum { + VectorizedCast = 1, + SrcCoeffRatio = 1, + TgtCoeffRatio = 1 + }; +}; + +template<> EIGEN_STRONG_INLINE Packet4h pcast(const Packet4f& a) { + EIGEN_ALIGN16 float aux[4]; + pstore(aux, a); + Eigen::half h0(aux[0]); + Eigen::half h1(aux[1]); + Eigen::half h2(aux[2]); + Eigen::half h3(aux[3]); + + Packet4h result; + result.x = _mm_set_pi16(h3.x, h2.x, h1.x, h0.x); + return result; +} + +#endif + +} // end namespace internal + +} // end namespace Eigen + +#endif // EIGEN_TYPE_CASTING_SSE_H diff --git a/Eigen/src/Core/arch/SVE/MathFunctions.h b/Eigen/src/Core/arch/SVE/MathFunctions.h new file mode 100644 index 0000000..b139ea2 --- /dev/null +++ b/Eigen/src/Core/arch/SVE/MathFunctions.h @@ -0,0 +1,44 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2020, Arm Limited and Contributors +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_MATH_FUNCTIONS_SVE_H +#define EIGEN_MATH_FUNCTIONS_SVE_H + +namespace Eigen { +namespace internal { + +template <> +EIGEN_STRONG_INLINE EIGEN_UNUSED PacketXf pexp(const PacketXf& x) { + return pexp_float(x); +} + +template <> +EIGEN_STRONG_INLINE EIGEN_UNUSED PacketXf plog(const PacketXf& x) { + return plog_float(x); +} + +template <> +EIGEN_STRONG_INLINE EIGEN_UNUSED PacketXf psin(const PacketXf& x) { + return psin_float(x); +} + +template <> +EIGEN_STRONG_INLINE EIGEN_UNUSED PacketXf pcos(const PacketXf& x) { + return pcos_float(x); +} + +// Hyperbolic Tangent function. +template <> +EIGEN_STRONG_INLINE EIGEN_UNUSED PacketXf ptanh(const PacketXf& x) { + return internal::generic_fast_tanh_float(x); +} +} // end namespace internal +} // end namespace Eigen + +#endif // EIGEN_MATH_FUNCTIONS_SVE_H diff --git a/Eigen/src/Core/arch/SVE/PacketMath.h b/Eigen/src/Core/arch/SVE/PacketMath.h new file mode 100644 index 0000000..9060b37 --- /dev/null +++ b/Eigen/src/Core/arch/SVE/PacketMath.h @@ -0,0 +1,752 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2020, Arm Limited and Contributors +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
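One note before the SVE implementation that follows: the packet types below are fixed-length SVE vectors of EIGEN_ARM64_SVE_VL bits, and sve_packet_size_selector derives the lane count by dividing that bit width by the scalar size in bits. A minimal sketch of the same computation (names are made up for illustration, not part of the patch):

#include <climits>

// lane count = vector length in bits / (sizeof(Scalar) * CHAR_BIT)
template <typename Scalar, int SVEVectorLengthBits>
struct sve_lane_count {
  static const int value = SVEVectorLengthBits / (int(sizeof(Scalar)) * CHAR_BIT);
};

static_assert(sve_lane_count<float, 512>::value == 16, "a 512-bit SVE vector holds 16 floats");
static_assert(sve_lane_count<double, 512>::value == 8, "a 512-bit SVE vector holds 8 doubles");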
+ +#ifndef EIGEN_PACKET_MATH_SVE_H +#define EIGEN_PACKET_MATH_SVE_H + +namespace Eigen +{ +namespace internal +{ +#ifndef EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD +#define EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD 8 +#endif + +#ifndef EIGEN_HAS_SINGLE_INSTRUCTION_MADD +#define EIGEN_HAS_SINGLE_INSTRUCTION_MADD +#endif + +#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 32 + +template +struct sve_packet_size_selector { + enum { size = SVEVectorLength / (sizeof(Scalar) * CHAR_BIT) }; +}; + +/********************************* int32 **************************************/ +typedef svint32_t PacketXi __attribute__((arm_sve_vector_bits(EIGEN_ARM64_SVE_VL))); + +template <> +struct packet_traits : default_packet_traits { + typedef PacketXi type; + typedef PacketXi half; // Half not implemented yet + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = sve_packet_size_selector::size, + HasHalfPacket = 0, + + HasAdd = 1, + HasSub = 1, + HasShift = 1, + HasMul = 1, + HasNegate = 1, + HasAbs = 1, + HasArg = 0, + HasAbs2 = 1, + HasMin = 1, + HasMax = 1, + HasConj = 1, + HasSetLinear = 0, + HasBlend = 0, + HasReduxp = 0 // Not implemented in SVE + }; +}; + +template <> +struct unpacket_traits { + typedef numext::int32_t type; + typedef PacketXi half; // Half not yet implemented + enum { + size = sve_packet_size_selector::size, + alignment = Aligned64, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; + +template <> +EIGEN_STRONG_INLINE void prefetch(const numext::int32_t* addr) +{ + svprfw(svptrue_b32(), addr, SV_PLDL1KEEP); +} + +template <> +EIGEN_STRONG_INLINE PacketXi pset1(const numext::int32_t& from) +{ + return svdup_n_s32(from); +} + +template <> +EIGEN_STRONG_INLINE PacketXi plset(const numext::int32_t& a) +{ + numext::int32_t c[packet_traits::size]; + for (int i = 0; i < packet_traits::size; i++) c[i] = i; + return svadd_s32_z(svptrue_b32(), pset1(a), svld1_s32(svptrue_b32(), c)); +} + +template <> +EIGEN_STRONG_INLINE PacketXi padd(const PacketXi& a, const PacketXi& b) +{ + return svadd_s32_z(svptrue_b32(), a, b); +} + +template <> +EIGEN_STRONG_INLINE PacketXi psub(const PacketXi& a, const PacketXi& b) +{ + return svsub_s32_z(svptrue_b32(), a, b); +} + +template <> +EIGEN_STRONG_INLINE PacketXi pnegate(const PacketXi& a) +{ + return svneg_s32_z(svptrue_b32(), a); +} + +template <> +EIGEN_STRONG_INLINE PacketXi pconj(const PacketXi& a) +{ + return a; +} + +template <> +EIGEN_STRONG_INLINE PacketXi pmul(const PacketXi& a, const PacketXi& b) +{ + return svmul_s32_z(svptrue_b32(), a, b); +} + +template <> +EIGEN_STRONG_INLINE PacketXi pdiv(const PacketXi& a, const PacketXi& b) +{ + return svdiv_s32_z(svptrue_b32(), a, b); +} + +template <> +EIGEN_STRONG_INLINE PacketXi pmadd(const PacketXi& a, const PacketXi& b, const PacketXi& c) +{ + return svmla_s32_z(svptrue_b32(), c, a, b); +} + +template <> +EIGEN_STRONG_INLINE PacketXi pmin(const PacketXi& a, const PacketXi& b) +{ + return svmin_s32_z(svptrue_b32(), a, b); +} + +template <> +EIGEN_STRONG_INLINE PacketXi pmax(const PacketXi& a, const PacketXi& b) +{ + return svmax_s32_z(svptrue_b32(), a, b); +} + +template <> +EIGEN_STRONG_INLINE PacketXi pcmp_le(const PacketXi& a, const PacketXi& b) +{ + return svdup_n_s32_z(svcmplt_s32(svptrue_b32(), a, b), 0xffffffffu); +} + +template <> +EIGEN_STRONG_INLINE PacketXi pcmp_lt(const PacketXi& a, const PacketXi& b) +{ + return svdup_n_s32_z(svcmplt_s32(svptrue_b32(), a, b), 0xffffffffu); +} + +template <> +EIGEN_STRONG_INLINE PacketXi pcmp_eq(const PacketXi& a, 
const PacketXi& b) +{ + return svdup_n_s32_z(svcmpeq_s32(svptrue_b32(), a, b), 0xffffffffu); +} + +template <> +EIGEN_STRONG_INLINE PacketXi ptrue(const PacketXi& /*a*/) +{ + return svdup_n_s32_z(svptrue_b32(), 0xffffffffu); +} + +template <> +EIGEN_STRONG_INLINE PacketXi pzero(const PacketXi& /*a*/) +{ + return svdup_n_s32_z(svptrue_b32(), 0); +} + +template <> +EIGEN_STRONG_INLINE PacketXi pand(const PacketXi& a, const PacketXi& b) +{ + return svand_s32_z(svptrue_b32(), a, b); +} + +template <> +EIGEN_STRONG_INLINE PacketXi por(const PacketXi& a, const PacketXi& b) +{ + return svorr_s32_z(svptrue_b32(), a, b); +} + +template <> +EIGEN_STRONG_INLINE PacketXi pxor(const PacketXi& a, const PacketXi& b) +{ + return sveor_s32_z(svptrue_b32(), a, b); +} + +template <> +EIGEN_STRONG_INLINE PacketXi pandnot(const PacketXi& a, const PacketXi& b) +{ + return svbic_s32_z(svptrue_b32(), a, b); +} + +template +EIGEN_STRONG_INLINE PacketXi parithmetic_shift_right(PacketXi a) +{ + return svasrd_n_s32_z(svptrue_b32(), a, N); +} + +template +EIGEN_STRONG_INLINE PacketXi plogical_shift_right(PacketXi a) +{ + return svreinterpret_s32_u32(svlsr_u32_z(svptrue_b32(), svreinterpret_u32_s32(a), svdup_n_u32_z(svptrue_b32(), N))); +} + +template +EIGEN_STRONG_INLINE PacketXi plogical_shift_left(PacketXi a) +{ + return svlsl_s32_z(svptrue_b32(), a, svdup_n_u32_z(svptrue_b32(), N)); +} + +template <> +EIGEN_STRONG_INLINE PacketXi pload(const numext::int32_t* from) +{ + EIGEN_DEBUG_ALIGNED_LOAD return svld1_s32(svptrue_b32(), from); +} + +template <> +EIGEN_STRONG_INLINE PacketXi ploadu(const numext::int32_t* from) +{ + EIGEN_DEBUG_UNALIGNED_LOAD return svld1_s32(svptrue_b32(), from); +} + +template <> +EIGEN_STRONG_INLINE PacketXi ploaddup(const numext::int32_t* from) +{ + svuint32_t indices = svindex_u32(0, 1); // index {base=0, base+step=1, base+step*2, ...} + indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a1, a1, a2, a2, ...} + return svld1_gather_u32index_s32(svptrue_b32(), from, indices); +} + +template <> +EIGEN_STRONG_INLINE PacketXi ploadquad(const numext::int32_t* from) +{ + svuint32_t indices = svindex_u32(0, 1); // index {base=0, base+step=1, base+step*2, ...} + indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a1, a1, a2, a2, ...} + indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a0, a0, a1, a1, a1, a1, ...} + return svld1_gather_u32index_s32(svptrue_b32(), from, indices); +} + +template <> +EIGEN_STRONG_INLINE void pstore(numext::int32_t* to, const PacketXi& from) +{ + EIGEN_DEBUG_ALIGNED_STORE svst1_s32(svptrue_b32(), to, from); +} + +template <> +EIGEN_STRONG_INLINE void pstoreu(numext::int32_t* to, const PacketXi& from) +{ + EIGEN_DEBUG_UNALIGNED_STORE svst1_s32(svptrue_b32(), to, from); +} + +template <> +EIGEN_DEVICE_FUNC inline PacketXi pgather(const numext::int32_t* from, Index stride) +{ + // Indice format: {base=0, base+stride, base+stride*2, base+stride*3, ...} + svint32_t indices = svindex_s32(0, stride); + return svld1_gather_s32index_s32(svptrue_b32(), from, indices); +} + +template <> +EIGEN_DEVICE_FUNC inline void pscatter(numext::int32_t* to, const PacketXi& from, Index stride) +{ + // Indice format: {base=0, base+stride, base+stride*2, base+stride*3, ...} + svint32_t indices = svindex_s32(0, stride); + svst1_scatter_s32index_s32(svptrue_b32(), to, indices, from); +} + +template <> +EIGEN_STRONG_INLINE numext::int32_t pfirst(const PacketXi& a) +{ + // svlasta returns the first element if all predicate bits are 0 + 
return svlasta_s32(svpfalse_b(), a); +} + +template <> +EIGEN_STRONG_INLINE PacketXi preverse(const PacketXi& a) +{ + return svrev_s32(a); +} + +template <> +EIGEN_STRONG_INLINE PacketXi pabs(const PacketXi& a) +{ + return svabs_s32_z(svptrue_b32(), a); +} + +template <> +EIGEN_STRONG_INLINE numext::int32_t predux(const PacketXi& a) +{ + return static_cast(svaddv_s32(svptrue_b32(), a)); +} + +template <> +EIGEN_STRONG_INLINE numext::int32_t predux_mul(const PacketXi& a) +{ + EIGEN_STATIC_ASSERT((EIGEN_ARM64_SVE_VL % 128 == 0), + EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT); + + // Multiply the vector by its reverse + svint32_t prod = svmul_s32_z(svptrue_b32(), a, svrev_s32(a)); + svint32_t half_prod; + + // Extract the high half of the vector. Depending on the VL more reductions need to be done + if (EIGEN_ARM64_SVE_VL >= 2048) { + half_prod = svtbl_s32(prod, svindex_u32(32, 1)); + prod = svmul_s32_z(svptrue_b32(), prod, half_prod); + } + if (EIGEN_ARM64_SVE_VL >= 1024) { + half_prod = svtbl_s32(prod, svindex_u32(16, 1)); + prod = svmul_s32_z(svptrue_b32(), prod, half_prod); + } + if (EIGEN_ARM64_SVE_VL >= 512) { + half_prod = svtbl_s32(prod, svindex_u32(8, 1)); + prod = svmul_s32_z(svptrue_b32(), prod, half_prod); + } + if (EIGEN_ARM64_SVE_VL >= 256) { + half_prod = svtbl_s32(prod, svindex_u32(4, 1)); + prod = svmul_s32_z(svptrue_b32(), prod, half_prod); + } + // Last reduction + half_prod = svtbl_s32(prod, svindex_u32(2, 1)); + prod = svmul_s32_z(svptrue_b32(), prod, half_prod); + + // The reduction is done to the first element. + return pfirst(prod); +} + +template <> +EIGEN_STRONG_INLINE numext::int32_t predux_min(const PacketXi& a) +{ + return svminv_s32(svptrue_b32(), a); +} + +template <> +EIGEN_STRONG_INLINE numext::int32_t predux_max(const PacketXi& a) +{ + return svmaxv_s32(svptrue_b32(), a); +} + +template +EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { + int buffer[packet_traits::size * N] = {0}; + int i = 0; + + PacketXi stride_index = svindex_s32(0, N); + + for (i = 0; i < N; i++) { + svst1_scatter_s32index_s32(svptrue_b32(), buffer + i, stride_index, kernel.packet[i]); + } + for (i = 0; i < N; i++) { + kernel.packet[i] = svld1_s32(svptrue_b32(), buffer + i * packet_traits::size); + } +} + +/********************************* float32 ************************************/ + +typedef svfloat32_t PacketXf __attribute__((arm_sve_vector_bits(EIGEN_ARM64_SVE_VL))); + +template <> +struct packet_traits : default_packet_traits { + typedef PacketXf type; + typedef PacketXf half; + + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = sve_packet_size_selector::size, + HasHalfPacket = 0, + + HasAdd = 1, + HasSub = 1, + HasShift = 1, + HasMul = 1, + HasNegate = 1, + HasAbs = 1, + HasArg = 0, + HasAbs2 = 1, + HasMin = 1, + HasMax = 1, + HasConj = 1, + HasSetLinear = 0, + HasBlend = 0, + HasReduxp = 0, // Not implemented in SVE + + HasDiv = 1, + HasFloor = 1, + + HasSin = EIGEN_FAST_MATH, + HasCos = EIGEN_FAST_MATH, + HasLog = 1, + HasExp = 1, + HasSqrt = 0, + HasTanh = EIGEN_FAST_MATH, + HasErf = EIGEN_FAST_MATH + }; +}; + +template <> +struct unpacket_traits { + typedef float type; + typedef PacketXf half; // Half not yet implemented + typedef PacketXi integer_packet; + + enum { + size = sve_packet_size_selector::size, + alignment = Aligned64, + vectorizable = true, + masked_load_available = false, + masked_store_available = false + }; +}; + +template <> +EIGEN_STRONG_INLINE PacketXf pset1(const float& from) +{ + return svdup_n_f32(from); +} + +template <> 
+EIGEN_STRONG_INLINE PacketXf pset1frombits(numext::uint32_t from) +{ + return svreinterpret_f32_u32(svdup_n_u32_z(svptrue_b32(), from)); +} + +template <> +EIGEN_STRONG_INLINE PacketXf plset(const float& a) +{ + float c[packet_traits::size]; + for (int i = 0; i < packet_traits::size; i++) c[i] = i; + return svadd_f32_z(svptrue_b32(), pset1(a), svld1_f32(svptrue_b32(), c)); +} + +template <> +EIGEN_STRONG_INLINE PacketXf padd(const PacketXf& a, const PacketXf& b) +{ + return svadd_f32_z(svptrue_b32(), a, b); +} + +template <> +EIGEN_STRONG_INLINE PacketXf psub(const PacketXf& a, const PacketXf& b) +{ + return svsub_f32_z(svptrue_b32(), a, b); +} + +template <> +EIGEN_STRONG_INLINE PacketXf pnegate(const PacketXf& a) +{ + return svneg_f32_z(svptrue_b32(), a); +} + +template <> +EIGEN_STRONG_INLINE PacketXf pconj(const PacketXf& a) +{ + return a; +} + +template <> +EIGEN_STRONG_INLINE PacketXf pmul(const PacketXf& a, const PacketXf& b) +{ + return svmul_f32_z(svptrue_b32(), a, b); +} + +template <> +EIGEN_STRONG_INLINE PacketXf pdiv(const PacketXf& a, const PacketXf& b) +{ + return svdiv_f32_z(svptrue_b32(), a, b); +} + +template <> +EIGEN_STRONG_INLINE PacketXf pmadd(const PacketXf& a, const PacketXf& b, const PacketXf& c) +{ + return svmla_f32_z(svptrue_b32(), c, a, b); +} + +template <> +EIGEN_STRONG_INLINE PacketXf pmin(const PacketXf& a, const PacketXf& b) +{ + return svmin_f32_z(svptrue_b32(), a, b); +} + +template <> +EIGEN_STRONG_INLINE PacketXf pmin(const PacketXf& a, const PacketXf& b) +{ + return pmin(a, b); +} + +template <> +EIGEN_STRONG_INLINE PacketXf pmin(const PacketXf& a, const PacketXf& b) +{ + return svminnm_f32_z(svptrue_b32(), a, b); +} + +template <> +EIGEN_STRONG_INLINE PacketXf pmax(const PacketXf& a, const PacketXf& b) +{ + return svmax_f32_z(svptrue_b32(), a, b); +} + +template <> +EIGEN_STRONG_INLINE PacketXf pmax(const PacketXf& a, const PacketXf& b) +{ + return pmax(a, b); +} + +template <> +EIGEN_STRONG_INLINE PacketXf pmax(const PacketXf& a, const PacketXf& b) +{ + return svmaxnm_f32_z(svptrue_b32(), a, b); +} + +// Float comparisons in SVE return svbool (predicate). Use svdup to set active +// lanes to 1 (0xffffffffu) and inactive lanes to 0. +template <> +EIGEN_STRONG_INLINE PacketXf pcmp_le(const PacketXf& a, const PacketXf& b) +{ + return svreinterpret_f32_u32(svdup_n_u32_z(svcmplt_f32(svptrue_b32(), a, b), 0xffffffffu)); +} + +template <> +EIGEN_STRONG_INLINE PacketXf pcmp_lt(const PacketXf& a, const PacketXf& b) +{ + return svreinterpret_f32_u32(svdup_n_u32_z(svcmplt_f32(svptrue_b32(), a, b), 0xffffffffu)); +} + +template <> +EIGEN_STRONG_INLINE PacketXf pcmp_eq(const PacketXf& a, const PacketXf& b) +{ + return svreinterpret_f32_u32(svdup_n_u32_z(svcmpeq_f32(svptrue_b32(), a, b), 0xffffffffu)); +} + +// Do a predicate inverse (svnot_b_z) on the predicate resulted from the +// greater/equal comparison (svcmpge_f32). Then fill a float vector with the +// active elements. 
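// In other words: lanes where a < b, plus lanes where either operand is NaN
// (the >= comparison is false for NaN), come back as all-ones bit patterns,
// and every other lane is zero, which is the mask convention the generic
// packet kernels expect from pcmp_lt_or_nan.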
+template <> +EIGEN_STRONG_INLINE PacketXf pcmp_lt_or_nan(const PacketXf& a, const PacketXf& b) +{ + return svreinterpret_f32_u32(svdup_n_u32_z(svnot_b_z(svptrue_b32(), svcmpge_f32(svptrue_b32(), a, b)), 0xffffffffu)); +} + +template <> +EIGEN_STRONG_INLINE PacketXf pfloor(const PacketXf& a) +{ + return svrintm_f32_z(svptrue_b32(), a); +} + +template <> +EIGEN_STRONG_INLINE PacketXf ptrue(const PacketXf& /*a*/) +{ + return svreinterpret_f32_u32(svdup_n_u32_z(svptrue_b32(), 0xffffffffu)); +} + +// Logical Operations are not supported for float, so reinterpret casts +template <> +EIGEN_STRONG_INLINE PacketXf pand(const PacketXf& a, const PacketXf& b) +{ + return svreinterpret_f32_u32(svand_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b))); +} + +template <> +EIGEN_STRONG_INLINE PacketXf por(const PacketXf& a, const PacketXf& b) +{ + return svreinterpret_f32_u32(svorr_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b))); +} + +template <> +EIGEN_STRONG_INLINE PacketXf pxor(const PacketXf& a, const PacketXf& b) +{ + return svreinterpret_f32_u32(sveor_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b))); +} + +template <> +EIGEN_STRONG_INLINE PacketXf pandnot(const PacketXf& a, const PacketXf& b) +{ + return svreinterpret_f32_u32(svbic_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b))); +} + +template <> +EIGEN_STRONG_INLINE PacketXf pload(const float* from) +{ + EIGEN_DEBUG_ALIGNED_LOAD return svld1_f32(svptrue_b32(), from); +} + +template <> +EIGEN_STRONG_INLINE PacketXf ploadu(const float* from) +{ + EIGEN_DEBUG_UNALIGNED_LOAD return svld1_f32(svptrue_b32(), from); +} + +template <> +EIGEN_STRONG_INLINE PacketXf ploaddup(const float* from) +{ + svuint32_t indices = svindex_u32(0, 1); // index {base=0, base+step=1, base+step*2, ...} + indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a1, a1, a2, a2, ...} + return svld1_gather_u32index_f32(svptrue_b32(), from, indices); +} + +template <> +EIGEN_STRONG_INLINE PacketXf ploadquad(const float* from) +{ + svuint32_t indices = svindex_u32(0, 1); // index {base=0, base+step=1, base+step*2, ...} + indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a1, a1, a2, a2, ...} + indices = svzip1_u32(indices, indices); // index in the format {a0, a0, a0, a0, a1, a1, a1, a1, ...} + return svld1_gather_u32index_f32(svptrue_b32(), from, indices); +} + +template <> +EIGEN_STRONG_INLINE void pstore(float* to, const PacketXf& from) +{ + EIGEN_DEBUG_ALIGNED_STORE svst1_f32(svptrue_b32(), to, from); +} + +template <> +EIGEN_STRONG_INLINE void pstoreu(float* to, const PacketXf& from) +{ + EIGEN_DEBUG_UNALIGNED_STORE svst1_f32(svptrue_b32(), to, from); +} + +template <> +EIGEN_DEVICE_FUNC inline PacketXf pgather(const float* from, Index stride) +{ + // Indice format: {base=0, base+stride, base+stride*2, base+stride*3, ...} + svint32_t indices = svindex_s32(0, stride); + return svld1_gather_s32index_f32(svptrue_b32(), from, indices); +} + +template <> +EIGEN_DEVICE_FUNC inline void pscatter(float* to, const PacketXf& from, Index stride) +{ + // Indice format: {base=0, base+stride, base+stride*2, base+stride*3, ...} + svint32_t indices = svindex_s32(0, stride); + svst1_scatter_s32index_f32(svptrue_b32(), to, indices, from); +} + +template <> +EIGEN_STRONG_INLINE float pfirst(const PacketXf& a) +{ + // svlasta returns the first element if all predicate bits are 0 + return svlasta_f32(svpfalse_b(), a); +} + +template <> +EIGEN_STRONG_INLINE 
PacketXf preverse(const PacketXf& a) +{ + return svrev_f32(a); +} + +template <> +EIGEN_STRONG_INLINE PacketXf pabs(const PacketXf& a) +{ + return svabs_f32_z(svptrue_b32(), a); +} + +// TODO(tellenbach): Should this go into MathFunctions.h? If so, change for +// all vector extensions and the generic version. +template <> +EIGEN_STRONG_INLINE PacketXf pfrexp(const PacketXf& a, PacketXf& exponent) +{ + return pfrexp_generic(a, exponent); +} + +template <> +EIGEN_STRONG_INLINE float predux(const PacketXf& a) +{ + return svaddv_f32(svptrue_b32(), a); +} + +// Other reduction functions: +// mul +// Only works for SVE Vls multiple of 128 +template <> +EIGEN_STRONG_INLINE float predux_mul(const PacketXf& a) +{ + EIGEN_STATIC_ASSERT((EIGEN_ARM64_SVE_VL % 128 == 0), + EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT); + // Multiply the vector by its reverse + svfloat32_t prod = svmul_f32_z(svptrue_b32(), a, svrev_f32(a)); + svfloat32_t half_prod; + + // Extract the high half of the vector. Depending on the VL more reductions need to be done + if (EIGEN_ARM64_SVE_VL >= 2048) { + half_prod = svtbl_f32(prod, svindex_u32(32, 1)); + prod = svmul_f32_z(svptrue_b32(), prod, half_prod); + } + if (EIGEN_ARM64_SVE_VL >= 1024) { + half_prod = svtbl_f32(prod, svindex_u32(16, 1)); + prod = svmul_f32_z(svptrue_b32(), prod, half_prod); + } + if (EIGEN_ARM64_SVE_VL >= 512) { + half_prod = svtbl_f32(prod, svindex_u32(8, 1)); + prod = svmul_f32_z(svptrue_b32(), prod, half_prod); + } + if (EIGEN_ARM64_SVE_VL >= 256) { + half_prod = svtbl_f32(prod, svindex_u32(4, 1)); + prod = svmul_f32_z(svptrue_b32(), prod, half_prod); + } + // Last reduction + half_prod = svtbl_f32(prod, svindex_u32(2, 1)); + prod = svmul_f32_z(svptrue_b32(), prod, half_prod); + + // The reduction is done to the first element. + return pfirst(prod); +} + +template <> +EIGEN_STRONG_INLINE float predux_min(const PacketXf& a) +{ + return svminv_f32(svptrue_b32(), a); +} + +template <> +EIGEN_STRONG_INLINE float predux_max(const PacketXf& a) +{ + return svmaxv_f32(svptrue_b32(), a); +} + +template +EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) +{ + float buffer[packet_traits::size * N] = {0}; + int i = 0; + + PacketXi stride_index = svindex_s32(0, N); + + for (i = 0; i < N; i++) { + svst1_scatter_s32index_f32(svptrue_b32(), buffer + i, stride_index, kernel.packet[i]); + } + + for (i = 0; i < N; i++) { + kernel.packet[i] = svld1_f32(svptrue_b32(), buffer + i * packet_traits::size); + } +} + +template<> +EIGEN_STRONG_INLINE PacketXf pldexp(const PacketXf& a, const PacketXf& exponent) +{ + return pldexp_generic(a, exponent); +} + +} // namespace internal +} // namespace Eigen + +#endif // EIGEN_PACKET_MATH_SVE_H diff --git a/Eigen/src/Core/arch/SVE/TypeCasting.h b/Eigen/src/Core/arch/SVE/TypeCasting.h new file mode 100644 index 0000000..7ba5d9c --- /dev/null +++ b/Eigen/src/Core/arch/SVE/TypeCasting.h @@ -0,0 +1,49 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2020, Arm Limited and Contributors +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
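//
// This header adds the vectorized float <-> int32 conversions and raw bit
// reinterpretations for the SVE packets defined above; svcvt_s32_f32
// truncates toward zero, matching the behaviour of a scalar static_cast.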
+ +#ifndef EIGEN_TYPE_CASTING_SVE_H +#define EIGEN_TYPE_CASTING_SVE_H + +namespace Eigen { +namespace internal { + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; + +template <> +EIGEN_STRONG_INLINE PacketXf pcast(const PacketXi& a) { + return svcvt_f32_s32_z(svptrue_b32(), a); +} + +template <> +EIGEN_STRONG_INLINE PacketXi pcast(const PacketXf& a) { + return svcvt_s32_f32_z(svptrue_b32(), a); +} + +template <> +EIGEN_STRONG_INLINE PacketXf preinterpret(const PacketXi& a) { + return svreinterpret_f32_s32(a); +} + +template <> +EIGEN_STRONG_INLINE PacketXi preinterpret(const PacketXf& a) { + return svreinterpret_s32_f32(a); +} + +} // namespace internal +} // namespace Eigen + +#endif // EIGEN_TYPE_CASTING_SVE_H diff --git a/Eigen/src/Core/arch/SYCL/InteropHeaders.h b/Eigen/src/Core/arch/SYCL/InteropHeaders.h new file mode 100644 index 0000000..10856ff --- /dev/null +++ b/Eigen/src/Core/arch/SYCL/InteropHeaders.h @@ -0,0 +1,232 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Mehdi Goli Codeplay Software Ltd. +// Ralph Potter Codeplay Software Ltd. +// Luke Iwanski Codeplay Software Ltd. +// Contact: +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +/***************************************************************** + * InteropHeaders.h + * + * \brief: + * InteropHeaders + * + *****************************************************************/ + +#ifndef EIGEN_INTEROP_HEADERS_SYCL_H +#define EIGEN_INTEROP_HEADERS_SYCL_H + +namespace Eigen { + +#if !defined(EIGEN_DONT_VECTORIZE_SYCL) + +namespace internal { + +template +struct sycl_packet_traits : default_packet_traits { + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = lengths, + HasHalfPacket = 0, + HasDiv = 1, + HasLog = 1, + HasExp = 1, + HasSqrt = 1, + HasRsqrt = 1, + HasSin = 1, + HasCos = 1, + HasTan = 1, + HasASin = 1, + HasACos = 1, + HasATan = 1, + HasSinh = 1, + HasCosh = 1, + HasTanh = 1, + HasLGamma = 0, + HasDiGamma = 0, + HasZeta = 0, + HasPolygamma = 0, + HasErf = 0, + HasErfc = 0, + HasNdtri = 0, + HasIGamma = 0, + HasIGammac = 0, + HasBetaInc = 0, + HasBlend = has_blend, + // This flag is used to indicate whether packet comparison is supported. + // pcmp_eq, pcmp_lt and pcmp_le should be defined for it to be true. + HasCmp = 1, + HasMax = 1, + HasMin = 1, + HasMul = 1, + HasAdd = 1, + HasFloor = 1, + HasRound = 1, + HasRint = 1, + HasLog1p = 1, + HasExpm1 = 1, + HasCeil = 1, + }; +}; + +#ifdef SYCL_DEVICE_ONLY +#define SYCL_PACKET_TRAITS(packet_type, has_blend, unpacket_type, lengths) \ + template <> \ + struct packet_traits \ + : sycl_packet_traits { \ + typedef packet_type type; \ + typedef packet_type half; \ + }; + +SYCL_PACKET_TRAITS(cl::sycl::cl_float4, 1, float, 4) +SYCL_PACKET_TRAITS(cl::sycl::cl_float4, 1, const float, 4) +SYCL_PACKET_TRAITS(cl::sycl::cl_double2, 0, double, 2) +SYCL_PACKET_TRAITS(cl::sycl::cl_double2, 0, const double, 2) +#undef SYCL_PACKET_TRAITS + +// Make sure this is only available when targeting a GPU: we don't want to +// introduce conflicts between these packet_traits definitions and the ones +// we'll use on the host side (SSE, AVX, ...) 
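// The is_arithmetic specializations below let Eigen's type traits treat a
// whole cl::sycl vector (cl_float4 / cl_double2) as a plain value type, the
// same way the SSE/AVX backends flag their native packet types.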
+#define SYCL_ARITHMETIC(packet_type) \ + template <> \ + struct is_arithmetic { \ + enum { value = true }; \ + }; +SYCL_ARITHMETIC(cl::sycl::cl_float4) +SYCL_ARITHMETIC(cl::sycl::cl_double2) +#undef SYCL_ARITHMETIC + +#define SYCL_UNPACKET_TRAITS(packet_type, unpacket_type, lengths) \ + template <> \ + struct unpacket_traits { \ + typedef unpacket_type type; \ + enum { size = lengths, vectorizable = true, alignment = Aligned16 }; \ + typedef packet_type half; \ + }; +SYCL_UNPACKET_TRAITS(cl::sycl::cl_float4, float, 4) +SYCL_UNPACKET_TRAITS(cl::sycl::cl_double2, double, 2) + +#undef SYCL_UNPACKET_TRAITS +#endif + +} // end namespace internal + +#endif + +namespace TensorSycl { +namespace internal { + +template +struct PacketWrapper; +// This function should never get called on the device +#ifndef SYCL_DEVICE_ONLY +template +struct PacketWrapper { + typedef typename ::Eigen::internal::unpacket_traits::type + Scalar; + template + EIGEN_DEVICE_FUNC static Scalar scalarize(Index, PacketReturnType &) { + eigen_assert(false && "THERE IS NO PACKETIZE VERSION FOR THE CHOSEN TYPE"); + abort(); + } + EIGEN_DEVICE_FUNC static PacketReturnType convert_to_packet_type(Scalar in, + Scalar) { + return ::Eigen::internal::template plset(in); + } + EIGEN_DEVICE_FUNC static void set_packet(PacketReturnType, Scalar *) { + eigen_assert(false && "THERE IS NO PACKETIZE VERSION FOR THE CHOSEN TYPE"); + abort(); + } +}; + +#elif defined(SYCL_DEVICE_ONLY) +template +struct PacketWrapper { + typedef typename ::Eigen::internal::unpacket_traits::type + Scalar; + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static Scalar scalarize(Index index, PacketReturnType &in) { + switch (index) { + case 0: + return in.x(); + case 1: + return in.y(); + case 2: + return in.z(); + case 3: + return in.w(); + default: + //INDEX MUST BE BETWEEN 0 and 3.There is no abort function in SYCL kernel. so we cannot use abort here. + // The code will never reach here + __builtin_unreachable(); + } + __builtin_unreachable(); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static PacketReturnType convert_to_packet_type( + Scalar in, Scalar other) { + return PacketReturnType(in, other, other, other); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static void set_packet(PacketReturnType &lhs, Scalar *rhs) { + lhs = PacketReturnType(rhs[0], rhs[1], rhs[2], rhs[3]); + } +}; + +template +struct PacketWrapper { + typedef typename ::Eigen::internal::unpacket_traits::type + Scalar; + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static Scalar scalarize(Index, PacketReturnType &in) { + return in; + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static PacketReturnType convert_to_packet_type(Scalar in, + Scalar) { + return PacketReturnType(in); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static void set_packet(PacketReturnType &lhs, Scalar *rhs) { + lhs = rhs[0]; + } +}; + +template +struct PacketWrapper { + typedef typename ::Eigen::internal::unpacket_traits::type + Scalar; + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static Scalar scalarize(Index index, PacketReturnType &in) { + switch (index) { + case 0: + return in.x(); + case 1: + return in.y(); + default: + //INDEX MUST BE BETWEEN 0 and 1.There is no abort function in SYCL kernel. so we cannot use abort here. 
+ // The code will never reach here + __builtin_unreachable(); + } + __builtin_unreachable(); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static PacketReturnType convert_to_packet_type( + Scalar in, Scalar other) { + return PacketReturnType(in, other); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE static void set_packet(PacketReturnType &lhs, Scalar *rhs) { + lhs = PacketReturnType(rhs[0], rhs[1]); + } +}; + +#endif + +} // end namespace internal +} // end namespace TensorSycl +} // end namespace Eigen + +#endif // EIGEN_INTEROP_HEADERS_SYCL_H diff --git a/Eigen/src/Core/arch/SYCL/MathFunctions.h b/Eigen/src/Core/arch/SYCL/MathFunctions.h new file mode 100644 index 0000000..2ab0f2a --- /dev/null +++ b/Eigen/src/Core/arch/SYCL/MathFunctions.h @@ -0,0 +1,301 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Mehdi Goli Codeplay Software Ltd. +// Ralph Potter Codeplay Software Ltd. +// Luke Iwanski Codeplay Software Ltd. +// Contact: +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +/***************************************************************** + * MathFunctions.h + * + * \brief: + * MathFunctions + * + *****************************************************************/ + +#ifndef EIGEN_MATH_FUNCTIONS_SYCL_H +#define EIGEN_MATH_FUNCTIONS_SYCL_H +namespace Eigen { + +namespace internal { + +// Make sure this is only available when targeting a GPU: we don't want to +// introduce conflicts between these packet_traits definitions and the ones +// we'll use on the host side (SSE, AVX, ...) +#if defined(SYCL_DEVICE_ONLY) +#define SYCL_PLOG(packet_type) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type plog( \ + const packet_type& a) { \ + return cl::sycl::log(a); \ + } + +SYCL_PLOG(cl::sycl::cl_float4) +SYCL_PLOG(cl::sycl::cl_double2) +#undef SYCL_PLOG + +#define SYCL_PLOG1P(packet_type) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type plog1p( \ + const packet_type& a) { \ + return cl::sycl::log1p(a); \ + } + +SYCL_PLOG1P(cl::sycl::cl_float4) +SYCL_PLOG1P(cl::sycl::cl_double2) +#undef SYCL_PLOG1P + +#define SYCL_PLOG10(packet_type) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type plog10( \ + const packet_type& a) { \ + return cl::sycl::log10(a); \ + } + +SYCL_PLOG10(cl::sycl::cl_float4) +SYCL_PLOG10(cl::sycl::cl_double2) +#undef SYCL_PLOG10 + +#define SYCL_PEXP(packet_type) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pexp( \ + const packet_type& a) { \ + return cl::sycl::exp(a); \ + } + +SYCL_PEXP(cl::sycl::cl_float4) +SYCL_PEXP(cl::sycl::cl_float) +SYCL_PEXP(cl::sycl::cl_double2) +#undef SYCL_PEXP + +#define SYCL_PEXPM1(packet_type) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pexpm1( \ + const packet_type& a) { \ + return cl::sycl::expm1(a); \ + } + +SYCL_PEXPM1(cl::sycl::cl_float4) +SYCL_PEXPM1(cl::sycl::cl_double2) +#undef SYCL_PEXPM1 + +#define SYCL_PSQRT(packet_type) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type psqrt( \ + const packet_type& a) { \ + return cl::sycl::sqrt(a); \ + } + +SYCL_PSQRT(cl::sycl::cl_float4) +SYCL_PSQRT(cl::sycl::cl_double2) +#undef SYCL_PSQRT + +#define SYCL_PRSQRT(packet_type) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type prsqrt( \ + const packet_type& a) { \ + return cl::sycl::rsqrt(a); \ + } + 
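// [Editorial sketch, not part of the original patch.] The specializations
// stamped out by these macros compose like any other packet ops inside a
// SYCL kernel. For example, a hypothetical per-lane softplus, log(1 + exp(x)),
// written against the pexp/plog1p specializations defined above (naive form,
// with no overflow guard):
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE cl::sycl::cl_float4 example_softplus(
    const cl::sycl::cl_float4& x) {
  // both calls forward lane-wise to the cl::sycl built-ins
  return plog1p(pexp(x));
}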
+SYCL_PRSQRT(cl::sycl::cl_float4) +SYCL_PRSQRT(cl::sycl::cl_double2) +#undef SYCL_PRSQRT + +/** \internal \returns the hyperbolic sine of \a a (coeff-wise) */ +#define SYCL_PSIN(packet_type) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type psin( \ + const packet_type& a) { \ + return cl::sycl::sin(a); \ + } + +SYCL_PSIN(cl::sycl::cl_float4) +SYCL_PSIN(cl::sycl::cl_double2) +#undef SYCL_PSIN + +/** \internal \returns the hyperbolic cosine of \a a (coeff-wise) */ +#define SYCL_PCOS(packet_type) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pcos( \ + const packet_type& a) { \ + return cl::sycl::cos(a); \ + } + +SYCL_PCOS(cl::sycl::cl_float4) +SYCL_PCOS(cl::sycl::cl_double2) +#undef SYCL_PCOS + +/** \internal \returns the hyperbolic tan of \a a (coeff-wise) */ +#define SYCL_PTAN(packet_type) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type ptan( \ + const packet_type& a) { \ + return cl::sycl::tan(a); \ + } + +SYCL_PTAN(cl::sycl::cl_float4) +SYCL_PTAN(cl::sycl::cl_double2) +#undef SYCL_PTAN + +/** \internal \returns the hyperbolic sine of \a a (coeff-wise) */ +#define SYCL_PASIN(packet_type) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pasin( \ + const packet_type& a) { \ + return cl::sycl::asin(a); \ + } + +SYCL_PASIN(cl::sycl::cl_float4) +SYCL_PASIN(cl::sycl::cl_double2) +#undef SYCL_PASIN + +/** \internal \returns the hyperbolic cosine of \a a (coeff-wise) */ +#define SYCL_PACOS(packet_type) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pacos( \ + const packet_type& a) { \ + return cl::sycl::acos(a); \ + } + +SYCL_PACOS(cl::sycl::cl_float4) +SYCL_PACOS(cl::sycl::cl_double2) +#undef SYCL_PACOS + +/** \internal \returns the hyperbolic tan of \a a (coeff-wise) */ +#define SYCL_PATAN(packet_type) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type patan( \ + const packet_type& a) { \ + return cl::sycl::atan(a); \ + } + +SYCL_PATAN(cl::sycl::cl_float4) +SYCL_PATAN(cl::sycl::cl_double2) +#undef SYCL_PATAN + +/** \internal \returns the hyperbolic sine of \a a (coeff-wise) */ +#define SYCL_PSINH(packet_type) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type psinh( \ + const packet_type& a) { \ + return cl::sycl::sinh(a); \ + } + +SYCL_PSINH(cl::sycl::cl_float4) +SYCL_PSINH(cl::sycl::cl_double2) +#undef SYCL_PSINH + +/** \internal \returns the hyperbolic cosine of \a a (coeff-wise) */ +#define SYCL_PCOSH(packet_type) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pcosh( \ + const packet_type& a) { \ + return cl::sycl::cosh(a); \ + } + +SYCL_PCOSH(cl::sycl::cl_float4) +SYCL_PCOSH(cl::sycl::cl_double2) +#undef SYCL_PCOSH + +/** \internal \returns the hyperbolic tan of \a a (coeff-wise) */ +#define SYCL_PTANH(packet_type) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type ptanh( \ + const packet_type& a) { \ + return cl::sycl::tanh(a); \ + } + +SYCL_PTANH(cl::sycl::cl_float4) +SYCL_PTANH(cl::sycl::cl_double2) +#undef SYCL_PTANH + +#define SYCL_PCEIL(packet_type) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pceil( \ + const packet_type& a) { \ + return cl::sycl::ceil(a); \ + } + +SYCL_PCEIL(cl::sycl::cl_float4) +SYCL_PCEIL(cl::sycl::cl_double2) +#undef SYCL_PCEIL + +#define SYCL_PROUND(packet_type) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pround( \ + const packet_type& a) { \ + return cl::sycl::round(a); \ + } + +SYCL_PROUND(cl::sycl::cl_float4) 
+SYCL_PROUND(cl::sycl::cl_double2) +#undef SYCL_PROUND + +#define SYCL_PRINT(packet_type) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type print( \ + const packet_type& a) { \ + return cl::sycl::rint(a); \ + } + +SYCL_PRINT(cl::sycl::cl_float4) +SYCL_PRINT(cl::sycl::cl_double2) +#undef SYCL_PRINT + +#define SYCL_FLOOR(packet_type) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pfloor( \ + const packet_type& a) { \ + return cl::sycl::floor(a); \ + } + +SYCL_FLOOR(cl::sycl::cl_float4) +SYCL_FLOOR(cl::sycl::cl_double2) +#undef SYCL_FLOOR + +#define SYCL_PMIN(packet_type, expr) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pmin( \ + const packet_type& a, const packet_type& b) { \ + return expr; \ + } + +SYCL_PMIN(cl::sycl::cl_float4, cl::sycl::fmin(a, b)) +SYCL_PMIN(cl::sycl::cl_double2, cl::sycl::fmin(a, b)) +#undef SYCL_PMIN + +#define SYCL_PMAX(packet_type, expr) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pmax( \ + const packet_type& a, const packet_type& b) { \ + return expr; \ + } + +SYCL_PMAX(cl::sycl::cl_float4, cl::sycl::fmax(a, b)) +SYCL_PMAX(cl::sycl::cl_double2, cl::sycl::fmax(a, b)) +#undef SYCL_PMAX + +#define SYCL_PLDEXP(packet_type) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type pldexp( \ + const packet_type& a, const packet_type& exponent) { \ + return cl::sycl::ldexp( \ + a, exponent.template convert()); \ + } + +SYCL_PLDEXP(cl::sycl::cl_float4) +SYCL_PLDEXP(cl::sycl::cl_double2) +#undef SYCL_PLDEXP + +#endif +} // end namespace internal + +} // end namespace Eigen + +#endif // EIGEN_MATH_FUNCTIONS_SYCL_H diff --git a/Eigen/src/Core/arch/SYCL/PacketMath.h b/Eigen/src/Core/arch/SYCL/PacketMath.h new file mode 100644 index 0000000..87badc0 --- /dev/null +++ b/Eigen/src/Core/arch/SYCL/PacketMath.h @@ -0,0 +1,670 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Mehdi Goli Codeplay Software Ltd. +// Ralph Potter Codeplay Software Ltd. +// Luke Iwanski Codeplay Software Ltd. +// Contact: +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+ +/***************************************************************** + * PacketMath.h + * + * \brief: + * PacketMath + * + *****************************************************************/ + +#ifndef EIGEN_PACKET_MATH_SYCL_H +#define EIGEN_PACKET_MATH_SYCL_H +#include +namespace Eigen { + +namespace internal { +#ifdef SYCL_DEVICE_ONLY + +#define SYCL_PLOADT_RO(address_space_target) \ + template \ + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type ploadt_ro( \ + typename cl::sycl::multi_ptr< \ + const typename unpacket_traits::type, \ + cl::sycl::access::address_space::address_space_target>::pointer_t \ + from) { \ + typedef typename unpacket_traits::type scalar; \ + typedef cl::sycl::multi_ptr< \ + scalar, cl::sycl::access::address_space::address_space_target> \ + multi_ptr; \ + auto res = packet_type( \ + static_cast::type>(0)); \ + res.load(0, multi_ptr(const_cast(from))); \ + return res; \ + } + +SYCL_PLOADT_RO(global_space) +SYCL_PLOADT_RO(local_space) +#undef SYCL_PLOADT_RO +#endif + +template +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type +ploadt_ro(const Eigen::TensorSycl::internal::RangeAccess< + cl::sycl::access::mode::read_write, T>& from) { + return ploadt_ro(from.get_pointer()); +} + +#ifdef SYCL_DEVICE_ONLY +#define SYCL_PLOAD(address_space_target, Alignment, AlignedType) \ + template \ + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type pload##AlignedType( \ + typename cl::sycl::multi_ptr< \ + const typename unpacket_traits::type, \ + cl::sycl::access::address_space::address_space_target>::pointer_t \ + from) { \ + return ploadt_ro(from); \ + } + +// global space +SYCL_PLOAD(global_space, Unaligned, u) +SYCL_PLOAD(global_space, Aligned, ) +// local space +SYCL_PLOAD(local_space, Unaligned, u) +SYCL_PLOAD(local_space, Aligned, ) + +#undef SYCL_PLOAD +#endif + +#define SYCL_PLOAD(Alignment, AlignedType) \ + template \ + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type pload##AlignedType( \ + const Eigen::TensorSycl::internal::RangeAccess< \ + cl::sycl::access::mode::read_write, \ + typename unpacket_traits::type> \ + from) { \ + return ploadt_ro(from); \ + } +SYCL_PLOAD(Unaligned, u) +SYCL_PLOAD(Aligned, ) +#undef SYCL_PLOAD + +#ifdef SYCL_DEVICE_ONLY +/** \internal \returns a packet version of \a *from. + * The pointer \a from must be aligned on a \a Alignment bytes boundary. 
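 * When \a Alignment is below the packet's required alignment, the macro below
 * falls back to the unaligned load (ploadu) instead.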
*/ +#define SYCL_PLOADT(address_space_target) \ + template \ + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type ploadt( \ + typename cl::sycl::multi_ptr< \ + const typename unpacket_traits::type, \ + cl::sycl::access::address_space::address_space_target>::pointer_t \ + from) { \ + if (Alignment >= unpacket_traits::alignment) \ + return pload(from); \ + else \ + return ploadu(from); \ + } + +// global space +SYCL_PLOADT(global_space) +// local space +SYCL_PLOADT(local_space) +#undef SYCL_PLOADT +#endif + +template +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type +ploadt(const Eigen::TensorSycl::internal::RangeAccess< + cl::sycl::access::mode::read_write, + typename unpacket_traits::type>& from) { + return ploadt(from.get_pointer()); +} +#ifdef SYCL_DEVICE_ONLY + +// private_space +#define SYCL_PLOADT_RO_SPECIAL(packet_type, Alignment) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type \ + ploadt_ro( \ + const typename unpacket_traits::type* from) { \ + typedef typename unpacket_traits::type scalar; \ + auto res = packet_type(static_cast(0)); \ + res.template load( \ + 0, const_cast(from)); \ + return res; \ + } + +SYCL_PLOADT_RO_SPECIAL(cl::sycl::cl_float4, Aligned) +SYCL_PLOADT_RO_SPECIAL(cl::sycl::cl_double2, Aligned) +SYCL_PLOADT_RO_SPECIAL(cl::sycl::cl_float4, Unaligned) +SYCL_PLOADT_RO_SPECIAL(cl::sycl::cl_double2, Unaligned) + +#define SYCL_PLOAD_SPECIAL(packet_type, alignment_type) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type pload##alignment_type( \ + const typename unpacket_traits::type* from) { \ + typedef typename unpacket_traits::type scalar; \ + auto res = packet_type(static_cast(0)); \ + res.template load( \ + 0, const_cast(from)); \ + return res; \ + } +SYCL_PLOAD_SPECIAL(cl::sycl::cl_float4, ) +SYCL_PLOAD_SPECIAL(cl::sycl::cl_double2, ) +SYCL_PLOAD_SPECIAL(cl::sycl::cl_float4, u) +SYCL_PLOAD_SPECIAL(cl::sycl::cl_double2, u) + +#undef SYCL_PLOAD_SPECIAL + +#define SYCL_PSTORE(scalar, packet_type, address_space_target, alignment) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void pstore##alignment( \ + typename cl::sycl::multi_ptr< \ + scalar, \ + cl::sycl::access::address_space::address_space_target>::pointer_t \ + to, \ + const packet_type& from) { \ + typedef cl::sycl::multi_ptr< \ + scalar, cl::sycl::access::address_space::address_space_target> \ + multi_ptr; \ + from.store(0, multi_ptr(to)); \ + } + +// global space +SYCL_PSTORE(float, cl::sycl::cl_float4, global_space, ) +SYCL_PSTORE(float, cl::sycl::cl_float4, global_space, u) +SYCL_PSTORE(double, cl::sycl::cl_double2, global_space, ) +SYCL_PSTORE(double, cl::sycl::cl_double2, global_space, u) +SYCL_PSTORE(float, cl::sycl::cl_float4, local_space, ) +SYCL_PSTORE(float, cl::sycl::cl_float4, local_space, u) +SYCL_PSTORE(double, cl::sycl::cl_double2, local_space, ) +SYCL_PSTORE(double, cl::sycl::cl_double2, local_space, u) + +SYCL_PSTORE(float, cl::sycl::cl_float4, private_space, ) +SYCL_PSTORE(float, cl::sycl::cl_float4, private_space, u) +SYCL_PSTORE(double, cl::sycl::cl_double2, private_space, ) +SYCL_PSTORE(double, cl::sycl::cl_double2, private_space, u) +#undef SYCL_PSTORE + +#define SYCL_PSTORE_T(address_space_target) \ + template \ + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void pstoret( \ + typename cl::sycl::multi_ptr< \ + scalar, \ + cl::sycl::access::address_space::address_space_target>::pointer_t \ + to, \ + const packet_type& from) { \ + if (Alignment) \ + pstore(to, from); \ + else \ + pstoreu(to, from); \ + } + +SYCL_PSTORE_T(global_space) + 
+SYCL_PSTORE_T(local_space) + +#undef SYCL_PSTORE_T + +#define SYCL_PSET1(packet_type) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type pset1( \ + const typename unpacket_traits::type& from) { \ + return packet_type(from); \ + } + +// global space +SYCL_PSET1(cl::sycl::cl_float4) +SYCL_PSET1(cl::sycl::cl_double2) + +#undef SYCL_PSET1 + +template +struct get_base_packet { + template + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type + get_ploaddup(sycl_multi_pointer) {} + + template + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type + get_pgather(sycl_multi_pointer, Index) {} +}; + +template <> +struct get_base_packet { + template + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE cl::sycl::cl_float4 get_ploaddup( + sycl_multi_pointer from) { + return cl::sycl::cl_float4(from[0], from[0], from[1], from[1]); + } + template + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE cl::sycl::cl_float4 get_pgather( + sycl_multi_pointer from, Index stride) { + return cl::sycl::cl_float4(from[0 * stride], from[1 * stride], + from[2 * stride], from[3 * stride]); + } + + template + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void set_pscatter( + sycl_multi_pointer to, const cl::sycl::cl_float4& from, Index stride) { + auto tmp = stride; + to[0] = from.x(); + to[tmp] = from.y(); + to[tmp += stride] = from.z(); + to[tmp += stride] = from.w(); + } + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE cl::sycl::cl_float4 set_plset( + const float& a) { + return cl::sycl::cl_float4(static_cast(a), static_cast(a + 1), + static_cast(a + 2), + static_cast(a + 3)); + } +}; + +template <> +struct get_base_packet { + template + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE cl::sycl::cl_double2 + get_ploaddup(const sycl_multi_pointer from) { + return cl::sycl::cl_double2(from[0], from[0]); + } + + template + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE cl::sycl::cl_double2 get_pgather( + const sycl_multi_pointer from, Index stride) { + return cl::sycl::cl_double2(from[0 * stride], from[1 * stride]); + } + + template + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void set_pscatter( + sycl_multi_pointer to, const cl::sycl::cl_double2& from, Index stride) { + to[0] = from.x(); + to[stride] = from.y(); + } + + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE cl::sycl::cl_double2 set_plset( + const double& a) { + return cl::sycl::cl_double2(static_cast(a), + static_cast(a + 1)); + } +}; + +#define SYCL_PLOAD_DUP(address_space_target) \ + template \ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type ploaddup( \ + typename cl::sycl::multi_ptr< \ + const typename unpacket_traits::type, \ + cl::sycl::access::address_space::address_space_target>::pointer_t \ + from) { \ + return get_base_packet::get_ploaddup(from); \ + } + +// global space +SYCL_PLOAD_DUP(global_space) +// local_space +SYCL_PLOAD_DUP(local_space) +#undef SYCL_PLOAD_DUP + +#define SYCL_PLOAD_DUP_SPECILIZE(packet_type) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type ploaddup( \ + const typename unpacket_traits::type* from) { \ + return get_base_packet::get_ploaddup(from); \ + } + +SYCL_PLOAD_DUP_SPECILIZE(cl::sycl::cl_float4) +SYCL_PLOAD_DUP_SPECILIZE(cl::sycl::cl_double2) + +#undef SYCL_PLOAD_DUP_SPECILIZE + +#define SYCL_PLSET(packet_type) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type plset( \ + const typename unpacket_traits::type& a) { \ + return get_base_packet::set_plset(a); \ + } + +SYCL_PLSET(cl::sycl::cl_float4) +SYCL_PLSET(cl::sycl::cl_double2) + +#undef SYCL_PLSET + +#define 
SYCL_PGATHER(address_space_target) \ + template \ + EIGEN_DEVICE_FUNC inline packet_type pgather( \ + typename cl::sycl::multi_ptr< \ + const typename unpacket_traits::type, \ + cl::sycl::access::address_space::address_space_target>::pointer_t \ + from, \ + Index stride) { \ + return get_base_packet::get_pgather(from, stride); \ + } + +// global space +SYCL_PGATHER(global_space) +// local space +SYCL_PGATHER(local_space) + +#undef SYCL_PGATHER + +#define SYCL_PGATHER_SPECILIZE(scalar, packet_type) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packet_type \ + pgather( \ + const typename unpacket_traits::type* from, Index stride) { \ + return get_base_packet::get_pgather(from, stride); \ + } + +SYCL_PGATHER_SPECILIZE(float, cl::sycl::cl_float4) +SYCL_PGATHER_SPECILIZE(double, cl::sycl::cl_double2) + +#undef SYCL_PGATHER_SPECILIZE + +#define SYCL_PSCATTER(address_space_target) \ + template \ + EIGEN_DEVICE_FUNC inline void pscatter( \ + typename cl::sycl::multi_ptr< \ + typename unpacket_traits::type, \ + cl::sycl::access::address_space::address_space_target>::pointer_t \ + to, \ + const packet_type& from, Index stride) { \ + get_base_packet::set_pscatter(to, from, stride); \ + } + +// global space +SYCL_PSCATTER(global_space) +// local space +SYCL_PSCATTER(local_space) + +#undef SYCL_PSCATTER + +#define SYCL_PSCATTER_SPECILIZE(scalar, packet_type) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void pscatter( \ + typename unpacket_traits::type * to, \ + const packet_type& from, Index stride) { \ + get_base_packet::set_pscatter(to, from, stride); \ + } + +SYCL_PSCATTER_SPECILIZE(float, cl::sycl::cl_float4) +SYCL_PSCATTER_SPECILIZE(double, cl::sycl::cl_double2) + +#undef SYCL_PSCATTER_SPECILIZE + +#define SYCL_PMAD(packet_type) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE packet_type pmadd( \ + const packet_type& a, const packet_type& b, const packet_type& c) { \ + return cl::sycl::mad(a, b, c); \ + } + +SYCL_PMAD(cl::sycl::cl_float4) +SYCL_PMAD(cl::sycl::cl_double2) +#undef SYCL_PMAD + +template <> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float pfirst( + const cl::sycl::cl_float4& a) { + return a.x(); +} +template <> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double pfirst( + const cl::sycl::cl_double2& a) { + return a.x(); +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float predux( + const cl::sycl::cl_float4& a) { + return a.x() + a.y() + a.z() + a.w(); +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double predux( + const cl::sycl::cl_double2& a) { + return a.x() + a.y(); +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float predux_max( + const cl::sycl::cl_float4& a) { + return cl::sycl::fmax(cl::sycl::fmax(a.x(), a.y()), + cl::sycl::fmax(a.z(), a.w())); +} +template <> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double predux_max( + const cl::sycl::cl_double2& a) { + return cl::sycl::fmax(a.x(), a.y()); +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float predux_min( + const cl::sycl::cl_float4& a) { + return cl::sycl::fmin(cl::sycl::fmin(a.x(), a.y()), + cl::sycl::fmin(a.z(), a.w())); +} +template <> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double predux_min( + const cl::sycl::cl_double2& a) { + return cl::sycl::fmin(a.x(), a.y()); +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float predux_mul( + const cl::sycl::cl_float4& a) { + return a.x() * a.y() * a.z() * a.w(); +} +template <> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double predux_mul( + const cl::sycl::cl_double2& a) { + return a.x() * a.y(); +} + +template <> 
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE cl::sycl::cl_float4 +pabs(const cl::sycl::cl_float4& a) { + return cl::sycl::cl_float4(cl::sycl::fabs(a.x()), cl::sycl::fabs(a.y()), + cl::sycl::fabs(a.z()), cl::sycl::fabs(a.w())); +} +template <> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE cl::sycl::cl_double2 +pabs(const cl::sycl::cl_double2& a) { + return cl::sycl::cl_double2(cl::sycl::fabs(a.x()), cl::sycl::fabs(a.y())); +} + +template +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet sycl_pcmp_le(const Packet &a, + const Packet &b) { + return ((a <= b) + .template convert::type, + cl::sycl::rounding_mode::automatic>()); +} + +template +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet sycl_pcmp_lt(const Packet &a, + const Packet &b) { + return ((a < b) + .template convert::type, + cl::sycl::rounding_mode::automatic>()); +} + +template +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet sycl_pcmp_eq(const Packet &a, + const Packet &b) { + return ((a == b) + .template convert::type, + cl::sycl::rounding_mode::automatic>()); +} + +#define SYCL_PCMP(OP, TYPE) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE TYPE pcmp_##OP(const TYPE &a, \ + const TYPE &b) { \ + return sycl_pcmp_##OP(a, b); \ + } + +SYCL_PCMP(le, cl::sycl::cl_float4) +SYCL_PCMP(lt, cl::sycl::cl_float4) +SYCL_PCMP(eq, cl::sycl::cl_float4) +SYCL_PCMP(le, cl::sycl::cl_double2) +SYCL_PCMP(lt, cl::sycl::cl_double2) +SYCL_PCMP(eq, cl::sycl::cl_double2) +#undef SYCL_PCMP + +template struct convert_to_integer; + +template <> struct convert_to_integer { + using type = std::int32_t; + using packet_type = cl::sycl::cl_int4; +}; +template <> struct convert_to_integer { + using type = std::int64_t; + using packet_type = cl::sycl::cl_long2; +}; + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename convert_to_integer< + typename unpacket_traits::type>::packet_type +vector_as_int(const PacketIn &p) { + return ( + p.template convert::type>::type, + cl::sycl::rounding_mode::automatic>()); +} + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packetOut +convert_vector(const PacketIn &p) { + return (p.template convert::type, + cl::sycl::rounding_mode::automatic>()); +} + +#define SYCL_PAND(TYPE) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TYPE pand(const TYPE &a, \ + const TYPE &b) { \ + return convert_vector(vector_as_int(a) & vector_as_int(b)); \ + } +SYCL_PAND(cl::sycl::cl_float4) +SYCL_PAND(cl::sycl::cl_double2) +#undef SYCL_PAND + +#define SYCL_POR(TYPE) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TYPE por(const TYPE &a, \ + const TYPE &b) { \ + return convert_vector(vector_as_int(a) | vector_as_int(b)); \ + } + +SYCL_POR(cl::sycl::cl_float4) +SYCL_POR(cl::sycl::cl_double2) +#undef SYCL_POR + +#define SYCL_PXOR(TYPE) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TYPE pxor(const TYPE &a, \ + const TYPE &b) { \ + return convert_vector(vector_as_int(a) ^ vector_as_int(b)); \ + } + +SYCL_PXOR(cl::sycl::cl_float4) +SYCL_PXOR(cl::sycl::cl_double2) +#undef SYCL_PXOR + +#define SYCL_PANDNOT(TYPE) \ + template <> \ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TYPE pandnot(const TYPE &a, \ + const TYPE &b) { \ + return convert_vector(vector_as_int(a) & (~vector_as_int(b))); \ + } +SYCL_PANDNOT(cl::sycl::cl_float4) +SYCL_PANDNOT(cl::sycl::cl_double2) +#undef SYCL_PANDNOT + +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void ptranspose( + PacketBlock& kernel) { + float tmp = kernel.packet[0].y(); + kernel.packet[0].y() = kernel.packet[1].x(); + kernel.packet[1].x() = tmp; + + tmp = kernel.packet[0].z(); + kernel.packet[0].z() = 
kernel.packet[2].x(); + kernel.packet[2].x() = tmp; + + tmp = kernel.packet[0].w(); + kernel.packet[0].w() = kernel.packet[3].x(); + kernel.packet[3].x() = tmp; + + tmp = kernel.packet[1].z(); + kernel.packet[1].z() = kernel.packet[2].y(); + kernel.packet[2].y() = tmp; + + tmp = kernel.packet[1].w(); + kernel.packet[1].w() = kernel.packet[3].y(); + kernel.packet[3].y() = tmp; + + tmp = kernel.packet[2].w(); + kernel.packet[2].w() = kernel.packet[3].z(); + kernel.packet[3].z() = tmp; +} + +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void ptranspose( + PacketBlock& kernel) { + double tmp = kernel.packet[0].y(); + kernel.packet[0].y() = kernel.packet[1].x(); + kernel.packet[1].x() = tmp; +} + +template <> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE cl::sycl::cl_float4 pblend( + const Selector::size>& ifPacket, + const cl::sycl::cl_float4& thenPacket, + const cl::sycl::cl_float4& elsePacket) { + cl::sycl::cl_int4 condition( + ifPacket.select[0] ? 0 : -1, ifPacket.select[1] ? 0 : -1, + ifPacket.select[2] ? 0 : -1, ifPacket.select[3] ? 0 : -1); + return cl::sycl::select(thenPacket, elsePacket, condition); +} + +template <> +inline cl::sycl::cl_double2 pblend( + const Selector::size>& ifPacket, + const cl::sycl::cl_double2& thenPacket, + const cl::sycl::cl_double2& elsePacket) { + cl::sycl::cl_long2 condition(ifPacket.select[0] ? 0 : -1, + ifPacket.select[1] ? 0 : -1); + return cl::sycl::select(thenPacket, elsePacket, condition); +} +#endif // SYCL_DEVICE_ONLY + +#define SYCL_PSTORE(alignment) \ + template \ + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void pstore##alignment( \ + const Eigen::TensorSycl::internal::RangeAccess< \ + cl::sycl::access::mode::read_write, \ + typename unpacket_traits::type>& to, \ + const packet_type& from) { \ + pstore##alignment(to.get_pointer(), from); \ + } + +// global space +SYCL_PSTORE() +SYCL_PSTORE(u) + +#undef SYCL_PSTORE + +template +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void pstoret( + Eigen::TensorSycl::internal::RangeAccess< + cl::sycl::access::mode::read_write, + typename unpacket_traits::type> + to, + const packet_type& from) { + pstoret(to.get_pointer(), from); +} + +} // end namespace internal + +} // end namespace Eigen + +#endif // EIGEN_PACKET_MATH_SYCL_H diff --git a/Eigen/src/Core/arch/SYCL/SyclMemoryModel.h b/Eigen/src/Core/arch/SYCL/SyclMemoryModel.h new file mode 100644 index 0000000..f81e59d --- /dev/null +++ b/Eigen/src/Core/arch/SYCL/SyclMemoryModel.h @@ -0,0 +1,694 @@ +/*************************************************************************** + * Copyright (C) 2017 Codeplay Software Limited + * This Source Code Form is subject to the terms of the Mozilla + * Public License v. 2.0. If a copy of the MPL was not distributed + * with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + * + * + * SyclMemoryModel.h + * + * Description: + * Interface for SYCL buffers to behave as a non-dereferenceable pointer + * Interface for Placeholder accessor to behave as a pointer on both host + * and device + * + * Authors: + * + * Ruyman Reyes Codeplay Software Ltd. + * Mehdi Goli Codeplay Software Ltd. + * Vanya Yaneva Codeplay Software Ltd. 
+ * + **************************************************************************/ + +#if defined(EIGEN_USE_SYCL) && \ + !defined(EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H) +#define EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H + +#include +#ifdef EIGEN_EXCEPTIONS +#include +#endif +#include +#include +#include +#include + +namespace Eigen { +namespace TensorSycl { +namespace internal { + +using sycl_acc_target = cl::sycl::access::target; +using sycl_acc_mode = cl::sycl::access::mode; + +/** + * Default values for template arguments + */ +using buffer_data_type_t = uint8_t; +const sycl_acc_target default_acc_target = sycl_acc_target::global_buffer; +const sycl_acc_mode default_acc_mode = sycl_acc_mode::read_write; + +/** + * PointerMapper + * Associates fake pointers with buffers. + * + */ +class PointerMapper { + public: + using base_ptr_t = std::intptr_t; + + /* Structure of a virtual pointer + * + * |================================================| + * | POINTER ADDRESS | + * |================================================| + */ + struct virtual_pointer_t { + /* Type for the pointers + */ + base_ptr_t m_contents; + + /** Conversions from virtual_pointer_t to + * void * should just reinterpret_cast the integer number + */ + operator void *() const { return reinterpret_cast(m_contents); } + + /** + * Convert back to the integer number. + */ + operator base_ptr_t() const { return m_contents; } + + /** + * Add a certain value to the pointer to create a + * new pointer to that offset + */ + virtual_pointer_t operator+(size_t off) { return m_contents + off; } + + /* Numerical order for sorting pointers in containers. */ + bool operator<(virtual_pointer_t rhs) const { + return (static_cast(m_contents) < + static_cast(rhs.m_contents)); + } + + bool operator>(virtual_pointer_t rhs) const { + return (static_cast(m_contents) > + static_cast(rhs.m_contents)); + } + + /** + * Numerical order for sorting pointers in containers + */ + bool operator==(virtual_pointer_t rhs) const { + return (static_cast(m_contents) == + static_cast(rhs.m_contents)); + } + + /** + * Simple forward to the equality overload. + */ + bool operator!=(virtual_pointer_t rhs) const { + return !(this->operator==(rhs)); + } + + /** + * Converts a void * into a virtual pointer structure. + * Note that this will only work if the void * was + * already a virtual_pointer_t, but we have no way of + * checking + */ + virtual_pointer_t(const void *ptr) + : m_contents(reinterpret_cast(ptr)){}; + + /** + * Creates a virtual_pointer_t from the given integer + * number + */ + virtual_pointer_t(base_ptr_t u) : m_contents(u){}; + }; + + /* Definition of a null pointer + */ + const virtual_pointer_t null_virtual_ptr = nullptr; + + /** + * Whether if a pointer is null or not. + * A pointer is nullptr if the value is of null_virtual_ptr + */ + static inline bool is_nullptr(virtual_pointer_t ptr) { + return (static_cast(ptr) == nullptr); + } + + /* basic type for all buffers + */ + using buffer_t = cl::sycl::buffer_mem; + + /** + * Node that stores information about a device allocation. + * Nodes are sorted by size to organise a free list of nodes + * that can be recovered. 
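   * A node is flagged free rather than erased straight away, so its buffer
   * can be handed back to a later allocation of a compatible size;
   * set_final_data(nullptr) in the constructor disables any copy-back to the
   * host when the underlying SYCL buffer is eventually destroyed.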
+ */ + struct pMapNode_t { + buffer_t m_buffer; + size_t m_size; + bool m_free; + + pMapNode_t(buffer_t b, size_t size, bool f) + : m_buffer{b}, m_size{size}, m_free{f} { + m_buffer.set_final_data(nullptr); + } + + bool operator<=(const pMapNode_t &rhs) { return (m_size <= rhs.m_size); } + }; + + /** Storage of the pointer / buffer tree + */ + using pointerMap_t = std::map; + + /** + * Obtain the insertion point in the pointer map for + * a pointer of the given size. + * \param requiredSize Size attemted to reclaim + */ + typename pointerMap_t::iterator get_insertion_point(size_t requiredSize) { + typename pointerMap_t::iterator retVal; + bool reuse = false; + if (!m_freeList.empty()) { + // try to re-use an existing block + for (auto freeElem : m_freeList) { + if (freeElem->second.m_size >= requiredSize) { + retVal = freeElem; + reuse = true; + // Element is not going to be free anymore + m_freeList.erase(freeElem); + break; + } + } + } + if (!reuse) { + retVal = std::prev(m_pointerMap.end()); + } + return retVal; + } + + /** + * Returns an iterator to the node that stores the information + * of the given virtual pointer from the given pointer map structure. + * If pointer is not found, throws std::out_of_range. + * If the pointer map structure is empty, throws std::out_of_range + * + * \param pMap the pointerMap_t structure storing all the pointers + * \param virtual_pointer_ptr The virtual pointer to obtain the node of + * \throws std::out:of_range if the pointer is not found or pMap is empty + */ + typename pointerMap_t::iterator get_node(const virtual_pointer_t ptr) { + if (this->count() == 0) { + m_pointerMap.clear(); + EIGEN_THROW_X(std::out_of_range("There are no pointers allocated\n")); + + } + if (is_nullptr(ptr)) { + m_pointerMap.clear(); + EIGEN_THROW_X(std::out_of_range("Cannot access null pointer\n")); + } + // The previous element to the lower bound is the node that + // holds this memory address + auto node = m_pointerMap.lower_bound(ptr); + // If the value of the pointer is not the one of the node + // then we return the previous one + if (node == std::end(m_pointerMap)) { + --node; + } else if (node->first != ptr) { + if (node == std::begin(m_pointerMap)) { + m_pointerMap.clear(); + EIGEN_THROW_X( + std::out_of_range("The pointer is not registered in the map\n")); + + } + --node; + } + + return node; + } + + /* get_buffer. + * Returns a buffer from the map using the pointer address + */ + template + cl::sycl::buffer get_buffer( + const virtual_pointer_t ptr) { + using sycl_buffer_t = cl::sycl::buffer; + + // get_node() returns a `buffer_mem`, so we need to cast it to a `buffer<>`. + // We can do this without the `buffer_mem` being a pointer, as we + // only declare member variables in the base class (`buffer_mem`) and not in + // the child class (`buffer<>). 
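    // The asserts below only verify that ptr falls inside the half-open
    // address range [node->first, node->first + m_size) owned by that node.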
+ auto node = get_node(ptr); + eigen_assert(node->first == ptr || node->first < ptr); + eigen_assert(ptr < static_cast(node->second.m_size + + node->first)); + return *(static_cast(&node->second.m_buffer)); + } + + /** + * @brief Returns an accessor to the buffer of the given virtual pointer + * @param accessMode + * @param accessTarget + * @param ptr The virtual pointer + */ + template + cl::sycl::accessor + get_access(const virtual_pointer_t ptr) { + auto buf = get_buffer(ptr); + return buf.template get_access(); + } + + /** + * @brief Returns an accessor to the buffer of the given virtual pointer + * in the given command group scope + * @param accessMode + * @param accessTarget + * @param ptr The virtual pointer + * @param cgh Reference to the command group scope + */ + template + cl::sycl::accessor + get_access(const virtual_pointer_t ptr, cl::sycl::handler &cgh) { + auto buf = get_buffer(ptr); + return buf.template get_access(cgh); + } + + /* + * Returns the offset from the base address of this pointer. + */ + inline std::ptrdiff_t get_offset(const virtual_pointer_t ptr) { + // The previous element to the lower bound is the node that + // holds this memory address + auto node = get_node(ptr); + auto start = node->first; + eigen_assert(start == ptr || start < ptr); + eigen_assert(ptr < start + node->second.m_size); + return (ptr - start); + } + + /* + * Returns the number of elements by which the given pointer is offset from + * the base address. + */ + template + inline size_t get_element_offset(const virtual_pointer_t ptr) { + return get_offset(ptr) / sizeof(buffer_data_type); + } + + /** + * Constructs the PointerMapper structure. + */ + PointerMapper(base_ptr_t baseAddress = 4096) + : m_pointerMap{}, m_freeList{}, m_baseAddress{baseAddress} { + if (m_baseAddress == 0) { + EIGEN_THROW_X(std::invalid_argument("Base address cannot be zero\n")); + } + }; + + /** + * PointerMapper cannot be copied or moved + */ + PointerMapper(const PointerMapper &) = delete; + + /** + * Empty the pointer list + */ + inline void clear() { + m_freeList.clear(); + m_pointerMap.clear(); + } + + /* add_pointer. + * Adds an existing pointer to the map and returns the virtual pointer id. + */ + inline virtual_pointer_t add_pointer(const buffer_t &b) { + return add_pointer_impl(b); + } + + /* add_pointer. + * Adds a pointer to the map and returns the virtual pointer id. 
+ */ + inline virtual_pointer_t add_pointer(buffer_t &&b) { + return add_pointer_impl(b); + } + + /** + * @brief Fuses the given node with the previous nodes in the + * pointer map if they are free + * + * @param node A reference to the free node to be fused + */ + void fuse_forward(typename pointerMap_t::iterator &node) { + while (node != std::prev(m_pointerMap.end())) { + // if following node is free + // remove it and extend the current node with its size + auto fwd_node = std::next(node); + if (!fwd_node->second.m_free) { + break; + } + auto fwd_size = fwd_node->second.m_size; + m_freeList.erase(fwd_node); + m_pointerMap.erase(fwd_node); + + node->second.m_size += fwd_size; + } + } + + /** + * @brief Fuses the given node with the following nodes in the + * pointer map if they are free + * + * @param node A reference to the free node to be fused + */ + void fuse_backward(typename pointerMap_t::iterator &node) { + while (node != m_pointerMap.begin()) { + // if previous node is free, extend it + // with the size of the current one + auto prev_node = std::prev(node); + if (!prev_node->second.m_free) { + break; + } + prev_node->second.m_size += node->second.m_size; + + // remove the current node + m_freeList.erase(node); + m_pointerMap.erase(node); + + // point to the previous node + node = prev_node; + } + } + + /* remove_pointer. + * Removes the given pointer from the map. + * The pointer is allowed to be reused only if ReUse if true. + */ + template + void remove_pointer(const virtual_pointer_t ptr) { + if (is_nullptr(ptr)) { + return; + } + auto node = this->get_node(ptr); + + node->second.m_free = true; + m_freeList.emplace(node); + + // Fuse the node + // with free nodes before and after it + fuse_forward(node); + fuse_backward(node); + + // If after fusing the node is the last one + // simply remove it (since it is free) + if (node == std::prev(m_pointerMap.end())) { + m_freeList.erase(node); + m_pointerMap.erase(node); + } + } + + /* count. + * Return the number of active pointers (i.e, pointers that + * have been malloc but not freed). + */ + size_t count() const { return (m_pointerMap.size() - m_freeList.size()); } + + private: + /* add_pointer_impl. + * Adds a pointer to the map and returns the virtual pointer id. + * BufferT is either a const buffer_t& or a buffer_t&&. 
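+ * If the free list contains a node of sufficient size it is reused, and any
+ * remaining space is split off as a new free node; otherwise the allocation
+ * is appended after the last node in the map.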
+ */ + template + virtual_pointer_t add_pointer_impl(BufferT b) { + virtual_pointer_t retVal = nullptr; + size_t bufSize = b.get_count(); + pMapNode_t p{b, bufSize, false}; + // If this is the first pointer: + if (m_pointerMap.empty()) { + virtual_pointer_t initialVal{m_baseAddress}; + m_pointerMap.emplace(initialVal, p); + return initialVal; + } + + auto lastElemIter = get_insertion_point(bufSize); + // We are recovering an existing free node + if (lastElemIter->second.m_free) { + lastElemIter->second.m_buffer = b; + lastElemIter->second.m_free = false; + + // If the recovered node is bigger than the inserted one + // add a new free node with the remaining space + if (lastElemIter->second.m_size > bufSize) { + // create a new node with the remaining space + auto remainingSize = lastElemIter->second.m_size - bufSize; + pMapNode_t p2{b, remainingSize, true}; + + // update size of the current node + lastElemIter->second.m_size = bufSize; + + // add the new free node + auto newFreePtr = lastElemIter->first + bufSize; + auto freeNode = m_pointerMap.emplace(newFreePtr, p2).first; + m_freeList.emplace(freeNode); + } + + retVal = lastElemIter->first; + } else { + size_t lastSize = lastElemIter->second.m_size; + retVal = lastElemIter->first + lastSize; + m_pointerMap.emplace(retVal, p); + } + return retVal; + } + + /** + * Compare two iterators to pointer map entries according to + * the size of the allocation on the device. + */ + struct SortBySize { + bool operator()(typename pointerMap_t::iterator a, + typename pointerMap_t::iterator b) const { + return ((a->first < b->first) && (a->second <= b->second)) || + ((a->first < b->first) && (b->second <= a->second)); + } + }; + + /* Maps the pointer addresses to buffer and size pairs. + */ + pointerMap_t m_pointerMap; + + /* List of free nodes available for re-using + */ + std::set m_freeList; + + /* Base address used when issuing the first virtual pointer, allows users + * to specify alignment. Cannot be zero. */ + std::intptr_t m_baseAddress; +}; + +/* remove_pointer. + * Removes the given pointer from the map. + * The pointer is allowed to be reused only if ReUse if true. + */ +template <> +inline void PointerMapper::remove_pointer(const virtual_pointer_t ptr) { + if (is_nullptr(ptr)) { + return; + } + m_pointerMap.erase(this->get_node(ptr)); +} + +/** + * Malloc-like interface to the pointer-mapper. + * Given a size, creates a byte-typed buffer and returns a + * fake pointer to keep track of it. + * \param size Size in bytes of the desired allocation + * \throw cl::sycl::exception if error while creating the buffer + */ +inline void *SYCLmalloc(size_t size, PointerMapper &pMap) { + if (size == 0) { + return nullptr; + } + // Create a generic buffer of the given size + using buffer_t = cl::sycl::buffer; + auto thePointer = pMap.add_pointer(buffer_t(cl::sycl::range<1>{size})); + // Store the buffer on the global list + return static_cast(thePointer); +} + +/** + * Free-like interface to the pointer mapper. + * Given a fake-pointer created with the virtual-pointer malloc, + * destroys the buffer and remove it from the list. + * If ReUse is false, the pointer is not added to the freeList, + * it should be false only for sub-buffers. + */ +template +inline void SYCLfree(void *ptr, PointerMapper &pMap) { + pMap.template remove_pointer(ptr); +} + +/** + * Clear all the memory allocated by SYCL. 
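+ * Discards every entry in the pointer map, including the free list.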
+ */ +template +inline void SYCLfreeAll(PointerMapper &pMap) { + pMap.clear(); +} + +template +struct RangeAccess { + static const auto global_access = cl::sycl::access::target::global_buffer; + static const auto is_place_holder = cl::sycl::access::placeholder::true_t; + typedef T scalar_t; + typedef scalar_t &ref_t; + typedef typename cl::sycl::global_ptr::pointer_t ptr_t; + + // the accessor type does not necessarily the same as T + typedef cl::sycl::accessor + accessor; + + typedef RangeAccess self_t; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE RangeAccess(accessor access, + size_t offset, + std::intptr_t virtual_ptr) + : access_(access), offset_(offset), virtual_ptr_(virtual_ptr) {} + + RangeAccess(cl::sycl::buffer buff = + cl::sycl::buffer(cl::sycl::range<1>(1))) + : access_{accessor{buff}}, offset_(0), virtual_ptr_(-1) {} + + // This should be only used for null constructor on the host side + RangeAccess(std::nullptr_t) : RangeAccess() {} + // This template parameter must be removed and scalar_t should be replaced + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ptr_t get_pointer() const { + return (access_.get_pointer().get() + offset_); + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t &operator+=(Index offset) { + offset_ += (offset); + return *this; + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t operator+(Index offset) const { + return self_t(access_, offset_ + offset, virtual_ptr_); + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t operator-(Index offset) const { + return self_t(access_, offset_ - offset, virtual_ptr_); + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t &operator-=(Index offset) { + offset_ -= offset; + return *this; + } + + // THIS IS FOR NULL COMPARISON ONLY + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend bool operator==( + const RangeAccess &lhs, std::nullptr_t) { + return ((lhs.virtual_ptr_ == -1)); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend bool operator!=( + const RangeAccess &lhs, std::nullptr_t i) { + return !(lhs == i); + } + + // THIS IS FOR NULL COMPARISON ONLY + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend bool operator==( + std::nullptr_t, const RangeAccess &rhs) { + return ((rhs.virtual_ptr_ == -1)); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend bool operator!=( + std::nullptr_t i, const RangeAccess &rhs) { + return !(i == rhs); + } + // Prefix operator (Increment and return value) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t &operator++() { + offset_++; + return (*this); + } + + // Postfix operator (Return value and increment) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t operator++(int i) { + EIGEN_UNUSED_VARIABLE(i); + self_t temp_iterator(*this); + offset_++; + return temp_iterator; + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t get_size() const { + return (access_.get_count() - offset_); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t get_offset() const { + return offset_; + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void set_offset(std::ptrdiff_t offset) { + offset_ = offset; + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator*() const { + return *get_pointer(); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator*() { + return *get_pointer(); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ptr_t operator->() = delete; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator[](int x) { + return *(get_pointer() + x); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator[](int x) const { + return *(get_pointer() + x); + } + + EIGEN_DEVICE_FUNC 
EIGEN_STRONG_INLINE scalar_t *get_virtual_pointer() const { + return reinterpret_cast(virtual_ptr_ + + (offset_ * sizeof(scalar_t))); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit operator bool() const { + return (virtual_ptr_ != -1); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE operator RangeAccess() { + return RangeAccess(access_, offset_, virtual_ptr_); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + operator RangeAccess() const { + return RangeAccess(access_, offset_, virtual_ptr_); + } + // binding placeholder accessors to a command group handler for SYCL + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind( + cl::sycl::handler &cgh) const { + cgh.require(access_); + } + + private: + accessor access_; + size_t offset_; + std::intptr_t virtual_ptr_; // the location of the buffer in the map +}; + +template +struct RangeAccess : RangeAccess { + typedef RangeAccess Base; + using Base::Base; +}; + +} // namespace internal +} // namespace TensorSycl +} // namespace Eigen + +#endif // EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H diff --git a/Eigen/src/Core/arch/SYCL/TypeCasting.h b/Eigen/src/Core/arch/SYCL/TypeCasting.h new file mode 100644 index 0000000..9208ab2 --- /dev/null +++ b/Eigen/src/Core/arch/SYCL/TypeCasting.h @@ -0,0 +1,85 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Mehdi Goli Codeplay Software Ltd. +// Ralph Potter Codeplay Software Ltd. +// Luke Iwanski Codeplay Software Ltd. +// Contact: +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +/***************************************************************** + * TypeCasting.h + * + * \brief: + * TypeCasting + * + *****************************************************************/ + +#ifndef EIGEN_TYPE_CASTING_SYCL_H +#define EIGEN_TYPE_CASTING_SYCL_H + +namespace Eigen { + +namespace internal { +#ifdef SYCL_DEVICE_ONLY +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; + +template <> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE cl::sycl::cl_int4 +pcast(const cl::sycl::cl_float4& a) { + return a + .template convert(); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 1 }; +}; + +template <> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE cl::sycl::cl_float4 +pcast(const cl::sycl::cl_int4& a) { + return a.template convert(); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 2, TgtCoeffRatio = 1 }; +}; + +template <> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE cl::sycl::cl_float4 +pcast( + const cl::sycl::cl_double2& a, const cl::sycl::cl_double2& b) { + auto a1 = a.template convert(); + auto b1 = b.template convert(); + return cl::sycl::float4(a1.x(), a1.y(), b1.x(), b1.y()); +} + +template <> +struct type_casting_traits { + enum { VectorizedCast = 1, SrcCoeffRatio = 1, TgtCoeffRatio = 2 }; +}; + +template <> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE cl::sycl::cl_double2 +pcast(const cl::sycl::cl_float4& a) { + // Simply discard the second half of the input + return cl::sycl::cl_double2(a.x(), a.y()); +} + +#endif +} // end namespace internal + +} // end namespace Eigen + +#endif // EIGEN_TYPE_CASTING_SYCL_H diff --git a/Eigen/src/Core/arch/ZVector/Complex.h b/Eigen/src/Core/arch/ZVector/Complex.h new file mode 100644 index 0000000..0b9b33d --- /dev/null +++ 
b/Eigen/src/Core/arch/ZVector/Complex.h @@ -0,0 +1,426 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2010 Gael Guennebaud +// Copyright (C) 2016 Konstantinos Margaritis +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_COMPLEX32_ALTIVEC_H +#define EIGEN_COMPLEX32_ALTIVEC_H + +namespace Eigen { + +namespace internal { + +#if !defined(__ARCH__) || (defined(__ARCH__) && __ARCH__ >= 12) +static Packet4ui p4ui_CONJ_XOR = { 0x00000000, 0x80000000, 0x00000000, 0x80000000 }; //vec_mergeh((Packet4ui)p4i_ZERO, (Packet4ui)p4f_MZERO); +#endif + +static Packet2ul p2ul_CONJ_XOR1 = (Packet2ul) vec_sld((Packet4ui) p2d_ZERO_, (Packet4ui) p2l_ZERO, 8);//{ 0x8000000000000000, 0x0000000000000000 }; +static Packet2ul p2ul_CONJ_XOR2 = (Packet2ul) vec_sld((Packet4ui) p2l_ZERO, (Packet4ui) p2d_ZERO_, 8);//{ 0x8000000000000000, 0x0000000000000000 }; + +struct Packet1cd +{ + EIGEN_STRONG_INLINE Packet1cd() {} + EIGEN_STRONG_INLINE explicit Packet1cd(const Packet2d& a) : v(a) {} + Packet2d v; +}; + +struct Packet2cf +{ + EIGEN_STRONG_INLINE Packet2cf() {} + EIGEN_STRONG_INLINE explicit Packet2cf(const Packet4f& a) : v(a) {} +#if !defined(__ARCH__) || (defined(__ARCH__) && __ARCH__ < 12) + union { + Packet4f v; + Packet1cd cd[2]; + }; +#else + Packet4f v; +#endif +}; + +template<> struct packet_traits > : default_packet_traits +{ + typedef Packet2cf type; + typedef Packet2cf half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 2, + HasHalfPacket = 0, + + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 1, + HasNegate = 1, + HasAbs = 0, + HasAbs2 = 0, + HasMin = 0, + HasMax = 0, + HasBlend = 1, + HasSetLinear = 0 + }; +}; + + +template<> struct packet_traits > : default_packet_traits +{ + typedef Packet1cd type; + typedef Packet1cd half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 1, + HasHalfPacket = 0, + + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 1, + HasNegate = 1, + HasAbs = 0, + HasAbs2 = 0, + HasMin = 0, + HasMax = 0, + HasSetLinear = 0 + }; +}; + +template<> struct unpacket_traits { typedef std::complex type; enum {size=2, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet2cf half; }; +template<> struct unpacket_traits { typedef std::complex type; enum {size=1, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet1cd half; }; + +/* Forward declaration */ +EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel); + +/* complex first */ +template<> EIGEN_STRONG_INLINE Packet1cd pload (const std::complex* from) { EIGEN_DEBUG_ALIGNED_LOAD return Packet1cd(pload((const double*)from)); } +template<> EIGEN_STRONG_INLINE Packet1cd ploadu(const std::complex* from) { EIGEN_DEBUG_UNALIGNED_LOAD return Packet1cd(ploadu((const double*)from)); } +template<> EIGEN_STRONG_INLINE void pstore >(std::complex * to, const Packet1cd& from) { EIGEN_DEBUG_ALIGNED_STORE pstore((double*)to, from.v); } +template<> EIGEN_STRONG_INLINE void pstoreu >(std::complex * to, const Packet1cd& from) { EIGEN_DEBUG_UNALIGNED_STORE pstoreu((double*)to, from.v); } + +template<> EIGEN_STRONG_INLINE Packet1cd pset1(const std::complex& from) +{ /* here we really have to use unaligned loads :( */ return ploadu(&from); } + +template<> 
EIGEN_DEVICE_FUNC inline Packet1cd pgather, Packet1cd>(const std::complex* from, Index stride EIGEN_UNUSED) +{ + return pload(from); +} +template<> EIGEN_DEVICE_FUNC inline void pscatter, Packet1cd>(std::complex* to, const Packet1cd& from, Index stride EIGEN_UNUSED) +{ + pstore >(to, from); +} +template<> EIGEN_STRONG_INLINE Packet1cd padd(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(a.v + b.v); } +template<> EIGEN_STRONG_INLINE Packet1cd psub(const Packet1cd& a, const Packet1cd& b) { return Packet1cd(a.v - b.v); } +template<> EIGEN_STRONG_INLINE Packet1cd pnegate(const Packet1cd& a) { return Packet1cd(pnegate(Packet2d(a.v))); } +template<> EIGEN_STRONG_INLINE Packet1cd pconj(const Packet1cd& a) { return Packet1cd((Packet2d)vec_xor((Packet2d)a.v, (Packet2d)p2ul_CONJ_XOR2)); } +template<> EIGEN_STRONG_INLINE Packet1cd pmul(const Packet1cd& a, const Packet1cd& b) +{ + Packet2d a_re, a_im, v1, v2; + + // Permute and multiply the real parts of a and b + a_re = vec_perm(a.v, a.v, p16uc_PSET64_HI); + // Get the imaginary parts of a + a_im = vec_perm(a.v, a.v, p16uc_PSET64_LO); + // multiply a_re * b + v1 = vec_madd(a_re, b.v, p2d_ZERO); + // multiply a_im * b and get the conjugate result + v2 = vec_madd(a_im, b.v, p2d_ZERO); + v2 = (Packet2d) vec_sld((Packet4ui)v2, (Packet4ui)v2, 8); + v2 = (Packet2d) vec_xor((Packet2d)v2, (Packet2d) p2ul_CONJ_XOR1); + + return Packet1cd(v1 + v2); +} +template<> EIGEN_STRONG_INLINE Packet1cd pand (const Packet1cd& a, const Packet1cd& b) { return Packet1cd(vec_and(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet1cd por (const Packet1cd& a, const Packet1cd& b) { return Packet1cd(vec_or(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet1cd pxor (const Packet1cd& a, const Packet1cd& b) { return Packet1cd(vec_xor(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet1cd pandnot (const Packet1cd& a, const Packet1cd& b) { return Packet1cd(vec_and(a.v, vec_nor(b.v,b.v))); } +template<> EIGEN_STRONG_INLINE Packet1cd ploaddup(const std::complex* from) { return pset1(*from); } +template<> EIGEN_STRONG_INLINE Packet1cd pcmp_eq(const Packet1cd& a, const Packet1cd& b) { + Packet2d eq = vec_cmpeq (a.v, b.v); + Packet2d tmp = { eq[1], eq[0] }; + return (Packet1cd)pand(eq, tmp); +} + +template<> EIGEN_STRONG_INLINE void prefetch >(const std::complex * addr) { EIGEN_ZVECTOR_PREFETCH(addr); } + +template<> EIGEN_STRONG_INLINE std::complex pfirst(const Packet1cd& a) +{ + std::complex EIGEN_ALIGN16 res; + pstore >(&res, a); + + return res; +} + +template<> EIGEN_STRONG_INLINE Packet1cd preverse(const Packet1cd& a) { return a; } +template<> EIGEN_STRONG_INLINE std::complex predux(const Packet1cd& a) +{ + return pfirst(a); +} +template<> EIGEN_STRONG_INLINE std::complex predux_mul(const Packet1cd& a) +{ + return pfirst(a); +} +EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet1cd,Packet2d) + +template<> EIGEN_STRONG_INLINE Packet1cd pdiv(const Packet1cd& a, const Packet1cd& b) +{ + // TODO optimize it for AltiVec + Packet1cd res = pmul(a,pconj(b)); + Packet2d s = vec_madd(b.v, b.v, p2d_ZERO_); + return Packet1cd(pdiv(res.v, s + vec_perm(s, s, p16uc_REVERSE64))); +} + +EIGEN_STRONG_INLINE Packet1cd pcplxflip/**/(const Packet1cd& x) +{ + return Packet1cd(preverse(Packet2d(x.v))); +} + +EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) +{ + Packet2d tmp = vec_perm(kernel.packet[0].v, kernel.packet[1].v, p16uc_TRANSPOSE64_HI); + kernel.packet[1].v = vec_perm(kernel.packet[0].v, kernel.packet[1].v, p16uc_TRANSPOSE64_LO); + kernel.packet[0].v = tmp; +} + +/* complex follows */ 
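+// The complex multiplies in this file (Packet1cd above, Packet2cf below) share
+// the same permute + conjugate-XOR scheme: the imaginary-part products are
+// sign-flipped and lane-swapped so that a single add yields the complex
+// product. A minimal scalar sketch of that lane arithmetic (illustrative
+// helper only, not part of this header):
+//
+//   inline void cmul_sketch(const double a[2], const double b[2], double r[2]) {
+//     // a = (ar, ai), b = (br, bi), stored as [real, imag]
+//     double v1[2] = { a[0] * b[0], a[0] * b[1] };  // splat ar, multiply by b
+//     double v2[2] = { a[1] * b[0], a[1] * b[1] };  // splat ai, multiply by b
+//     double v2s[2] = { -v2[1], v2[0] };            // swap lanes, negate lane 0
+//     r[0] = v1[0] + v2s[0];                        // ar*br - ai*bi
+//     r[1] = v1[1] + v2s[1];                        // ar*bi + ai*br
+//   }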
+template<> EIGEN_STRONG_INLINE Packet2cf pload (const std::complex* from) { EIGEN_DEBUG_ALIGNED_LOAD return Packet2cf(pload((const float*)from)); } +template<> EIGEN_STRONG_INLINE Packet2cf ploadu(const std::complex* from) { EIGEN_DEBUG_UNALIGNED_LOAD return Packet2cf(ploadu((const float*)from)); } +template<> EIGEN_STRONG_INLINE void pstore >(std::complex * to, const Packet2cf& from) { EIGEN_DEBUG_ALIGNED_STORE pstore((float*)to, from.v); } +template<> EIGEN_STRONG_INLINE void pstoreu >(std::complex * to, const Packet2cf& from) { EIGEN_DEBUG_UNALIGNED_STORE pstoreu((float*)to, from.v); } + +template<> EIGEN_STRONG_INLINE std::complex pfirst(const Packet2cf& a) +{ + std::complex EIGEN_ALIGN16 res[2]; + pstore >(res, a); + + return res[0]; +} + + +#if !defined(__ARCH__) || (defined(__ARCH__) && __ARCH__ < 12) +template<> EIGEN_STRONG_INLINE Packet2cf pset1(const std::complex& from) +{ + Packet2cf res; + res.cd[0] = Packet1cd(vec_ld2f((const float *)&from)); + res.cd[1] = res.cd[0]; + return res; +} +#else +template<> EIGEN_STRONG_INLINE Packet2cf pset1(const std::complex& from) +{ + Packet2cf res; + if((std::ptrdiff_t(&from) % 16) == 0) + res.v = pload((const float *)&from); + else + res.v = ploadu((const float *)&from); + res.v = vec_perm(res.v, res.v, p16uc_PSET64_HI); + return res; +} +#endif + +template<> EIGEN_DEVICE_FUNC inline Packet2cf pgather, Packet2cf>(const std::complex* from, Index stride) +{ + std::complex EIGEN_ALIGN16 af[2]; + af[0] = from[0*stride]; + af[1] = from[1*stride]; + return pload(af); +} +template<> EIGEN_DEVICE_FUNC inline void pscatter, Packet2cf>(std::complex* to, const Packet2cf& from, Index stride) +{ + std::complex EIGEN_ALIGN16 af[2]; + pstore >((std::complex *) af, from); + to[0*stride] = af[0]; + to[1*stride] = af[1]; +} + +template<> EIGEN_STRONG_INLINE Packet2cf padd(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(padd(a.v, b.v)); } +template<> EIGEN_STRONG_INLINE Packet2cf psub(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(psub(a.v, b.v)); } +template<> EIGEN_STRONG_INLINE Packet2cf pnegate(const Packet2cf& a) { return Packet2cf(pnegate(Packet4f(a.v))); } + +template<> EIGEN_STRONG_INLINE Packet2cf pand (const Packet2cf& a, const Packet2cf& b) { return Packet2cf(pand(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet2cf por (const Packet2cf& a, const Packet2cf& b) { return Packet2cf(por(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet2cf pxor (const Packet2cf& a, const Packet2cf& b) { return Packet2cf(pxor(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet2cf pandnot(const Packet2cf& a, const Packet2cf& b) { return Packet2cf(pandnot(a.v,b.v)); } + +template<> EIGEN_STRONG_INLINE Packet2cf ploaddup(const std::complex* from) { return pset1(*from); } + +template<> EIGEN_STRONG_INLINE void prefetch >(const std::complex * addr) { EIGEN_ZVECTOR_PREFETCH(addr); } + + +#if !defined(__ARCH__) || (defined(__ARCH__) && __ARCH__ < 12) + +template<> EIGEN_STRONG_INLINE Packet2cf pcmp_eq(const Packet2cf& a, const Packet2cf& b) { + Packet4f eq = pcmp_eq (a.v, b.v); + Packet2cf res; + Packet2d tmp1 = { eq.v4f[0][1], eq.v4f[0][0] }; + Packet2d tmp2 = { eq.v4f[1][1], eq.v4f[1][0] }; + res.v.v4f[0] = pand(eq.v4f[0], tmp1); + res.v.v4f[1] = pand(eq.v4f[1], tmp2); + return res; +} + +template<> EIGEN_STRONG_INLINE Packet2cf pconj(const Packet2cf& a) +{ + Packet2cf res; + res.v.v4f[0] = pconj(Packet1cd(reinterpret_cast(a.v.v4f[0]))).v; + res.v.v4f[1] = pconj(Packet1cd(reinterpret_cast(a.v.v4f[1]))).v; + return res; +} + +template<> 
EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& a, const Packet2cf& b) +{ + Packet2cf res; + res.v.v4f[0] = pmul(Packet1cd(reinterpret_cast(a.v.v4f[0])), Packet1cd(reinterpret_cast(b.v.v4f[0]))).v; + res.v.v4f[1] = pmul(Packet1cd(reinterpret_cast(a.v.v4f[1])), Packet1cd(reinterpret_cast(b.v.v4f[1]))).v; + return res; +} + +template<> EIGEN_STRONG_INLINE Packet2cf preverse(const Packet2cf& a) +{ + Packet2cf res; + res.cd[0] = a.cd[1]; + res.cd[1] = a.cd[0]; + return res; +} + +template<> EIGEN_STRONG_INLINE std::complex predux(const Packet2cf& a) +{ + std::complex res; + Packet1cd b = padd(a.cd[0], a.cd[1]); + vec_st2f(b.v, (float*)&res); + return res; +} + +template<> EIGEN_STRONG_INLINE std::complex predux_mul(const Packet2cf& a) +{ + std::complex res; + Packet1cd b = pmul(a.cd[0], a.cd[1]); + vec_st2f(b.v, (float*)&res); + return res; +} + +EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cf,Packet4f) + +template<> EIGEN_STRONG_INLINE Packet2cf pdiv(const Packet2cf& a, const Packet2cf& b) +{ + // TODO optimize it for AltiVec + Packet2cf res; + res.cd[0] = pdiv(a.cd[0], b.cd[0]); + res.cd[1] = pdiv(a.cd[1], b.cd[1]); + return res; +} + +EIGEN_STRONG_INLINE Packet2cf pcplxflip/**/(const Packet2cf& x) +{ + Packet2cf res; + res.cd[0] = pcplxflip(x.cd[0]); + res.cd[1] = pcplxflip(x.cd[1]); + return res; +} + +EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) +{ + Packet1cd tmp = kernel.packet[0].cd[1]; + kernel.packet[0].cd[1] = kernel.packet[1].cd[0]; + kernel.packet[1].cd[0] = tmp; +} + +template<> EIGEN_STRONG_INLINE Packet2cf pblend(const Selector<2>& ifPacket, const Packet2cf& thenPacket, const Packet2cf& elsePacket) { + Packet2cf result; + const Selector<4> ifPacket4 = { ifPacket.select[0], ifPacket.select[0], ifPacket.select[1], ifPacket.select[1] }; + result.v = pblend(ifPacket4, thenPacket.v, elsePacket.v); + return result; +} +#else +template<> EIGEN_STRONG_INLINE Packet2cf pcmp_eq(const Packet2cf& a, const Packet2cf& b) { + Packet4f eq = vec_cmpeq (a.v, b.v); + Packet4f tmp = { eq[1], eq[0], eq[3], eq[2] }; + return (Packet2cf)pand(eq, tmp); +} +template<> EIGEN_STRONG_INLINE Packet2cf pconj(const Packet2cf& a) { return Packet2cf(pxor(a.v, reinterpret_cast(p4ui_CONJ_XOR))); } +template<> EIGEN_STRONG_INLINE Packet2cf pmul(const Packet2cf& a, const Packet2cf& b) +{ + Packet4f a_re, a_im, prod, prod_im; + + // Permute and multiply the real parts of a and b + a_re = vec_perm(a.v, a.v, p16uc_PSET32_WODD); + + // Get the imaginary parts of a + a_im = vec_perm(a.v, a.v, p16uc_PSET32_WEVEN); + + // multiply a_im * b and get the conjugate result + prod_im = a_im * b.v; + prod_im = pxor(prod_im, reinterpret_cast(p4ui_CONJ_XOR)); + // permute back to a proper order + prod_im = vec_perm(prod_im, prod_im, p16uc_COMPLEX32_REV); + + // multiply a_re * b, add prod_im + prod = pmadd(a_re, b.v, prod_im); + + return Packet2cf(prod); +} + +template<> EIGEN_STRONG_INLINE Packet2cf preverse(const Packet2cf& a) +{ + Packet4f rev_a; + rev_a = vec_perm(a.v, a.v, p16uc_COMPLEX32_REV2); + return Packet2cf(rev_a); +} + +template<> EIGEN_STRONG_INLINE std::complex predux(const Packet2cf& a) +{ + Packet4f b; + b = vec_sld(a.v, a.v, 8); + b = padd(a.v, b); + return pfirst(Packet2cf(b)); +} + +template<> EIGEN_STRONG_INLINE std::complex predux_mul(const Packet2cf& a) +{ + Packet4f b; + Packet2cf prod; + b = vec_sld(a.v, a.v, 8); + prod = pmul(a, Packet2cf(b)); + + return pfirst(prod); +} + +EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cf,Packet4f) + +template<> EIGEN_STRONG_INLINE Packet2cf pdiv(const Packet2cf& a, 
const Packet2cf& b) +{ + // TODO optimize it for AltiVec + Packet2cf res = pmul(a, pconj(b)); + Packet4f s = pmul(b.v, b.v); + return Packet2cf(pdiv(res.v, padd(s, vec_perm(s, s, p16uc_COMPLEX32_REV)))); +} + +template<> EIGEN_STRONG_INLINE Packet2cf pcplxflip(const Packet2cf& x) +{ + return Packet2cf(vec_perm(x.v, x.v, p16uc_COMPLEX32_REV)); +} + +EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) +{ + Packet4f tmp = vec_perm(kernel.packet[0].v, kernel.packet[1].v, p16uc_TRANSPOSE64_HI); + kernel.packet[1].v = vec_perm(kernel.packet[0].v, kernel.packet[1].v, p16uc_TRANSPOSE64_LO); + kernel.packet[0].v = tmp; +} + +template<> EIGEN_STRONG_INLINE Packet2cf pblend(const Selector<2>& ifPacket, const Packet2cf& thenPacket, const Packet2cf& elsePacket) { + Packet2cf result; + result.v = reinterpret_cast(pblend(ifPacket, reinterpret_cast(thenPacket.v), reinterpret_cast(elsePacket.v))); + return result; +} +#endif + +} // end namespace internal + +} // end namespace Eigen + +#endif // EIGEN_COMPLEX32_ALTIVEC_H diff --git a/Eigen/src/Core/arch/ZVector/MathFunctions.h b/Eigen/src/Core/arch/ZVector/MathFunctions.h new file mode 100644 index 0000000..1635e12 --- /dev/null +++ b/Eigen/src/Core/arch/ZVector/MathFunctions.h @@ -0,0 +1,233 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2007 Julien Pommier +// Copyright (C) 2009 Gael Guennebaud +// Copyright (C) 2016 Konstantinos Margaritis +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +/* The sin, cos, exp, and log functions of this file come from + * Julien Pommier's sse math library: http://gruntthepeon.free.fr/ssemath/ + */ + +#ifndef EIGEN_MATH_FUNCTIONS_ALTIVEC_H +#define EIGEN_MATH_FUNCTIONS_ALTIVEC_H + +namespace Eigen { + +namespace internal { + +#if !defined(__ARCH__) || (defined(__ARCH__) && __ARCH__ >= 12) +static _EIGEN_DECLARE_CONST_Packet4f(1 , 1.0f); +static _EIGEN_DECLARE_CONST_Packet4f(half, 0.5f); +static _EIGEN_DECLARE_CONST_Packet4i(0x7f, 0x7f); +static _EIGEN_DECLARE_CONST_Packet4i(23, 23); + +static _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(inv_mant_mask, ~0x7f800000); + +/* the smallest non denormalized float number */ +static _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(min_norm_pos, 0x00800000); +static _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(minus_inf, 0xff800000); // -1.f/0.f +static _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(minus_nan, 0xffffffff); + +/* natural logarithm computed for 4 simultaneous float + return NaN for x <= 0 +*/ +static _EIGEN_DECLARE_CONST_Packet4f(cephes_SQRTHF, 0.707106781186547524f); +static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p0, 7.0376836292E-2f); +static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p1, - 1.1514610310E-1f); +static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p2, 1.1676998740E-1f); +static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p3, - 1.2420140846E-1f); +static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p4, + 1.4249322787E-1f); +static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p5, - 1.6668057665E-1f); +static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p6, + 2.0000714765E-1f); +static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p7, - 2.4999993993E-1f); +static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_p8, + 3.3333331174E-1f); +static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_q1, -2.12194440e-4f); +static _EIGEN_DECLARE_CONST_Packet4f(cephes_log_q2, 0.693359375f); + +static 
_EIGEN_DECLARE_CONST_Packet4f(exp_hi, 88.3762626647950f); +static _EIGEN_DECLARE_CONST_Packet4f(exp_lo, -88.3762626647949f); + +static _EIGEN_DECLARE_CONST_Packet4f(cephes_LOG2EF, 1.44269504088896341f); +static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_C1, 0.693359375f); +static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_C2, -2.12194440e-4f); + +static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p0, 1.9875691500E-4f); +static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p1, 1.3981999507E-3f); +static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p2, 8.3334519073E-3f); +static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p3, 4.1665795894E-2f); +static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p4, 1.6666665459E-1f); +static _EIGEN_DECLARE_CONST_Packet4f(cephes_exp_p5, 5.0000001201E-1f); +#endif + +static _EIGEN_DECLARE_CONST_Packet2d(1 , 1.0); +static _EIGEN_DECLARE_CONST_Packet2d(2 , 2.0); +static _EIGEN_DECLARE_CONST_Packet2d(half, 0.5); + +static _EIGEN_DECLARE_CONST_Packet2d(exp_hi, 709.437); +static _EIGEN_DECLARE_CONST_Packet2d(exp_lo, -709.436139303); + +static _EIGEN_DECLARE_CONST_Packet2d(cephes_LOG2EF, 1.4426950408889634073599); + +static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_p0, 1.26177193074810590878e-4); +static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_p1, 3.02994407707441961300e-2); +static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_p2, 9.99999999999999999910e-1); + +static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_q0, 3.00198505138664455042e-6); +static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_q1, 2.52448340349684104192e-3); +static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_q2, 2.27265548208155028766e-1); +static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_q3, 2.00000000000000000009e0); + +static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_C1, 0.693145751953125); +static _EIGEN_DECLARE_CONST_Packet2d(cephes_exp_C2, 1.42860682030941723212e-6); + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet2d pexp(const Packet2d& _x) +{ + Packet2d x = _x; + + Packet2d tmp, fx; + Packet2l emm0; + + // clamp x + x = pmax(pmin(x, p2d_exp_hi), p2d_exp_lo); + /* express exp(x) as exp(g + n*log(2)) */ + fx = pmadd(p2d_cephes_LOG2EF, x, p2d_half); + + fx = vec_floor(fx); + + tmp = pmul(fx, p2d_cephes_exp_C1); + Packet2d z = pmul(fx, p2d_cephes_exp_C2); + x = psub(x, tmp); + x = psub(x, z); + + Packet2d x2 = pmul(x,x); + + Packet2d px = p2d_cephes_exp_p0; + px = pmadd(px, x2, p2d_cephes_exp_p1); + px = pmadd(px, x2, p2d_cephes_exp_p2); + px = pmul (px, x); + + Packet2d qx = p2d_cephes_exp_q0; + qx = pmadd(qx, x2, p2d_cephes_exp_q1); + qx = pmadd(qx, x2, p2d_cephes_exp_q2); + qx = pmadd(qx, x2, p2d_cephes_exp_q3); + + x = pdiv(px,psub(qx,px)); + x = pmadd(p2d_2,x,p2d_1); + + // build 2^n + emm0 = vec_ctsl(fx, 0); + + static const Packet2l p2l_1023 = { 1023, 1023 }; + static const Packet2ul p2ul_52 = { 52, 52 }; + + emm0 = emm0 + p2l_1023; + emm0 = emm0 << reinterpret_cast(p2ul_52); + + // Altivec's max & min operators just drop silent NaNs. Check NaNs in + // inputs and return them unmodified. 
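+  // (x == x) is false only for NaN, so isnumber_mask below is all-ones for
+  // ordinary lanes and all-zeros for NaN lanes; vec_sel then passes the NaN
+  // inputs through unchanged and returns the computed exponential elsewhere.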
+ Packet2ul isnumber_mask = reinterpret_cast(vec_cmpeq(_x, _x)); + return vec_sel(_x, pmax(pmul(x, reinterpret_cast(emm0)), _x), + isnumber_mask); +} + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet4f pexp(const Packet4f& _x) +{ +#if !defined(__ARCH__) || (defined(__ARCH__) && __ARCH__ >= 12) + Packet4f x = _x; + + Packet4f tmp, fx; + Packet4i emm0; + + // clamp x + x = pmax(pmin(x, p4f_exp_hi), p4f_exp_lo); + + // express exp(x) as exp(g + n*log(2)) + fx = pmadd(x, p4f_cephes_LOG2EF, p4f_half); + + fx = pfloor(fx); + + tmp = pmul(fx, p4f_cephes_exp_C1); + Packet4f z = pmul(fx, p4f_cephes_exp_C2); + x = psub(x, tmp); + x = psub(x, z); + + z = pmul(x,x); + + Packet4f y = p4f_cephes_exp_p0; + y = pmadd(y, x, p4f_cephes_exp_p1); + y = pmadd(y, x, p4f_cephes_exp_p2); + y = pmadd(y, x, p4f_cephes_exp_p3); + y = pmadd(y, x, p4f_cephes_exp_p4); + y = pmadd(y, x, p4f_cephes_exp_p5); + y = pmadd(y, z, x); + y = padd(y, p4f_1); + + // build 2^n + emm0 = (Packet4i){ (int)fx[0], (int)fx[1], (int)fx[2], (int)fx[3] }; + emm0 = emm0 + p4i_0x7f; + emm0 = emm0 << reinterpret_cast(p4i_23); + + return pmax(pmul(y, reinterpret_cast(emm0)), _x); +#else + Packet4f res; + res.v4f[0] = pexp(_x.v4f[0]); + res.v4f[1] = pexp(_x.v4f[1]); + return res; +#endif +} + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet2d psqrt(const Packet2d& x) +{ + return vec_sqrt(x); +} + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet4f psqrt(const Packet4f& x) +{ + Packet4f res; +#if !defined(__ARCH__) || (defined(__ARCH__) && __ARCH__ >= 12) + res = vec_sqrt(x); +#else + res.v4f[0] = psqrt(x.v4f[0]); + res.v4f[1] = psqrt(x.v4f[1]); +#endif + return res; +} + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet2d prsqrt(const Packet2d& x) { + return pset1(1.0) / psqrt(x); +} + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet4f prsqrt(const Packet4f& x) { + Packet4f res; +#if !defined(__ARCH__) || (defined(__ARCH__) && __ARCH__ >= 12) + res = pset1(1.0) / psqrt(x); +#else + res.v4f[0] = prsqrt(x.v4f[0]); + res.v4f[1] = prsqrt(x.v4f[1]); +#endif + return res; +} + +// Hyperbolic Tangent function. +template <> +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f +ptanh(const Packet4f& x) { + return internal::generic_fast_tanh_float(x); +} + +} // end namespace internal + +} // end namespace Eigen + +#endif // EIGEN_MATH_FUNCTIONS_ALTIVEC_H diff --git a/Eigen/src/Core/arch/ZVector/PacketMath.h b/Eigen/src/Core/arch/ZVector/PacketMath.h new file mode 100755 index 0000000..1f55a90 --- /dev/null +++ b/Eigen/src/Core/arch/ZVector/PacketMath.h @@ -0,0 +1,1060 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2016 Konstantinos Margaritis +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
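+// Note: on z13 (__ARCH__ < 12) there is no native vector-float type, so the
+// Packet4f defined below is emulated as a pair of Packet2d halves and the
+// float operations forward to the double implementations lane by lane.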
+ +#ifndef EIGEN_PACKET_MATH_ZVECTOR_H +#define EIGEN_PACKET_MATH_ZVECTOR_H + +namespace Eigen { + +namespace internal { + +#ifndef EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD +#define EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD 16 +#endif + +#ifndef EIGEN_HAS_SINGLE_INSTRUCTION_MADD +#define EIGEN_HAS_SINGLE_INSTRUCTION_MADD +#endif + +#ifndef EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS +#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 32 +#endif + +typedef __vector int Packet4i; +typedef __vector unsigned int Packet4ui; +typedef __vector __bool int Packet4bi; +typedef __vector short int Packet8i; +typedef __vector unsigned char Packet16uc; +typedef __vector double Packet2d; +typedef __vector unsigned long long Packet2ul; +typedef __vector long long Packet2l; + +// Z14 has builtin support for float vectors +#if !defined(__ARCH__) || (defined(__ARCH__) && __ARCH__ >= 12) +typedef __vector float Packet4f; +#else +typedef struct { + Packet2d v4f[2]; +} Packet4f; +#endif + +typedef union { + numext::int32_t i[4]; + numext::uint32_t ui[4]; + numext::int64_t l[2]; + numext::uint64_t ul[2]; + double d[2]; + float f[4]; + Packet4i v4i; + Packet4ui v4ui; + Packet2l v2l; + Packet2ul v2ul; + Packet2d v2d; +#if !defined(__ARCH__) || (defined(__ARCH__) && __ARCH__ >= 12) + Packet4f v4f; +#endif +} Packet; + +// We don't want to write the same code all the time, but we need to reuse the constants +// and it doesn't really work to declare them global, so we define macros instead + +#define _EIGEN_DECLARE_CONST_FAST_Packet4i(NAME,X) \ + Packet4i p4i_##NAME = reinterpret_cast(vec_splat_s32(X)) + +#define _EIGEN_DECLARE_CONST_FAST_Packet2d(NAME,X) \ + Packet2d p2d_##NAME = reinterpret_cast(vec_splat_s64(X)) + +#define _EIGEN_DECLARE_CONST_FAST_Packet2l(NAME,X) \ + Packet2l p2l_##NAME = reinterpret_cast(vec_splat_s64(X)) + +#define _EIGEN_DECLARE_CONST_Packet4i(NAME,X) \ + Packet4i p4i_##NAME = pset1(X) + +#define _EIGEN_DECLARE_CONST_Packet2d(NAME,X) \ + Packet2d p2d_##NAME = pset1(X) + +#define _EIGEN_DECLARE_CONST_Packet2l(NAME,X) \ + Packet2l p2l_##NAME = pset1(X) + +// These constants are endian-agnostic +static _EIGEN_DECLARE_CONST_FAST_Packet4i(ZERO, 0); //{ 0, 0, 0, 0,} +static _EIGEN_DECLARE_CONST_FAST_Packet4i(ONE, 1); //{ 1, 1, 1, 1} + +static _EIGEN_DECLARE_CONST_FAST_Packet2d(ZERO, 0); +static _EIGEN_DECLARE_CONST_FAST_Packet2l(ZERO, 0); +static _EIGEN_DECLARE_CONST_FAST_Packet2l(ONE, 1); + +static Packet2d p2d_ONE = { 1.0, 1.0 }; +static Packet2d p2d_ZERO_ = { numext::bit_cast0x8000000000000000ull), + numext::bit_cast0x8000000000000000ull) }; + +#if !defined(__ARCH__) || (defined(__ARCH__) && __ARCH__ >= 12) +#define _EIGEN_DECLARE_CONST_FAST_Packet4f(NAME,X) \ + Packet4f p4f_##NAME = reinterpret_cast(vec_splat_s32(X)) + +#define _EIGEN_DECLARE_CONST_Packet4f(NAME,X) \ + Packet4f p4f_##NAME = pset1(X) + +#define _EIGEN_DECLARE_CONST_Packet4f_FROM_INT(NAME,X) \ + const Packet4f p4f_##NAME = reinterpret_cast(pset1(X)) + +static _EIGEN_DECLARE_CONST_FAST_Packet4f(ZERO, 0); //{ 0.0, 0.0, 0.0, 0.0} +static _EIGEN_DECLARE_CONST_FAST_Packet4i(MINUS1,-1); //{ -1, -1, -1, -1} +static Packet4f p4f_MZERO = { 0x80000000, 0x80000000, 0x80000000, 0x80000000}; +#endif + +static Packet4i p4i_COUNTDOWN = { 0, 1, 2, 3 }; +static Packet4f p4f_COUNTDOWN = { 0.0, 1.0, 2.0, 3.0 }; +static Packet2d p2d_COUNTDOWN = reinterpret_cast(vec_sld(reinterpret_cast(p2d_ZERO), reinterpret_cast(p2d_ONE), 8)); + +static Packet16uc p16uc_PSET64_HI = { 0,1,2,3, 4,5,6,7, 0,1,2,3, 4,5,6,7 }; +static Packet16uc p16uc_DUPLICATE32_HI = { 0,1,2,3, 
0,1,2,3, 4,5,6,7, 4,5,6,7 }; + +// Mask alignment +#define _EIGEN_MASK_ALIGNMENT 0xfffffffffffffff0 + +#define _EIGEN_ALIGNED_PTR(x) ((std::ptrdiff_t)(x) & _EIGEN_MASK_ALIGNMENT) + +// Handle endianness properly while loading constants +// Define global static constants: + +static Packet16uc p16uc_FORWARD = { 0,1,2,3, 4,5,6,7, 8,9,10,11, 12,13,14,15 }; +static Packet16uc p16uc_REVERSE32 = { 12,13,14,15, 8,9,10,11, 4,5,6,7, 0,1,2,3 }; +static Packet16uc p16uc_REVERSE64 = { 8,9,10,11, 12,13,14,15, 0,1,2,3, 4,5,6,7 }; + +static Packet16uc p16uc_PSET32_WODD = vec_sld((Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 0), (Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 2), 8);//{ 0,1,2,3, 0,1,2,3, 8,9,10,11, 8,9,10,11 }; +static Packet16uc p16uc_PSET32_WEVEN = vec_sld(p16uc_DUPLICATE32_HI, (Packet16uc) vec_splat((Packet4ui)p16uc_FORWARD, 3), 8);//{ 4,5,6,7, 4,5,6,7, 12,13,14,15, 12,13,14,15 }; +/*static Packet16uc p16uc_HALF64_0_16 = vec_sld((Packet16uc)p4i_ZERO, vec_splat((Packet16uc) vec_abs(p4i_MINUS16), 3), 8); //{ 0,0,0,0, 0,0,0,0, 16,16,16,16, 16,16,16,16}; + +static Packet16uc p16uc_PSET64_HI = (Packet16uc) vec_mergeh((Packet4ui)p16uc_PSET32_WODD, (Packet4ui)p16uc_PSET32_WEVEN); //{ 0,1,2,3, 4,5,6,7, 0,1,2,3, 4,5,6,7 };*/ +static Packet16uc p16uc_PSET64_LO = (Packet16uc) vec_mergel((Packet4ui)p16uc_PSET32_WODD, (Packet4ui)p16uc_PSET32_WEVEN); //{ 8,9,10,11, 12,13,14,15, 8,9,10,11, 12,13,14,15 }; +/*static Packet16uc p16uc_TRANSPOSE64_HI = vec_add(p16uc_PSET64_HI, p16uc_HALF64_0_16); //{ 0,1,2,3, 4,5,6,7, 16,17,18,19, 20,21,22,23}; +static Packet16uc p16uc_TRANSPOSE64_LO = vec_add(p16uc_PSET64_LO, p16uc_HALF64_0_16); //{ 8,9,10,11, 12,13,14,15, 24,25,26,27, 28,29,30,31};*/ +static Packet16uc p16uc_TRANSPOSE64_HI = { 0,1,2,3, 4,5,6,7, 16,17,18,19, 20,21,22,23}; +static Packet16uc p16uc_TRANSPOSE64_LO = { 8,9,10,11, 12,13,14,15, 24,25,26,27, 28,29,30,31}; + +static Packet16uc p16uc_COMPLEX32_REV = vec_sld(p16uc_REVERSE32, p16uc_REVERSE32, 8); //{ 4,5,6,7, 0,1,2,3, 12,13,14,15, 8,9,10,11 }; + +static Packet16uc p16uc_COMPLEX32_REV2 = vec_sld(p16uc_FORWARD, p16uc_FORWARD, 8); //{ 8,9,10,11, 12,13,14,15, 0,1,2,3, 4,5,6,7 }; + + +#if EIGEN_HAS_BUILTIN(__builtin_prefetch) || EIGEN_COMP_GNUC + #define EIGEN_ZVECTOR_PREFETCH(ADDR) __builtin_prefetch(ADDR); +#else + #define EIGEN_ZVECTOR_PREFETCH(ADDR) asm( " pfd [%[addr]]\n" :: [addr] "r" (ADDR) : "cc" ); +#endif + +template<> struct packet_traits : default_packet_traits +{ + typedef Packet4i type; + typedef Packet4i half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 4, + HasHalfPacket = 0, + + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 1, + HasBlend = 1 + }; +}; + +template <> +struct packet_traits : default_packet_traits { + typedef Packet4f type; + typedef Packet4f half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 4, + HasHalfPacket = 0, + + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 1, + HasMin = 1, + HasMax = 1, + HasAbs = 1, + HasSin = 0, + HasCos = 0, + HasLog = 0, + HasExp = 1, + HasSqrt = 1, + HasRsqrt = 1, + HasTanh = 1, + HasErf = 1, + HasRound = 1, + HasFloor = 1, + HasCeil = 1, + HasNegate = 1, + HasBlend = 1 + }; +}; + +template<> struct packet_traits : default_packet_traits +{ + typedef Packet2d type; + typedef Packet2d half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size=2, + HasHalfPacket = 1, + + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 1, + HasMin = 1, + HasMax = 1, + HasAbs = 1, + HasSin = 0, + HasCos = 0, + HasLog = 0, + HasExp = 1, + HasSqrt = 1, + HasRsqrt 
= 1, + HasRound = 1, + HasFloor = 1, + HasCeil = 1, + HasNegate = 1, + HasBlend = 1 + }; +}; + +template<> struct unpacket_traits { typedef int type; enum {size=4, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet4i half; }; +template<> struct unpacket_traits { typedef float type; enum {size=4, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet4f half; }; +template<> struct unpacket_traits { typedef double type; enum {size=2, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet2d half; }; + +/* Forward declaration */ +EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel); + +inline std::ostream & operator <<(std::ostream & s, const Packet4i & v) +{ + Packet vt; + vt.v4i = v; + s << vt.i[0] << ", " << vt.i[1] << ", " << vt.i[2] << ", " << vt.i[3]; + return s; +} + +inline std::ostream & operator <<(std::ostream & s, const Packet4ui & v) +{ + Packet vt; + vt.v4ui = v; + s << vt.ui[0] << ", " << vt.ui[1] << ", " << vt.ui[2] << ", " << vt.ui[3]; + return s; +} + +inline std::ostream & operator <<(std::ostream & s, const Packet2l & v) +{ + Packet vt; + vt.v2l = v; + s << vt.l[0] << ", " << vt.l[1]; + return s; +} + +inline std::ostream & operator <<(std::ostream & s, const Packet2ul & v) +{ + Packet vt; + vt.v2ul = v; + s << vt.ul[0] << ", " << vt.ul[1] ; + return s; +} + +inline std::ostream & operator <<(std::ostream & s, const Packet2d & v) +{ + Packet vt; + vt.v2d = v; + s << vt.d[0] << ", " << vt.d[1]; + return s; +} + +#if !defined(__ARCH__) || (defined(__ARCH__) && __ARCH__ >= 12) +inline std::ostream & operator <<(std::ostream & s, const Packet4f & v) +{ + Packet vt; + vt.v4f = v; + s << vt.f[0] << ", " << vt.f[1] << ", " << vt.f[2] << ", " << vt.f[3]; + return s; +} +#endif + +template<> EIGEN_STRONG_INLINE Packet4i pload(const int* from) +{ + // FIXME: No intrinsic yet + EIGEN_DEBUG_ALIGNED_LOAD + Packet *vfrom; + vfrom = (Packet *) from; + return vfrom->v4i; +} + +template<> EIGEN_STRONG_INLINE Packet2d pload(const double* from) +{ + // FIXME: No intrinsic yet + EIGEN_DEBUG_ALIGNED_LOAD + Packet *vfrom; + vfrom = (Packet *) from; + return vfrom->v2d; +} + +template<> EIGEN_STRONG_INLINE void pstore(int* to, const Packet4i& from) +{ + // FIXME: No intrinsic yet + EIGEN_DEBUG_ALIGNED_STORE + Packet *vto; + vto = (Packet *) to; + vto->v4i = from; +} + +template<> EIGEN_STRONG_INLINE void pstore(double* to, const Packet2d& from) +{ + // FIXME: No intrinsic yet + EIGEN_DEBUG_ALIGNED_STORE + Packet *vto; + vto = (Packet *) to; + vto->v2d = from; +} + +template<> EIGEN_STRONG_INLINE Packet4i pset1(const int& from) +{ + return vec_splats(from); +} +template<> EIGEN_STRONG_INLINE Packet2d pset1(const double& from) { + return vec_splats(from); +} + +template<> EIGEN_STRONG_INLINE void +pbroadcast4(const int *a, + Packet4i& a0, Packet4i& a1, Packet4i& a2, Packet4i& a3) +{ + a3 = pload(a); + a0 = vec_splat(a3, 0); + a1 = vec_splat(a3, 1); + a2 = vec_splat(a3, 2); + a3 = vec_splat(a3, 3); +} + +template<> EIGEN_STRONG_INLINE void +pbroadcast4(const double *a, + Packet2d& a0, Packet2d& a1, Packet2d& a2, Packet2d& a3) +{ + a1 = pload(a); + a0 = vec_splat(a1, 0); + a1 = vec_splat(a1, 1); + a3 = pload(a+2); + a2 = vec_splat(a3, 0); + a3 = vec_splat(a3, 1); +} + +template<> EIGEN_DEVICE_FUNC inline Packet4i pgather(const int* from, Index stride) +{ + int EIGEN_ALIGN16 ai[4]; + ai[0] = from[0*stride]; + 
ai[1] = from[1*stride]; + ai[2] = from[2*stride]; + ai[3] = from[3*stride]; + return pload(ai); +} + +template<> EIGEN_DEVICE_FUNC inline Packet2d pgather(const double* from, Index stride) +{ + double EIGEN_ALIGN16 af[2]; + af[0] = from[0*stride]; + af[1] = from[1*stride]; + return pload(af); +} + +template<> EIGEN_DEVICE_FUNC inline void pscatter(int* to, const Packet4i& from, Index stride) +{ + int EIGEN_ALIGN16 ai[4]; + pstore((int *)ai, from); + to[0*stride] = ai[0]; + to[1*stride] = ai[1]; + to[2*stride] = ai[2]; + to[3*stride] = ai[3]; +} + +template<> EIGEN_DEVICE_FUNC inline void pscatter(double* to, const Packet2d& from, Index stride) +{ + double EIGEN_ALIGN16 af[2]; + pstore(af, from); + to[0*stride] = af[0]; + to[1*stride] = af[1]; +} + +template<> EIGEN_STRONG_INLINE Packet4i padd(const Packet4i& a, const Packet4i& b) { return (a + b); } +template<> EIGEN_STRONG_INLINE Packet2d padd(const Packet2d& a, const Packet2d& b) { return (a + b); } + +template<> EIGEN_STRONG_INLINE Packet4i psub(const Packet4i& a, const Packet4i& b) { return (a - b); } +template<> EIGEN_STRONG_INLINE Packet2d psub(const Packet2d& a, const Packet2d& b) { return (a - b); } + +template<> EIGEN_STRONG_INLINE Packet4i pmul(const Packet4i& a, const Packet4i& b) { return (a * b); } +template<> EIGEN_STRONG_INLINE Packet2d pmul(const Packet2d& a, const Packet2d& b) { return (a * b); } + +template<> EIGEN_STRONG_INLINE Packet4i pdiv(const Packet4i& a, const Packet4i& b) { return (a / b); } +template<> EIGEN_STRONG_INLINE Packet2d pdiv(const Packet2d& a, const Packet2d& b) { return (a / b); } + +template<> EIGEN_STRONG_INLINE Packet4i pnegate(const Packet4i& a) { return (-a); } +template<> EIGEN_STRONG_INLINE Packet2d pnegate(const Packet2d& a) { return (-a); } + +template<> EIGEN_STRONG_INLINE Packet4i pconj(const Packet4i& a) { return a; } +template<> EIGEN_STRONG_INLINE Packet2d pconj(const Packet2d& a) { return a; } + +template<> EIGEN_STRONG_INLINE Packet4i pmadd(const Packet4i& a, const Packet4i& b, const Packet4i& c) { return padd(pmul(a, b), c); } +template<> EIGEN_STRONG_INLINE Packet2d pmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return vec_madd(a, b, c); } + +template<> EIGEN_STRONG_INLINE Packet4i plset(const int& a) { return padd(pset1(a), p4i_COUNTDOWN); } +template<> EIGEN_STRONG_INLINE Packet2d plset(const double& a) { return padd(pset1(a), p2d_COUNTDOWN); } + +template<> EIGEN_STRONG_INLINE Packet4i pmin(const Packet4i& a, const Packet4i& b) { return vec_min(a, b); } +template<> EIGEN_STRONG_INLINE Packet2d pmin(const Packet2d& a, const Packet2d& b) { return vec_min(a, b); } + +template<> EIGEN_STRONG_INLINE Packet4i pmax(const Packet4i& a, const Packet4i& b) { return vec_max(a, b); } +template<> EIGEN_STRONG_INLINE Packet2d pmax(const Packet2d& a, const Packet2d& b) { return vec_max(a, b); } + +template<> EIGEN_STRONG_INLINE Packet4i pand(const Packet4i& a, const Packet4i& b) { return vec_and(a, b); } +template<> EIGEN_STRONG_INLINE Packet2d pand(const Packet2d& a, const Packet2d& b) { return vec_and(a, b); } + +template<> EIGEN_STRONG_INLINE Packet4i por(const Packet4i& a, const Packet4i& b) { return vec_or(a, b); } +template<> EIGEN_STRONG_INLINE Packet2d por(const Packet2d& a, const Packet2d& b) { return vec_or(a, b); } + +template<> EIGEN_STRONG_INLINE Packet4i pxor(const Packet4i& a, const Packet4i& b) { return vec_xor(a, b); } +template<> EIGEN_STRONG_INLINE Packet2d pxor(const Packet2d& a, const Packet2d& b) { return vec_xor(a, b); } + +template<> EIGEN_STRONG_INLINE 
Packet4i pandnot(const Packet4i& a, const Packet4i& b) { return pand(a, vec_nor(b, b)); } +template<> EIGEN_STRONG_INLINE Packet2d pandnot(const Packet2d& a, const Packet2d& b) { return vec_and(a, vec_nor(b, b)); } + +template<> EIGEN_STRONG_INLINE Packet2d pround(const Packet2d& a) { return vec_round(a); } +template<> EIGEN_STRONG_INLINE Packet2d pceil(const Packet2d& a) { return vec_ceil(a); } +template<> EIGEN_STRONG_INLINE Packet2d pfloor(const Packet2d& a) { return vec_floor(a); } + +template<> EIGEN_STRONG_INLINE Packet4i ploadu(const int* from) { return pload(from); } +template<> EIGEN_STRONG_INLINE Packet2d ploadu(const double* from) { return pload(from); } + + +template<> EIGEN_STRONG_INLINE Packet4i ploaddup(const int* from) +{ + Packet4i p = pload(from); + return vec_perm(p, p, p16uc_DUPLICATE32_HI); +} + +template<> EIGEN_STRONG_INLINE Packet2d ploaddup(const double* from) +{ + Packet2d p = pload(from); + return vec_perm(p, p, p16uc_PSET64_HI); +} + +template<> EIGEN_STRONG_INLINE void pstoreu(int* to, const Packet4i& from) { pstore(to, from); } +template<> EIGEN_STRONG_INLINE void pstoreu(double* to, const Packet2d& from) { pstore(to, from); } + +template<> EIGEN_STRONG_INLINE void prefetch(const int* addr) { EIGEN_ZVECTOR_PREFETCH(addr); } +template<> EIGEN_STRONG_INLINE void prefetch(const double* addr) { EIGEN_ZVECTOR_PREFETCH(addr); } + +template<> EIGEN_STRONG_INLINE int pfirst(const Packet4i& a) { int EIGEN_ALIGN16 x[4]; pstore(x, a); return x[0]; } +template<> EIGEN_STRONG_INLINE double pfirst(const Packet2d& a) { double EIGEN_ALIGN16 x[2]; pstore(x, a); return x[0]; } + +template<> EIGEN_STRONG_INLINE Packet4i preverse(const Packet4i& a) +{ + return reinterpret_cast(vec_perm(reinterpret_cast(a), reinterpret_cast(a), p16uc_REVERSE32)); +} + +template<> EIGEN_STRONG_INLINE Packet2d preverse(const Packet2d& a) +{ + return reinterpret_cast(vec_perm(reinterpret_cast(a), reinterpret_cast(a), p16uc_REVERSE64)); +} + +template<> EIGEN_STRONG_INLINE Packet4i pabs(const Packet4i& a) { return vec_abs(a); } +template<> EIGEN_STRONG_INLINE Packet2d pabs(const Packet2d& a) { return vec_abs(a); } + +template<> EIGEN_STRONG_INLINE int predux(const Packet4i& a) +{ + Packet4i b, sum; + b = vec_sld(a, a, 8); + sum = padd(a, b); + b = vec_sld(sum, sum, 4); + sum = padd(sum, b); + return pfirst(sum); +} + +template<> EIGEN_STRONG_INLINE double predux(const Packet2d& a) +{ + Packet2d b, sum; + b = reinterpret_cast(vec_sld(reinterpret_cast(a), reinterpret_cast(a), 8)); + sum = padd(a, b); + return pfirst(sum); +} + +// Other reduction functions: +// mul +template<> EIGEN_STRONG_INLINE int predux_mul(const Packet4i& a) +{ + EIGEN_ALIGN16 int aux[4]; + pstore(aux, a); + return aux[0] * aux[1] * aux[2] * aux[3]; +} + +template<> EIGEN_STRONG_INLINE double predux_mul(const Packet2d& a) +{ + return pfirst(pmul(a, reinterpret_cast(vec_sld(reinterpret_cast(a), reinterpret_cast(a), 8)))); +} + +// min +template<> EIGEN_STRONG_INLINE int predux_min(const Packet4i& a) +{ + Packet4i b, res; + b = pmin(a, vec_sld(a, a, 8)); + res = pmin(b, vec_sld(b, b, 4)); + return pfirst(res); +} + +template<> EIGEN_STRONG_INLINE double predux_min(const Packet2d& a) +{ + return pfirst(pmin(a, reinterpret_cast(vec_sld(reinterpret_cast(a), reinterpret_cast(a), 8)))); +} + +// max +template<> EIGEN_STRONG_INLINE int predux_max(const Packet4i& a) +{ + Packet4i b, res; + b = pmax(a, vec_sld(a, a, 8)); + res = pmax(b, vec_sld(b, b, 4)); + return pfirst(res); +} + +// max +template<> EIGEN_STRONG_INLINE double 
predux_max(const Packet2d& a) +{ + return pfirst(pmax(a, reinterpret_cast(vec_sld(reinterpret_cast(a), reinterpret_cast(a), 8)))); +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + Packet4i t0 = vec_mergeh(kernel.packet[0], kernel.packet[2]); + Packet4i t1 = vec_mergel(kernel.packet[0], kernel.packet[2]); + Packet4i t2 = vec_mergeh(kernel.packet[1], kernel.packet[3]); + Packet4i t3 = vec_mergel(kernel.packet[1], kernel.packet[3]); + kernel.packet[0] = vec_mergeh(t0, t2); + kernel.packet[1] = vec_mergel(t0, t2); + kernel.packet[2] = vec_mergeh(t1, t3); + kernel.packet[3] = vec_mergel(t1, t3); +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + Packet2d t0 = vec_perm(kernel.packet[0], kernel.packet[1], p16uc_TRANSPOSE64_HI); + Packet2d t1 = vec_perm(kernel.packet[0], kernel.packet[1], p16uc_TRANSPOSE64_LO); + kernel.packet[0] = t0; + kernel.packet[1] = t1; +} + +template<> EIGEN_STRONG_INLINE Packet4i pblend(const Selector<4>& ifPacket, const Packet4i& thenPacket, const Packet4i& elsePacket) { + Packet4ui select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3] }; + Packet4ui mask = vec_cmpeq(select, reinterpret_cast(p4i_ONE)); + return vec_sel(elsePacket, thenPacket, mask); +} + + +template<> EIGEN_STRONG_INLINE Packet2d pblend(const Selector<2>& ifPacket, const Packet2d& thenPacket, const Packet2d& elsePacket) { + Packet2ul select = { ifPacket.select[0], ifPacket.select[1] }; + Packet2ul mask = vec_cmpeq(select, reinterpret_cast(p2l_ONE)); + return vec_sel(elsePacket, thenPacket, mask); +} + +/* z13 has no vector float support so we emulate that with double + z14 has proper vector float support. +*/ +#if !defined(__ARCH__) || (defined(__ARCH__) && __ARCH__ < 12) +/* Helper function to simulate a vec_splat_packet4f + */ +template EIGEN_STRONG_INLINE Packet4f vec_splat_packet4f(const Packet4f& from) +{ + Packet4f splat; + switch (element) { + case 0: + splat.v4f[0] = vec_splat(from.v4f[0], 0); + splat.v4f[1] = splat.v4f[0]; + break; + case 1: + splat.v4f[0] = vec_splat(from.v4f[0], 1); + splat.v4f[1] = splat.v4f[0]; + break; + case 2: + splat.v4f[0] = vec_splat(from.v4f[1], 0); + splat.v4f[1] = splat.v4f[0]; + break; + case 3: + splat.v4f[0] = vec_splat(from.v4f[1], 1); + splat.v4f[1] = splat.v4f[0]; + break; + } + return splat; +} + +template<> EIGEN_STRONG_INLINE Packet4f pload(const float* from) +{ + // FIXME: No intrinsic yet + EIGEN_DEBUG_ALIGNED_LOAD + Packet4f vfrom; + vfrom.v4f[0] = vec_ld2f(&from[0]); + vfrom.v4f[1] = vec_ld2f(&from[2]); + return vfrom; +} + +template<> EIGEN_STRONG_INLINE void pstore(float* to, const Packet4f& from) +{ + // FIXME: No intrinsic yet + EIGEN_DEBUG_ALIGNED_STORE + vec_st2f(from.v4f[0], &to[0]); + vec_st2f(from.v4f[1], &to[2]); +} + +template<> EIGEN_STRONG_INLINE Packet4f pset1(const float& from) +{ + Packet4f to; + to.v4f[0] = pset1(static_cast(from)); + to.v4f[1] = to.v4f[0]; + return to; +} + +template<> EIGEN_STRONG_INLINE void +pbroadcast4(const float *a, + Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3) +{ + a3 = pload(a); + a0 = vec_splat_packet4f<0>(a3); + a1 = vec_splat_packet4f<1>(a3); + a2 = vec_splat_packet4f<2>(a3); + a3 = vec_splat_packet4f<3>(a3); +} + +template<> EIGEN_DEVICE_FUNC inline Packet4f pgather(const float* from, Index stride) +{ + float EIGEN_ALIGN16 ai[4]; + ai[0] = from[0*stride]; + ai[1] = from[1*stride]; + ai[2] = from[2*stride]; + ai[3] = from[3*stride]; + return pload(ai); +} + +template<> EIGEN_DEVICE_FUNC inline void pscatter(float* 
+template<> EIGEN_DEVICE_FUNC inline void pscatter<float, Packet4f>(float* to, const Packet4f& from, Index stride)
+{
+  float EIGEN_ALIGN16 ai[4];
+  pstore((float *)ai, from);
+  to[0*stride] = ai[0];
+  to[1*stride] = ai[1];
+  to[2*stride] = ai[2];
+  to[3*stride] = ai[3];
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f padd(const Packet4f& a, const Packet4f& b)
+{
+  Packet4f c;
+  c.v4f[0] = a.v4f[0] + b.v4f[0];
+  c.v4f[1] = a.v4f[1] + b.v4f[1];
+  return c;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f psub(const Packet4f& a, const Packet4f& b)
+{
+  Packet4f c;
+  c.v4f[0] = a.v4f[0] - b.v4f[0];
+  c.v4f[1] = a.v4f[1] - b.v4f[1];
+  return c;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pmul(const Packet4f& a, const Packet4f& b)
+{
+  Packet4f c;
+  c.v4f[0] = a.v4f[0] * b.v4f[0];
+  c.v4f[1] = a.v4f[1] * b.v4f[1];
+  return c;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pdiv(const Packet4f& a, const Packet4f& b)
+{
+  Packet4f c;
+  c.v4f[0] = a.v4f[0] / b.v4f[0];
+  c.v4f[1] = a.v4f[1] / b.v4f[1];
+  return c;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pnegate(const Packet4f& a)
+{
+  Packet4f c;
+  c.v4f[0] = -a.v4f[0];
+  c.v4f[1] = -a.v4f[1];
+  return c;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c)
+{
+  Packet4f res;
+  res.v4f[0] = vec_madd(a.v4f[0], b.v4f[0], c.v4f[0]);
+  res.v4f[1] = vec_madd(a.v4f[1], b.v4f[1], c.v4f[1]);
+  return res;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pmin(const Packet4f& a, const Packet4f& b)
+{
+  Packet4f res;
+  res.v4f[0] = pmin(a.v4f[0], b.v4f[0]);
+  res.v4f[1] = pmin(a.v4f[1], b.v4f[1]);
+  return res;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pmax(const Packet4f& a, const Packet4f& b)
+{
+  Packet4f res;
+  res.v4f[0] = pmax(a.v4f[0], b.v4f[0]);
+  res.v4f[1] = pmax(a.v4f[1], b.v4f[1]);
+  return res;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pand(const Packet4f& a, const Packet4f& b)
+{
+  Packet4f res;
+  res.v4f[0] = pand(a.v4f[0], b.v4f[0]);
+  res.v4f[1] = pand(a.v4f[1], b.v4f[1]);
+  return res;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f por(const Packet4f& a, const Packet4f& b)
+{
+  Packet4f res;
+  res.v4f[0] = por(a.v4f[0], b.v4f[0]);
+  res.v4f[1] = por(a.v4f[1], b.v4f[1]);
+  return res;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pxor(const Packet4f& a, const Packet4f& b)
+{
+  Packet4f res;
+  res.v4f[0] = pxor(a.v4f[0], b.v4f[0]);
+  res.v4f[1] = pxor(a.v4f[1], b.v4f[1]);
+  return res;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pandnot(const Packet4f& a, const Packet4f& b)
+{
+  Packet4f res;
+  res.v4f[0] = pandnot(a.v4f[0], b.v4f[0]);
+  res.v4f[1] = pandnot(a.v4f[1], b.v4f[1]);
+  return res;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pround(const Packet4f& a)
+{
+  Packet4f res;
+  res.v4f[0] = vec_round(a.v4f[0]);
+  res.v4f[1] = vec_round(a.v4f[1]);
+  return res;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pceil(const Packet4f& a)
+{
+  Packet4f res;
+  res.v4f[0] = vec_ceil(a.v4f[0]);
+  res.v4f[1] = vec_ceil(a.v4f[1]);
+  return res;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pfloor(const Packet4f& a)
+{
+  Packet4f res;
+  res.v4f[0] = vec_floor(a.v4f[0]);
+  res.v4f[1] = vec_floor(a.v4f[1]);
+  return res;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f ploaddup<Packet4f>(const float* from)
+{
+  Packet4f p = pload<Packet4f>(from);
+  p.v4f[1] = vec_splat(p.v4f[0], 1);
+  p.v4f[0] = vec_splat(p.v4f[0], 0);
+  return p;
+}
+
+template<> EIGEN_STRONG_INLINE float pfirst(const Packet4f& a) { float EIGEN_ALIGN16 x[2]; vec_st2f(a.v4f[0], &x[0]); return x[0]; }
+
+template<> EIGEN_STRONG_INLINE Packet4f preverse(const Packet4f& a)
+{
+  Packet4f rev;
+  rev.v4f[0] = preverse(a.v4f[1]);
+  rev.v4f[1] = preverse(a.v4f[0]);
+  return rev;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pabs(const Packet4f& a)
+{
+  Packet4f res;
+  res.v4f[0] = pabs(a.v4f[0]);
+  res.v4f[1] = pabs(a.v4f[1]);
+  return res;
+}
+
+template<> EIGEN_STRONG_INLINE float predux(const Packet4f& a)
+{
+  Packet2d sum;
+  sum = padd(a.v4f[0], a.v4f[1]);
+  double first = predux(sum);
+  return static_cast<float>(first);
+}
+
+template<> EIGEN_STRONG_INLINE float predux_mul(const Packet4f& a)
+{
+  // Return predux_mul of the subvectors product
+  return static_cast<float>(pfirst(predux_mul(pmul(a.v4f[0], a.v4f[1]))));
+}
+
+template<> EIGEN_STRONG_INLINE float predux_min(const Packet4f& a)
+{
+  Packet2d b, res;
+  b = pmin(a.v4f[0], a.v4f[1]);
+  res = pmin(b, reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4i>(b), reinterpret_cast<Packet4i>(b), 8)));
+  return static_cast<float>(pfirst(res));
+}
+
+template<> EIGEN_STRONG_INLINE float predux_max(const Packet4f& a)
+{
+  Packet2d b, res;
+  b = pmax(a.v4f[0], a.v4f[1]);
+  res = pmax(b, reinterpret_cast<Packet2d>(vec_sld(reinterpret_cast<Packet4i>(b), reinterpret_cast<Packet4i>(b), 8)));
+  return static_cast<float>(pfirst(res));
+}
+
+/* Split the Packet4f PacketBlock into 4 Packet2d PacketBlocks and transpose each one
+ */
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<Packet4f,4>& kernel) {
+  PacketBlock<Packet2d,2> t0,t1,t2,t3;
+  // copy top-left 2x2 Packet2d block
+  t0.packet[0] = kernel.packet[0].v4f[0];
+  t0.packet[1] = kernel.packet[1].v4f[0];
+
+  // copy top-right 2x2 Packet2d block
+  t1.packet[0] = kernel.packet[0].v4f[1];
+  t1.packet[1] = kernel.packet[1].v4f[1];
+
+  // copy bottom-left 2x2 Packet2d block
+  t2.packet[0] = kernel.packet[2].v4f[0];
+  t2.packet[1] = kernel.packet[3].v4f[0];
+
+  // copy bottom-right 2x2 Packet2d block
+  t3.packet[0] = kernel.packet[2].v4f[1];
+  t3.packet[1] = kernel.packet[3].v4f[1];
+
+  // Transpose all 2x2 blocks
+  ptranspose(t0);
+  ptranspose(t1);
+  ptranspose(t2);
+  ptranspose(t3);
+
+  // Copy back transposed blocks, but exchange t1 and t2 due to transposition
+  kernel.packet[0].v4f[0] = t0.packet[0];
+  kernel.packet[0].v4f[1] = t2.packet[0];
+  kernel.packet[1].v4f[0] = t0.packet[1];
+  kernel.packet[1].v4f[1] = t2.packet[1];
+  kernel.packet[2].v4f[0] = t1.packet[0];
+  kernel.packet[2].v4f[1] = t3.packet[0];
+  kernel.packet[3].v4f[0] = t1.packet[1];
+  kernel.packet[3].v4f[1] = t3.packet[1];
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pblend(const Selector<4>& ifPacket, const Packet4f& thenPacket, const Packet4f& elsePacket) {
+  Packet2ul select_hi = { ifPacket.select[0], ifPacket.select[1] };
+  Packet2ul select_lo = { ifPacket.select[2], ifPacket.select[3] };
+  Packet2ul mask_hi = vec_cmpeq(select_hi, reinterpret_cast<Packet2ul>(p2l_ONE));
+  Packet2ul mask_lo = vec_cmpeq(select_lo, reinterpret_cast<Packet2ul>(p2l_ONE));
+  Packet4f result;
+  result.v4f[0] = vec_sel(elsePacket.v4f[0], thenPacket.v4f[0], mask_hi);
+  result.v4f[1] = vec_sel(elsePacket.v4f[1], thenPacket.v4f[1], mask_lo);
+  return result;
+}
+
+template<> Packet4f EIGEN_STRONG_INLINE pcmp_le(const Packet4f& a, const Packet4f& b)
+{
+  Packet4f res;
+  res.v4f[0] = pcmp_le(a.v4f[0], b.v4f[0]);
+  res.v4f[1] = pcmp_le(a.v4f[1], b.v4f[1]);
+  return res;
+}
+
+template<> Packet4f EIGEN_STRONG_INLINE pcmp_lt(const Packet4f& a, const Packet4f& b)
+{
+  Packet4f res;
+  res.v4f[0] = pcmp_lt(a.v4f[0], b.v4f[0]);
+  res.v4f[1] = pcmp_lt(a.v4f[1], b.v4f[1]);
+  return res;
+}
+
+template<> Packet4f EIGEN_STRONG_INLINE pcmp_eq(const Packet4f& a, const Packet4f& b)
+{
+  Packet4f res;
+  res.v4f[0] = pcmp_eq(a.v4f[0], b.v4f[0]);
+  res.v4f[1] = pcmp_eq(a.v4f[1], b.v4f[1]);
+  return res;
+}
+
+#else
+template<> EIGEN_STRONG_INLINE Packet4f pload<Packet4f>(const float* from)
+{
+  // FIXME: No intrinsic yet
+  EIGEN_DEBUG_ALIGNED_LOAD
+  Packet *vfrom;
+  vfrom = (Packet *) from;
+  return vfrom->v4f;
+}
+
+template<> EIGEN_STRONG_INLINE void pstore(float* to, const Packet4f& from)
+{
+  // FIXME: No intrinsic yet
+  EIGEN_DEBUG_ALIGNED_STORE
+  Packet *vto;
+  vto = (Packet *) to;
+  vto->v4f = from;
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pset1<Packet4f>(const float& from)
+{
+  return vec_splats(from);
+}
+
+template<> EIGEN_STRONG_INLINE void
+pbroadcast4(const float *a,
+            Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3)
+{
+  a3 = pload<Packet4f>(a);
+  a0 = vec_splat(a3, 0);
+  a1 = vec_splat(a3, 1);
+  a2 = vec_splat(a3, 2);
+  a3 = vec_splat(a3, 3);
+}
+
+template<> EIGEN_DEVICE_FUNC inline Packet4f pgather<float, Packet4f>(const float* from, Index stride)
+{
+  float EIGEN_ALIGN16 af[4];
+  af[0] = from[0*stride];
+  af[1] = from[1*stride];
+  af[2] = from[2*stride];
+  af[3] = from[3*stride];
+  return pload<Packet4f>(af);
+}
+
+template<> EIGEN_DEVICE_FUNC inline void pscatter<float, Packet4f>(float* to, const Packet4f& from, Index stride)
+{
+  float EIGEN_ALIGN16 af[4];
+  pstore((float*)af, from);
+  to[0*stride] = af[0];
+  to[1*stride] = af[1];
+  to[2*stride] = af[2];
+  to[3*stride] = af[3];
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f padd(const Packet4f& a, const Packet4f& b) { return (a + b); }
+template<> EIGEN_STRONG_INLINE Packet4f psub(const Packet4f& a, const Packet4f& b) { return (a - b); }
+template<> EIGEN_STRONG_INLINE Packet4f pmul(const Packet4f& a, const Packet4f& b) { return (a * b); }
+template<> EIGEN_STRONG_INLINE Packet4f pdiv(const Packet4f& a, const Packet4f& b) { return (a / b); }
+template<> EIGEN_STRONG_INLINE Packet4f pnegate(const Packet4f& a) { return (-a); }
+template<> EIGEN_STRONG_INLINE Packet4f pconj (const Packet4f& a) { return a; }
+template<> EIGEN_STRONG_INLINE Packet4f pmadd (const Packet4f& a, const Packet4f& b, const Packet4f& c) { return vec_madd(a, b, c); }
+template<> EIGEN_STRONG_INLINE Packet4f pmin (const Packet4f& a, const Packet4f& b) { return vec_min(a, b); }
+template<> EIGEN_STRONG_INLINE Packet4f pmax (const Packet4f& a, const Packet4f& b) { return vec_max(a, b); }
+template<> EIGEN_STRONG_INLINE Packet4f pand (const Packet4f& a, const Packet4f& b) { return vec_and(a, b); }
+template<> EIGEN_STRONG_INLINE Packet4f por (const Packet4f& a, const Packet4f& b) { return vec_or(a, b); }
+template<> EIGEN_STRONG_INLINE Packet4f pxor (const Packet4f& a, const Packet4f& b) { return vec_xor(a, b); }
+template<> EIGEN_STRONG_INLINE Packet4f pandnot(const Packet4f& a, const Packet4f& b) { return vec_and(a, vec_nor(b, b)); }
+template<> EIGEN_STRONG_INLINE Packet4f pround (const Packet4f& a) { return vec_round(a); }
+template<> EIGEN_STRONG_INLINE Packet4f pceil (const Packet4f& a) { return vec_ceil(a); }
+template<> EIGEN_STRONG_INLINE Packet4f pfloor (const Packet4f& a) { return vec_floor(a); }
+template<> EIGEN_STRONG_INLINE Packet4f pabs (const Packet4f& a) { return vec_abs(a); }
+template<> EIGEN_STRONG_INLINE float pfirst(const Packet4f& a) { float EIGEN_ALIGN16 x[4]; pstore(x, a); return x[0]; }
+
+template<> EIGEN_STRONG_INLINE Packet4f ploaddup<Packet4f>(const float* from)
+{
+  Packet4f p = pload<Packet4f>(from);
+  return vec_perm(p, p, p16uc_DUPLICATE32_HI);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f preverse(const Packet4f& a)
+{
+  return reinterpret_cast<Packet4f>(vec_perm(reinterpret_cast<Packet16uc>(a), reinterpret_cast<Packet16uc>(a), p16uc_REVERSE32));
+}
+
+template<> EIGEN_STRONG_INLINE float predux(const Packet4f& a)
+{
+  Packet4f b, sum;
+  b = vec_sld(a, a, 8);
+  sum = padd(a, b);
+  b = vec_sld(sum, sum, 4);
+  sum = padd(sum, b);
+  return pfirst(sum);
+}
+
+// Other reduction functions:
+// mul
+template<> EIGEN_STRONG_INLINE float predux_mul(const Packet4f& a)
+{
+  Packet4f prod;
+  prod = pmul(a, vec_sld(a, a, 8));
+  return pfirst(pmul(prod, vec_sld(prod, prod, 4)));
+}
+
+// min
+template<> EIGEN_STRONG_INLINE float predux_min(const Packet4f& a)
+{
+  Packet4f b, res;
+  b = pmin(a, vec_sld(a, a, 8));
+  res = pmin(b, vec_sld(b, b, 4));
+  return pfirst(res);
+}
+
+// max
+template<> EIGEN_STRONG_INLINE float predux_max(const Packet4f& a)
+{
+  Packet4f b, res;
+  b = pmax(a, vec_sld(a, a, 8));
+  res = pmax(b, vec_sld(b, b, 4));
+  return pfirst(res);
+}
+
+EIGEN_DEVICE_FUNC inline void
+ptranspose(PacketBlock<Packet4f,4>& kernel) {
+  Packet4f t0 = vec_mergeh(kernel.packet[0], kernel.packet[2]);
+  Packet4f t1 = vec_mergel(kernel.packet[0], kernel.packet[2]);
+  Packet4f t2 = vec_mergeh(kernel.packet[1], kernel.packet[3]);
+  Packet4f t3 = vec_mergel(kernel.packet[1], kernel.packet[3]);
+  kernel.packet[0] = vec_mergeh(t0, t2);
+  kernel.packet[1] = vec_mergel(t0, t2);
+  kernel.packet[2] = vec_mergeh(t1, t3);
+  kernel.packet[3] = vec_mergel(t1, t3);
+}
+
+template<> EIGEN_STRONG_INLINE Packet4f pblend(const Selector<4>& ifPacket, const Packet4f& thenPacket, const Packet4f& elsePacket) {
+  Packet4ui select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3] };
+  Packet4ui mask = vec_cmpeq(select, reinterpret_cast<Packet4ui>(p4i_ONE));
+  return vec_sel(elsePacket, thenPacket, mask);
+}
+
+#endif
+
+template<> EIGEN_STRONG_INLINE void prefetch(const float* addr) { EIGEN_ZVECTOR_PREFETCH(addr); }
+template<> EIGEN_STRONG_INLINE Packet4f ploadu<Packet4f>(const float* from) { return pload<Packet4f>(from); }
+template<> EIGEN_STRONG_INLINE void pstoreu(float* to, const Packet4f& from) { pstore(to, from); }
+template<> EIGEN_STRONG_INLINE Packet4f plset<Packet4f>(const float& a) { return padd(pset1<Packet4f>(a), p4f_COUNTDOWN); }
+
+} // end namespace internal
+
+} // end namespace Eigen
+
+#endif // EIGEN_PACKET_MATH_ZVECTOR_H
--
cgit v1.2.3-70-g09d2
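Reviewer note (not part of the patch): the primitives added above (pload/ploadu, pset1, padd, pmadd, predux, ptranspose, ...) are normally consumed through Eigen's generic vectorization layer rather than called directly. The standalone C++ sketch below illustrates the typical composition pattern using only public Eigen internals (packet_traits, unpacket_traits, pset1, ploadu, pmadd, predux); the function name dot_kernel is illustrative and is not defined anywhere in this commit.

// Illustrative sketch, assuming an Eigen build where packet_traits<float>::type
// maps to the Packet4f defined in this file (z13 emulated or z14 native).
#include <Eigen/Core>

static float dot_kernel(const float* a, const float* b, int n)
{
  using namespace Eigen::internal;
  typedef packet_traits<float>::type Packet;             // Packet4f on ZVector, float on scalar builds
  const int PacketSize = unpacket_traits<Packet>::size;  // number of lanes per packet
  Packet acc = pset1<Packet>(0.0f);                      // broadcast the accumulator
  int i = 0;
  for (; i + PacketSize <= n; i += PacketSize)
    acc = pmadd(ploadu<Packet>(a + i), ploadu<Packet>(b + i), acc);  // fused multiply-add per lane
  float result = predux(acc);                            // horizontal sum across the lanes
  for (; i < n; ++i)
    result += a[i] * b[i];                               // scalar tail
  return result;
}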