-rw-r--r-- 9016 libmceliece-20240726/crypto_kem/460896/avx/pk_gen.c raw
/*
This file is for public-key generation
*/
// 20240508 djb: switch to crypto_sort_int32
// 20240508 djb: switch to crypto_sort_int64
// 20230105 tony: use merge exchange in sort_rows(); fewer lines for minmax_rows()
// 20221231 djb: more 0 initialization to clarify data flow; tnx thom wiggers
// 20221230 djb: add linker lines
// 20221230 djb: move nBlocks_I to macro
// linker define pk_gen
// linker use fft vec256_inv vec256_mul_asm
#include "pk_gen.h"
#include "controlbits.h"
#include "crypto_sort_int64.h"
#include "crypto_sort_int32.h"
#include "params.h"
#include "util.h"
#include "fft.h"
#include "crypto_declassify.h"
#include "crypto_uint64.h"
static crypto_uint64 uint64_is_equal_declassify(uint64_t t,uint64_t u)
{
crypto_uint64 mask = crypto_uint64_equal_mask(t,u);
crypto_declassify(&mask,sizeof mask);
return mask;
}
static crypto_uint64 uint64_is_zero_declassify(uint64_t t)
{
crypto_uint64 mask = crypto_uint64_zero_mask(t);
crypto_declassify(&mask,sizeof mask);
return mask;
}
#include <stdint.h>
#define nBlocks_I ((PK_NROWS + 255) / 256)
#define par_width 7
/* set m to 11...1 if the i-th bit of x is 0 and the i-th bit of y is 1 */
/* set m to 00...0 otherwise */
static inline void extract_01_masks(uint16_t *m, uint64_t *x, uint64_t *y, int i)
{
*m = (((~x[ i>>6 ]) & y[ i>>6 ]) >> (i&63)) & 1;
*m = -(*m);
}
/* return a 128-bit vector of which each bits is set to the i-th bit of v */
static inline vec256 extract_mask256(uint64_t v[], int i)
{
uint32_t mask;
mask = (v[ i>>6 ] >> (i&63)) & 1;
mask = -mask;
return vec256_set1_32b(mask);
}
// swap x and y if m = 11...1
static inline void uint16_cswap(uint16_t *x, uint16_t *y, uint16_t m)
{
uint16_t d;
d = *x ^ *y;
d &= m;
*x ^= d;
*y ^= d;
}
// swap x and y if m = 11...1
static inline void vec256_cswap(vec256 *x, vec256 *y, vec256 m)
{
vec256 d;
d = *x ^ *y;
d &= m;
*x ^= d;
*y ^= d;
}
/* swap x[i0] and x[i1] if x[i1] > x[i0] */
/* swap mat[i0] and mat[i1] if x[i1] > x[i0] */
static inline void minmax_rows(uint16_t *x, vec256 (*mat)[par_width], int i0, int i1)
{
int i;
uint16_t m;
vec256 mm;
m = x[i1] - x[i0]; m >>= 15; m = -m;
mm = vec256_set1_16b(m);
uint16_cswap(&x[i0], &x[i1], m);
for (i = 0; i < par_width; i++)
vec256_cswap(&mat[i0][i], &mat[i1][i], mm);
}
static void sort_rows(int n, uint16_t *x, vec256 (*mat)[par_width])
{
int t = 1;
while ((1 << t)*2 < n) t++;
for (int j = t-1; j >= 0; j--)
{
int p = 1 << j, q = 1 << (t-1), r = 0, d = p;
while (1)
{
for (int i = 0; i < n-d; i++)
if ((i & p) == r)
minmax_rows(x, mat, i, i+d);
if (q != p) { d = q - p; q = q / 2; r = p; }
else break;
}
}
}
/* extract numbers represented in bitsliced form */
static void de_bitslicing(uint64_t * out, const vec256 in[][GFBITS])
{
int i, j, r;
uint64_t u = 0;
for (i = 0; i < (1 << GFBITS); i++)
out[i] = 0 ;
for (i = 0; i < 32; i++)
for (j = GFBITS-1; j >= 0; j--) {
u = vec256_extract(in[i][j], 0);
for (r = 0; r < 64; r++)
{
out[i*256 + 0*64 + r] <<= 1;
out[i*256 + 0*64 + r] |= (u >> r) & 1;
}
u = vec256_extract(in[i][j], 1);
for (r = 0; r < 64; r++)
{
out[i*256 + 1*64 + r] <<= 1;
out[i*256 + 1*64 + r] |= (u >> r) & 1;
}
u = vec256_extract(in[i][j], 2);
for (r = 0; r < 64; r++)
{
out[i*256 + 2*64 + r] <<= 1;
out[i*256 + 2*64 + r] |= (u >> r) & 1;
}
u = vec256_extract(in[i][j], 3);
for (r = 0; r < 64; r++)
{
out[i*256 + 3*64 + r] <<= 1;
out[i*256 + 3*64 + r] |= (u >> r) & 1;
}
}
}
/* convert numbers into bitsliced form */
static void to_bitslicing_2x(vec256 out0[][GFBITS], vec256 out1[][GFBITS], const uint64_t * in)
{
int i, j, k, r;
uint64_t u[2][4];
for (j = 0;j < 2;++j) for (k = 0;k < 4;++k) u[j][k] = 0;
for (i = 0; i < 32; i++)
for (j = GFBITS-1; j >= 0; j--)
{
for (k = 0; k < 4; k++)
for (r = 63; r >= 0; r--)
{
u[0][k] <<= 1;
u[0][k] |= (in[i*256 + k*64 + r] >> (GFBITS-1-j)) & 1;
u[1][k] <<= 1;
u[1][k] |= (in[i*256 + k*64 + r] >> (j + GFBITS)) & 1;
}
out0[i][j] = vec256_set4x(u[0][0], u[0][1], u[0][2], u[0][3]);
out1[i][j] = vec256_set4x(u[1][0], u[1][1], u[1][2], u[1][3]);
}
}
/* y[pi[i]] = x[i] */
/* requires pi to be a permutation */
static void composeinv(int n, uint16_t y[n], uint16_t x[n], uint16_t pi[n])
{
int i;
int32_t t[n];
for (i = 0;i < n;++i)
{
t[i] = pi[i];
t[i] <<= 16;
t[i] |= x[i];
}
crypto_sort_int32(t,n);
for (i = 0;i < n;++i)
y[i] = t[i] & 0xFFFF;
}
/* input: irr, an irreducible polynomial */
/* perm, a permutation represented as an array of 32-bit numbers */
/* pi, same permutation represented as an array of 16-bit numbers */
/* output: pk, the public key*/
/* return: 0 if pk is successfully generated, -1 otherwise */
int pk_gen(unsigned char * pk, const unsigned char * irr, uint32_t * perm, int16_t * pi)
{
int i, j, k, b;
int row, c;
union
{
uint64_t w[ PK_NROWS ][ nBlocks_I*4 ];
vec256 v[ PK_NROWS ][ nBlocks_I ];
} mat;
union
{
uint64_t w[ PK_NROWS ][ par_width*4 ];
vec256 v[ PK_NROWS ][ par_width ];
} par;
uint64_t one = 1;
uint64_t t;
uint16_t m;
vec256 mm;
vec128 sk_int[ GFBITS ];
vec256 consts[ 32 ][ GFBITS ];
vec256 eval[ 32 ][ GFBITS ];
vec256 prod[ 32 ][ GFBITS ];
vec256 tmp[ GFBITS ];
uint64_t list[1 << GFBITS];
uint16_t ind[ PK_NROWS ];
uint16_t ind_inv[ PK_NROWS ];
// compute the inverses
irr_load(sk_int, irr);
fft(eval, sk_int);
vec256_copy(prod[0], eval[0]);
for (i = 1; i < 32; i++)
vec256_mul(prod[i], prod[i-1], eval[i]);
vec256_inv(tmp, prod[31]);
for (i = 30; i >= 0; i--)
{
vec256_mul(prod[i+1], prod[i], tmp);
vec256_mul(tmp, tmp, eval[i+1]);
}
vec256_copy(prod[0], tmp);
// fill matrix
de_bitslicing(list, prod);
for (i = 0; i < (1 << GFBITS); i++)
{
list[i] <<= GFBITS;
list[i] |= i;
list[i] |= ((uint64_t) perm[i]) << 31;
}
crypto_sort_int64(list, 1 << GFBITS);
for (i = 1; i < (1 << GFBITS); i++)
if (uint64_is_equal_declassify(list[i-1] >> 31,list[i] >> 31))
return -1;
to_bitslicing_2x(consts, prod, list);
for (i = 0; i < (1 << GFBITS); i++)
pi[i] = list[i] & GFMASK;
for (j = 0; j < nBlocks_I; j++)
for (k = 0; k < GFBITS; k++)
mat.v[ k ][ j ] = prod[ j ][ k ];
for (i = 1; i < SYS_T; i++)
for (j = 0; j < nBlocks_I; j++)
{
vec256_mul(prod[j], prod[j], consts[j]);
for (k = 0; k < GFBITS; k++)
mat.v[ i*GFBITS + k ][ j ] = prod[ j ][ k ];
}
// gaussian elimination to obtain L, U, and P such that LP M = U
// L and U are stored in the space of M
// P is stored in ind
for (i = 0; i < PK_NROWS; i++)
ind_inv[i] = ind[ i ] = i;
for (row = 0; row < PK_NROWS; row++)
{
i = row >> 6;
j = row & 63;
for (k = row + 1; k < PK_NROWS; k++)
{
extract_01_masks(&m, mat.w[ row ], mat.w[ k ], row);
uint16_cswap(&ind[row], &ind[k], m);
mm = vec256_set1_16b(m);
for (c = 0; c < nBlocks_I; c++)
vec256_cswap(&mat.v[ row ][ c ], &mat.v[ k ][ c ], mm);
}
if ( uint64_is_zero_declassify((mat.w[ row ][ i ] >> j) & 1) ) // return if not systematic
{
return -1;
}
for (k = row+1; k < PK_NROWS; k++)
{
t = mat.w[ k ][ i ] & (one << j);
mm = extract_mask256(mat.w[k], row);
for (c = 0; c < nBlocks_I; c++)
mat.v[ k ][ c ] ^= mat.v[ row ][ c ] & mm;
mat.w[ k ][ i ] |= t;
}
}
// apply M^-1 to the remaining columns
composeinv(PK_NROWS, ind_inv, ind_inv, ind);
for (j = 4; j < 18; j += par_width)
{
for (k = 0; k < GFBITS; k++)
for (b = 0; b < par_width; b++)
par.v[ k ][ b ] = prod[ j+b ][ k ];
for (i = 1; i < SYS_T; i++)
{
for (b = 0; b < par_width; b++)
vec256_mul(prod[j+b], prod[j+b], consts[j+b]);
for (k = 0; k < GFBITS; k++)
for (b = 0; b < par_width; b++)
par.v[ i*GFBITS + k ][ b ] = prod[ j+b ][ k ];
}
// apply P
for (i = 0; i < PK_NROWS; i++)
ind[i] = ind_inv[i];
sort_rows(PK_NROWS, ind, par.v);
// apply L
for (row = PK_NROWS-1; row >= 0; row--)
for (i = row-1; i >= 0; i--)
{
mm = extract_mask256(mat.w[row], i);
for (k = 0; k < par_width; k++)
par.v[ row ][ k ] ^= par.v[ i ][ k ] & mm;
}
// apply U^-1
if (j == 4)
{
for (i = 0; i < PK_NROWS; i++)
par.v[i][0] = mat.v[i][nBlocks_I-1];
}
for (row = PK_NROWS-1; row >= 0; row--)
for (i = PK_NROWS-1; i > row; i--)
{
mm = extract_mask256(mat.w[row], i);
for (k = 0; k < par_width; k++)
par.v[ row ][ k ] ^= par.v[ i ][ k ] & mm;
}
if (j == 4)
{
for (row = 0; row < PK_NROWS; row++)
{
k = 3;
store_i(pk + PK_ROW_BYTES * row + j*32 + k*8 + 4 - 13*12, par.w[row][k] >> 32, 4);
for (k = 4; k < 28; k++)
store8(pk + PK_ROW_BYTES * row + j*32 + k*8 - 13*12, par.w[row][k]);
}
}
else
{
for (row = 0; row < PK_NROWS; row++)
for (k = 0; k < 28; k++)
store8(pk + PK_ROW_BYTES * row + j*32 + k*8 - 13*12, par.w[row][k]);
}
}
//
return 0;
}