Быстрое вычисление квадрата

Чтобы ускорить мои бигнум деления мне нужно ускорить операцию y = x^2 для bigints, которые представлены как динамические массивы беззнаковых DWORD. Чтобы было ясно:

DWORD x[n+1] = { LSW, ......, MSW };
  • где n + 1 — количество используемых DWORD
  • так что значение числа x = x[0]+x[1]<<32 + ... x[N]<<32*(n)

Вопрос в том: Как мне вычислить y = x^2 максимально быстро без потери точности?
— С помощью C ++ и с целочисленной арифметикой (32 бита с Carry) в распоряжении.

Мой текущий подход заключается в применении умножения y = x*x и избежать многократного умножения.

Например:

x = x[0] + x[1]<<32 + ... x[n]<<32*(n)

Для простоты позвольте мне переписать это:

x = x0+ x1 + x2 + ... + xn

где индекс представляет адрес внутри массива, поэтому:

y = x*x
y = (x0 + x1 + x2 + ...xn)*(x0 + x1 + x2 + ...xn)
y = x0*(x0 + x1 + x2 + ...xn) + x1*(x0 + x1 + x2 + ...xn) + x2*(x0 + x1 + x2 + ...xn) + ...xn*(x0 + x1 + x2 + ...xn)

y0     = x0*x0
y1     = x1*x0 + x0*x1
y2     = x2*x0 + x1*x1 + x0*x2
y3     = x3*x0 + x2*x1 + x1*x2
...
y(2n-3) = xn(n-2)*x(n  ) + x(n-1)*x(n-1) + x(n  )*x(n-2)
y(2n-2) = xn(n-1)*x(n  ) + x(n  )*x(n-1)
y(2n-1) = xn(n  )*x(n  )

При ближайшем рассмотрении становится ясно, что почти все xi*xj появляется дважды (не первый и не последний), что означает, что N*N умножения могут быть заменены (N+1)*(N/2) умножения. Постскриптум 32bit*32bit = 64bit так что результат каждого mul+add операция обрабатывается как 64+1 bit,

Есть ли лучший способ вычислить это быстро? Все, что я нашел во время поисков, было алгоритмами sqrts, а не sqr …

Fast sqr

!!! Помните, что все числа в моем коде — сначала MSW, а не как в тесте, приведенном выше (сначала LSW для простоты уравнений, иначе это был бы беспорядок в индексе).

Текущая функциональная реализация fsqr

void arbnum::sqr(const arbnum &x)
{
// O((N+1)*N/2)
arbnum c;
DWORD h, l;
int N, nx, nc, i, i0, i1, k;
c._alloc(x.siz + x.siz + 1);
nx = x.siz - 1;
nc = c.siz - 1;
N = nx + nx;
for (i=0; i<=nc; i++)
c.dat[i]=0;
for (i=1; i<N; i++)
for (i0=0; (i0<=nx) && (i0<=i); i0++)
{
i1 = i - i0;
if (i0 >= i1)
break;
if (i1 > nx)
continue;
h = x.dat[nx-i0];
if (!h)
continue;
l = x.dat[nx-i1];
if (!l)
continue;
alu.mul(h, l, h, l);
k = nc - i;
if (k >= 0)
alu.add(c.dat[k], c.dat[k], l);
k--;
if (k>=0)
alu.adc(c.dat[k], c.dat[k],h);
k--;
for (; (alu.cy) && (k>=0); k--)
alu.inc(c.dat[k]);
}
c.shl(1);
for (i = 0; i <= N; i += 2)
{
i0 = i>>1;
h = x.dat[nx-i0];
if (!h)
continue;
alu.mul(h, l, h, h);
k = nc - i;
if (k >= 0)
alu.add(c.dat[k], c.dat[k],l);
k--;
if (k>=0)
alu.adc(c.dat[k], c.dat[k], h);
k--;
for (; (alu.cy) && (k >= 0); k--)
alu.inc(c.dat[k]);
}
c.bits = c.siz<<5;
c.exp = x.exp + x.exp + ((c.siz - x.siz - x.siz)<<5) + 1;
c.sig = sig;
*this = c;
}

Использование умножения Карацубы

(спасибо Калпису)

Я реализовал умножение Карацубы, но результаты значительно медленнее, чем при использовании простых O(N^2) умножение, вероятно, из-за той ужасной рекурсии, которую я не вижу способа избежать. Его компромисс должен быть в действительно больших числах (больше, чем сотни цифр) … но даже в этом случае происходит много передач памяти. Есть ли способ избежать рекурсивных вызовов (нерекурсивный вариант, … Почти все рекурсивные алгоритмы могут быть выполнены таким образом). Тем не менее, я постараюсь изменить ситуацию и посмотреть, что произойдет (избегайте нормализаций и т. Д., Также это может быть глупой ошибкой в ​​коде). Во всяком случае, после решения Карацуба на случай x*x прирост производительности невелик.

Оптимизировано умножение Карацубы

Тест производительности для y = x^2 looped 1000x times, 0.9 < x < 1 ~ 32*98 bits:

x = 0.98765588997654321000000009876... | 98*32 bits
sqr [ 213.989 ms ] ... O((N+1)*N/2) fast sqr
mul1[ 363.472 ms ] ... O(N^2) classic multiplication
mul2[ 349.384 ms ] ... O(3*(N^log2(3))) optimized Karatsuba multiplication
mul3[ 9345.127 ms] ... O(3*(N^log2(3))) unoptimized Karatsuba multiplication

x = 0.98765588997654321000... | 195*32 bits
sqr [ 883.01 ms ]
mul1[ 1427.02 ms ]
mul2[ 1089.84 ms ]

x = 0.98765588997654321000... | 389*32 bits
sqr [ 3189.19 ms ]
mul1[ 5553.23 ms ]
mul2[ 3159.07 ms ]

После оптимизации для Karatsuba код стал значительно быстрее, чем раньше. Тем не менее, для меньших чисел это чуть меньше половины скорости моего O(N^2) умножение. Для больших чисел это быстрее с коэффициентом, определяемым сложностями умножения Бута. Пороговое значение для умножения составляет около 32 * 98 битов, а для sqr — около 32 * 389 битов, поэтому, если сумма входных битов пересекает этот порог, то умножение Карацубы будет использовано для ускорения умножения, что также будет аналогичным для sqr.

Кстати, оптимизации включены:

  • Минимизируйте кучи, используя слишком большой аргумент рекурсии
  • Вместо этого используется предотвращение любых 32-битных ALU с арифметикой bignum (+, -).
  • игнорирование 0*y или же x*0 или же 0*0 случаи
  • Переформатирование ввода x,y количество чисел к степени два, чтобы избежать перераспределения
  • Реализовать умножение по модулю для z1 = (x0 + x1)*(y0 + y1) минимизировать рекурсию

Модифицированное умножение Шёнхаге-Штрассена для реализации sqr

Я проверил использование FFT а также NTT преобразовывает, чтобы ускорить вычисление sqr. Результаты таковы:

  1. FFT

    Потерять точность и, следовательно, нужны высокоточные комплексные числа. Это на самом деле значительно замедляет процесс, поэтому ускорение отсутствует. Результат не является точным (может быть ошибочно округлен), поэтому FFT непригоден (пока)

  2. NTT

    NTT конечное поле ДПФ и поэтому не происходит потеря точности. Нужна модульная арифметика на целых числах без знака: modpow, modmul, modadd а также modsub,

    я использую DWORD (32-разрядные целые числа без знака). NTT Размер вектора ввода / вывода ограничен из-за проблем переполнения !!! Для 32-битной модульной арифметики N ограничено (2^32)/(max(input[])^2) так bigint должен быть разделен на более мелкие куски (я использую BYTES поэтому максимальный размер bigint обработано

    (2^32)/((2^8)^2) = 2^16 bytes = 2^14 DWORDs = 16384 DWORDs)
    

    sqr использует только 1xNTT + 1xINTT вместо 2xNTT + 1xINTT для умножения, но NTT использование слишком медленное и размер порогового числа слишком велик для практического использования в моей реализации (для mul а также для sqr).

    Возможно даже превышение предела переполнения, поэтому следует использовать 64-битную модульную арифметику, которая может еще больше замедлить работу. Так NTT для моих целей тоже непригодна.

Некоторые измерения:

a = 0.98765588997654321000 | 389*32 bits
looped 1x times
sqr1[ 3.177 ms ] fast sqr
sqr2[ 720.419 ms ] NTT sqr
mul1[ 5.588 ms ] simpe mul
mul2[ 3.172 ms ] karatsuba mul
mul3[ 1053.382 ms ] NTT mul

Моя реализация:

void arbnum::sqr_NTT(const arbnum &x)
{
// O(N*log(N)*(log(log(N)))) - 1x NTT
// Schönhage-Strassen sqr
// To prevent NTT overflow: n <= 48K * 8 bit -> result siz <= 12K * 32 bit -> x.siz + y.siz <= 12K!!!
int i, j, k, n;
int s = x.sig*x.sig, exp0 = x.exp + x.exp - ((x.siz+x.siz)<<5) + 2;
i = x.siz;
for (n = 1; n < i; n<<=1)
;
if (n + n > 0x3000) {
_error(_arbnum_error_TooBigNumber);
zero();
return;
}
n <<= 3;
DWORD *xx, *yy, q, qq;
xx = new DWORD[n+n];
#ifdef _mmap_h
if (xx)
mmap_new(xx, (n+n) << 2);
#endif
if (xx==NULL) {
_error(_arbnum_error_NotEnoughMemory);
zero();
return;
}
yy = xx + n;

// Zero padding (and split DWORDs to BYTEs)
for (i--, k=0; i >= 0; i--)
{
q = x.dat[i];
xx[k] = q&0xFF; k++; q>>=8;
xx[k] = q&0xFF; k++; q>>=8;
xx[k] = q&0xFF; k++; q>>=8;
xx[k] = q&0xFF; k++;
}
for (;k<n;k++)
xx[k] = 0;

//NTT
fourier_NTT ntt;

ntt.NTT(yy,xx,n);    // init NTT for n

// Convolution
for (i=0; i<n; i++)
yy[i] = modmul(yy[i], yy[i], ntt.p);

//INTT
ntt.INTT(xx, yy);

//suma
q=0;
for (i = 0, j = 0; i<n; i++) {
qq = xx[i];
q += qq&0xFF;
yy[n-i-1] = q&0xFF;
q>>=8;
qq>>=8;
q+=qq;
}

// Merge WORDs to DWORDs and copy them to result
_alloc(n>>2);
for (i = 0, j = 0; i<siz; i++)
{
q  =(yy[j]<<24)&0xFF000000; j++;
q |=(yy[j]<<16)&0x00FF0000; j++;
q |=(yy[j]<< 8)&0x0000FF00; j++;
q |=(yy[j]    )&0x000000FF; j++;
dat[i] = q;
}

#ifdef _mmap_h
if (xx)
mmap_del(xx);
#endif
delete xx;
bits = siz<<5;
sig = s;
exp = exp0 + (siz<<5) - 1;
// _normalize();
}

Заключение

Для меньших номеров это лучший вариант мой быстрый sqr подход, а после
порог Карацуба умножение лучше. Но я все еще думаю, что должно быть что-то тривиальное, что мы упустили из виду. У кого-нибудь есть другие идеи?

NTT оптимизация

После чрезвычайно интенсивных оптимизаций (в основном NTT): Вопрос переполнения стека Модульная арифметика и NTT (конечное поле DFT) оптимизации.

Некоторые значения изменились:

a = 0.98765588997654321000 | 1553*32bits
looped 10x times
mul2[ 28.585 ms ] Karatsuba mul
mul3[ 26.311 ms ] NTT mul

А сейчас NTT умножение, наконец, быстрее, чем Карацуба после 1500 * 32-битного порога.

Некоторые измерения и ошибка обнаружены

a = 0.99991970486 | 1553*32 bits
looped: 10x
sqr1[  58.656 ms ] fast sqr
sqr2[  13.447 ms ] NTT sqr
mul1[ 102.563 ms ] simpe mul
mul2[  28.916 ms ] Karatsuba mul Error
mul3[  19.470 ms ] NTT mul

Я узнал, что мой Карацуба (больше / меньше) течет LSB каждого DWORD сегмент бигнум. Когда я исследую, я обновлю код …

Кроме того, после NTT оптимизации пороги изменились, поэтому для NTT sqr это 310*32 bits = 9920 bits из операнд, и для NTT mul это 1396*32 bits = 44672 bits из результат (сумма битов операндов).

Код Карацубы исправлен благодаря @greybeard

//---------------------------------------------------------------------------
void arbnum::_mul_karatsuba(DWORD *z, DWORD *x, DWORD *y, int n)
{
// Recursion for Karatsuba
// z[2n] = x[n]*y[n];
// n=2^m
int i;
for (i=0; i<n; i++)
if (x[i]) {
i=-1;
break;
} // x==0 ?

if (i < 0)
for (i = 0; i<n; i++)
if (y[i]) {
i = -1;
break;
} // y==0 ?

if (i >= 0) {
for (i = 0; i < n + n; i++)
z[i]=0;
return;
} // 0.? = 0

if (n == 1) {
alu.mul(z[0], z[1], x[0], y[0]);
return;
}

if (n< 1)
return;
int n2 = n>>1;
_mul_karatsuba(z+n, x+n2, y+n2, n2);                         // z0 = x0.y0
_mul_karatsuba(z  , x   , y   , n2);                         // z2 = x1.y1
DWORD *q = new DWORD[n<<1], *q0, *q1, *qq;
BYTE cx,cy;
if (q == NULL) {
_error(_arbnum_error_NotEnoughMemory);
return;
}
#define _add { alu.add(qq[i], q0[i], q1[i]); for (i--; i>=0; i--) alu.adc(qq[i], q0[i], q1[i]); } // qq = q0 + q1 ...[i..0]
#define _sub { alu.sub(qq[i], q0[i], q1[i]); for (i--; i>=0; i--) alu.sbc(qq[i], q0[i], q1[i]); } // qq = q0 - q1 ...[i..0]
qq = q;
q0 = x + n2;
q1 = x;
i = n2 - 1;
_add;
cx = alu.cy; // =x0+x1

qq = q + n2;
q0 = y + n2;
q1 = y;
i = n2 - 1;
_add;
cy = alu.cy; // =y0+y1

_mul_karatsuba(q + n, q + n2, q, n2);                       // =(x0+x1)(y0+y1) mod ((2^N)-1)

if (cx) {
qq = q + n;
q0 = qq;
q1 = q + n2;
i = n2 - 1;
_add;
cx = alu.cy;
}// += cx*(y0 + y1) << n2

if (cy) {
qq = q + n;
q0 = qq;
q1 = q;
i = n2 -1;
_add;
cy = alu.cy;
}// +=cy*(x0+x1)<<n2

qq = q + n;  q0 = qq; q1 = z + n; i = n - 1; _sub;  // -=z0
qq = q + n;  q0 = qq; q1 = z;     i = n - 1; _sub;  // -=z2
qq = z + n2; q0 = qq; q1 = q + n; i = n - 1; _add;  // z1=(x0+x1)(y0+y1)-z0-z2

DWORD ccc=0;

if (alu.cy)
ccc++;    // Handle carry from last operation
if (cx || cy)
ccc++;    // Handle carry from before last operation
if (ccc)
{
i = n2 - 1;
alu.add(z[i], z[i], ccc);
for (i--; i>=0; i--)
if (alu.cy)
alu.inc(z[i]);
else
break;
}

delete[] q;
#undef _add
#undef _sub
}

//---------------------------------------------------------------------------
void arbnum::mul_karatsuba(const arbnum &x, const arbnum &y)
{
// O(3*(N)^log2(3)) ~ O(3*(N^1.585))
// Karatsuba multiplication
//
int s = x.sig*y.sig;
arbnum a, b;
a = x;
b = y;
a.sig = +1;
b.sig = +1;
int i, n;
for (n = 1; (n < a.siz) || (n < b.siz); n <<= 1)
;
a._realloc(n);
b._realloc(n);
_alloc(n + n);
for (i=0; i < siz; i++)
dat[i]=0;
_mul_karatsuba(dat, a.dat, b.dat, n);
bits = siz << 5;
sig = s;
exp = a.exp + b.exp + ((siz-a.siz-b.siz)<<5) + 1;
//    _normalize();
}
//---------------------------------------------------------------------------

мой arbnum представление чисел:

// dat is MSDW first ... LSDW last
DWORD *dat; int siz,exp,sig,bits;
  • dat[siz] это богомол LSDW означает наименее значимое DWORD.
  • exp является показателем MSB dat[0]
  • Первый ненулевой бит присутствует в мантиссе !!!

    // |-----|---------------------------|---------------|------|
    // | sig | MSB      mantisa      LSB |   exponent    | bits |
    // |-----|---------------------------|---------------|------|
    // | +1  | 0.(0      ...          0) | 2^0           |   0  | +zero
    // | -1  | 0.(0      ...          0) | 2^0           |   0  | -zero
    // |-----|---------------------------|---------------|------|
    // | +1  | 1.(dat[0] ... dat[siz-1]) | 2^exp         |   n  | +number
    // | -1  | 1.(dat[0] ... dat[siz-1]) | 2^exp         |   n  | -number
    // |-----|---------------------------|---------------|------|
    // | +1  | 1.0                       | 2^+0x7FFFFFFE |   1  | +infinity
    // | -1  | 1.0                       | 2^+0x7FFFFFFE |   1  | -infinity
    // |-----|---------------------------|---------------|------|
    

12

Решение

Если я правильно понимаю ваш алгоритм, кажется O(n^2) где n количество цифр

Вы смотрели на Алгоритм Карацубы?
Это ускоряет умножение, используя подход «разделяй и властвуй». Возможно, стоит взглянуть.

2

Другие решения

Если вы хотите написать новый лучший показатель, вам, возможно, придется написать его в сборке. Это код от Голанга.

https://code.google.com/p/go/source/browse/src/pkg/math/exp_amd64.s

0

По вопросам рекламы ammmcru@yandex.ru
Adblock
detector