Computação quadrada de bignum rápido

Para acelerar minhas divisões bignum eu preciso acelerar a operação y = x^2 para bigints que são representados como arrays dynamics de DWORDs não assinados. Para ser claro:

 DWORD x[n+1] = { LSW, ......, MSW }; 
  • onde n + 1 é o número de DWORDs usados
  • então valor do número x = x[0]+x[1]<<32 + ... x[N]<<32*(n)

A questão é: Como calcular y = x^2 mais rápido possível sem perda de precisão? – Usando C ++ e com aritmética inteira (32 bits com Carry) à disposição.

Minha abordagem atual é aplicar multiplicação y = x*x e evitar multiplicações múltiplas.

Por exemplo:

 x = x[0] + x[1]<<32 + ... x[n]<<32*(n) 

Para simplificar, deixe-me reescrevê-lo:

 x = x0+ x1 + x2 + ... + xn 

onde index representa o endereço dentro da matriz, então:

 y = x*x y = (x0 + x1 + x2 + ...xn)*(x0 + x1 + x2 + ...xn) y = x0*(x0 + x1 + x2 + ...xn) + x1*(x0 + x1 + x2 + ...xn) + x2*(x0 + x1 + x2 + ...xn) + ...xn*(x0 + x1 + x2 + ...xn) y0 = x0*x0 y1 = x1*x0 + x0*x1 y2 = x2*x0 + x1*x1 + x0*x2 y3 = x3*x0 + x2*x1 + x1*x2 ... y(2n-3) = xn(n-2)*x(n ) + x(n-1)*x(n-1) + x(n )*x(n-2) y(2n-2) = xn(n-1)*x(n ) + x(n )*x(n-1) y(2n-1) = xn(n )*x(n ) 

Depois de um olhar mais atento, fica claro que quase todos os xi*xj aparecem duas vezes (não o primeiro e o último), o que significa que as multiplicações N*N podem ser substituídas por multiplicações (N+1)*(N/2) . PS 32bit*32bit = 64bit então o resultado de cada operação mul+add é tratado como 64+1 bit .

Existe uma maneira melhor de calcular isso rapidamente? Tudo que eu encontrei durante buscas foram algoritmos sqrts, não sqr …

Sqr rápido

!!! Cuidado que todos os números no meu código são MSW primeiro, … não como no teste acima (há LSW primeiro para simplicidade de equações, caso contrário, seria uma bagunça de índice).

Implementação funcional atual do fsqr

 void arbnum::sqr(const arbnum &x) { // O((N+1)*N/2) arbnum c; DWORD h, l; int N, nx, nc, i, i0, i1, k; c._alloc(x.siz + x.siz + 1); nx = x.siz - 1; nc = c.siz - 1; N = nx + nx; for (i=0; i<=nc; i++) c.dat[i]=0; for (i=1; i<N; i++) for (i0=0; (i0<=nx) && (i0= i1) break; if (i1 > nx) continue; h = x.dat[nx-i0]; if (!h) continue; l = x.dat[nx-i1]; if (!l) continue; alu.mul(h, l, h, l); k = nc - i; if (k >= 0) alu.add(c.dat[k], c.dat[k], l); k--; if (k>=0) alu.adc(c.dat[k], c.dat[k],h); k--; for (; (alu.cy) && (k>=0); k--) alu.inc(c.dat[k]); } c.shl(1); for (i = 0; i >1; h = x.dat[nx-i0]; if (!h) continue; alu.mul(h, l, h, h); k = nc - i; if (k >= 0) alu.add(c.dat[k], c.dat[k],l); k--; if (k>=0) alu.adc(c.dat[k], c.dat[k], h); k--; for (; (alu.cy) && (k >= 0); k--) alu.inc(c.dat[k]); } c.bits = c.siz<<5; c.exp = x.exp + x.exp + ((c.siz - x.siz - x.siz)<<5) + 1; c.sig = sig; *this = c; } 

Uso da multiplicação de Karatsuba

(graças a Calpis)

Eu implementei a multiplicação de Karatsuba, mas os resultados são maciçamente mais lentos até do que pelo uso da multiplicação simples de O(N^2) , provavelmente por causa daquela recursion horrível que não vejo como evitar. O trade-off deve ser em números muito grandes (maiores que centenas de dígitos) … mas mesmo assim há muitas transferências de memory. Existe uma maneira de evitar chamadas de recursion (variante não recursiva, … Quase todos os algoritmos recursivos podem ser feitos dessa maneira). Ainda assim, vou tentar ajustar as coisas e ver o que acontece (evitar normalizações, etc …, também pode ser algum erro bobo no código). De qualquer forma, depois de resolver Karatsuba para o caso x*x não há muito ganho de desempenho.

Multiplicação otimizada de Karatsuba

Teste de desempenho para y = x^2 looped 1000x times, 0.9 < x < 1 ~ 32*98 bits :

 x = 0.98765588997654321000000009876... | 98*32 bits sqr [ 213.989 ms ] ... O((N+1)*N/2) fast sqr mul1[ 363.472 ms ] ... O(N^2) classic multiplication mul2[ 349.384 ms ] ... O(3*(N^log2(3))) optimized Karatsuba multiplication mul3[ 9345.127 ms] ... O(3*(N^log2(3))) unoptimized Karatsuba multiplication x = 0.98765588997654321000... | 195*32 bits sqr [ 883.01 ms ] mul1[ 1427.02 ms ] mul2[ 1089.84 ms ] x = 0.98765588997654321000... | 389*32 bits sqr [ 3189.19 ms ] mul1[ 5553.23 ms ] mul2[ 3159.07 ms ] 

Após as otimizações para a Karatsuba, o código é massivamente mais rápido do que antes. Ainda assim, para números menores, é ligeiramente menor que a metade da velocidade da minha multiplicação O(N^2) . Para números maiores, é mais rápido com a proporção dada pelas complexidades das multiplicações de Booth. O limite para multiplicação é em torno de 32 * 98 bits e para sqr em torno de 32 * 389 bits, portanto, se a sum dos bits de input ultrapassar esse limite, a multiplicação de Karatsuba será usada para acelerar a multiplicação e também para sqr.

BTW, otimizações incluídas:

  • Minimizar o lixo da pilha por um argumento de recursion muito grande
  • Evitar qualquer aritmética de bignum (+, -) 32-bit ALU com carry é usado no lugar.
  • Ignorando 0*y ou x*0 ou 0*0 casos
  • Reformatar os tamanhos dos números x,y para a potência de dois para evitar a realocação
  • Implemente multiplicação de módulo para z1 = (x0 + x1)*(y0 + y1) para minimizar a recursion

Multiplicação de Schönhage-Strassen modificada para implementação de sqr

Eu testei o uso de transformações FFT e NTT para acelerar a computação do sqr. Os resultados são estes:

  1. FFT

    Perder precisão e, portanto, precisa de números complexos de alta precisão. Isso realmente diminui consideravelmente as coisas, portanto, não há aceleração. O resultado não é preciso (pode ser erroneamente arredondado), então a FFT é inutilizável (por enquanto)

  2. NTT

    O NTT é um DFT de campo finito e, portanto, nenhuma perda de precisão ocorre. Ele precisa de aritmética modular em inteiros sem sinal: modpow, modmul, modadd e modsub .

    Eu uso DWORD (32bit números inteiros sem sinal). O tamanho do vetor de input / saída da NTT é limitado devido a problemas de estouro !!! Para aritmética modular de 32 bits, N é limitado a (2^32)/(max(input[])^2) portanto bigint deve ser dividido em partes menores (eu uso BYTES para que o tamanho máximo do bigint processado seja

     (2^32)/((2^8)^2) = 2^16 bytes = 2^14 DWORDs = 16384 DWORDs) 

    O sqr usa apenas 1xNTT + 1xINTT vez de 2xNTT + 1xINTT para multiplicação, mas o uso da NTT é muito lento e o tamanho do número de limite é muito grande para uso prático na minha implementação (para mul e também para sqr ).

    É possível que seja até mesmo acima do limite de estouro, então aritmética modular de 64 bits deve ser usada, o que pode retardar ainda mais as coisas. Então NTT é para os meus propósitos também inutilizável também.

Algumas medições:

 a = 0.98765588997654321000 | 389*32 bits looped 1x times sqr1[ 3.177 ms ] fast sqr sqr2[ 720.419 ms ] NTT sqr mul1[ 5.588 ms ] simpe mul mul2[ 3.172 ms ] karatsuba mul mul3[ 1053.382 ms ] NTT mul 

Minha implementação:

 void arbnum::sqr_NTT(const arbnum &x) { // O(N*log(N)*(log(log(N)))) - 1x NTT // Schönhage-Strassen sqr // To prevent NTT overflow: n  result siz  x.siz + y.siz <= 12K!!! int i, j, k, n; int s = x.sig*x.sig, exp0 = x.exp + x.exp - ((x.siz+x.siz)<<5) + 2; i = x.siz; for (n = 1; n < i; n< 0x3000) { _error(_arbnum_error_TooBigNumber); zero(); return; } n <<= 3; DWORD *xx, *yy, q, qq; xx = new DWORD[n+n]; #ifdef _mmap_h if (xx) mmap_new(xx, (n+n) <= 0; i--) { q = x.dat[i]; xx[k] = q&0xFF; k++; q>>=8; xx[k] = q&0xFF; k++; q>>=8; xx[k] = q&0xFF; k++; q>>=8; xx[k] = q&0xFF; k++; } for (;k<n;k++) xx[k] = 0; //NTT fourier_NTT ntt; ntt.NTT(yy,xx,n); // init NTT for n // Convolution for (i=0; i<n; i++) yy[i] = modmul(yy[i], yy[i], ntt.p); //INTT ntt.INTT(xx, yy); //suma q=0; for (i = 0, j = 0; i>=8; qq>>=8; q+=qq; } // Merge WORDs to DWORDs and copy them to result _alloc(n>>2); for (i = 0, j = 0; i<siz; i++) { q =(yy[j]<<24)&0xFF000000; j++; q |=(yy[j]<<16)&0x00FF0000; j++; q |=(yy[j]<< 8)&0x0000FF00; j++; q |=(yy[j] )&0x000000FF; j++; dat[i] = q; } #ifdef _mmap_h if (xx) mmap_del(xx); #endif delete xx; bits = siz<<5; sig = s; exp = exp0 + (siz<<5) - 1; // _normalize(); } 

Conclusão

Para números menores, é a melhor opção minha abordagem rápida de sqr , e após a multiplicação de limiar de Karatsuba é melhor. Mas ainda acho que deveria haver algo trivial que negligenciamos. Tem outras ideias?

Otimização de NTT

Após otimizações extremamente intensas (principalmente NTT ): Perguntas sobre o Stack Overflow Aritmética modular e otimizações de NTT (campo finito DFT) .

Alguns valores foram alterados:

 a = 0.98765588997654321000 | 1553*32bits looped 10x times mul2[ 28.585 ms ] Karatsuba mul mul3[ 26.311 ms ] NTT mul 

Então, agora a multiplicação da NTT é finalmente mais rápida que a da Karatsuba após um limite de 1500 * 32 bits.

Algumas medições e erros detectados

 a = 0.99991970486 | 1553*32 bits looped: 10x sqr1[ 58.656 ms ] fast sqr sqr2[ 13.447 ms ] NTT sqr mul1[ 102.563 ms ] simpe mul mul2[ 28.916 ms ] Karatsuba mul Error mul3[ 19.470 ms ] NTT mul 

Descobri que meu Karatsuba (acima / abaixo) flui o LSB de cada segmento DWORD de bignum. Quando eu tiver pesquisado, atualizarei o código …

Além disso, após novas otimizações de NTT, os limites foram alterados. Assim, para NTR sqr , são 310*32 bits = 9920 bits de operando e, para NTT mul , são 1396*32 bits = 44672 bits de resultado (sum de bits de operandos).

Código Karatsuba reparado graças a @greybeard

 //--------------------------------------------------------------------------- void arbnum::_mul_karatsuba(DWORD *z, DWORD *x, DWORD *y, int n) { // Recursion for Karatsuba // z[2n] = x[n]*y[n]; // n=2^m int i; for (i=0; i<n; i++) if (x[i]) { i=-1; break; } // x==0 ? if (i < 0) for (i = 0; i= 0) { for (i = 0; i < n + n; i++) z[i]=0; return; } // 0.? = 0 if (n == 1) { alu.mul(z[0], z[1], x[0], y[0]); return; } if (n>1; _mul_karatsuba(z+n, x+n2, y+n2, n2); // z0 = x0.y0 _mul_karatsuba(z , x , y , n2); // z2 = x1.y1 DWORD *q = new DWORD[n<=0; i--) alu.adc(qq[i], q0[i], q1[i]); } // qq = q0 + q1 ...[i..0] #define _sub { alu.sub(qq[i], q0[i], q1[i]); for (i--; i>=0; i--) alu.sbc(qq[i], q0[i], q1[i]); } // qq = q0 - q1 ...[i..0] qq = q; q0 = x + n2; q1 = x; i = n2 - 1; _add; cx = alu.cy; // =x0+x1 qq = q + n2; q0 = y + n2; q1 = y; i = n2 - 1; _add; cy = alu.cy; // =y0+y1 _mul_karatsuba(q + n, q + n2, q, n2); // =(x0+x1)(y0+y1) mod ((2^N)-1) if (cx) { qq = q + n; q0 = qq; q1 = q + n2; i = n2 - 1; _add; cx = alu.cy; }// += cx*(y0 + y1) << n2 if (cy) { qq = q + n; q0 = qq; q1 = q; i = n2 -1; _add; cy = alu.cy; }// +=cy*(x0+x1)<=0; i--) if (alu.cy) alu.inc(z[i]); else break; } delete[] q; #undef _add #undef _sub } //--------------------------------------------------------------------------- void arbnum::mul_karatsuba(const arbnum &x, const arbnum &y) { // O(3*(N)^log2(3)) ~ O(3*(N^1.585)) // Karatsuba multiplication // int s = x.sig*y.sig; arbnum a, b; a = x; b = y; a.sig = +1; b.sig = +1; int i, n; for (n = 1; (n < a.siz) || (n < b.siz); n <<= 1) ; a._realloc(n); b._realloc(n); _alloc(n + n); for (i=0; i < siz; i++) dat[i]=0; _mul_karatsuba(dat, a.dat, b.dat, n); bits = siz << 5; sig = s; exp = a.exp + b.exp + ((siz-a.siz-b.siz)<<5) + 1; // _normalize(); } //--------------------------------------------------------------------------- 

Minha representação numérica do arbnum :

 // dat is MSDW first ... LSDW last DWORD *dat; int siz,exp,sig,bits; 
  • dat[siz] é o mantisa. LSDW significa DWORD menos significativo.
  • exp é o expoente do MSB de dat[0]
  • O primeiro bit não nulo está presente na mantissa !!!

     // |-----|---------------------------|---------------|------| // | sig | MSB mantisa LSB | exponent | bits | // |-----|---------------------------|---------------|------| // | +1 | 0.(0 ... 0) | 2^0 | 0 | +zero // | -1 | 0.(0 ... 0) | 2^0 | 0 | -zero // |-----|---------------------------|---------------|------| // | +1 | 1.(dat[0] ... dat[siz-1]) | 2^exp | n | +number // | -1 | 1.(dat[0] ... dat[siz-1]) | 2^exp | n | -number // |-----|---------------------------|---------------|------| // | +1 | 1.0 | 2^+0x7FFFFFFE | 1 | +infinity // | -1 | 1.0 | 2^+0x7FFFFFFE | 1 | -infinity // |-----|---------------------------|---------------|------| 

   

Se eu entendi seu algoritmo corretamente, parece O(n^2) onde n é o número de dígitos.

Você já olhou para o Algoritmo Karatsuba ? Acelera a multiplicação usando a abordagem de dividir e conquistar. Pode valer a pena dar uma olhada.

Se você estiver procurando escrever um novo expoente melhor, talvez seja necessário escrevê-lo na assembly. Este é o código do golang.

https://code.google.com/p/go/source/browse/src/pkg/math/exp_amd64.s