Вычитание целочисленного результата сравнения с SSE

Я изучаю кодирование SSE и теперь застрял на пару строк кода при переносе / векторизации exp(double) функция.

Я тоже начинающий с C / C ++, поэтому эти две строки сбивают с толку, и я не знаю, что там происходит:

fast_exp(double x) {
double  a = c1 * x + 0.5;
int     n = (int)a;  // truncate toward zero
ieee754 u;  // union of double and unsigned short[4]

n -= (a < 0); //(1)
...

u.s[3] = (unsigned short)((n << 4) & 0x7FF0); //(2)
return  u.d;
}

Что там делается?

(1) a вычитается из n если а имеет отрицательное значение?

(2)? сдвиг влево по битам, но что еще там происходит?

И как эти две строки кода пишутся с использованием встроенных функций SSE?


Вот то, что я до сих пор частично портировал на SSE. Линии, на которых я застрял, помечены ?? в коде ниже. Он не компилируется (пока).

typedef union {
double d;
unsigned short s[4];
} ieee754;

double exp_sse (double value){

double px; //, a;
ieee754 u;
__m128i n;
__m128d a;
__m128d  x = _mm_set1_pd (x);
__m128d c1 = _mm_set1_pd (1.4426950408889634073599);
__m128d c2 = _mm_set1_pd (6.93145751953125E-1);
__m128d c3 = _mm_set1_pd (1.42860682030941723212E-6);
__m128i c1023 = _mm_set1_epi32(1023);
__m128d c4 = _mm_set1_pd (0.5);

/* n = round(x / log 2) --------------------- */
// a = c1 * x + 0.5;
a = _mm_mul_pd (c1, x);
a = _mm_add_pd (a, c4);

// n = (int)a;
n = _mm_cvtpd_epi32(a);

n -= (a < 0); ??

/* x -= n * log2 ---------------------- */
//px = (double)n;
px = _mm_cvtepi32_pd(n);
//x -= px * c2;
x = _mm_sub_pd(x, _mm_mul_pd(px, c2));
//x -= px * c3;
x = _mm_sub_pd(x, _mm_mul_pd(px, c3));

// calc e^x -------------------
...
// ----------------------------

/* 2^n in double. */
n = _mm_add_(n, c1023);

u.s[3] = (unsigned short)((n << 4) & 0x7FF0); ??

return d[0] * u.d;
}

0

Решение

Задача ещё не решена.

Другие решения

Других решений пока нет …

По вопросам рекламы ammmcru@yandex.ru
Adblock
detector