Intel HEXL for FPGA
Intel Homomorphic Encryption FPGA Acceleration Library, accelerating the modular arithmetic operations used in homomorphic encryption.
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
ntt.hpp
Go to the documentation of this file.
1 // Copyright (C) 2020-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 
4 #pragma once
5 
6 #include "utils-test.hpp"
7 
8 namespace hetest {
9 namespace utils {
10 
11 // Stores an integer on which modular multiplication can be performed more
12 // efficiently, at the cost of some precomputation.
14 public:
15  MultiplyFactor() = default;
16 
17  // Computes and stores the Barrett factor (operand << bit_shift) / modulus
18  MultiplyFactor(uint64_t operand, uint64_t bit_shift, uint64_t modulus)
19  : m_operand(operand) {
21  operand <= modulus,
22  "operand " << operand << " must be less than modulus " << modulus);
23  UTILS_CHECK(bit_shift == 64 || bit_shift == 52,
24  "Unsupported BitShift " << bit_shift);
25  uint64_t op_hi{0};
26  uint64_t op_lo{0};
27 
28  if (bit_shift == 64) {
29  op_hi = operand;
30  op_lo = 0;
31  } else if (bit_shift == 52) {
32  op_hi = operand >> 12;
33  op_lo = operand << 52;
34  }
35  m_barrett_factor = DivideUInt128UInt64Lo(op_hi, op_lo, modulus);
36  }
37 
38  inline uint64_t BarrettFactor() const { return m_barrett_factor; }
39  inline uint64_t Operand() const { return m_operand; }
40 
41 private:
42  uint64_t m_operand;
43  uint64_t m_barrett_factor;
44 };
45 
46 // Reverses the bits
47 uint64_t ReverseBitsUInt(uint64_t x, uint64_t bits);
48 
49 // Returns a^{-1} mod modulus
50 uint64_t InverseUIntMod(uint64_t a, uint64_t modulus);
51 
54 uint64_t MultiplyUIntMod(uint64_t x, uint64_t y, uint64_t modulus);
55 
56 // Returns (x * y) mod modulus
57 // @param y_precon floor(2**64 / modulus)
58 uint64_t MultiplyMod(uint64_t x, uint64_t y, uint64_t y_precon,
59  uint64_t modulus);
60 
61 // Returns (x + y) mod modulus
62 // Assumes x, y < modulus
63 uint64_t AddUIntMod(uint64_t x, uint64_t y, uint64_t modulus);
64 
65 // Returns (x - y) mod modulus
66 // Assumes x, y < modulus
67 uint64_t SubUIntMod(uint64_t x, uint64_t y, uint64_t modulus);
68 
69 // Returns base^exp mod modulus
70 uint64_t PowMod(uint64_t base, uint64_t exp, uint64_t modulus);
71 
72 // Returns true if root is a degree-th root of unity
73 // degree must be a power of two.
74 bool IsPrimitiveRoot(uint64_t root, uint64_t degree, uint64_t modulus);
75 
76 // Tries to return a primitive degree-th root of unity
77 // Returns -1 (which wraps to the maximum uint64_t value) if no root is found
78 uint64_t GeneratePrimitiveRoot(uint64_t degree, uint64_t modulus);
79 
80 // Returns the minimal primitive degree-th root of unity modulo modulus
81 // degree must be a power of two.
82 uint64_t MinimalPrimitiveRoot(uint64_t degree, uint64_t modulus);
83 
84 // Computes (x * y) mod modulus, except that the output is in [0, 2 * modulus]
85 // @param modulus_precon Pre-computed Barrett reduction factor
86 template <int BitShift>
87 inline uint64_t MultiplyUIntModLazy(uint64_t x, uint64_t y_operand,
88  uint64_t y_barrett_factor,
89  uint64_t modulus) {
91  y_operand < modulus,
92  "y_operand " << y_operand << " must be less than modulus " << modulus);
94  modulus <= MaximumValue(BitShift),
95  "Modulus " << modulus << " exceeds bound " << MaximumValue(BitShift));
96  UTILS_CHECK(x <= MaximumValue(BitShift),
97  "Operand " << x << " exceeds bound " << MaximumValue(BitShift));
98 
99  uint64_t Q = MultiplyUInt64Hi<BitShift>(x, y_barrett_factor);
100  return y_operand * x - Q * modulus;
101 }
102 
103 // Computes (x * y) mod modulus, except that the output is in [0, 2 * modulus]
104 template <int BitShift>
105 inline uint64_t MultiplyUIntModLazy(uint64_t x, uint64_t y, uint64_t modulus) {
106  UTILS_CHECK(BitShift == 64 || BitShift == 52,
107  "Unsupported BitShift " << BitShift);
108  UTILS_CHECK(x <= MaximumValue(BitShift),
109  "Operand " << x << " exceeds bound " << MaximumValue(BitShift));
110  UTILS_CHECK(y < modulus,
111  "y " << y << " must be less than modulus " << modulus);
112  UTILS_CHECK(
113  modulus <= MaximumValue(BitShift),
114  "Modulus " << modulus << " exceeds bound " << MaximumValue(BitShift));
115  uint64_t y_hi{0};
116  uint64_t y_lo{0};
117  if (BitShift == 64) {
118  y_hi = y;
119  y_lo = 0;
120  } else if (BitShift == 52) {
121  y_hi = y >> 12;
122  y_lo = y << 52;
123  }
124  uint64_t y_barrett = DivideUInt128UInt64Lo(y_hi, y_lo, modulus);
125  return MultiplyUIntModLazy<BitShift>(x, y, y_barrett, modulus);
126 }
127 
// Computes the wrapping sum of two unsigned 64-bit integers
// @param operand1 First addend
// @param operand2 Second addend
// @param result Receives the low 64 bits of the sum
// @return 1 if the addition overflowed 64 bits, 0 otherwise
inline unsigned char AddUInt64(uint64_t operand1, uint64_t operand2,
                               uint64_t* result) {
  const uint64_t sum = operand1 + operand2;
  *result = sum;
  // Unsigned addition wrapped around exactly when the sum is smaller than
  // one of its addends.
  const bool carried = sum < operand1;
  return static_cast<unsigned char>(carried);
}
138 
139 // Returns whether or not the input is prime
140 bool IsPrime(uint64_t n);
141 
142 // Generates a list of num_primes primes in the range [2^(bit_size),
143 // 2^(bit_size+1)]. Ensures each prime q satisfies
144 // q % (2*ntt_size) == 1
145 // @param num_primes Number of primes to generate
146 // @param bit_size Bit size of each prime
147 // @param ntt_size N such that each prime q satisfies q % (2N) == 1. N must be
148 // a power of two
149 std::vector<uint64_t> GeneratePrimes(size_t num_primes, size_t bit_size,
150  size_t ntt_size = 1);
151 
152 // returns input mod modulus, computed via Barrett reduction
153 // @param q_barr floor(2^64 / p)
154 uint64_t BarrettReduce64(uint64_t input, uint64_t modulus, uint64_t q_barr);
155 
156 template <int InputModFactor>
157 uint64_t ReduceMod(uint64_t x, uint64_t modulus,
158  const uint64_t* twice_modulus = nullptr,
159  const uint64_t* four_times_modulus = nullptr) {
160  UTILS_CHECK(InputModFactor == 1 || InputModFactor == 2 ||
161  InputModFactor == 4 || InputModFactor == 8,
162  "InputModFactor should be 1, 2, 4, or 8");
163  if (InputModFactor == 1) {
164  return x;
165  }
166  if (InputModFactor == 2) {
167  if (x >= modulus) {
168  x -= modulus;
169  }
170  return x;
171  }
172  if (InputModFactor == 4) {
173  UTILS_CHECK(twice_modulus != nullptr,
174  "twice_modulus should not be nullptr");
175  if (x >= *twice_modulus) {
176  x -= *twice_modulus;
177  }
178  if (x >= modulus) {
179  x -= modulus;
180  }
181  return x;
182  }
183  if (InputModFactor == 8) {
184  UTILS_CHECK(twice_modulus != nullptr,
185  "twice_modulus should not be nullptr");
186  UTILS_CHECK(four_times_modulus != nullptr,
187  "four_times_modulus should not be nullptr");
188 
189  if (x >= *four_times_modulus) {
190  x -= *four_times_modulus;
191  }
192  if (x >= *twice_modulus) {
193  x -= *twice_modulus;
194  }
195  if (x >= modulus) {
196  x -= modulus;
197  }
198  return x;
199  }
200  UTILS_CHECK(false, "Should be unreachable");
201  return x;
202 }
203 
204 } // namespace utils
205 } // namespace hetest
206 
207 namespace hetest {
208 namespace utils {
209 
215 class NTT {
216 public:
217  template <class Adaptee, class... Args>
219  : public AllocatorInterface<AllocatorAdapter<Adaptee, Args...>> {
220  explicit AllocatorAdapter(Adaptee&& _a, Args&&... args);
221  AllocatorAdapter(const Adaptee& _a, Args&... args);
222 
223  // interface implementation
224  void* allocate_impl(size_t bytes_count);
225  void deallocate_impl(void* p, size_t n);
226 
227  private:
228  Adaptee alloc;
229  };
230 
232  NTT();
233 
235  ~NTT();
236 
246  NTT(uint64_t degree, uint64_t q,
247  std::shared_ptr<AllocatorBase> alloc_ptr = {});
248 
249  template <class Allocator, class... AllocatorArgs>
250  NTT(uint64_t degree, uint64_t q, Allocator&& a, AllocatorArgs&&... args)
251  : NTT(degree, q,
252  std::static_pointer_cast<AllocatorBase>(
253  std::make_shared<
254  AllocatorAdapter<Allocator, AllocatorArgs...>>(
255  std::move(a), std::forward<AllocatorArgs>(args)...))) {}
256 
268  NTT(uint64_t degree, uint64_t q, uint64_t root_of_unity,
269  std::shared_ptr<AllocatorBase> alloc_ptr = {});
270 
271  template <class Allocator, class... AllocatorArgs>
272  NTT(uint64_t degree, uint64_t q, uint64_t root_of_unity, Allocator&& a,
273  AllocatorArgs&&... args)
274  : NTT(degree, q, root_of_unity,
275  std::static_pointer_cast<AllocatorBase>(
276  std::make_shared<
277  AllocatorAdapter<Allocator, AllocatorArgs...>>(
278  std::move(a), std::forward<AllocatorArgs>(args)...))) {}
279 
287  void ComputeForward(uint64_t* result, const uint64_t* operand,
288  uint64_t input_mod_factor, uint64_t output_mod_factor);
289 
297  void ComputeInverse(uint64_t* result, const uint64_t* operand,
298  uint64_t input_mod_factor, uint64_t output_mod_factor);
299 
300  class NTTImpl;
301 
302 public:
303  std::shared_ptr<NTTImpl> m_impl;
304 };
305 
306 } // namespace utils
307 } // namespace hetest
308 
309 namespace hetest {
310 namespace utils {
311 
313 public:
314  NTTImpl(uint64_t degree, uint64_t q, uint64_t root_of_unity,
315  std::shared_ptr<AllocatorBase> alloc_ptr = {});
316  NTTImpl(uint64_t degree, uint64_t q,
317  std::shared_ptr<AllocatorBase> alloc_ptr = {});
318 
319  ~NTTImpl();
320 
321  uint64_t GetMinimalRootOfUnity() const { return m_w; }
322 
323  uint64_t GetDegree() const { return m_degree; }
324 
325  uint64_t GetModulus() const { return m_q; }
326 
328  return m_precon64_root_of_unity_powers;
329  }
330 
332  return GetPrecon64RootOfUnityPowers().data();
333  }
334 
336  return m_precon52_root_of_unity_powers;
337  }
338 
340  return GetPrecon52RootOfUnityPowers().data();
341  }
342 
344  return GetRootOfUnityPowers().data();
345  }
346 
347  // Returns the vector of pre-computed root of unity powers for the modulus
348  // and root of unity.
350  return m_root_of_unity_powers;
351  }
352 
353  // Returns the root of unity at index i.
354  uint64_t GetRootOfUnityPower(size_t i) { return GetRootOfUnityPowers()[i]; }
355 
356  // Returns the vector of 64-bit pre-conditioned pre-computed root of unity
357  // powers for the modulus and root of unity.
359  return m_precon64_inv_root_of_unity_powers;
360  }
361 
363  return GetPrecon64InvRootOfUnityPowers().data();
364  }
365 
366  // Returns the vector of 52-bit pre-conditioned pre-computed root of unity
367  // powers for the modulus and root of unity.
369  return m_precon52_inv_root_of_unity_powers;
370  }
371 
373  return GetPrecon52InvRootOfUnityPowers().data();
374  }
375 
377  return m_inv_root_of_unity_powers;
378  }
379 
381  return GetInvRootOfUnityPowers().data();
382  }
383 
384  uint64_t GetInvRootOfUnityPower(size_t i) {
385  return GetInvRootOfUnityPowers()[i];
386  }
387 
388  void ComputeForward(uint64_t* result, const uint64_t* operand,
389  uint64_t input_mod_factor, uint64_t output_mod_factor);
390 
391  void ComputeInverse(uint64_t* result, const uint64_t* operand,
392  uint64_t input_mod_factor, uint64_t output_mod_factor);
393 
394  static const size_t s_max_degree_bits{20}; // Maximum power of 2 in degree
395 
396  // Maximum number of bits in modulus;
397  static const size_t s_max_modulus_bits{62};
398 
399  // Default bit shift used in Barrett precomputation
400  static const size_t s_default_shift_bits{64};
401 
402  // Bit shift used in Barrett precomputation when IFMA acceleration is
403  // enabled
404  static const size_t s_ifma_shift_bits{52};
405 
406  // Maximum number of bits in modulus to use IFMA acceleration for the
407  // forward transform
408  static const size_t s_max_fwd_ifma_modulus{1ULL << (s_ifma_shift_bits - 2)};
409 
410  // Maximum number of bits in modulus to use IFMA acceleration for the
411  // inverse transform
412  static const size_t s_max_inv_ifma_modulus{1ULL << (s_ifma_shift_bits - 1)};
413 
414 private:
415  void ComputeRootOfUnityPowers();
416  uint64_t m_degree; // N: size of NTT transform, should be power of 2
417  uint64_t m_q; // prime modulus. Must satisfy q == 1 mod 2n
418 
419  uint64_t m_degree_bits; // log_2(m_degree)
420  // Bit shift to use in computing Barrett reduction for forward transform
421 
422  uint64_t m_winv; // Inverse of minimal root of unity
423  uint64_t m_w; // A 2N'th root of unity
424 
425  std::shared_ptr<AllocatorBase> alloc;
426 
427  // vector of floor(W * 2**52 / m_q), with W the root of unity powers
428  AlignedVector64<uint64_t> m_precon52_root_of_unity_powers;
429  // vector of floor(W * 2**64 / m_q), with W the root of unity powers
430  AlignedVector64<uint64_t> m_precon64_root_of_unity_powers;
431  // powers of the minimal root of unity
432  AlignedVector64<uint64_t> m_root_of_unity_powers;
433 
434  // vector of floor(W * 2**52 / m_q), with W the inverse root of unity powers
435  AlignedVector64<uint64_t> m_precon52_inv_root_of_unity_powers;
436  // vector of floor(W * 2**64 / m_q), with W the inverse root of unity powers
437  AlignedVector64<uint64_t> m_precon64_inv_root_of_unity_powers;
438 
439  AlignedVector64<uint64_t> m_inv_root_of_unity_powers;
440 };
441 
442 void ForwardTransformToBitReverse64(uint64_t* operand, uint64_t n,
443  uint64_t modulus,
444  const uint64_t* root_of_unity_powers,
445  const uint64_t* precon_root_of_unity_powers,
446  uint64_t input_mod_factor = 1,
447  uint64_t output_mod_factor = 1);
448 
// Reference NTT which is written for clarity rather than performance.
// @param operand Data buffer holding n values
// (presumably transformed in place - confirm against ntt.cpp)
// @param n Size of the transform
// @param modulus Modulus
// @param root_of_unity_powers Pre-computed powers of the root of unity
void ReferenceForwardTransformToBitReverse(
    uint64_t* operand, uint64_t n, uint64_t modulus,
    const uint64_t* root_of_unity_powers);
459 
// Inverse counterpart of ForwardTransformToBitReverse64.
// @param operand Data buffer holding n values
// (presumably transformed in place - confirm against ntt.cpp)
// @param n Size of the transform
// @param modulus Modulus
// @param inv_root_of_unity_powers Pre-computed powers of the inverse root of
// unity
// @param precon_inv_root_of_unity_powers Pre-conditioned counterparts of
// inv_root_of_unity_powers
// @param input_mod_factor Defaults to 1
// @param output_mod_factor Defaults to 1
void InverseTransformFromBitReverse64(
    uint64_t* operand, uint64_t n, uint64_t modulus,
    const uint64_t* inv_root_of_unity_powers,
    const uint64_t* precon_inv_root_of_unity_powers,
    uint64_t input_mod_factor = 1, uint64_t output_mod_factor = 1);
465 
466 // Returns true if arguments satisfy constraints for negacyclic NTT
467 bool CheckNTTArguments(uint64_t degree, uint64_t modulus);
468 
469 } // namespace utils
470 } // namespace hetest
void InverseTransformFromBitReverse64(uint64_t *operand, uint64_t n, uint64_t modulus, const uint64_t *inv_root_of_unity_powers, const uint64_t *precon_inv_root_of_unity_powers, uint64_t input_mod_factor, uint64_t output_mod_factor)
Definition: ntt.cpp:580
Definition: utils-test.hpp:202
uint64_t * GetInvRootOfUnityPowersPtr()
Definition: ntt.hpp:380
AlignedVector64< uint64_t > & GetPrecon52RootOfUnityPowers()
Definition: ntt.hpp:335
uint64_t InverseUIntMod(uint64_t input, uint64_t modulus)
Definition: ntt.cpp:14
uint64_t PowMod(uint64_t base, uint64_t exp, uint64_t modulus)
Definition: ntt.cpp:84
uint64_t DivideUInt128UInt64Lo(uint64_t x1, uint64_t x0, uint64_t y)
Definition: utils-test.hpp:108
void ComputeForward(uint64_t *result, const uint64_t *operand, uint64_t input_mod_factor, uint64_t output_mod_factor)
Compute forward NTT. Results are bit-reversed.
Definition: ntt.cpp:443
NTT(uint64_t degree, uint64_t q, Allocator &&a, AllocatorArgs &&...args)
Definition: ntt.hpp:250
std::vector< uint64_t > GeneratePrimes(size_t num_primes, size_t bit_size, size_t ntt_size)
Definition: ntt.cpp:221
AlignedVector64< uint64_t > & GetInvRootOfUnityPowers()
Definition: ntt.hpp:376
static const size_t s_max_modulus_bits
Definition: ntt.hpp:397
uint64_t GetMinimalRootOfUnity() const
Definition: ntt.hpp:321
uint64_t MinimalPrimitiveRoot(uint64_t degree, uint64_t modulus)
Definition: ntt.cpp:137
bool IsPrime(uint64_t n)
Definition: ntt.cpp:173
static const size_t s_max_inv_ifma_modulus
Definition: ntt.hpp:412
uint64_t * GetPrecon52InvRootOfUnityPowersPtr()
Definition: ntt.hpp:372
uint64_t SubUIntMod(uint64_t x, uint64_t y, uint64_t modulus)
Definition: ntt.cpp:76
uint64_t GetRootOfUnityPower(size_t i)
Definition: ntt.hpp:354
uint64_t ReverseBitsUInt(uint64_t x, uint64_t bit_width)
Definition: ntt.cpp:160
Performs negacyclic forward and inverse number-theoretic transform (NTT), commonly used in RLWE crypt...
Definition: ntt.hpp:215
uint64_t GetDegree() const
Definition: ntt.hpp:323
Definition: ntt.hpp:13
uint64_t GetModulus() const
Definition: ntt.hpp:325
AllocatorAdapter(Adaptee &&_a, Args &&...args)
std::vector< T, AlignedAllocator< T, 64 > > AlignedVector64
Definition: utils-test.hpp:352
~NTT()
Destructs the NTT object.
uint64_t MultiplyUIntMod(uint64_t x, uint64_t y, uint64_t modulus)
Definition: ntt.cpp:52
uint64_t MultiplyUIntModLazy(uint64_t x, uint64_t y_operand, uint64_t y_barrett_factor, uint64_t modulus)
Definition: ntt.hpp:87
uint64_t MultiplyMod(uint64_t x, uint64_t y, uint64_t y_precon, uint64_t modulus)
Definition: ntt.cpp:62
static const size_t s_max_fwd_ifma_modulus
Definition: ntt.hpp:408
uint64_t BarrettFactor() const
Definition: ntt.hpp:38
AlignedVector64< uint64_t > & GetRootOfUnityPowers()
Definition: ntt.hpp:349
static const size_t s_ifma_shift_bits
Definition: ntt.hpp:404
static const size_t s_max_degree_bits
Definition: ntt.hpp:394
bool CheckNTTArguments(uint64_t degree, uint64_t modulus)
Definition: ntt.cpp:661
uint64_t * GetPrecon52RootOfUnityPowersPtr()
Definition: ntt.hpp:339
void deallocate_impl(void *p, size_t n)
Definition: utils-test.hpp:195
NTT(uint64_t degree, uint64_t q, uint64_t root_of_unity, Allocator &&a, AllocatorArgs &&...args)
Definition: ntt.hpp:272
#define UTILS_CHECK(cond, expr)
Definition: utils-test.hpp:81
void ComputeInverse(uint64_t *result, const uint64_t *operand, uint64_t input_mod_factor, uint64_t output_mod_factor)
Definition: ntt.cpp:408
AlignedVector64< uint64_t > & GetPrecon52InvRootOfUnityPowers()
Definition: ntt.hpp:368
uint64_t AddUIntMod(uint64_t x, uint64_t y, uint64_t modulus)
Definition: ntt.cpp:69
uint64_t GeneratePrimitiveRoot(uint64_t degree, uint64_t modulus)
Definition: ntt.cpp:111
NTT()
Initializes an empty NTT object.
uint64_t * GetPrecon64RootOfUnityPowersPtr()
Definition: ntt.hpp:331
bool IsPrimitiveRoot(uint64_t root, uint64_t degree, uint64_t modulus)
Definition: ntt.cpp:99
Definition: ntt.hpp:312
uint64_t MaximumValue(uint64_t bits)
Definition: utils-test.hpp:182
AlignedVector64< uint64_t > & GetPrecon64InvRootOfUnityPowers()
Definition: ntt.hpp:358
void ForwardTransformToBitReverse64(uint64_t *operand, uint64_t n, uint64_t modulus, const uint64_t *root_of_unity_powers, const uint64_t *precon_root_of_unity_powers, uint64_t input_mod_factor, uint64_t output_mod_factor)
Definition: ntt.cpp:474
uint64_t GetInvRootOfUnityPower(size_t i)
Definition: ntt.hpp:384
std::shared_ptr< NTTImpl > m_impl
Class implementing the NTT.
Definition: ntt.hpp:300
unsigned char AddUInt64(uint64_t operand1, uint64_t operand2, uint64_t *result)
Definition: ntt.hpp:133
uint64_t * GetPrecon64InvRootOfUnityPowersPtr()
Definition: ntt.hpp:362
static const size_t s_default_shift_bits
Definition: ntt.hpp:400
MultiplyFactor(uint64_t operand, uint64_t bit_shift, uint64_t modulus)
Definition: ntt.hpp:18
void ComputeInverse(uint64_t *result, const uint64_t *operand, uint64_t input_mod_factor, uint64_t output_mod_factor)
Definition: ntt.cpp:458
NTTImpl(uint64_t degree, uint64_t q, uint64_t root_of_unity, std::shared_ptr< AllocatorBase > alloc_ptr={})
Definition: ntt.cpp:258
uint64_t ReduceMod(uint64_t x, uint64_t modulus, const uint64_t *twice_modulus=nullptr, const uint64_t *four_times_modulus=nullptr)
Definition: ntt.hpp:157
void * allocate_impl(size_t bytes_count)
void ComputeForward(uint64_t *result, const uint64_t *operand, uint64_t input_mod_factor, uint64_t output_mod_factor)
Definition: ntt.cpp:386
uint64_t * GetRootOfUnityPowersPtr()
Definition: ntt.hpp:343
void ReferenceForwardTransformToBitReverse(uint64_t *operand, uint64_t n, uint64_t modulus, const uint64_t *root_of_unity_powers)
Reference NTT which is written for clarity rather than performance.
Definition: ntt.cpp:550
AlignedVector64< uint64_t > & GetPrecon64RootOfUnityPowers()
Definition: ntt.hpp:327
uint64_t Operand() const
Definition: ntt.hpp:39
uint64_t BarrettReduce64(uint64_t input, uint64_t modulus, uint64_t q_barr)
Definition: ntt.cpp:45