19 : m_operand(operand) {
22 "operand " << operand <<
" must be less than modulus " << modulus);
24 "Unsupported BitShift " << bit_shift);
28 if (bit_shift == 64) {
31 }
else if (bit_shift == 52) {
32 op_hi = operand >> 12;
33 op_lo = operand << 52;
39 inline uint64_t
Operand()
const {
return m_operand; }
43 uint64_t m_barrett_factor;
58 uint64_t
MultiplyMod(uint64_t x, uint64_t y, uint64_t y_precon,
63 uint64_t
AddUIntMod(uint64_t x, uint64_t y, uint64_t modulus);
67 uint64_t
SubUIntMod(uint64_t x, uint64_t y, uint64_t modulus);
70 uint64_t
PowMod(uint64_t base, uint64_t exp, uint64_t modulus);
86 template <
int BitShift>
88 uint64_t y_barrett_factor,
92 "y_operand " << y_operand <<
" must be less than modulus " << modulus);
95 "Modulus " << modulus <<
" exceeds bound " <<
MaximumValue(BitShift));
97 "Operand " << x <<
" exceeds bound " <<
MaximumValue(BitShift));
99 uint64_t Q = MultiplyUInt64Hi<BitShift>(x, y_barrett_factor);
100 return y_operand * x - Q * modulus;
104 template <
int BitShift>
107 "Unsupported BitShift " << BitShift);
109 "Operand " << x <<
" exceeds bound " <<
MaximumValue(BitShift));
111 "y " << y <<
" must be less than modulus " << modulus);
114 "Modulus " << modulus <<
" exceeds bound " <<
MaximumValue(BitShift));
117 if (BitShift == 64) {
120 }
else if (BitShift == 52) {
125 return MultiplyUIntModLazy<BitShift>(x, y, y_barrett, modulus);
133 inline unsigned char AddUInt64(uint64_t operand1, uint64_t operand2,
135 *result = operand1 + operand2;
136 return static_cast<unsigned char>(*result < operand1);
149 std::vector<uint64_t>
GeneratePrimes(
size_t num_primes,
size_t bit_size,
150 size_t ntt_size = 1);
154 uint64_t
BarrettReduce64(uint64_t input, uint64_t modulus, uint64_t q_barr);
156 template <
int InputModFactor>
158 const uint64_t* twice_modulus =
nullptr,
159 const uint64_t* four_times_modulus =
nullptr) {
160 UTILS_CHECK(InputModFactor == 1 || InputModFactor == 2 ||
161 InputModFactor == 4 || InputModFactor == 8,
162 "InputModFactor should be 1, 2, 4, or 8");
163 if (InputModFactor == 1) {
166 if (InputModFactor == 2) {
172 if (InputModFactor == 4) {
174 "twice_modulus should not be nullptr");
175 if (x >= *twice_modulus) {
183 if (InputModFactor == 8) {
185 "twice_modulus should not be nullptr");
187 "four_times_modulus should not be nullptr");
189 if (x >= *four_times_modulus) {
190 x -= *four_times_modulus;
192 if (x >= *twice_modulus) {
217 template <
class Adaptee,
class... Args>
246 NTT(uint64_t degree, uint64_t q,
247 std::shared_ptr<AllocatorBase> alloc_ptr = {});
249 template <
class Allocator,
class... AllocatorArgs>
250 NTT(uint64_t degree, uint64_t q, Allocator&& a, AllocatorArgs&&... args)
255 std::move(a), std::forward<AllocatorArgs>(args)...))) {}
268 NTT(uint64_t degree, uint64_t q, uint64_t root_of_unity,
269 std::shared_ptr<AllocatorBase> alloc_ptr = {});
271 template <
class Allocator,
class... AllocatorArgs>
272 NTT(uint64_t degree, uint64_t q, uint64_t root_of_unity, Allocator&& a,
273 AllocatorArgs&&... args)
274 :
NTT(degree, q, root_of_unity,
278 std::move(a), std::forward<AllocatorArgs>(args)...))) {}
288 uint64_t input_mod_factor, uint64_t output_mod_factor);
298 uint64_t input_mod_factor, uint64_t output_mod_factor);
303 std::shared_ptr<NTTImpl>
m_impl;
314 NTTImpl(uint64_t degree, uint64_t q, uint64_t root_of_unity,
315 std::shared_ptr<AllocatorBase> alloc_ptr = {});
316 NTTImpl(uint64_t degree, uint64_t q,
317 std::shared_ptr<AllocatorBase> alloc_ptr = {});
328 return m_precon64_root_of_unity_powers;
336 return m_precon52_root_of_unity_powers;
350 return m_root_of_unity_powers;
359 return m_precon64_inv_root_of_unity_powers;
369 return m_precon52_inv_root_of_unity_powers;
377 return m_inv_root_of_unity_powers;
389 uint64_t input_mod_factor, uint64_t output_mod_factor);
392 uint64_t input_mod_factor, uint64_t output_mod_factor);
415 void ComputeRootOfUnityPowers();
419 uint64_t m_degree_bits;
425 std::shared_ptr<AllocatorBase> alloc;
444 const uint64_t* root_of_unity_powers,
445 const uint64_t* precon_root_of_unity_powers,
446 uint64_t input_mod_factor = 1,
447 uint64_t output_mod_factor = 1);
457 uint64_t* operand, uint64_t n, uint64_t modulus,
458 const uint64_t* root_of_unity_powers);
461 uint64_t* operand, uint64_t n, uint64_t modulus,
462 const uint64_t* inv_root_of_unity_powers,
463 const uint64_t* precon_inv_root_of_unity_powers,
464 uint64_t input_mod_factor = 1, uint64_t output_mod_factor = 1);
void InverseTransformFromBitReverse64(uint64_t *operand, uint64_t n, uint64_t modulus, const uint64_t *inv_root_of_unity_powers, const uint64_t *precon_inv_root_of_unity_powers, uint64_t input_mod_factor, uint64_t output_mod_factor)
Definition: ntt.cpp:580
Definition: utils-test.hpp:202
uint64_t * GetInvRootOfUnityPowersPtr()
Definition: ntt.hpp:380
AlignedVector64< uint64_t > & GetPrecon52RootOfUnityPowers()
Definition: ntt.hpp:335
uint64_t InverseUIntMod(uint64_t input, uint64_t modulus)
Definition: ntt.cpp:14
uint64_t PowMod(uint64_t base, uint64_t exp, uint64_t modulus)
Definition: ntt.cpp:84
uint64_t DivideUInt128UInt64Lo(uint64_t x1, uint64_t x0, uint64_t y)
Definition: utils-test.hpp:108
void ComputeForward(uint64_t *result, const uint64_t *operand, uint64_t input_mod_factor, uint64_t output_mod_factor)
Compute forward NTT. Results are bit-reversed.
Definition: ntt.cpp:443
NTT(uint64_t degree, uint64_t q, Allocator &&a, AllocatorArgs &&...args)
Definition: ntt.hpp:250
std::vector< uint64_t > GeneratePrimes(size_t num_primes, size_t bit_size, size_t ntt_size)
Definition: ntt.cpp:221
AlignedVector64< uint64_t > & GetInvRootOfUnityPowers()
Definition: ntt.hpp:376
static const size_t s_max_modulus_bits
Definition: ntt.hpp:397
uint64_t GetMinimalRootOfUnity() const
Definition: ntt.hpp:321
uint64_t MinimalPrimitiveRoot(uint64_t degree, uint64_t modulus)
Definition: ntt.cpp:137
bool IsPrime(uint64_t n)
Definition: ntt.cpp:173
static const size_t s_max_inv_ifma_modulus
Definition: ntt.hpp:412
uint64_t * GetPrecon52InvRootOfUnityPowersPtr()
Definition: ntt.hpp:372
uint64_t SubUIntMod(uint64_t x, uint64_t y, uint64_t modulus)
Definition: ntt.cpp:76
uint64_t GetRootOfUnityPower(size_t i)
Definition: ntt.hpp:354
uint64_t ReverseBitsUInt(uint64_t x, uint64_t bit_width)
Definition: ntt.cpp:160
Performs negacyclic forward and inverse number-theoretic transform (NTT), commonly used in RLWE crypt...
Definition: ntt.hpp:215
uint64_t GetDegree() const
Definition: ntt.hpp:323
uint64_t GetModulus() const
Definition: ntt.hpp:325
AllocatorAdapter(Adaptee &&_a, Args &&...args)
std::vector< T, AlignedAllocator< T, 64 > > AlignedVector64
Definition: utils-test.hpp:352
~NTT()
Destructs the NTT object.
uint64_t MultiplyUIntMod(uint64_t x, uint64_t y, uint64_t modulus)
Definition: ntt.cpp:52
uint64_t MultiplyUIntModLazy(uint64_t x, uint64_t y_operand, uint64_t y_barrett_factor, uint64_t modulus)
Definition: ntt.hpp:87
uint64_t MultiplyMod(uint64_t x, uint64_t y, uint64_t y_precon, uint64_t modulus)
Definition: ntt.cpp:62
static const size_t s_max_fwd_ifma_modulus
Definition: ntt.hpp:408
uint64_t BarrettFactor() const
Definition: ntt.hpp:38
AlignedVector64< uint64_t > & GetRootOfUnityPowers()
Definition: ntt.hpp:349
static const size_t s_ifma_shift_bits
Definition: ntt.hpp:404
static const size_t s_max_degree_bits
Definition: ntt.hpp:394
bool CheckNTTArguments(uint64_t degree, uint64_t modulus)
Definition: ntt.cpp:661
uint64_t * GetPrecon52RootOfUnityPowersPtr()
Definition: ntt.hpp:339
void deallocate_impl(void *p, size_t n)
Definition: utils-test.hpp:195
NTT(uint64_t degree, uint64_t q, uint64_t root_of_unity, Allocator &&a, AllocatorArgs &&...args)
Definition: ntt.hpp:272
#define UTILS_CHECK(cond, expr)
Definition: utils-test.hpp:81
void ComputeInverse(uint64_t *result, const uint64_t *operand, uint64_t input_mod_factor, uint64_t output_mod_factor)
Definition: ntt.cpp:408
AlignedVector64< uint64_t > & GetPrecon52InvRootOfUnityPowers()
Definition: ntt.hpp:368
uint64_t AddUIntMod(uint64_t x, uint64_t y, uint64_t modulus)
Definition: ntt.cpp:69
uint64_t GeneratePrimitiveRoot(uint64_t degree, uint64_t modulus)
Definition: ntt.cpp:111
NTT()
Initializes an empty NTT object.
uint64_t * GetPrecon64RootOfUnityPowersPtr()
Definition: ntt.hpp:331
bool IsPrimitiveRoot(uint64_t root, uint64_t degree, uint64_t modulus)
Definition: ntt.cpp:99
uint64_t MaximumValue(uint64_t bits)
Definition: utils-test.hpp:182
AlignedVector64< uint64_t > & GetPrecon64InvRootOfUnityPowers()
Definition: ntt.hpp:358
void ForwardTransformToBitReverse64(uint64_t *operand, uint64_t n, uint64_t modulus, const uint64_t *root_of_unity_powers, const uint64_t *precon_root_of_unity_powers, uint64_t input_mod_factor, uint64_t output_mod_factor)
Definition: ntt.cpp:474
uint64_t GetInvRootOfUnityPower(size_t i)
Definition: ntt.hpp:384
std::shared_ptr< NTTImpl > m_impl
Class implementing the NTT.
Definition: ntt.hpp:300
unsigned char AddUInt64(uint64_t operand1, uint64_t operand2, uint64_t *result)
Definition: ntt.hpp:133
uint64_t * GetPrecon64InvRootOfUnityPowersPtr()
Definition: ntt.hpp:362
static const size_t s_default_shift_bits
Definition: ntt.hpp:400
MultiplyFactor(uint64_t operand, uint64_t bit_shift, uint64_t modulus)
Definition: ntt.hpp:18
void ComputeInverse(uint64_t *result, const uint64_t *operand, uint64_t input_mod_factor, uint64_t output_mod_factor)
Definition: ntt.cpp:458
NTTImpl(uint64_t degree, uint64_t q, uint64_t root_of_unity, std::shared_ptr< AllocatorBase > alloc_ptr={})
Definition: ntt.cpp:258
uint64_t ReduceMod(uint64_t x, uint64_t modulus, const uint64_t *twice_modulus=nullptr, const uint64_t *four_times_modulus=nullptr)
Definition: ntt.hpp:157
void * allocate_impl(size_t bytes_count)
void ComputeForward(uint64_t *result, const uint64_t *operand, uint64_t input_mod_factor, uint64_t output_mod_factor)
Definition: ntt.cpp:386
uint64_t * GetRootOfUnityPowersPtr()
Definition: ntt.hpp:343
void ReferenceForwardTransformToBitReverse(uint64_t *operand, uint64_t n, uint64_t modulus, const uint64_t *root_of_unity_powers)
Reference NTT which is written for clarity rather than performance.
Definition: ntt.cpp:550
AlignedVector64< uint64_t > & GetPrecon64RootOfUnityPowers()
Definition: ntt.hpp:327
uint64_t Operand() const
Definition: ntt.hpp:39
uint64_t BarrettReduce64(uint64_t input, uint64_t modulus, uint64_t q_barr)
Definition: ntt.cpp:45