diff --git a/tools/clang/unittests/HLSLExec/LongVectorOps.def b/tools/clang/unittests/HLSLExec/LongVectorOps.def index 568c166641..f3908ff055 100644 --- a/tools/clang/unittests/HLSLExec/LongVectorOps.def +++ b/tools/clang/unittests/HLSLExec/LongVectorOps.def @@ -198,7 +198,14 @@ OP_LOAD_AND_STORE_SB(LoadAndStore_RD_SB_SRV, "RootDescriptor_SRV") OP_DEFAULT(Wave, WaveActiveSum, 1, "WaveActiveSum", "") OP_DEFAULT_DEFINES(Wave, WaveActiveMin, 1, "TestWaveActiveMin", "", " -DFUNC_WAVE_ACTIVE_MIN=1") OP_DEFAULT_DEFINES(Wave, WaveActiveMax, 1, "TestWaveActiveMax", "", " -DFUNC_WAVE_ACTIVE_MAX=1") -OP(Wave, WaveActiveProduct, 1, "TestWaveActiveProduct", "", " -DFUNC_WAVE_ACTIVE_PRODUCT=1", "LongVectorOp", - AllOnes, Default2, Default3) +OP(Wave, WaveActiveProduct, 1, "TestWaveActiveProduct", "", " -DFUNC_WAVE_ACTIVE_PRODUCT=1", "LongVectorOp", AllOnes, Default2, Default3) +OP_DEFAULT_DEFINES(Wave, WaveActiveBitAnd, 1, "TestWaveActiveBitAnd", "", " -DFUNC_WAVE_ACTIVE_BIT_AND=1") +OP_DEFAULT_DEFINES(Wave, WaveActiveBitOr, 1, "TestWaveActiveBitOr", "", " -DFUNC_WAVE_ACTIVE_BIT_OR=1") +OP_DEFAULT_DEFINES(Wave, WaveActiveBitXor, 1, "TestWaveActiveBitXor", "", " -DFUNC_WAVE_ACTIVE_BIT_XOR=1") +OP_DEFAULT_DEFINES(Wave, WaveActiveAllEqual, 1, "TestWaveActiveAllEqual", "", " -DFUNC_WAVE_ACTIVE_ALL_EQUAL=1") +OP_DEFAULT_DEFINES(Wave, WaveReadLaneAt, 1, "TestWaveReadLaneAt", "", " -DFUNC_WAVE_READ_LANE_AT=1") +OP_DEFAULT_DEFINES(Wave, WaveReadLaneFirst, 1, "TestWaveReadLaneFirst", "", " -DFUNC_WAVE_READ_LANE_FIRST=1") +OP_DEFAULT_DEFINES(Wave, WavePrefixSum, 1, "TestWavePrefixSum", "", " -DFUNC_WAVE_PREFIX_SUM=1 -DIS_WAVE_PREFIX_OP=1") +OP_DEFAULT_DEFINES(Wave, WavePrefixProduct, 1, "TestWavePrefixProduct", "", " -DFUNC_WAVE_PREFIX_PRODUCT=1 -DIS_WAVE_PREFIX_OP=1") #undef OP diff --git a/tools/clang/unittests/HLSLExec/LongVectors.cpp b/tools/clang/unittests/HLSLExec/LongVectors.cpp index 88b531bb46..5ab16b75a8 100644 --- a/tools/clang/unittests/HLSLExec/LongVectors.cpp +++ 
b/tools/clang/unittests/HLSLExec/LongVectors.cpp @@ -1300,7 +1300,7 @@ template struct ExpectedBuilder { // Wave Ops // -#define WAVE_ACTIVE_OP(OP, IMPL) \ +#define WAVE_OP(OP, IMPL) \ template struct Op : DefaultValidation { \ T operator()(T A, UINT WaveSize) { return IMPL; } \ }; @@ -1310,7 +1310,7 @@ template T waveActiveSum(T A, UINT WaveSize) { return A * WaveSizeT; } -WAVE_ACTIVE_OP(OpType::WaveActiveSum, (waveActiveSum(A, WaveSize))); +WAVE_OP(OpType::WaveActiveSum, (waveActiveSum(A, WaveSize))); template T waveActiveMin(T A, UINT WaveSize) { std::vector Values; @@ -1320,7 +1320,7 @@ template T waveActiveMin(T A, UINT WaveSize) { return *std::min_element(Values.begin(), Values.end()); } -WAVE_ACTIVE_OP(OpType::WaveActiveMin, (waveActiveMin(A, WaveSize))); +WAVE_OP(OpType::WaveActiveMin, (waveActiveMin(A, WaveSize))); template T waveActiveMax(T A, UINT WaveSize) { std::vector Values; @@ -1330,7 +1330,7 @@ template T waveActiveMax(T A, UINT WaveSize) { return *std::max_element(Values.begin(), Values.end()); } -WAVE_ACTIVE_OP(OpType::WaveActiveMax, (waveActiveMax(A, WaveSize))); +WAVE_OP(OpType::WaveActiveMax, (waveActiveMax(A, WaveSize))); template T waveActiveProduct(T A, UINT WaveSize) { // We want to avoid overflow of a large product. So, the WaveActiveProdFn has @@ -1339,9 +1339,100 @@ template T waveActiveProduct(T A, UINT WaveSize) { return A * static_cast(WaveSize - 1); } -WAVE_ACTIVE_OP(OpType::WaveActiveProduct, (waveActiveProduct(A, WaveSize))); +WAVE_OP(OpType::WaveActiveProduct, (waveActiveProduct(A, WaveSize))); -#undef WAVE_ACTIVE_OP +template T waveActiveBitAnd(T A, UINT) { + // We set the LSB to 0 in one of the lanes. + return static_cast(A & ~static_cast(1)); +} + +WAVE_OP(OpType::WaveActiveBitAnd, (waveActiveBitAnd(A, WaveSize))); + +template T waveActiveBitOr(T A, UINT) { + // We set the LSB to 1 in one of the lanes.
+ return static_cast(A | static_cast(1)); +} + +WAVE_OP(OpType::WaveActiveBitOr, (waveActiveBitOr(A, WaveSize))); + +template T waveActiveBitXor(T A, UINT) { + // The shader arranges the per-lane inputs so the reduction is A with its LSB set. + return static_cast(A | static_cast(1)); +} + +WAVE_OP(OpType::WaveActiveBitXor, (waveActiveBitXor(A, WaveSize))); + +template +struct Op : StrictValidation {}; + +template struct ExpectedBuilder { + static std::vector + buildExpected(Op &, + const InputSets &Inputs, UINT) { + DXASSERT_NOMSG(Inputs.size() == 1); + + std::vector Expected; + const size_t VectorSize = Inputs[0].size(); + Expected.assign(VectorSize, static_cast(true)); + // All VectorSize elements exist before indexing; only the last element differs on a single lane, so it alone is false. + Expected[VectorSize - 1] = static_cast(false); + + return Expected; + } +}; + +template +struct Op : StrictValidation {}; + +template struct ExpectedBuilder { + static std::vector buildExpected(Op &, + const InputSets &Inputs, UINT) { + DXASSERT_NOMSG(Inputs.size() == 1); + + std::vector Expected; + const size_t VectorSize = Inputs[0].size(); + // Simple test, on the lane that we read we also fill the vector with the + // value of the first element. + Expected.assign(VectorSize, Inputs[0][0]); + + return Expected; + } +}; + +template +struct Op : StrictValidation {}; + +template struct ExpectedBuilder { + static std::vector buildExpected(Op &, + const InputSets &Inputs, UINT) { + DXASSERT_NOMSG(Inputs.size() == 1); + + std::vector Expected; + const size_t VectorSize = Inputs[0].size(); + // Simple test, on the lane that we read we also fill the vector with the + // value of the first element. + Expected.assign(VectorSize, Inputs[0][0]); + + return Expected; + } +}; + +template T wavePrefixSum(T A, UINT WaveSize) { + // We test the prefix sum in the 'middle' lane. This choice is arbitrary. 
+ return static_cast(A * static_cast(WaveSize / 2)); +} + +WAVE_OP(OpType::WavePrefixSum, (wavePrefixSum(A, WaveSize))); + +template T wavePrefixProduct(T A, UINT) { + // We test the prefix product in the 3rd lane to avoid overflow issues. + // So the result is A * A. + return static_cast(A * A); +} + +WAVE_OP(OpType::WavePrefixProduct, (wavePrefixProduct(A, WaveSize))); + +#undef WAVE_OP // // dispatchTest @@ -1384,9 +1475,6 @@ template struct ExpectedBuilder { return Expected; } -}; - -template struct WaveOpExpectedBuilder { static auto buildExpected(Op Op, const InputSets &Inputs, UINT WaveSize) { @@ -1466,8 +1554,7 @@ void dispatchWaveOpTest(ID3D12Device *D3DDevice, bool VerboseLogging, std::vector> Inputs = buildTestInputs(VectorSize, Operation.InputSets, Operation.Arity); - auto Expected = - WaveOpExpectedBuilder::buildExpected(Op, Inputs, WaveSize); + auto Expected = ExpectedBuilder::buildExpected(Op, Inputs, WaveSize); runAndVerify(D3DDevice, VerboseLogging, Operation, Inputs, Expected, Op.ValidationConfig, AdditionalCompilerOptions); @@ -2243,44 +2330,100 @@ class DxilConf_SM69_Vectorized { HLK_TEST(LoadAndStore_RD_SB_SRV, double); HLK_TEST(LoadAndStore_RD_SB_UAV, double); + HLK_WAVEOP_TEST(WaveActiveAllEqual, HLSLBool_t); + HLK_WAVEOP_TEST(WaveReadLaneAt, HLSLBool_t); + HLK_WAVEOP_TEST(WaveReadLaneFirst, HLSLBool_t); + HLK_WAVEOP_TEST(WaveActiveSum, int16_t); HLK_WAVEOP_TEST(WaveActiveMin, int16_t); HLK_WAVEOP_TEST(WaveActiveMax, int16_t); HLK_WAVEOP_TEST(WaveActiveProduct, int16_t); + HLK_WAVEOP_TEST(WaveActiveAllEqual, int16_t); + HLK_WAVEOP_TEST(WaveReadLaneAt, int16_t); + HLK_WAVEOP_TEST(WaveReadLaneFirst, int16_t); + HLK_WAVEOP_TEST(WavePrefixSum, int16_t); + HLK_WAVEOP_TEST(WavePrefixProduct, int16_t); HLK_WAVEOP_TEST(WaveActiveSum, int32_t); HLK_WAVEOP_TEST(WaveActiveMin, int32_t); HLK_WAVEOP_TEST(WaveActiveMax, int32_t); HLK_WAVEOP_TEST(WaveActiveProduct, int32_t); + HLK_WAVEOP_TEST(WaveActiveAllEqual, int32_t); + HLK_WAVEOP_TEST(WaveReadLaneAt, int32_t); + HLK_WAVEOP_TEST(WaveReadLaneFirst, int32_t); + 
HLK_WAVEOP_TEST(WavePrefixSum, int32_t); + HLK_WAVEOP_TEST(WavePrefixProduct, int32_t); HLK_WAVEOP_TEST(WaveActiveSum, int64_t); HLK_WAVEOP_TEST(WaveActiveMin, int64_t); HLK_WAVEOP_TEST(WaveActiveMax, int64_t); HLK_WAVEOP_TEST(WaveActiveProduct, int64_t); + HLK_WAVEOP_TEST(WaveActiveAllEqual, int64_t); + HLK_WAVEOP_TEST(WaveReadLaneAt, int64_t); + HLK_WAVEOP_TEST(WaveReadLaneFirst, int64_t); + HLK_WAVEOP_TEST(WavePrefixSum, int64_t); + HLK_WAVEOP_TEST(WavePrefixProduct, int64_t); HLK_WAVEOP_TEST(WaveActiveSum, uint16_t); HLK_WAVEOP_TEST(WaveActiveMin, uint16_t); HLK_WAVEOP_TEST(WaveActiveMax, uint16_t); HLK_WAVEOP_TEST(WaveActiveProduct, uint16_t); + HLK_WAVEOP_TEST(WaveActiveAllEqual, uint16_t); + HLK_WAVEOP_TEST(WaveReadLaneAt, uint16_t); + HLK_WAVEOP_TEST(WaveReadLaneFirst, uint16_t); + HLK_WAVEOP_TEST(WavePrefixSum, uint16_t); + HLK_WAVEOP_TEST(WavePrefixProduct, uint16_t); HLK_WAVEOP_TEST(WaveActiveSum, uint32_t); HLK_WAVEOP_TEST(WaveActiveMin, uint32_t); HLK_WAVEOP_TEST(WaveActiveMax, uint32_t); HLK_WAVEOP_TEST(WaveActiveProduct, uint32_t); + // Note: WaveActiveBit* ops don't support uint16_t in HLSL + HLK_WAVEOP_TEST(WaveActiveBitAnd, uint32_t); + HLK_WAVEOP_TEST(WaveActiveBitOr, uint32_t); + HLK_WAVEOP_TEST(WaveActiveBitXor, uint32_t); + HLK_WAVEOP_TEST(WaveActiveAllEqual, uint32_t); + HLK_WAVEOP_TEST(WaveReadLaneAt, uint32_t); + HLK_WAVEOP_TEST(WaveReadLaneFirst, uint32_t); + HLK_WAVEOP_TEST(WavePrefixSum, uint32_t); + HLK_WAVEOP_TEST(WavePrefixProduct, uint32_t); HLK_WAVEOP_TEST(WaveActiveSum, uint64_t); HLK_WAVEOP_TEST(WaveActiveMin, uint64_t); HLK_WAVEOP_TEST(WaveActiveMax, uint64_t); HLK_WAVEOP_TEST(WaveActiveProduct, uint64_t); + HLK_WAVEOP_TEST(WaveActiveBitAnd, uint64_t); + HLK_WAVEOP_TEST(WaveActiveBitOr, uint64_t); + HLK_WAVEOP_TEST(WaveActiveBitXor, uint64_t); + HLK_WAVEOP_TEST(WaveActiveAllEqual, uint64_t); + HLK_WAVEOP_TEST(WaveReadLaneAt, uint64_t); + HLK_WAVEOP_TEST(WaveReadLaneFirst, uint64_t); + HLK_WAVEOP_TEST(WavePrefixSum, uint64_t); + 
HLK_WAVEOP_TEST(WavePrefixProduct, uint64_t); HLK_WAVEOP_TEST(WaveActiveSum, HLSLHalf_t); HLK_WAVEOP_TEST(WaveActiveMin, HLSLHalf_t); HLK_WAVEOP_TEST(WaveActiveMax, HLSLHalf_t); HLK_WAVEOP_TEST(WaveActiveProduct, HLSLHalf_t); + HLK_WAVEOP_TEST(WaveActiveAllEqual, HLSLHalf_t); + HLK_WAVEOP_TEST(WaveReadLaneAt, HLSLHalf_t); + HLK_WAVEOP_TEST(WaveReadLaneFirst, HLSLHalf_t); + HLK_WAVEOP_TEST(WavePrefixSum, HLSLHalf_t); + HLK_WAVEOP_TEST(WavePrefixProduct, HLSLHalf_t); HLK_WAVEOP_TEST(WaveActiveSum, float); HLK_WAVEOP_TEST(WaveActiveMin, float); HLK_WAVEOP_TEST(WaveActiveMax, float); HLK_WAVEOP_TEST(WaveActiveProduct, float); + HLK_WAVEOP_TEST(WaveActiveAllEqual, float); + HLK_WAVEOP_TEST(WaveReadLaneAt, float); + HLK_WAVEOP_TEST(WaveReadLaneFirst, float); + HLK_WAVEOP_TEST(WavePrefixSum, float); + HLK_WAVEOP_TEST(WavePrefixProduct, float); HLK_WAVEOP_TEST(WaveActiveSum, double); HLK_WAVEOP_TEST(WaveActiveMin, double); HLK_WAVEOP_TEST(WaveActiveMax, double); HLK_WAVEOP_TEST(WaveActiveProduct, double); + HLK_WAVEOP_TEST(WaveActiveAllEqual, double); + HLK_WAVEOP_TEST(WaveReadLaneAt, double); + HLK_WAVEOP_TEST(WaveReadLaneFirst, double); + HLK_WAVEOP_TEST(WavePrefixSum, double); + HLK_WAVEOP_TEST(WavePrefixProduct, double); private: bool Initialized = false; diff --git a/tools/clang/unittests/HLSLExec/ShaderOpArith.xml b/tools/clang/unittests/HLSLExec/ShaderOpArith.xml index b138011686..5bc9f7118e 100644 --- a/tools/clang/unittests/HLSLExec/ShaderOpArith.xml +++ b/tools/clang/unittests/HLSLExec/ShaderOpArith.xml @@ -4111,6 +4111,110 @@ void MSMain(uint GID : SV_GroupIndex, } #endif + #ifdef FUNC_WAVE_ACTIVE_BIT_AND + vector TestWaveActiveBitAnd(vector Vector) + { + if(WaveGetLaneIndex() == (WaveGetLaneCount() - 1)) + { + // Clear the LSB on the last lane only. 
+ Vector = Vector & ~((OUT_TYPE)1); + } + return WaveActiveBitAnd(Vector); + } + #endif + + #ifdef FUNC_WAVE_ACTIVE_BIT_OR + vector TestWaveActiveBitOr(vector Vector) + { + if(WaveGetLaneIndex() == (WaveGetLaneCount() - 1)) + { + // Set the LSB on the last lane only. + Vector = Vector | ((OUT_TYPE)1); + } + return WaveActiveBitOr(Vector); + } + #endif + + #ifdef FUNC_WAVE_ACTIVE_BIT_XOR + vector TestWaveActiveBitXor(vector Vector) + { + const uint isChosen = (WaveGetLaneIndex() == 0) ? 1 : 0; + // Lane 0 contributes ~(Vector | 1) and every other lane contributes all ones (multiplying by isChosen zeroes the value before the complement). Assuming an even number of active lanes, the XOR reduces to Vector | 1, which an OR or AND reduction would not produce. + Vector = ~((Vector | ((OUT_TYPE)1)) * (OUT_TYPE)isChosen); + + return WaveActiveBitXor(Vector); + } + #endif + + #ifdef FUNC_WAVE_ACTIVE_ALL_EQUAL + vector TestWaveActiveAllEqual(vector Vector) + { + if(WaveGetLaneIndex() == (WaveGetLaneCount() - 1)) + { + Vector[NUM - 1] = (TYPE)1337; + } + + return WaveActiveAllEqual(Vector); + } + #endif + + #ifdef FUNC_WAVE_READ_LANE_AT + vector TestWaveReadLaneAt(vector Vector) + { + // Keep it simple and just read the last lane. 
+ const uint LaneToRead = WaveGetLaneCount() - 1; + if(WaveGetLaneIndex() == LaneToRead) + { + [unroll] + for(uint i = 1; i < NUM; ++i) + { + Vector[i] = Vector[0]; + } + } + return WaveReadLaneAt(Vector, LaneToRead); + } + #endif + + #ifdef FUNC_WAVE_READ_LANE_FIRST + vector TestWaveReadLaneFirst(vector Vector) + { + if(WaveGetLaneIndex() == 0) + { + [unroll] + for(uint i = 1; i < NUM; ++i) + { + Vector[i] = Vector[0]; + } + } + return WaveReadLaneFirst(Vector); + } + #endif + + #ifdef FUNC_WAVE_PREFIX_SUM + void TestWavePrefixSum(vector Vector) + { + const uint LaneCount = WaveGetLaneCount(); + const uint MidLane = LaneCount/2; + + Vector = WavePrefixSum(Vector); + if(WaveGetLaneIndex() == MidLane) + { + g_OutputVector.Store< vector >(0, Vector); + } + } + #endif + + #ifdef FUNC_WAVE_PREFIX_PRODUCT + void TestWavePrefixProduct(vector Vector) + { + Vector = WavePrefixProduct(Vector); + if(WaveGetLaneIndex() == 2) + { + g_OutputVector.Store< vector >(0, Vector); + } + } + #endif + #ifdef FUNC_TEST_SELECT vector TestSelect(vector Vector1, vector Vector2, @@ -4184,16 +4288,19 @@ void MSMain(uint GID : SV_GroupIndex, const uint32_t OutNum = NUM; #endif - #if IS_UNARY_OP - vector OutputVector = FUNC(Input1); + vector OutputVector; + #ifdef IS_WAVE_PREFIX_OP + // Wave prefix ops store the output on a specific lane only. + FUNC(Input1); + return; + #elif IS_UNARY_OP + OutputVector = FUNC(Input1); #elif IS_BINARY_OP - vector OutputVector = FUNC(Input1 OPERATOR - Input2); + OutputVector = FUNC(Input1 OPERATOR Input2); #elif IS_TERNARY_OP // Ternary ops don't bother expanding OPERATOR because its // always going to be comma for these test cases. - vector OutputVector = FUNC(Input1, Input2, - Input3); + OutputVector = FUNC(Input1, Input2, Input3); #endif g_OutputVector.Store< vector >(0, OutputVector);