Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more multiplication primitives #107

Open
wants to merge 25 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions inc/zoo/swar/SWAR.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ constexpr __uint128_t lsbIndex(__uint128_t v) noexcept {
}
#endif



/// Core abstraction around SIMD Within A Register (SWAR). Specifies 'lanes'
/// of NBits width against a type T, and provides an abstraction for performing
/// SIMD operations against that primitive type T treated as a SIMD register.
Expand Down Expand Up @@ -108,6 +110,17 @@ struct SWAR {
return result;
}

constexpr static auto evenLaneMask() {
using S = SWAR<NBits, T>;
static_assert(0 == S::Lanes % 2, "Only even number of elements supported");
using D = SWAR<NBits * 2, T>;
return S{(D::LeastSignificantBit << S::NBits) - D::LeastSignificantBit};
}

constexpr static auto oddLaneMask() {
return SWAR<NBits, T>{static_cast<T>(~evenLaneMask().value())};
}

template <typename Range>
constexpr static auto from(const Range &values) noexcept {
using std::begin; using std::end;
Expand Down Expand Up @@ -245,6 +258,12 @@ constexpr auto horizontalEquality(SWAR<NBits, T> left, SWAR<NBits, T> right) {
return left.m_v == right.m_v;
}

template <int NBits, typename T>
constexpr static auto consumeMSB(SWAR<NBits, T> s) noexcept {
using S = SWAR<NBits, T>;
auto msbCleared = s & ~S{S::MostSignificantBit};
return S{static_cast<T>(msbCleared.value() << 1)};
}

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sold on promoting this to the main header of swar.
This really seems to be an artifact of the "regressive" direction of "associative iteration", it does not cohere enough to the SWAR library itself.


#if ZOO_USE_LEASTNBITSMASK
Expand Down
130 changes: 86 additions & 44 deletions inc/zoo/swar/associative_iteration.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
#ifndef ZOO_SWAR_ASSOCIATIVE_ITERATION_H
#define ZOO_SWAR_ASSOCIATIVE_ITERATION_H

#include "SWAR.h"
#include "zoo/swar/SWAR.h"
#include <assert.h>
#include <cstdint>

//#define ZOO_DEVELOPMENT_DEBUGGING
#ifdef ZOO_DEVELOPMENT_DEBUGGING
Expand Down Expand Up @@ -393,8 +396,7 @@ constexpr auto multiplication_OverflowUnsafe_SpecificBitCount(
};

auto halver = [](auto counts) {
auto msbCleared = counts & ~S{S::MostSignificantBit};
return S{msbCleared.value() << 1};
return swar::consumeMSB(counts);
};

auto shifted = S{multiplier.value() << (NB - ActualBits)};
Expand Down Expand Up @@ -426,38 +428,6 @@ constexpr auto multiplication_OverflowUnsafe_SpecificBitCount_deprecated(
return product;
}

// TODO(Jamie): Add tests from other PR.
template<int ActualBits, int NB, typename T>
constexpr auto exponentiation_OverflowUnsafe_SpecificBitCount(
SWAR<NB, T> x,
SWAR<NB, T> exponent
) {
using S = SWAR<NB, T>;

auto operation = [](auto left, auto right, auto counts) {
const auto mask = makeLaneMaskFromMSB(counts);
const auto product =
multiplication_OverflowUnsafe_SpecificBitCount<ActualBits>(left, right);
return (product & mask) | (left & ~mask);
};

// halver should work same as multiplication... i think...
auto halver = [](auto counts) {
auto msbCleared = counts & ~S{S::MostSignificantBit};
return S{static_cast<T>(msbCleared.value() << 1)};
};

exponent = S{static_cast<T>(exponent.value() << (NB - ActualBits))};
return associativeOperatorIterated_regressive(
x,
S{meta::BitmaskMaker<T, 1, NB>().value}, // neutral is lane wise..
exponent,
S{S::MostSignificantBit},
operation,
ActualBits,
halver
);
}

template<int NB, typename T>
constexpr auto multiplication_OverflowUnsafe(
Expand All @@ -475,14 +445,6 @@ struct SWAR_Pair{
SWAR<NB, T> even, odd;
};

template<int NB, typename T>
constexpr SWAR<NB, T> doublingMask() {
using S = SWAR<NB, T>;
static_assert(0 == S::Lanes % 2, "Only even number of elements supported");
using D = SWAR<NB * 2, T>;
return S{(D::LeastSignificantBit << NB) - D::LeastSignificantBit};
}

template<int NB, typename T>
constexpr auto doublePrecision(SWAR<NB, T> input) {
using S = SWAR<NB, T>;
Expand All @@ -491,7 +453,7 @@ constexpr auto doublePrecision(SWAR<NB, T> input) {
"Precision can only be doubled for SWARs of even element count"
);
using RV = SWAR<NB * 2, T>;
constexpr auto DM = doublingMask<NB, T>();
constexpr auto DM = SWAR<NB, T>::evenLaneMask();
return SWAR_Pair<NB * 2, T>{
RV{(input & DM).value()},
RV{(input.value() >> NB) & DM.value()}
Expand All @@ -503,13 +465,93 @@ constexpr auto halvePrecision(SWAR<NB, T> even, SWAR<NB, T> odd) {
using S = SWAR<NB, T>;
static_assert(0 == NB % 2, "Only even lane-bitcounts supported");
using RV = SWAR<NB/2, T>;
constexpr auto HalvingMask = doublingMask<NB/2, T>();
constexpr auto HalvingMask = SWAR<NB/2, T>::evenLaneMask();
auto
evenHalf = RV{even.value()} & HalvingMask,
oddHalf = RV{(RV{odd.value()} & HalvingMask).value() << NB/2};
return evenHalf | oddHalf;
}


template <int NB, typename T> struct MultiplicationResult {
SWAR<NB, T> lower;
SWAR<NB, T> upper;
Comment on lines +482 to +483
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

merge

};

template <int NB, typename T>
constexpr auto
doublingMultiplication(SWAR<NB, T> multiplicand, SWAR<NB, T> multiplier) {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doubling is confusing here. doublePrecisionMultiplication is fine, multiplicationByDoublingPrecision, ...

using S = SWAR<NB, T>; using D = SWAR<NB * 2, T>;
auto [l_even, l_odd] = doublePrecision(multiplicand);
auto [r_even, r_odd] = doublePrecision(multiplier);
auto
res_even = multiplication_OverflowUnsafe(l_even, r_even),
res_odd = multiplication_OverflowUnsafe(l_odd, r_odd);
return SWAR_Pair<NB * 2, T>{res_even, res_odd};
}

template <int NB, typename T>
constexpr MultiplicationResult<NB, T>
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the explicit return type?

wideningMultiplication(SWAR<NB, T> multiplicand, SWAR<NB, T> multiplier) {
using S = SWAR<NB, T>; using D = SWAR<NB * 2, T>;
constexpr auto
HalfLane = S::NBits,
UpperHalfOfLanes = SWAR<S::NBits, T>::oddLaneMask().value();
auto [lower, upper] = doublingMultiplication(multiplicand, multiplier);
auto result = halvePrecision(lower, upper);
auto
over_even = D{(lower.value() & UpperHalfOfLanes) >> HalfLane},
over_odd = D{(upper.value() & UpperHalfOfLanes) >> HalfLane};
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shift intra lane allows you to provide the mask.
Please use those primitives instead of deploying the pick-axe

auto upper_lanes_overflow = halvePrecision(over_even, over_odd);
return {result, upper_lanes_overflow};
}

template <int NB, typename T>
constexpr
auto saturatedMultiplication(SWAR<NB, T> multiplicand, SWAR<NB, T> multiplier) {
using S = SWAR<NB, T>;
constexpr auto One = S{S::LeastSignificantBit};
auto [result, overflow] = wideningMultiplication(multiplicand, multiplier);
auto did_overflow = zoo::swar::greaterEqual(overflow, One);
auto lane_mask = did_overflow.MSBtoLaneMask();
auto saturated = result | lane_mask;
return S{saturated};
}


// TODO(Jamie): Add tests from other PR.
template<int NB, typename T>
constexpr auto saturatingExponentiation(
SWAR<NB, T> x,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Absolutely not.
We're not removing the non-saturating exponentiation and provide only the saturating exponentiation. Don't do that.
Always the general operation is pre-requisite for the more specific.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good call

SWAR<NB, T> exponent
) {
using S = SWAR<NB, T>;
constexpr auto NumBitsPerLane = S::NBits;
constexpr auto
MSB = S{S::MostSignificantBit},
LSB = S{S::LeastSignificantBit};

auto operation = [](auto left, auto right, auto counts) {
auto mask = makeLaneMaskFromMSB(counts);
auto product = saturatedMultiplication(left, right);
return (product & mask) | (left & ~mask);
};

auto halver = [](auto counts) {
return swar::consumeMSB(counts);
};

return associativeOperatorIterated_regressive(
x,
LSB,
exponent,
MSB,
operation,
NumBitsPerLane,
halver
);
}

}

#endif
39 changes: 38 additions & 1 deletion test/swar/BasicOperations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ static_assert(BooleanSWAR{Literals<4, u16>,
namespace Multiplication {

static_assert(~int64_t(0) == negate(S4_64{S4_64::LeastSignificantBit}).value());
static_assert(0x0F0F0F0F == doublingMask<4, uint32_t>().value());
static_assert(0x0F0F0F0F == SWAR<4, uint32_t>::evenLaneMask().value());

constexpr auto PrecisionFixtureTest = 0x89ABCDEF;
constexpr auto Doubled =
Expand Down Expand Up @@ -255,6 +255,43 @@ HE(3, u8, 0xFF, 0x7);
HE(2, u8, 0xAA, 0x2);
#undef HE

template<int NB, typename T>
constexpr auto testSaturatingMultiplication(T left, T right, T expected) {
using S = SWAR<NB, T>;
return saturatingExponentiation(S{left}, S{right}).value() == expected;
}
static_assert(
testSaturatingMultiplication<8, u32>(
0x09'40'03'01,
0x37'03'C0'01,
0xFF'FF'FF'01
));
static_assert(
testSaturatingMultiplication<8, u32>(
0x02'02'02'02,
0x02'02'02'02,
0x04'04'04'04
));
static_assert(
testSaturatingMultiplication<8, u32>(
0xFF'FF'FF'FF,
0x04'03'02'01,
0xFF'FF'FF'FF
));
static_assert(
testSaturatingMultiplication<8, u32>(
0x02'FF'FF'FF,
0x03'03'02'01,
0x08'FF'FF'FF
));
static_assert(
testSaturatingMultiplication<4, u32>(
0x1243'0003,
0x0002'0002,
0x1119'1119
));


TEST_CASE("Old multiply version", "[deprecated][swar]") {
SWAR<8, u32> Micand{0x5030201};
SWAR<8, u32> Mplier{0xA050301};
Expand Down
Loading