AmpGen 2.1
Loading...
Searching...
No Matches
avx512d_types.h
Go to the documentation of this file.
1#ifndef AMPGEN_AVXd_TYPES
2#define AMPGEN_AVXd_TYPES 1
3
4#include <immintrin.h>
5#include <array>
6#include <iostream>
7#include <complex>
8#include <omp.h>
9#include <cmath>
10
11namespace AmpGen {
12 namespace AVX512d {
// Generates an element-wise real_v overload named x: the argument is unpacked
// with to_array(), std::x is applied to each of the eight lanes, and the
// results are repacked with the eight-argument real_v constructor.
#define stl_fallback( x ) \
  inline real_v x( const real_v& v ) \
  { \
    auto lane = v.to_array(); \
    return real_v( std::x(lane[0]), std::x(lane[1]), std::x(lane[2]), std::x(lane[3]), \
                   std::x(lane[4]), std::x(lane[5]), std::x(lane[6]), std::x(lane[7]) ); \
  }
15
// Eight-lane double-precision vector: a thin wrapper around one __m512d.
struct real_v {
  __m512d data;                        // the packed lanes
  static constexpr unsigned size = 8;  // number of doubles held in `data`
  typedef double scalar_type;

  // Defaulted: performs no explicit initialisation of `data`.
  real_v() = default;
  // Wraps an existing register value.
  real_v(__m512d data ) : data(data) {}
  // Broadcasts f to every lane.
  real_v(const double& f ) : data( _mm512_set1_pd( f )) {}
  // Builds the vector lane by lane: x0 is stored first (lowest address of the
  // temporary buffer that is then loaded), x7 last.
  real_v(
      const double& x0, const double& x1, const double& x2, const double& x3,
      const double& x4, const double& x5, const double& x6, const double& x7)
  {
    const double tmp[8] = {x0,x1,x2,x3,x4,x5,x6,x7};
    data = _mm512_loadu_pd(tmp);
  }
  // Loads eight consecutive doubles starting at f (unaligned load).
  real_v(const double* f ) : data( _mm512_loadu_pd( f ) ) {}
  // Writes the eight lanes to ptr[0..7] (unaligned store).
  void store( double* ptr ) const { _mm512_storeu_pd( ptr, data ); }
  // Returns a copy of the lanes as a std::array.
  std::array<double, 8> to_array() const { std::array<double, 8> b; store( &b[0] ); return b; }
  // Returns lane i by unpacking the whole vector; i is not range-checked.
  double at(const unsigned i) const { return to_array()[i] ; }
  // Implicit conversion so a real_v can be passed directly to intrinsics.
  operator __m512d() const { return data ; }
};
36
37 inline real_v operator+( const real_v& lhs, const real_v& rhs ) { return _mm512_add_pd(lhs, rhs); }
38 inline real_v operator-( const real_v& lhs, const real_v& rhs ) { return _mm512_sub_pd(lhs, rhs); }
39 inline real_v operator*( const real_v& lhs, const real_v& rhs ) { return _mm512_mul_pd(lhs, rhs); }
40 inline real_v operator/( const real_v& lhs, const real_v& rhs ) { return _mm512_div_pd(lhs, rhs); }
41 inline real_v operator-( const real_v& x ) { return -1.f * x; }
42 inline real_v operator&( const real_v& lhs, const real_v& rhs ) { return _mm512_and_pd( lhs, rhs ); }
43 inline real_v operator|( const real_v& lhs, const real_v& rhs ) { return _mm512_or_pd( lhs, rhs ); }
44 inline real_v operator^( const real_v& lhs, const real_v& rhs ) { return _mm512_xor_pd( lhs, rhs ); }
45 inline real_v operator+=(real_v& lhs, const real_v& rhs ){ lhs = lhs + rhs; return lhs; }
46 inline real_v operator-=(real_v& lhs, const real_v& rhs ){ lhs = lhs - rhs; return lhs; }
47 inline real_v operator*=(real_v& lhs, const real_v& rhs ){ lhs = lhs * rhs; return lhs; }
48 inline real_v operator/=(real_v& lhs, const real_v& rhs ){ lhs = lhs / rhs; return lhs; }
49 inline real_v operator&&( const real_v& lhs, const real_v& rhs ) { return _mm512_and_pd( lhs, rhs ); }
50 inline real_v operator||( const real_v& lhs, const real_v& rhs ) { return _mm512_or_pd( lhs, rhs ); }
51 inline real_v operator!( const real_v& x ) { return x ^ _mm512_castsi512_pd( _mm512_set1_epi32( -1 ) ); }
52 inline __mmask8 operator<( const real_v& lhs, const real_v& rhs ) { return _mm512_cmp_pd_mask( lhs, rhs, _CMP_LT_OS ); }
53 inline __mmask8 operator>( const real_v& lhs, const real_v& rhs ) { return _mm512_cmp_pd_mask( lhs, rhs, _CMP_GT_OS ); }
54 inline __mmask8 operator==( const real_v& lhs, const real_v& rhs ){ return _mm512_cmp_pd_mask( lhs, rhs, _CMP_EQ_OS ); }
55 inline real_v sqrt( const real_v& v ) { return _mm512_sqrt_pd(v); }
56 inline real_v abs ( const real_v& v ) { return _mm512_andnot_pd(_mm512_set1_pd(-0.), v); }
57 // inline real_v sin( const real_v& v ) { return sin512_pd(v) ; }
58 // inline real_v cos( const real_v& v ) { return cos512_pd(v) ; }
59 // inline real_v tan( const real_v& v ) { real_v s; real_v c; sincos512_pd(v, (__m512*)&s, (__m512*)&c) ; return s/c; }
60 // inline real_v exp( const real_v& v ) { return exp512_ps(v) ; }
61 inline real_v select(const __mmask8& mask, const real_v& a, const real_v& b ) { return _mm512_mask_mov_pd( b, mask, a ); }
62 inline real_v select(const bool& mask , const real_v& a, const real_v& b ) { return mask ? a : b; }
63 inline real_v sign ( const real_v& v){ return select( v > 0., +1., -1. ); }
64 inline real_v atan2( const real_v& y, const real_v& x ){
65 std::array<double, 8> bx{x.to_array()}, by{y.to_array()};
66 return real_v (
67 std::atan2(by[0], bx[0]) , std::atan2( by[1], bx[1]), std::atan2( by[2], bx[2]), std::atan2( by[3], bx[3]) ,
68 std::atan2(by[4], bx[4]) , std::atan2( by[5], bx[5]), std::atan2( by[6], bx[6]), std::atan2( by[7], bx[7]) );
69 }
70 inline __m512i double_to_int( const real_v& x )
71 {
72 auto xr = _mm512_roundscale_pd(x, _MM_FROUND_TO_ZERO);
73 // based on: https://stackoverflow.com/questions/41144668/how-to-efficiently-perform-double-int64-conversions-with-sse-avx
74 return _mm512_sub_epi64(_mm512_castpd_si512(_mm512_add_pd(xr, _mm512_set1_pd(0x0018000000000000))),
75 _mm512_castpd_si512(_mm512_set1_pd(0x0018000000000000)));
76 }
77 inline real_v gather( const double* base_addr, const real_v& offsets)
78 {
79 return _mm512_i64gather_pd(double_to_int(offsets), base_addr, sizeof(double));
80 }
81
82 inline void frexp(const real_v& value, real_v& mant, real_v& exponent)
83 {
84 auto arg_as_int = _mm512_castpd_si512(value);
85 static const real_v offset(4503599627370496.0 + 1022.0); // 2^52 + 1022.0
86 static const __m512i pow2_52_i = _mm512_set1_epi64(0x4330000000000000); // *reinterpret_cast<const uint64_t*>(&pow2_52_d);
87 auto b = _mm512_srl_epi64(arg_as_int, _mm_cvtsi32_si128(52));
88 auto c = _mm512_or_si512( b , pow2_52_i);
89 exponent = real_v( _mm512_castsi512_pd(c) ) - offset;
90 mant = _mm512_castsi512_pd(_mm512_or_si512(_mm512_and_si512 (arg_as_int, _mm512_set1_epi64(0x000FFFFFFFFFFFFFll) ), _mm512_set1_epi64(0x3FE0000000000000ll)));
91 }
92
93 inline real_v fmadd( const real_v& a, const real_v& b, const real_v& c )
94 {
95 return _mm512_fmadd_pd(a, b, c);
96 }
97 inline real_v log(const real_v& arg)
98 {
99 static const real_v corr = 0.693147180559945286226764;
100 static const real_v CL15 = 0.148197055177935105296783;
101 static const real_v CL13 = 0.153108178020442575739679;
102 static const real_v CL11 = 0.181837339521549679055568;
103 static const real_v CL9 = 0.22222194152736701733275;
104 static const real_v CL7 = 0.285714288030134544449368;
105 static const real_v CL5 = 0.399999999989941956712869;
106 static const real_v CL3 = 0.666666666666685503450651;
107 static const real_v CL1 = 2.0;
108 real_v mant, exponent;
109 frexp(arg, mant, exponent);
110 auto x = (mant - 1.) / (mant + 1.);
111 auto x2 = x * x;
112 auto p = fmadd(CL15, x2, CL13);
113 p = fmadd(p, x2, CL11);
114 p = fmadd(p, x2, CL9);
115 p = fmadd(p, x2, CL7);
116 p = fmadd(p, x2, CL5);
117 p = fmadd(p, x2, CL3);
118 p = fmadd(p, x2, CL1);
119 p = fmadd(p, x, corr * exponent);
120 return p;
121 }
126 inline real_v remainder( const real_v& a, const real_v& b ){ return a - real_v(_mm512_roundscale_pd(a/b, _MM_FROUND_TO_NEG_INF)) * b; }
127 inline real_v fmod( const real_v& a, const real_v& b )
128 {
129 auto r = remainder( abs(a), abs(b) );
130 return select( a > 0., r, -r );
131 }
132
133 inline std::ostream& operator<<( std::ostream& os, const real_v& obj ) {
134 auto buffer = obj.to_array();
135 for( unsigned i = 0 ; i != 8; ++i ) os << buffer[i] << " ";
136 return os;
137 }
138
139 using complex_v = std::complex<real_v>;
140 inline complex_v operator+( const complex_v& lhs, const real_v& rhs ) { return complex_v(lhs.real() + rhs, lhs.imag()); }
141 inline complex_v operator-( const complex_v& lhs, const real_v& rhs ) { return complex_v(lhs.real() - rhs, lhs.imag()); }
142 inline complex_v operator*( const complex_v& lhs, const real_v& rhs ) { return complex_v(lhs.real()*rhs, lhs.imag()*rhs); }
143 inline complex_v operator/( const complex_v& lhs, const real_v& rhs ) { return complex_v(lhs.real()/rhs, lhs.imag()/rhs); }
144 inline complex_v operator+( const real_v& lhs, const complex_v& rhs ) { return complex_v(lhs + rhs.real(), rhs.imag()); }
145 inline complex_v operator-( const real_v& lhs, const complex_v& rhs ) { return complex_v(lhs - rhs.real(), - rhs.imag()); }
146 inline complex_v operator*( const real_v& lhs, const complex_v& rhs ) { return complex_v(lhs*rhs.real(), lhs*rhs.imag()); }
147 inline complex_v operator/( const real_v& lhs, const complex_v& rhs ) { return complex_v( lhs * rhs.real() , -lhs *rhs.imag()) / (rhs.real() * rhs.real() + rhs.imag() * rhs.imag() ); }
148 inline real_v abs( const complex_v& v ) { return sqrt( v.real() * v.real() + v.imag() * v.imag() ) ; }
149 inline real_v norm( const complex_v& v ) { return ( v.real() * v.real() + v.imag() * v.imag() ) ; }
150 inline complex_v select(const __mmask8& mask, const complex_v& a, const complex_v& b ) { return complex_v( select(mask, a.real(), b.real()), select(mask, a.imag(), b.imag() ) ) ; }
151 inline complex_v select(const __mmask8& mask, const real_v& a, const complex_v& b ) { return complex_v( select(mask, a , b.real()), select(mask, 0.f, b.imag()) ); }
152 inline complex_v select(const __mmask8& mask, const complex_v& a, const real_v& b ) { return complex_v( select(mask, a.real(), b ) , select(mask, a.imag(), 0.f) ); }
153 inline complex_v select(const bool& mask , const complex_v& a, const complex_v& b ) { return mask ? a : b; }
154 inline complex_v exp( const complex_v& v ){
155 auto [s,c] = sincos( v.imag());
156 return exp(v.real()) * complex_v(c, s);
157 }
158 inline complex_v sqrt( const complex_v& v )
159 {
160 auto r = abs(v);
161 return complex_v ( sqrt( 0.5 * (r + v.real()) ), sign(v.imag()) * sqrt( 0.5*( r - v.real() ) ) );
162 }
163 inline complex_v log( const complex_v& v )
164 {
165 return complex_v( 0.5 * log( norm(v) ) , atan2(v.imag(), v.real()) );
166 }
167
168 inline std::ostream& operator<<( std::ostream& os, const complex_v& obj ) { return os << "( "<< obj.real() << ") (" << obj.imag() << ")"; }
// User-defined OpenMP '+' reductions for the vector types. The initializer
// clauses make every thread-private partial sum start from zero: without them
// the private copy is default-initialised, and real_v's defaulted constructor
// performs no initialisation of its register.
#pragma omp declare reduction(+: real_v: \
  omp_out = omp_out + omp_in) initializer(omp_priv = real_v(0.))
#pragma omp declare reduction(+: complex_v: \
  omp_out = omp_out + omp_in) initializer(omp_priv = complex_v(real_v(0.), real_v(0.)))
173
174 }
175}
176
177#endif
#define stl_fallback(x)
real_v abs(const real_v &v)
real_v operator||(const real_v &lhs, const real_v &rhs)
real_v sqrt(const real_v &v)
real_v atan2(const real_v &y, const real_v &x)
real_v gather(const double *base_addr, const real_v &offsets)
real_v fmadd(const real_v &a, const real_v &b, const real_v &c)
real_v tan(const real_v &v)
real_v operator-(const real_v &lhs, const real_v &rhs)
real_v cos(const real_v &v)
__mmask8 operator>(const real_v &lhs, const real_v &rhs)
real_v operator|(const real_v &lhs, const real_v &rhs)
real_v fmod(const real_v &a, const real_v &b)
real_v log(const real_v &arg)
real_v operator-=(real_v &lhs, const real_v &rhs)
__m512i double_to_int(const real_v &x)
real_v select(const __mmask8 &mask, const real_v &a, const real_v &b)
real_v operator!(const real_v &x)
void frexp(const real_v &value, real_v &mant, real_v &exponent)
std::complex< real_v > complex_v
__mmask8 operator<(const real_v &lhs, const real_v &rhs)
real_v operator^(const real_v &lhs, const real_v &rhs)
real_v operator/(const real_v &lhs, const real_v &rhs)
real_v sign(const real_v &v)
std::ostream & operator<<(std::ostream &os, const real_v &obj)
real_v sin(const real_v &v)
real_v operator&&(const real_v &lhs, const real_v &rhs)
real_v operator&(const real_v &lhs, const real_v &rhs)
real_v remainder(const real_v &a, const real_v &b)
__mmask8 operator==(const real_v &lhs, const real_v &rhs)
real_v operator*=(real_v &lhs, const real_v &rhs)
real_v operator/=(real_v &lhs, const real_v &rhs)
real_v operator+=(real_v &lhs, const real_v &rhs)
real_v operator+(const real_v &lhs, const real_v &rhs)
real_v norm(const complex_v &v)
real_v exp(const real_v &v)
real_v operator*(const real_v &lhs, const real_v &rhs)
AVX::real_v real_v
Definition utils.h:46
real_v(const double &f)
std::array< double, 8 > to_array() const
void store(double *ptr) const
double at(const unsigned i) const
real_v(const double *f)
real_v(const double &x0, const double &x1, const double &x2, const double &x3, const double &x4, const double &x5, const double &x6, const double &x7)
static constexpr unsigned size