Intel HEXL
Intel Homomorphic Encryption Acceleration Library, accelerating the modular arithmetic operations used in homomorphic encryption.
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
ntt.hpp
Go to the documentation of this file.
1 // Copyright (C) 2020-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 
4 #pragma once
5 
6 #include <stdint.h>
7 
8 #include <memory>
9 #include <vector>
10 
12 #include "hexl/util/allocator.hpp"
13 
14 namespace intel {
15 namespace hexl {
16 
22 class NTT {
23  public:
25  template <class Adaptee, class... Args>
27  : public AllocatorInterface<AllocatorAdapter<Adaptee, Args...>> {
28  explicit AllocatorAdapter(Adaptee&& _a, Args&&... args);
29  AllocatorAdapter(const Adaptee& _a, Args&... args);
30 
31  // interface implementation
32  void* allocate_impl(size_t bytes_count);
33  void deallocate_impl(void* p, size_t n);
34 
35  private:
36  Adaptee alloc;
37  };
38 
40  NTT() = default;
41 
43  ~NTT() = default;
44 
54  NTT(uint64_t degree, uint64_t q,
55  std::shared_ptr<AllocatorBase> alloc_ptr = {});
56 
57  template <class Allocator, class... AllocatorArgs>
58  NTT(uint64_t degree, uint64_t q, Allocator&& a, AllocatorArgs&&... args)
59  : NTT(degree, q,
60  std::static_pointer_cast<AllocatorBase>(
61  std::make_shared<AllocatorAdapter<Allocator, AllocatorArgs...>>(
62  std::move(a), std::forward<AllocatorArgs>(args)...))) {}
63 
75  NTT(uint64_t degree, uint64_t q, uint64_t root_of_unity,
76  std::shared_ptr<AllocatorBase> alloc_ptr = {});
77 
78  template <class Allocator, class... AllocatorArgs>
79  NTT(uint64_t degree, uint64_t q, uint64_t root_of_unity, Allocator&& a,
80  AllocatorArgs&&... args)
81  : NTT(degree, q, root_of_unity,
82  std::static_pointer_cast<AllocatorBase>(
83  std::make_shared<AllocatorAdapter<Allocator, AllocatorArgs...>>(
84  std::move(a), std::forward<AllocatorArgs>(args)...))) {}
85 
90  static bool CheckArguments(uint64_t degree, uint64_t modulus);
91 
99  void ComputeForward(uint64_t* result, const uint64_t* operand,
100  uint64_t input_mod_factor, uint64_t output_mod_factor);
101 
109  void ComputeInverse(uint64_t* result, const uint64_t* operand,
110  uint64_t input_mod_factor, uint64_t output_mod_factor);
111 
113  uint64_t GetMinimalRootOfUnity() const { return m_w; }
114 
116  uint64_t GetDegree() const { return m_degree; }
117 
119  uint64_t GetModulus() const { return m_q; }
120 
123  return m_root_of_unity_powers;
124  }
125 
127  uint64_t GetRootOfUnityPower(size_t i) { return GetRootOfUnityPowers()[i]; }
128 
132  return m_precon32_root_of_unity_powers;
133  }
134 
138  return m_precon64_root_of_unity_powers;
139  }
140 
144  return m_avx512_root_of_unity_powers;
145  }
146 
150  return m_avx512_precon32_root_of_unity_powers;
151  }
152 
156  return m_avx512_precon52_root_of_unity_powers;
157  }
158 
162  return m_avx512_precon64_root_of_unity_powers;
163  }
164 
167  return m_inv_root_of_unity_powers;
168  }
169 
171  uint64_t GetInvRootOfUnityPower(size_t i) {
172  return GetInvRootOfUnityPowers()[i];
173  }
174 
177  // powers for the modulus and root of unity.
179  return m_precon32_inv_root_of_unity_powers;
180  }
181 
184  // powers for the modulus and root of unity.
186  return m_precon52_inv_root_of_unity_powers;
187  }
188 
191  // powers for the modulus and root of unity.
193  return m_precon64_inv_root_of_unity_powers;
194  }
195 
197  static size_t MaxDegreeBits() { return 20; }
198 
200  static size_t MaxModulusBits() { return 62; }
201 
203  static const size_t s_default_shift_bits{64};
204 
207  static const size_t s_ifma_shift_bits{52};
208 
211  static const size_t s_max_fwd_32_modulus{1ULL << (32 - 2)};
212 
215  static const size_t s_max_inv_32_modulus{1ULL << (32 - 1)};
216 
219  static const size_t s_max_fwd_ifma_modulus{1ULL << (s_ifma_shift_bits - 2)};
220 
223  static const size_t s_max_inv_ifma_modulus{1ULL << (s_ifma_shift_bits - 1)};
224 
225  private:
226  void ComputeRootOfUnityPowers();
227 
228  uint64_t m_degree; // N: size of NTT transform, should be power of 2
229  uint64_t m_q; // prime modulus. Must satisfy q == 1 mod 2n
230 
231  uint64_t m_degree_bits; // log_2(m_degree)
232 
233  uint64_t m_w_inv; // Inverse of minimal root of unity
234  uint64_t m_w; // A 2N'th root of unity
235 
236  std::shared_ptr<AllocatorBase> m_alloc;
237 
238  AlignedAllocator<uint64_t, 64> m_aligned_alloc;
239 
240  // powers of the minimal root of unity
241  AlignedVector64<uint64_t> m_root_of_unity_powers;
242  // vector of floor(W * 2**32 / m_q), with W the root of unity powers
243  AlignedVector64<uint64_t> m_precon32_root_of_unity_powers;
244  // vector of floor(W * 2**64 / m_q), with W the root of unity powers
245  AlignedVector64<uint64_t> m_precon64_root_of_unity_powers;
246 
247  // powers of the minimal root of unity adjusted for use in AVX512
248  // implementations
249  AlignedVector64<uint64_t> m_avx512_root_of_unity_powers;
250  // vector of floor(W * 2**32 / m_q), with W the AVX512 root of unity powers
251  AlignedVector64<uint64_t> m_avx512_precon32_root_of_unity_powers;
252  // vector of floor(W * 2**52 / m_q), with W the AVX512 root of unity powers
253  AlignedVector64<uint64_t> m_avx512_precon52_root_of_unity_powers;
254  // vector of floor(W * 2**64 / m_q), with W the AVX512 root of unity powers
255  AlignedVector64<uint64_t> m_avx512_precon64_root_of_unity_powers;
256 
257  // vector of floor(W * 2**32 / m_q), with W the inverse root of unity powers
258  AlignedVector64<uint64_t> m_precon32_inv_root_of_unity_powers;
259  // vector of floor(W * 2**52 / m_q), with W the inverse root of unity powers
260  AlignedVector64<uint64_t> m_precon52_inv_root_of_unity_powers;
261  // vector of floor(W * 2**64 / m_q), with W the inverse root of unity powers
262  AlignedVector64<uint64_t> m_precon64_inv_root_of_unity_powers;
263 
264  AlignedVector64<uint64_t> m_inv_root_of_unity_powers;
265 };
266 
267 } // namespace hexl
268 } // namespace intel
const AlignedVector64< uint64_t > & GetInvRootOfUnityPowers() const
Returns the inverse root of unity powers in bit-reversed order.
Definition: ntt.hpp:166
const AlignedVector64< uint64_t > & GetPrecon52InvRootOfUnityPowers() const
Returns the vector of 52-bit pre-conditioned pre-computed inverse root of unity powers.
Definition: ntt.hpp:185
NTT(uint64_t degree, uint64_t q, uint64_t root_of_unity, Allocator &&a, AllocatorArgs &&... args)
Definition: ntt.hpp:79
uint64_t GetRootOfUnityPower(size_t i)
Returns the root of unity power at bit-reversed index i.
Definition: ntt.hpp:127
Base class for custom memory allocator.
Definition: allocator.hpp:12
Performs negacyclic forward and inverse number-theoretic transforms (NTT), commonly used in RLWE cryptography.
Definition: ntt.hpp:22
NTT(uint64_t degree, uint64_t q, Allocator &&a, AllocatorArgs &&... args)
Definition: ntt.hpp:58
const AlignedVector64< uint64_t > & GetPrecon32InvRootOfUnityPowers() const
Returns the vector of 32-bit pre-conditioned pre-computed inverse root of unity powers.
Definition: ntt.hpp:178
uint64_t GetInvRootOfUnityPower(size_t i)
Returns the inverse root of unity power at bit-reversed index i.
Definition: ntt.hpp:171
std::vector< T, AlignedAllocator< T, 64 > > AlignedVector64
std::vector that uses a 64-byte aligned memory allocator.
Definition: aligned-allocator.hpp:107
AllocatorAdapter(Adaptee &&_a, Args &&... args)
static const size_t s_ifma_shift_bits
Bit shift used in Barrett precomputation when AVX512-IFMA acceleration is enabled.
Definition: ntt.hpp:207
static size_t MaxDegreeBits()
Maximum power of 2 in the degree, i.e. the largest supported log2(N).
Definition: ntt.hpp:197
const AlignedVector64< uint64_t > & GetPrecon64RootOfUnityPowers() const
Returns 64-bit pre-conditioned root of unity powers in bit-reversed order.
Definition: ntt.hpp:137
uint64_t GetDegree() const
Returns the degree N.
Definition: ntt.hpp:116
const AlignedVector64< uint64_t > & GetAVX512Precon52RootOfUnityPowers() const
Returns 52-bit pre-conditioned AVX512 root of unity powers in bit-reversed order. ...
Definition: ntt.hpp:155
void ComputeInverse(uint64_t *result, const uint64_t *operand, uint64_t input_mod_factor, uint64_t output_mod_factor)
void deallocate_impl(void *p, size_t n)
static const size_t s_max_inv_32_modulus
Maximum modulus to use 32-bit AVX512-DQ acceleration for the inverse transform.
Definition: ntt.hpp:215
const AlignedVector64< uint64_t > & GetPrecon32RootOfUnityPowers() const
Returns 32-bit pre-conditioned root of unity powers in bit-reversed order.
Definition: ntt.hpp:131
const AlignedVector64< uint64_t > & GetPrecon64InvRootOfUnityPowers() const
Returns the vector of 64-bit pre-conditioned pre-computed inverse root of unity powers.
Definition: ntt.hpp:192
const AlignedVector64< uint64_t > & GetAVX512Precon64RootOfUnityPowers() const
Returns 64-bit pre-conditioned AVX512 root of unity powers in bit-reversed order. ...
Definition: ntt.hpp:161
Definition: eltwise-add-mod.hpp:8
uint64_t GetMinimalRootOfUnity() const
Returns the minimal 2N'th root of unity.
Definition: ntt.hpp:113
NTT()=default
Initializes an empty NTT object.
const AlignedVector64< uint64_t > & GetRootOfUnityPowers() const
Returns the root of unity powers in bit-reversed order.
Definition: ntt.hpp:122
static const size_t s_max_fwd_32_modulus
Maximum modulus to use 32-bit AVX512-DQ acceleration for the forward transform.
Definition: ntt.hpp:211
uint64_t GetModulus() const
Returns the word-sized prime modulus.
Definition: ntt.hpp:119
static const size_t s_max_fwd_ifma_modulus
Maximum modulus to use AVX512-IFMA acceleration for the forward transform.
Definition: ntt.hpp:219
~NTT()=default
Destructs the NTT object.
static const size_t s_default_shift_bits
Default bit shift used in Barrett precomputation.
Definition: ntt.hpp:203
Helper class for custom memory allocation.
Definition: ntt.hpp:26
static const size_t s_max_inv_ifma_modulus
Maximum modulus to use AVX512-IFMA acceleration for the inverse transform.
Definition: ntt.hpp:223
void * allocate_impl(size_t bytes_count)
const AlignedVector64< uint64_t > & GetAVX512Precon32RootOfUnityPowers() const
Returns 32-bit pre-conditioned AVX512 root of unity powers in bit-reversed order. ...
Definition: ntt.hpp:149
void ComputeForward(uint64_t *result, const uint64_t *operand, uint64_t input_mod_factor, uint64_t output_mod_factor)
Compute forward NTT. Results are bit-reversed.
static bool CheckArguments(uint64_t degree, uint64_t modulus)
Returns true if arguments satisfy constraints for negacyclic NTT.
const AlignedVector64< uint64_t > & GetAVX512RootOfUnityPowers() const
Returns the root of unity powers in bit-reversed order with modifications for use by AVX512 implementations.
Definition: ntt.hpp:143
static size_t MaxModulusBits()
Maximum number of bits in the modulus.
Definition: ntt.hpp:200
Helper memory allocation struct which delegates implementation to AllocatorImpl.
Definition: allocator.hpp:29