-rw-r--r-- 2900 libmceliece-20240513/crypto_kem/348864/avx/xor_mat_vec256.c raw
// 20240508 djb: split out of encrypt.c; optimize
// linker define xor_mat_vec256
#include "xor_mat_vec256.h"
// input: M, size rows*rowbytes
// input: v, size ceil(rowbytes/32)
// input: s, size ceil(rows/8)
// output: s, size ceil(rows/8)
// operation: s_i ^= dotproduct(M_i,v)
// where s_i is ith bit of s
// and M_i is ith row of M
// and dotproduct is dot product mod 2
void xor_mat_vec256(unsigned char *s, const unsigned char *M, int64_t rows, int64_t rowbytes, const vec256 *v)
{
int64_t i, j, rowwords1;
vec256 vj, t_0, t_1, t_2, t_3, u_0, u_1;
vec128 x128;
uint64_t x64;
const unsigned char *Mi0, *Mi1, *Mi2, *Mi3;
if (rows <= 0) return;
if (rowbytes <= 0) return;
rowwords1 = (rowbytes-1)/32;
i = 0;
Mi0 = M;
Mi1 = Mi0 + rowbytes;
Mi2 = Mi1 + rowbytes;
Mi3 = Mi2 + rowbytes;
if (rows <= 1) Mi1 = Mi0;
if (rows <= 2) Mi2 = Mi0;
if (rows <= 3) Mi3 = Mi0;
mainloop:
// handle rows i, i+1, i+2, i+3 where i is a multiple of 4
// pointers to the M rows are Mi0, Mi1, Mi2, Mi3
// if i+1 or i+2 or i+3 is >=rows: Mi1 or Mi2 or Mi3 is actually Mi0
t_0 = vec256_setzero();
t_1 = vec256_setzero();
t_2 = vec256_setzero();
t_3 = vec256_setzero();
for (j = 0; j < rowwords1; j++) {
vj = v[j];
t_0 ^= vec256_load(Mi0+32*j) & vj;
t_1 ^= vec256_load(Mi1+32*j) & vj;
t_2 ^= vec256_load(Mi2+32*j) & vj;
t_3 ^= vec256_load(Mi3+32*j) & vj;
}
vj = v[j];
t_0 ^= vec256_load(Mi0+rowbytes-32) & vj;
t_1 ^= vec256_load(Mi1+rowbytes-32) & vj;
t_2 ^= vec256_load(Mi2+rowbytes-32) & vj;
t_3 ^= vec256_load(Mi3+rowbytes-32) & vj;
// want: horizontal bit sums of t_0, t_1, t_2, t_3
u_0 = vec256_unpack_low(t_0,t_2) ^ vec256_unpack_high(t_0,t_2);
u_1 = vec256_unpack_low(t_1,t_3) ^ vec256_unpack_high(t_1,t_3);
// want: horizontal bit sums of u_0 bot, u_1 bot, u_0 top, u_1 top
u_0 ^= vec256_8x_shr(u_0,1);
u_1 ^= vec256_8x_shl(u_1,1);
u_0 &= vec256_set1_32b(0x55555555);
u_1 &= vec256_set1_32b(0xaaaaaaaa);
u_0 |= u_1;
// want: sums of u_0 bot even bits, u_0 bot odd bits, u_0 top even bits, u_0 top odd bits
u_0 ^= vec256_2x_swap64(u_0);
u_0 ^= vec256_4x_shr(u_0,32);
u_0 ^= vec256_8x_shr(u_0,16);
u_0 ^= vec256_8x_shr(u_0,8);
u_0 ^= vec256_8x_shr(u_0,4);
u_0 ^= vec256_8x_shr(u_0,2);
u_0 &= vec256_set1_32b(3);
x128 = vec256_extract2x(u_0,1);
x128 = vec128_4x_shl(x128,2) | vec256_extractbot(u_0);
x64 = vec128_extract(x128,0);
s[ i/8 ] ^= (x64 & 15) << (i%8);
i += 4;
Mi0 += 4*rowbytes;
Mi1 += 4*rowbytes;
Mi2 += 4*rowbytes;
Mi3 += 4*rowbytes;
if (i + 4 <= rows) // normal case
goto mainloop;
if (i < rows) {
// some trailing rows
// at most 3 rows out of many, so prioritize code size
if (rows <= i+1) Mi1 = Mi0;
if (rows <= i+2) Mi2 = Mi0;
if (rows <= i+3) Mi3 = Mi0;
goto mainloop;
}
if (i > rows) {
// clean up after trailing rows
i -= 4;
s[ i/8 ] ^= (x64 & 15 & (15 << (rows - i))) << (i%8);
}
}