31 : m_operand(operand) {
32 HEXL_CHECK(operand <= modulus,
"operand " << operand
33 <<
" must be less than modulus " 35 HEXL_CHECK(bit_shift == 32 || bit_shift == 52 || bit_shift == 64,
36 "Unsupported BitShift " << bit_shift);
37 uint64_t op_hi = operand >> (64 - bit_shift);
38 uint64_t op_lo = (bit_shift == 64) ? 0 : (operand << bit_shift);
40 m_barrett_factor = DivideUInt128UInt64Lo(op_hi, op_lo, modulus);
47 inline uint64_t
Operand()
const {
return m_operand; }
51 uint64_t m_barrett_factor;
/// @brief Returns whether or not @p num is a power of two.
/// @param num Value to test.
/// @return true iff exactly one bit of @p num is set; false for 0.
inline bool IsPowerOfTwo(uint64_t num) {
  // num & (num - 1) clears the lowest set bit, so it is zero only when num had
  // at most one bit set; the explicit num != 0 check rejects zero itself.
  return num != 0 && (num & (num - 1)) == 0;
}
58 inline uint64_t
Log2(uint64_t x) {
return MSB(x); }
66 HEXL_CHECK(bits <= 64,
"MaximumValue requires bits <= 64; got " << bits);
68 return (std::numeric_limits<uint64_t>::max)();
70 return (1ULL << bits) - 1;
/// @brief Reverses the bits.
/// @param x Value whose bits are reversed.
/// @param bit_width NOTE(review): presumably the number of low-order bits of
/// @p x that participate in the reversal — confirm against the definition.
uint64_t ReverseBits(uint64_t x, uint64_t bit_width);
/// @brief Returns x^{-1} mod modulus.
/// @param x Value to invert.
/// @param modulus Modulus.
uint64_t InverseMod(uint64_t x, uint64_t modulus);
/// @brief Returns (x * y) mod modulus.
/// @param x First factor.
/// @param y Second factor.
/// @param modulus Modulus.
uint64_t MultiplyMod(uint64_t x, uint64_t y, uint64_t modulus);
92 uint64_t
MultiplyMod(uint64_t x, uint64_t y, uint64_t y_precon,
/// @brief Returns (x + y) mod modulus.
/// @param x First summand.
/// @param y Second summand.
/// @param modulus Modulus.
uint64_t AddUIntMod(uint64_t x, uint64_t y, uint64_t modulus);
/// @brief Returns (x - y) mod modulus.
/// @param x Minuend.
/// @param y Subtrahend.
/// @param modulus Modulus.
uint64_t SubUIntMod(uint64_t x, uint64_t y, uint64_t modulus);
/// @brief Returns base^exp mod modulus.
/// @param base Base.
/// @param exp Exponent.
/// @param modulus Modulus.
uint64_t PowMod(uint64_t base, uint64_t exp, uint64_t modulus);
/// @brief Returns whether or not @p root is a primitive @p degree-th root of
/// unity mod @p modulus.
/// @param root Candidate root.
/// @param degree Order the root is tested against.
/// @param modulus Modulus.
bool IsPrimitiveRoot(uint64_t root, uint64_t degree, uint64_t modulus);
128 template <
int BitShift>
130 uint64_t y_barrett_factor, uint64_t modulus) {
131 HEXL_CHECK(y_operand < modulus,
"y_operand " << y_operand
132 <<
" must be less than modulus " 136 "Modulus " << modulus <<
" exceeds bound " <<
MaximumValue(BitShift));
138 "Operand " << x <<
" exceeds bound " <<
MaximumValue(BitShift));
140 uint64_t Q = MultiplyUInt64Hi<BitShift>(x, y_barrett_factor);
141 return y_operand * x - Q * modulus;
149 template <
int BitShift>
152 "Unsupported BitShift " << BitShift);
154 "Operand " << x <<
" exceeds bound " <<
MaximumValue(BitShift));
156 "y " << y <<
" must be less than modulus " << modulus);
159 "Modulus " << modulus <<
" exceeds bound " <<
MaximumValue(BitShift));
161 uint64_t y_barrett =
MultiplyFactor(y, BitShift, modulus).BarrettFactor();
162 return MultiplyModLazy<BitShift>(x, y, y_barrett, modulus);
170 inline unsigned char AddUInt64(uint64_t operand1, uint64_t operand2,
172 *result = operand1 + operand2;
173 return static_cast<unsigned char>(*result < operand1);
/// @brief Generates a list of @p num_primes primes in a range starting at
/// 2^(bit_size).
/// NOTE(review): the upper bound is presumably 2^(bit_size+1) — confirm
/// against the definition.
/// @param num_primes Number of primes to generate.
/// @param bit_size Determines the lower bound 2^(bit_size) of the search range.
/// @param prefer_small_primes NOTE(review): presumably selects whether the
/// search proceeds from the low end of the range — confirm.
/// @param ntt_size NTT size the primes are generated for (default 1).
/// NOTE(review): presumably each prime q satisfies q == 1 mod (2 * ntt_size)
/// so an NTT of that size exists — confirm.
/// @return The generated primes.
std::vector<uint64_t> GeneratePrimes(size_t num_primes, size_t bit_size,
                                     bool prefer_small_primes,
                                     size_t ntt_size = 1);
/// @brief Returns input mod modulus, computed via 64-bit Barrett reduction.
/// @param input Value to reduce.
/// @param modulus Modulus.
/// @param q_barr Pre-computed Barrett factor for @p modulus.
/// NOTE(review): presumably floor(2^64 / modulus) — confirm against the
/// definition.
uint64_t BarrettReduce64(uint64_t input, uint64_t modulus, uint64_t q_barr);
205 template <
int InputModFactor>
207 const uint64_t* twice_modulus =
nullptr,
208 const uint64_t* four_times_modulus =
nullptr) {
209 HEXL_CHECK(InputModFactor == 1 || InputModFactor == 2 ||
210 InputModFactor == 4 || InputModFactor == 8,
211 "InputModFactor should be 1, 2, 4, or 8");
212 if (InputModFactor == 1) {
215 if (InputModFactor == 2) {
221 if (InputModFactor == 4) {
222 HEXL_CHECK(twice_modulus !=
nullptr,
"twice_modulus should not be nullptr");
223 if (x >= *twice_modulus) {
231 if (InputModFactor == 8) {
232 HEXL_CHECK(twice_modulus !=
nullptr,
"twice_modulus should not be nullptr");
234 "four_times_modulus should not be nullptr");
236 if (x >= *four_times_modulus) {
237 x -= *four_times_modulus;
239 if (x >= *twice_modulus) {
uint64_t ReduceMod(uint64_t x, uint64_t modulus, const uint64_t *twice_modulus=nullptr, const uint64_t *four_times_modulus=nullptr)
Returns x mod modulus, assuming x < InputModFactor * modulus.
Definition: number-theory.hpp:206
uint64_t GeneratePrimitiveRoot(uint64_t degree, uint64_t modulus)
Tries to return a primitive degree-th root of unity.
uint64_t MultiplyMod(uint64_t x, uint64_t y, uint64_t modulus)
Returns (x * y) mod modulus.
uint64_t MaximumValue(uint64_t bits)
Returns the maximum value that can be represented using bits bits.
Definition: number-theory.hpp:65
std::vector< uint64_t > GeneratePrimes(size_t num_primes, size_t bit_size, bool prefer_small_primes, size_t ntt_size=1)
Generates a list of num_primes primes in the range [2^(bit_size), 2^(bit_size+1)].
uint64_t MultiplyModLazy(uint64_t x, uint64_t y_operand, uint64_t y_barrett_factor, uint64_t modulus)
Computes (x * y) mod modulus, except that the output is in [0, 2 * modulus].
Definition: number-theory.hpp:129
MultiplyFactor(uint64_t operand, uint64_t bit_shift, uint64_t modulus)
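Computes and stores the Barrett factor floor((operand << bit_shift) / modulus). This is useful when modular multiplication of the form (x * operand) mod modulus is performed several times with the same operand and modulus.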
Definition: number-theory.hpp:30
uint64_t SubUIntMod(uint64_t x, uint64_t y, uint64_t modulus)
Returns (x - y) mod modulus.
unsigned char AddUInt64(uint64_t operand1, uint64_t operand2, uint64_t *result)
Adds two unsigned 64-bit integers.
Definition: number-theory.hpp:170
uint64_t BarrettFactor() const
Returns the pre-computed Barrett factor.
Definition: number-theory.hpp:44
bool IsPrime(uint64_t n)
Returns whether or not the input is prime.
bool IsPowerOfTwo(uint64_t num)
Returns whether or not num is a power of two.
Definition: number-theory.hpp:55
#define HEXL_CHECK(cond, expr)
Definition: check.hpp:39
Pre-computes a Barrett factor with which modular multiplication can be performed more efficiently.
Definition: number-theory.hpp:20
Definition: eltwise-add-mod.hpp:8
uint64_t AddUIntMod(uint64_t x, uint64_t y, uint64_t modulus)
Returns (x + y) mod modulus.
uint64_t ReverseBits(uint64_t x, uint64_t bit_width)
Reverses the bits.
bool IsPowerOfFour(uint64_t num)
Returns whether or not num is a power of four.
Definition: number-theory.hpp:60
uint64_t InverseMod(uint64_t x, uint64_t modulus)
Returns x^{-1} mod modulus.
bool IsPrimitiveRoot(uint64_t root, uint64_t degree, uint64_t modulus)
Returns whether or not root is a primitive degree-th root of unity mod modulus.
uint64_t MinimalPrimitiveRoot(uint64_t degree, uint64_t modulus)
Returns the minimal primitive degree-th root of unity mod modulus.
uint64_t PowMod(uint64_t base, uint64_t exp, uint64_t modulus)
Returns base^exp mod modulus.
uint64_t Operand() const
Returns the operand corresponding to the Barrett factor.
Definition: number-theory.hpp:47
uint64_t BarrettReduce64(uint64_t input, uint64_t modulus, uint64_t q_barr)
Returns input mod modulus, computed via 64-bit Barrett reduction.
uint64_t Log2(uint64_t x)
Returns floor(log2(x)).
Definition: number-theory.hpp:58