Intel HEXL
Intel Homomorphic Encryption Acceleration Library: accelerates the modular arithmetic operations used in homomorphic encryption.
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
msvc.hpp
Go to the documentation of this file.
1 // Copyright (C) 2020-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 
4 #pragma once
5 
6 #ifdef HEXL_USE_MSVC
7 
8 #define NOMINMAX // Avoid errors with std::min/std::max
9 #undef min
10 #undef max
11 
12 #include <immintrin.h>
13 #include <intrin.h>
14 #include <stdint.h>
15 
16 #include <cmath>
17 #include <iostream>
18 
19 #include "hexl/util/check.hpp"
20 
21 #pragma intrinsic(_addcarry_u64, _BitScanReverse64, _subborrow_u64, _udiv128, \
22  _umul128)
23 
24 #undef TRUE
25 #undef FALSE
26 
27 namespace intel {
28 namespace hexl {
29 
30 inline uint64_t BarrettReduce128(uint64_t input_hi, uint64_t input_lo,
31  uint64_t modulus) {
32  HEXL_CHECK(modulus != 0, "modulus == 0")
33  uint64_t remainder;
34  _udiv128(input_hi, input_lo, modulus, &remainder);
35 
36  return remainder;
37 }
38 
39 // Multiplies x * y as 128-bit integer.
40 // @param prod_hi Stores high 64 bits of product
41 // @param prod_lo Stores low 64 bits of product
42 inline void MultiplyUInt64(uint64_t x, uint64_t y, uint64_t* prod_hi,
43  uint64_t* prod_lo) {
44  *prod_lo = _umul128(x, y, prod_hi);
45 }
46 
47 // Return the high 128 minus BitShift bits of the 128-bit product x * y
48 template <int BitShift>
49 inline uint64_t MultiplyUInt64Hi(uint64_t x, uint64_t y) {
50  HEXL_CHECK(BitShift == 52 || BitShift == 64,
51  "Invalid BitShift " << BitShift << "; expected 52 or 64");
52  uint64_t prod_hi;
53  uint64_t prod_lo = _umul128(x, y, &prod_hi);
54  uint64_t result_hi;
55  uint64_t result_lo;
56  RightShift128(&result_hi, &result_lo, prod_hi, prod_lo, BitShift);
57  return result_lo;
58 }
59 
65 inline void LeftShift128(uint64_t* result_hi, uint64_t* result_lo,
66  const uint64_t op_hi, const uint64_t op_lo,
67  const uint64_t shift_value) {
68  HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr");
69  HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr");
70  HEXL_CHECK(shift_value <= 128,
71  "shift_value cannot be greater than 128 " << shift_value);
72 
73  if (shift_value == 0) {
74  *result_hi = op_hi;
75  *result_lo = op_lo;
76  } else if (shift_value == 64) {
77  *result_hi = op_lo;
78  *result_lo = 0ULL;
79  } else if (shift_value == 128) {
80  *result_hi = 0ULL;
81  *result_lo = 0ULL;
82  } else if (shift_value >= 1 && shift_value <= 63) {
83  *result_hi = (op_hi << shift_value) | (op_lo >> (64 - shift_value));
84  *result_lo = op_lo << shift_value;
85  } else if (shift_value >= 65 && shift_value < 128) {
86  *result_hi = op_lo << (shift_value - 64);
87  *result_lo = 0ULL;
88  }
89 }
90 
96 inline void RightShift128(uint64_t* result_hi, uint64_t* result_lo,
97  const uint64_t op_hi, const uint64_t op_lo,
98  const uint64_t shift_value) {
99  HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr");
100  HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr");
101  HEXL_CHECK(shift_value <= 128,
102  "shift_value cannot be greater than 128 " << shift_value);
103 
104  if (shift_value == 0) {
105  *result_hi = op_hi;
106  *result_lo = op_lo;
107  } else if (shift_value == 64) {
108  *result_hi = 0ULL;
109  *result_lo = op_hi;
110  } else if (shift_value == 128) {
111  *result_hi = 0ULL;
112  *result_lo = 0ULL;
113  } else if (shift_value >= 1 && shift_value <= 63) {
114  *result_hi = op_hi >> shift_value;
115  *result_lo = (op_hi << (64 - shift_value)) | (op_lo >> shift_value);
116  } else if (shift_value >= 65 && shift_value < 128) {
117  *result_hi = 0ULL;
118  *result_lo = op_hi >> (shift_value - 64);
119  }
120 }
121 
125 inline void AddWithCarry128(uint64_t* result_hi, uint64_t* result_lo,
126  const uint64_t op1_hi, const uint64_t op1_lo,
127  const uint64_t op2_hi, const uint64_t op2_lo) {
128  HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr");
129  HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr");
130 
131  // first 64bit block
132  *result_lo = op1_lo + op2_lo;
133  unsigned char carry = static_cast<unsigned char>(*result_lo < op1_lo);
134 
135  // second 64bit block
136  _addcarry_u64(carry, op1_hi, op2_hi, result_hi);
137 }
138 
142 inline void SubWithCarry128(uint64_t* result_hi, uint64_t* result_lo,
143  const uint64_t op1_hi, const uint64_t op1_lo,
144  const uint64_t op2_hi, const uint64_t op2_lo) {
145  HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr");
146  HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr");
147 
148  unsigned char borrow;
149 
150  // first 64bit block
151  *result_lo = op1_lo - op2_lo;
152  borrow = static_cast<unsigned char>(op2_lo > op1_lo);
153 
154  // second 64bit block
155  _subborrow_u64(borrow, op1_hi, op2_hi, result_hi);
156 }
157 
160 inline uint64_t SignificantBitLength(const uint64_t* value) {
161  HEXL_CHECK(value != nullptr, "Require value != nullptr");
162 
163  unsigned long count = 0; // NOLINT(runtime/int)
164 
165  // second 64bit block
166  _BitScanReverse64(&count, *(value + 1));
167  if (count >= 0 && *(value + 1) > 0) {
168  return static_cast<uint64_t>(count) + 1 + 64;
169  }
170 
171  // first 64bit block
172  _BitScanReverse64(&count, *value);
173  if (count >= 0 && *(value) > 0) {
174  return static_cast<uint64_t>(count) + 1;
175  }
176  return 0;
177 }
178 
181 inline bool CheckSign(const uint64_t* input) {
182  HEXL_CHECK(input != nullptr, "Require input != nullptr");
183 
184  uint64_t input_temp[2]{0, 0};
185  RightShift128(&input_temp[1], &input_temp[0], input[1], input[0], 127);
186  return (input_temp[0] == 1);
187 }
188 
193 inline void DivideUInt128UInt64(uint64_t* quotient, const uint64_t* numerator,
194  const uint64_t denominator) {
195  HEXL_CHECK(quotient != nullptr, "Require quotient != nullptr");
196  HEXL_CHECK(numerator != nullptr, "Require numerator != nullptr");
197  HEXL_CHECK(denominator != 0, "denominator cannot be 0 " << denominator);
198 
199  // get bit count of divisor
200  uint64_t numerator_bits = SignificantBitLength(numerator);
201  const uint64_t numerator_bits_const = numerator_bits;
202  const uint64_t uint_128_bit = 128ULL;
203 
204  uint64_t MASK[2]{0x0000000000000001, 0x0000000000000000};
205  uint64_t remainder[2]{0, 0};
206  uint64_t quotient_temp[2]{0, 0};
207  uint64_t denominator_temp[2]{denominator, 0};
208 
209  quotient[0] = numerator[0];
210  quotient[1] = numerator[1];
211 
212  // align numerator
213  LeftShift128(&quotient[1], &quotient[0], quotient[1], quotient[0],
214  (uint_128_bit - numerator_bits_const));
215 
216  while (numerator_bits) {
217  // if remainder is negative
218  if (CheckSign(remainder)) {
219  LeftShift128(&remainder[1], &remainder[0], remainder[1], remainder[0], 1);
220  RightShift128(&quotient_temp[1], &quotient_temp[0], quotient[1],
221  quotient[0], (uint_128_bit - 1));
222  remainder[0] = remainder[0] | quotient_temp[0];
223  LeftShift128(&quotient[1], &quotient[0], quotient[1], quotient[0], 1);
224  // remainder=remainder+denominator_temp
225  AddWithCarry128(&remainder[1], &remainder[0], remainder[1], remainder[0],
226  denominator_temp[1], denominator_temp[0]);
227  } else { // if remainder is positive
228  LeftShift128(&remainder[1], &remainder[0], remainder[1], remainder[0], 1);
229  RightShift128(&quotient_temp[1], &quotient_temp[0], quotient[1],
230  quotient[0], (uint_128_bit - 1));
231  remainder[0] = remainder[0] | quotient_temp[0];
232  LeftShift128(&quotient[1], &quotient[0], quotient[1], quotient[0], 1);
233  // remainder=remainder-denominator_temp
234  SubWithCarry128(&remainder[1], &remainder[0], remainder[1], remainder[0],
235  denominator_temp[1], denominator_temp[0]);
236  }
237 
238  // if remainder is positive set MSB of quotient[0]=1
239  if (!CheckSign(remainder)) {
240  MASK[0] = 0x0000000000000001;
241  MASK[1] = 0x0000000000000000;
242  LeftShift128(&MASK[1], &MASK[0], MASK[1], MASK[0],
243  (uint_128_bit - numerator_bits_const));
244  quotient[0] = quotient[0] | MASK[0];
245  quotient[1] = quotient[1] | MASK[1];
246  }
247  quotient_temp[0] = 0;
248  quotient_temp[1] = 0;
249  numerator_bits--;
250  }
251 
252  if (CheckSign(remainder)) {
253  // remainder=remainder+denominator_temp
254  AddWithCarry128(&remainder[1], &remainder[0], remainder[1], remainder[0],
255  denominator_temp[1], denominator_temp[0]);
256  }
257  RightShift128(&quotient[1], &quotient[0], quotient[1], quotient[0],
258  (uint_128_bit - numerator_bits_const));
259 }
260 
265 inline uint64_t DivideUInt128UInt64Lo(const uint64_t numerator_hi,
266  const uint64_t numerator_lo,
267  const uint64_t denominator) {
268  uint64_t numerator[2]{numerator_lo, numerator_hi};
269  uint64_t quotient[2]{0, 0};
270 
271  DivideUInt128UInt64(quotient, numerator, denominator);
272  return quotient[0];
273 }
274 
275 // Returns most-significant bit of the input
276 inline uint64_t MSB(uint64_t input) {
277  unsigned long index{0}; // NOLINT(runtime/int)
278  _BitScanReverse64(&index, input);
279  return index;
280 }
281 
282 #define HEXL_LOOP_UNROLL_4 \
283  {}
284 #define HEXL_LOOP_UNROLL_8 \
285  {}
286 
287 #endif
288 
289 } // namespace hexl
290 } // namespace intel
Cross-reference tooltips: macro `HEXL_CHECK(cond, expr)` — definition: check.hpp:39.
Another symbol referenced above — definition: eltwise-add-mod.hpp:8.