Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/DXIL.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2426,6 +2426,8 @@ ID Name Description
309 VectorReduceAnd Bitwise AND reduction of the vector returning a scalar
310 VectorReduceOr Bitwise OR reduction of the vector returning a scalar
311 FDot computes the n-dimensional vector dot-product
312 GetGroupWaveIndex returns the index of the wave in the thread group
313 GetGroupWaveCount returns the number of waves in the thread group
=== ===================================================== =======================================================================================================================================================================================================================


Expand Down
3 changes: 3 additions & 0 deletions docs/ReleaseNotes.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ The included licenses apply to the following files:
- Fixed regression: [#7508](https://github.com/microsoft/DirectXShaderCompiler/issues/7508) crash when calling `Load` with `status`.
- Header file `dxcpix.h` was added to the release package.
- Moved Linear Algebra (Cooperative Vector) DXIL Opcodes to experimental Shader Model 6.10
- Implement GetGroupWaveIndex and GetGroupWaveCount : [proposal](https://github.com/microsoft/hlsl-specs/blob/main/proposals/0048-group-wave-index.md)
- GetGroupWaveIndex: New intrinsic for Compute, Mesh, Amplification and Node shaders which returns the index of the wave within the thread group that the the thread is executing.
- GetGroupWaveCount: New intrinsic for Compute, Mesh, Amplification and Node shaders which returns the total number of waves executing within the thread group.

### Version 1.8.2505

Expand Down
16 changes: 12 additions & 4 deletions include/dxc/DXIL/DxilConstants.h
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,10 @@ enum class OpCode : unsigned {
// Graphics shader
ViewID = 138, // returns the view index

// Group Wave Ops
GetGroupWaveCount = 313, // returns the number of waves in the thread group
GetGroupWaveIndex = 312, // returns the index of the wave in the thread group

// Helper Lanes
IsHelperLane = 221, // returns true on helper lanes in pixel shaders

Expand Down Expand Up @@ -1087,9 +1091,9 @@ enum class OpCode : unsigned {
NumOpCodes_Dxil_1_6 = 222,
NumOpCodes_Dxil_1_7 = 226,
NumOpCodes_Dxil_1_8 = 258,
NumOpCodes_Dxil_1_9 = 312,
NumOpCodes_Dxil_1_9 = 314,

NumOpCodes = 312 // exclusive last value of enumeration
NumOpCodes = 314 // exclusive last value of enumeration
};
// OPCODE-ENUM:END

Expand Down Expand Up @@ -1194,6 +1198,10 @@ enum class OpCodeClass : unsigned {
// Graphics shader
ViewID,

// Group Wave Ops
GetGroupWaveCount,
GetGroupWaveIndex,

// Helper Lanes
IsHelperLane,

Expand Down Expand Up @@ -1423,9 +1431,9 @@ enum class OpCodeClass : unsigned {
NumOpClasses_Dxil_1_6 = 149,
NumOpClasses_Dxil_1_7 = 153,
NumOpClasses_Dxil_1_8 = 174,
NumOpClasses_Dxil_1_9 = 196,
NumOpClasses_Dxil_1_9 = 198,

NumOpClasses = 196 // exclusive last value of enumeration
NumOpClasses = 198 // exclusive last value of enumeration
};
// OPCODECLASS-ENUM:END

Expand Down
40 changes: 40 additions & 0 deletions include/dxc/DXIL/DxilInstructions.h
Original file line number Diff line number Diff line change
Expand Up @@ -10231,5 +10231,45 @@ struct DxilInst_FDot {
llvm::Value *get_b() const { return Instr->getOperand(2); }
void set_b(llvm::Value *val) { Instr->setOperand(2, val); }
};

/// This instruction returns the index of the wave in the thread group
struct DxilInst_GetGroupWaveIndex {
llvm::Instruction *Instr;
// Construction and identification
DxilInst_GetGroupWaveIndex(llvm::Instruction *pInstr) : Instr(pInstr) {}
operator bool() const {
return hlsl::OP::IsDxilOpFuncCallInst(Instr,
hlsl::OP::OpCode::GetGroupWaveIndex);
}
// Validation support
bool isAllowed() const { return true; }
bool isArgumentListValid() const {
if (1 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands())
return false;
return true;
}
// Metadata
bool requiresUniformInputs() const { return false; }
};

/// This instruction returns the number of waves in the thread group
struct DxilInst_GetGroupWaveCount {
llvm::Instruction *Instr;
// Construction and identification
DxilInst_GetGroupWaveCount(llvm::Instruction *pInstr) : Instr(pInstr) {}
operator bool() const {
return hlsl::OP::IsDxilOpFuncCallInst(Instr,
hlsl::OP::OpCode::GetGroupWaveCount);
}
// Validation support
bool isAllowed() const { return true; }
bool isArgumentListValid() const {
if (1 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands())
return false;
return true;
}
// Metadata
bool requiresUniformInputs() const { return false; }
};
// INSTR-HELPER:END
} // namespace hlsl
4 changes: 3 additions & 1 deletion include/dxc/HlslIntrinsicOp.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ enum class IntrinsicOp {
IOP_EvaluateAttributeSnapped = 17,
IOP_GeometryIndex = 18,
IOP_GetAttributeAtVertex = 19,
IOP_GetGroupWaveCount = 395,
IOP_GetGroupWaveIndex = 396,
IOP_GetRemainingRecursionLevels = 20,
IOP_GetRenderTargetSampleCount = 21,
IOP_GetRenderTargetSamplePosition = 22,
Expand Down Expand Up @@ -401,7 +403,7 @@ enum class IntrinsicOp {
IOP_usign = 355,
MOP_InterlockedUMax = 356,
MOP_InterlockedUMin = 357,
Num_Intrinsics = 395,
Num_Intrinsics = 397,
};
inline bool HasUnsignedIntrinsicOpcode(IntrinsicOp opcode) {
switch (opcode) {
Expand Down
37 changes: 37 additions & 0 deletions lib/DXIL/DxilOperations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2714,6 +2714,24 @@ const OP::OpCodeProperty OP::m_OpCodeProps[(unsigned)OP::OpCode::NumOpCodes] = {
1,
{{0x400}},
{{0x3}}}, // Overloads: <hf

// Group Wave Ops
{OC::GetGroupWaveIndex,
"GetGroupWaveIndex",
OCC::GetGroupWaveIndex,
"getGroupWaveIndex",
Attribute::ReadNone,
0,
{},
{}}, // Overloads: v
{OC::GetGroupWaveCount,
"GetGroupWaveCount",
OCC::GetGroupWaveCount,
"getGroupWaveCount",
Attribute::ReadNone,
0,
{},
{}}, // Overloads: v
};
// OPCODE-OLOADS:END

Expand Down Expand Up @@ -3542,6 +3560,13 @@ void OP::GetMinShaderModelAndMask(OpCode C, bool bWithTranslation,
minor = 10;
return;
}
// Instructions: GetGroupWaveIndex=312, GetGroupWaveCount=313
if ((312 <= op && op <= 313)) {
major = 6;
minor = 10;
mask = SFLAG(Compute) | SFLAG(Mesh) | SFLAG(Amplification) | SFLAG(Library);
return;
}
// OPCODE-SMMASK:END
}

Expand Down Expand Up @@ -6044,6 +6069,16 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
A(pETy);
A(pETy);
break;

// Group Wave Ops
case OpCode::GetGroupWaveIndex:
A(pI32);
A(pI32);
break;
case OpCode::GetGroupWaveCount:
A(pI32);
A(pI32);
break;
// OPCODE-OLOAD-FUNCS:END
default:
DXASSERT(false, "otherwise unhandled case");
Expand Down Expand Up @@ -6335,6 +6370,8 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
case OpCode::ReservedC7:
case OpCode::ReservedC8:
case OpCode::ReservedC9:
case OpCode::GetGroupWaveIndex:
case OpCode::GetGroupWaveCount:
return Type::getVoidTy(Ctx);
case OpCode::CheckAccessFullyMapped:
case OpCode::SampleIndex:
Expand Down
14 changes: 14 additions & 0 deletions lib/DxilValidation/DxilValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -920,6 +920,20 @@ static void ValidateSignatureDxilOp(CallInst *CI, DXIL::OpCode Opcode,
{"DispatchMesh", "Amplification shader"});
}
} break;
case DXIL::OpCode::GetGroupWaveIndex:
case DXIL::OpCode::GetGroupWaveCount: {
bool IsCSLike = Props.shaderKind == DXIL::ShaderKind::Compute ||
Props.shaderKind == DXIL::ShaderKind::Mesh ||
Props.shaderKind == DXIL::ShaderKind::Amplification ||
Props.shaderKind == DXIL::ShaderKind::Node;
if (!IsCSLike) {
ValCtx.EmitInstrFormatError(CI, ValidationRule::SmOpcodeInInvalidFunction,
{Opcode == DXIL::OpCode::GetGroupWaveIndex
? "GetGroupWaveIndex"
: "GetGroupWaveCount",
"compute, mesh, or amplification shader"});
}
} break;
default:
break;
}
Expand Down
4 changes: 4 additions & 0 deletions lib/HLSL/HLOperationLower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7513,6 +7513,10 @@ IntrinsicLower gLowerTable[] = {
{IntrinsicOp::IOP___builtin_VectorAccumulate, TranslateVectorAccumulate,
DXIL::OpCode::VectorAccumulate},
{IntrinsicOp::IOP_isnormal, TrivialIsSpecialFloat, DXIL::OpCode::IsNormal},
{IntrinsicOp::IOP_GetGroupWaveCount, TranslateWaveToVal,
DXIL::OpCode::GetGroupWaveCount},
{IntrinsicOp::IOP_GetGroupWaveIndex, TranslateWaveToVal,
DXIL::OpCode::GetGroupWaveIndex},
};
} // namespace
static_assert(
Expand Down
2 changes: 2 additions & 0 deletions tools/clang/lib/SPIRV/DeclResultIdMapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4231,6 +4231,8 @@ SpirvVariable *DeclResultIdMapper::getBuiltinVar(spv::BuiltIn builtIn,
case spv::BuiltIn::LocalInvocationIndex:
case spv::BuiltIn::RemainingRecursionLevelsAMDX:
case spv::BuiltIn::ShaderIndexAMDX:
case spv::BuiltIn::SubgroupId:
case spv::BuiltIn::NumSubgroups:
sc = spv::StorageClass::Input;
break;
case spv::BuiltIn::TaskCountNV:
Expand Down
16 changes: 16 additions & 0 deletions tools/clang/lib/SPIRV/SpirvEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9454,6 +9454,22 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
case hlsl::IntrinsicOp::IOP_WaveActiveCountBits:
retVal = processWaveCountBits(callExpr, spv::GroupOperation::Reduce);
break;
case hlsl::IntrinsicOp::IOP_GetGroupWaveIndex: {
featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "GetGroupWaveIndex",
srcLoc);
const QualType retType = callExpr->getCallReturnType(astContext);
auto *var =
declIdMapper.getBuiltinVar(spv::BuiltIn::SubgroupId, retType, srcLoc);
retVal = spvBuilder.createLoad(retType, var, srcLoc, srcRange);
} break;
case hlsl::IntrinsicOp::IOP_GetGroupWaveCount: {
featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "GetGroupWaveCount",
srcLoc);
const QualType retType = callExpr->getCallReturnType(astContext);
auto *var =
declIdMapper.getBuiltinVar(spv::BuiltIn::NumSubgroups, retType, srcLoc);
retVal = spvBuilder.createLoad(retType, var, srcLoc, srcRange);
} break;
case hlsl::IntrinsicOp::IOP_WaveActiveUSum:
case hlsl::IntrinsicOp::IOP_WaveActiveSum:
case hlsl::IntrinsicOp::IOP_WaveActiveUProduct:
Expand Down
11 changes: 11 additions & 0 deletions tools/clang/test/CodeGenSPIRV/sm6_10.group-wave-count.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// RUN: %dxc -T cs_6_10 -E main -fspv-target-env=vulkan1.1 -fcgl %s -spirv | FileCheck %s

// CHECK: ; Version: 1.3

RWStructuredBuffer<uint> output: register(u0);

[numthreads(64, 1, 1)]
void main(uint3 id: SV_DispatchThreadID) {

output[id.x] = GetGroupWaveCount();
}
11 changes: 11 additions & 0 deletions tools/clang/test/CodeGenSPIRV/sm6_10.group-wave-index.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// RUN: %dxc -T cs_6_10 -E main -fspv-target-env=vulkan1.1 -fcgl %s -spirv | FileCheck %s

// CHECK: ; Version: 1.3

RWStructuredBuffer<uint> output: register(u0);

[numthreads(64, 1, 1)]
void main(uint3 id: SV_DispatchThreadID) {

output[id.x] = GetGroupWaveIndex();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// RUN: %dxc -T cs_6_10 -E main %s | FileCheck %s

// CHECK: call i32 @dx.op.getGroupWaveIndex(i32 312
// CHECK: call i32 @dx.op.getGroupWaveCount(i32 313

RWStructuredBuffer<uint> output0 : register(u0);
RWStructuredBuffer<uint> output1 : register(u1);

[numthreads(64, 1, 1)]
void main(uint3 id: SV_DispatchThreadID) {
uint waveIdx = GetGroupWaveIndex();
uint waveCount = GetGroupWaveCount();

output0[id.x] = waveIdx;
output1[id.x] = waveCount;
}
2 changes: 2 additions & 0 deletions utils/hct/gen_intrin_main.txt
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,8 @@ $type1 [[]] QuadReadAcrossY(in numeric<> value);
$type1 [[]] QuadReadAcrossDiagonal(in numeric<> value);
bool [[]] QuadAny(in bool cond);
bool [[]] QuadAll(in bool cond);
uint [[rn]] GetGroupWaveIndex();
uint [[rn]] GetGroupWaveCount();

// Raytracing
void [[]] TraceRay(in acceleration_struct AccelerationStructure, in uint RayFlags, in uint InstanceInclusionMask, in uint RayContributionToHitGroupIndex, in uint MultiplierForGeometryContributionToHitGroupIndex, in uint MissShaderIndex, in ray_desc Ray, inout udt Payload);
Expand Down
26 changes: 24 additions & 2 deletions utils/hct/hctdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,10 @@ def populate_categories_and_models(self):
)
elif i.name.startswith("Bitcast"):
i.category = "Bitcasts with different sizes"
elif i.name.startswith("GetGroupWave"):
i.category = "Group Wave Ops"
i.shader_model = 6, 10
i.shader_stages = ("compute", "mesh", "amplification", "library")
for i in "ViewID,AttributeAtVertex".split(","):
self.name_idx[i].shader_model = 6, 1
for i in "RawBufferLoad,RawBufferStore".split(","):
Expand Down Expand Up @@ -5882,10 +5886,28 @@ def UFI(name, **mappings):
counters=("floats",),
)

# Group Wave Operations
self.add_dxil_op(
"GetGroupWaveIndex",
"GetGroupWaveIndex",
"returns the index of the wave in the thread group",
"v",
"rn",
[db_dxil_param(0, "i32", "", "operation result")],
)
self.add_dxil_op(
"GetGroupWaveCount",
"GetGroupWaveCount",
"returns the number of waves in the thread group",
"v",
"rn",
[db_dxil_param(0, "i32", "", "operation result")],
)

# End of DXIL 1.9 opcodes.
op_count = self.set_op_count_for_version(1, 9)
assert op_count == 312, (
"312 is expected next operation index but encountered %d and thus opcodes are broken"
assert op_count == 314, (
"314 is expected next operation index but encountered %d and thus opcodes are broken"
% op_count
)

Expand Down
6 changes: 4 additions & 2 deletions utils/hct/hlsl_intrinsic_opcodes.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"IntrinsicOpCodes": {
"Num_Intrinsics": 395,
"Num_Intrinsics": 397,
"IOP_AcceptHitAndEndSearch": 0,
"IOP_AddUint64": 1,
"IOP_AllMemoryBarrier": 2,
Expand Down Expand Up @@ -395,6 +395,8 @@
"IOP___builtin_MatVecMulAdd": 391,
"IOP___builtin_OuterProductAccumulate": 392,
"IOP___builtin_VectorAccumulate": 393,
"IOP_isnormal": 394
"IOP_isnormal": 394,
"IOP_GetGroupWaveCount": 395,
"IOP_GetGroupWaveIndex": 396
}
}