8 #define NOMINMAX // Avoid errors with std::min/std::max 12 #include <immintrin.h> 21 #pragma intrinsic(_addcarry_u64, _BitScanReverse64, _subborrow_u64, _udiv128, \ 30 inline uint64_t BarrettReduce128(uint64_t input_hi, uint64_t input_lo,
34 _udiv128(input_hi, input_lo, modulus, &remainder);
42 inline
void MultiplyUInt64(uint64_t x, uint64_t y, uint64_t* prod_hi,
44 *prod_lo = _umul128(x, y, prod_hi);
48 template <
int BitShift>
49 inline uint64_t MultiplyUInt64Hi(uint64_t x, uint64_t y) {
51 "Invalid BitShift " << BitShift <<
"; expected 52 or 64");
53 uint64_t prod_lo = _umul128(x, y, &prod_hi);
56 RightShift128(&result_hi, &result_lo, prod_hi, prod_lo, BitShift);
65 inline void LeftShift128(uint64_t* result_hi, uint64_t* result_lo,
66 const uint64_t op_hi,
const uint64_t op_lo,
67 const uint64_t shift_value) {
68 HEXL_CHECK(result_hi !=
nullptr,
"Require result_hi != nullptr");
69 HEXL_CHECK(result_lo !=
nullptr,
"Require result_lo != nullptr");
71 "shift_value cannot be greater than 128 " << shift_value);
73 if (shift_value == 0) {
76 }
else if (shift_value == 64) {
79 }
else if (shift_value == 128) {
82 }
else if (shift_value >= 1 && shift_value <= 63) {
83 *result_hi = (op_hi << shift_value) | (op_lo >> (64 - shift_value));
84 *result_lo = op_lo << shift_value;
85 }
else if (shift_value >= 65 && shift_value < 128) {
86 *result_hi = op_lo << (shift_value - 64);
96 inline void RightShift128(uint64_t* result_hi, uint64_t* result_lo,
97 const uint64_t op_hi,
const uint64_t op_lo,
98 const uint64_t shift_value) {
99 HEXL_CHECK(result_hi !=
nullptr,
"Require result_hi != nullptr");
100 HEXL_CHECK(result_lo !=
nullptr,
"Require result_lo != nullptr");
102 "shift_value cannot be greater than 128 " << shift_value);
104 if (shift_value == 0) {
107 }
else if (shift_value == 64) {
110 }
else if (shift_value == 128) {
113 }
else if (shift_value >= 1 && shift_value <= 63) {
114 *result_hi = op_hi >> shift_value;
115 *result_lo = (op_hi << (64 - shift_value)) | (op_lo >> shift_value);
116 }
else if (shift_value >= 65 && shift_value < 128) {
118 *result_lo = op_hi >> (shift_value - 64);
125 inline void AddWithCarry128(uint64_t* result_hi, uint64_t* result_lo,
126 const uint64_t op1_hi,
const uint64_t op1_lo,
127 const uint64_t op2_hi,
const uint64_t op2_lo) {
128 HEXL_CHECK(result_hi !=
nullptr,
"Require result_hi != nullptr");
129 HEXL_CHECK(result_lo !=
nullptr,
"Require result_lo != nullptr");
132 *result_lo = op1_lo + op2_lo;
133 unsigned char carry =
static_cast<unsigned char>(*result_lo < op1_lo);
136 _addcarry_u64(carry, op1_hi, op2_hi, result_hi);
142 inline void SubWithCarry128(uint64_t* result_hi, uint64_t* result_lo,
143 const uint64_t op1_hi,
const uint64_t op1_lo,
144 const uint64_t op2_hi,
const uint64_t op2_lo) {
145 HEXL_CHECK(result_hi !=
nullptr,
"Require result_hi != nullptr");
146 HEXL_CHECK(result_lo !=
nullptr,
"Require result_lo != nullptr");
148 unsigned char borrow;
151 *result_lo = op1_lo - op2_lo;
152 borrow =
static_cast<unsigned char>(op2_lo > op1_lo);
155 _subborrow_u64(borrow, op1_hi, op2_hi, result_hi);
160 inline uint64_t SignificantBitLength(
const uint64_t* value) {
161 HEXL_CHECK(value !=
nullptr,
"Require value != nullptr");
163 unsigned long count = 0;
166 _BitScanReverse64(&count, *(value + 1));
167 if (count >= 0 && *(value + 1) > 0) {
168 return static_cast<uint64_t
>(count) + 1 + 64;
172 _BitScanReverse64(&count, *value);
173 if (count >= 0 && *(value) > 0) {
174 return static_cast<uint64_t
>(count) + 1;
181 inline bool CheckSign(
const uint64_t* input) {
182 HEXL_CHECK(input !=
nullptr,
"Require input != nullptr");
184 uint64_t input_temp[2]{0, 0};
185 RightShift128(&input_temp[1], &input_temp[0], input[1], input[0], 127);
186 return (input_temp[0] == 1);
193 inline void DivideUInt128UInt64(uint64_t* quotient,
const uint64_t* numerator,
194 const uint64_t denominator) {
195 HEXL_CHECK(quotient !=
nullptr,
"Require quotient != nullptr");
196 HEXL_CHECK(numerator !=
nullptr,
"Require numerator != nullptr");
197 HEXL_CHECK(denominator != 0,
"denominator cannot be 0 " << denominator);
200 uint64_t numerator_bits = SignificantBitLength(numerator);
201 const uint64_t numerator_bits_const = numerator_bits;
202 const uint64_t uint_128_bit = 128ULL;
204 uint64_t MASK[2]{0x0000000000000001, 0x0000000000000000};
205 uint64_t remainder[2]{0, 0};
206 uint64_t quotient_temp[2]{0, 0};
207 uint64_t denominator_temp[2]{denominator, 0};
209 quotient[0] = numerator[0];
210 quotient[1] = numerator[1];
213 LeftShift128("ient[1], "ient[0], quotient[1], quotient[0],
214 (uint_128_bit - numerator_bits_const));
216 while (numerator_bits) {
218 if (CheckSign(remainder)) {
219 LeftShift128(&remainder[1], &remainder[0], remainder[1], remainder[0], 1);
220 RightShift128("ient_temp[1], "ient_temp[0], quotient[1],
221 quotient[0], (uint_128_bit - 1));
222 remainder[0] = remainder[0] | quotient_temp[0];
223 LeftShift128("ient[1], "ient[0], quotient[1], quotient[0], 1);
225 AddWithCarry128(&remainder[1], &remainder[0], remainder[1], remainder[0],
226 denominator_temp[1], denominator_temp[0]);
228 LeftShift128(&remainder[1], &remainder[0], remainder[1], remainder[0], 1);
229 RightShift128("ient_temp[1], "ient_temp[0], quotient[1],
230 quotient[0], (uint_128_bit - 1));
231 remainder[0] = remainder[0] | quotient_temp[0];
232 LeftShift128("ient[1], "ient[0], quotient[1], quotient[0], 1);
234 SubWithCarry128(&remainder[1], &remainder[0], remainder[1], remainder[0],
235 denominator_temp[1], denominator_temp[0]);
239 if (!CheckSign(remainder)) {
240 MASK[0] = 0x0000000000000001;
241 MASK[1] = 0x0000000000000000;
242 LeftShift128(&MASK[1], &MASK[0], MASK[1], MASK[0],
243 (uint_128_bit - numerator_bits_const));
244 quotient[0] = quotient[0] | MASK[0];
245 quotient[1] = quotient[1] | MASK[1];
247 quotient_temp[0] = 0;
248 quotient_temp[1] = 0;
252 if (CheckSign(remainder)) {
254 AddWithCarry128(&remainder[1], &remainder[0], remainder[1], remainder[0],
255 denominator_temp[1], denominator_temp[0]);
257 RightShift128("ient[1], "ient[0], quotient[1], quotient[0],
258 (uint_128_bit - numerator_bits_const));
265 inline uint64_t DivideUInt128UInt64Lo(
const uint64_t numerator_hi,
266 const uint64_t numerator_lo,
267 const uint64_t denominator) {
268 uint64_t numerator[2]{numerator_lo, numerator_hi};
269 uint64_t quotient[2]{0, 0};
271 DivideUInt128UInt64(quotient, numerator, denominator);
276 inline uint64_t MSB(uint64_t input) {
277 unsigned long index{0};
278 _BitScanReverse64(&index, input);
// Loop-unroll hints. In this header they expand to an empty statement block,
// so no unrolling directive is emitted; the annotated loop is unaffected.
#define HEXL_LOOP_UNROLL_4 {}
#define HEXL_LOOP_UNROLL_8 {}
#define HEXL_CHECK(cond, expr)
Definition: check.hpp:39
Definition: eltwise-add-mod.hpp:8