Skip to content

Commit a699e7e

Browse files
authored
[SPIR-V] Support sign() intrinsics for unsigned integer. (#7845)
Fixes #7755
1 parent b3619bf commit a699e7e

File tree

3 files changed

+154
-0
lines changed

3 files changed

+154
-0
lines changed

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9352,6 +9352,9 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
93529352
case hlsl::IntrinsicOp::IOP_printf:
93539353
retVal = processIntrinsicPrintf(callExpr);
93549354
break;
9355+
case hlsl::IntrinsicOp::IOP_usign:
9356+
retVal = processIntrinsicSignUnsignedInt(callExpr);
9357+
break;
93559358
case hlsl::IntrinsicOp::IOP_sign: {
93569359
if (isFloatOrVecMatOfFloatType(callExpr->getArg(0)->getType()))
93579360
retVal = processIntrinsicFloatSign(callExpr);
@@ -12656,6 +12659,80 @@ SpirvEmitter::processIntrinsicSaturate(const CallExpr *callExpr) {
1265612659
return nullptr;
1265712660
}
1265812661

12662+
SpirvInstruction *
12663+
SpirvEmitter::processIntrinsicSignUnsignedInt(const CallExpr *callExpr) {
12664+
const auto srcLoc = callExpr->getExprLoc();
12665+
const auto srcRange = callExpr->getSourceRange();
12666+
12667+
const Expr *firstArg = callExpr->getArg(0);
12668+
const QualType firstArgType = firstArg->getType();
12669+
auto elemType = QualType{};
12670+
uint32_t numRows;
12671+
uint32_t numCols;
12672+
uint32_t count;
12673+
bool isScalar =
12674+
isScalarType(firstArgType, &elemType) ||
12675+
(isVectorType(firstArgType, &elemType, &count) && count == 1) ||
12676+
(isMxNMatrix(firstArgType, &elemType, &numRows, &numCols) &&
12677+
(numRows == 1 && numCols == 1));
12678+
12679+
auto *zero = getValueZero(astContext.IntTy);
12680+
auto *one = getValueOne(astContext.IntTy);
12681+
if (isScalar) {
12682+
auto *argVal = doExpr(callExpr->getArg(0));
12683+
auto *zeroUint = getValueZero(callExpr->getArg(0)->getType());
12684+
auto *cmp =
12685+
spvBuilder.createBinaryOp(spv::Op::OpUGreaterThan, astContext.BoolTy,
12686+
argVal, zeroUint, srcLoc, srcRange);
12687+
return spvBuilder.createSelect(astContext.IntTy, cmp, one, zero, srcLoc,
12688+
srcRange);
12689+
}
12690+
12691+
uint32_t size;
12692+
if (isVectorType(firstArgType)) {
12693+
size = count;
12694+
} else if (is1xNMatrix(firstArgType)) {
12695+
size = numCols;
12696+
} else if (isMx1Matrix(firstArgType)) {
12697+
size = numRows;
12698+
} else {
12699+
size = numRows;
12700+
}
12701+
12702+
const auto actOnEachVec = [this, srcLoc, srcRange, zero, one, elemType,
12703+
size](uint32_t index, QualType inType,
12704+
QualType outType, SpirvInstruction *curRow) {
12705+
auto zeroUint = getValueZero(elemType);
12706+
// Create `size` vector of uint zeros.
12707+
auto *zerosUint = spvBuilder.getConstantComposite(
12708+
astContext.getExtVectorType(elemType, size),
12709+
std::vector<clang::spirv::SpirvConstant *>(size, zeroUint));
12710+
// Compare if they are greater than zero.
12711+
auto *cmp = spvBuilder.createBinaryOp(
12712+
spv::Op::OpUGreaterThan,
12713+
astContext.getExtVectorType(astContext.BoolTy, size), curRow, zerosUint,
12714+
srcLoc, srcRange);
12715+
12716+
// Create a vector of int ones and zeros.
12717+
auto *zeros = spvBuilder.getConstantComposite(
12718+
astContext.getExtVectorType(astContext.IntTy, size),
12719+
std::vector<clang::spirv::SpirvConstant *>(size, zero));
12720+
auto *ones = spvBuilder.getConstantComposite(
12721+
astContext.getExtVectorType(astContext.IntTy, size),
12722+
std::vector<clang::spirv::SpirvConstant *>(size, one));
12723+
// Select between ones and zeros based on the comparison.
12724+
return spvBuilder.createSelect(
12725+
astContext.getExtVectorType(astContext.IntTy, size), cmp, ones, zeros,
12726+
srcLoc, srcRange);
12727+
};
12728+
12729+
if (isVectorType(firstArgType)) {
12730+
return actOnEachVec(0, firstArgType, callExpr->getType(), doExpr(firstArg));
12731+
}
12732+
return processEachVectorInMatrix(firstArg, doExpr(firstArg), actOnEachVec,
12733+
srcLoc, srcRange);
12734+
}
12735+
1265912736
SpirvInstruction *
1266012737
SpirvEmitter::processIntrinsicFloatSign(const CallExpr *callExpr) {
1266112738
// Import the GLSL.std.450 extended instruction set.

tools/clang/lib/SPIRV/SpirvEmitter.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,9 @@ class SpirvEmitter : public ASTConsumer {
648648
/// Processes the 'ReadClock' intrinsic function.
649649
SpirvInstruction *processIntrinsicReadClock(const CallExpr *);
650650

651+
/// Processes the 'sign' intrinsic function for unsigned integer types.
652+
SpirvInstruction *processIntrinsicSignUnsignedInt(const CallExpr *callExpr);
653+
651654
/// Processes the 'sign' intrinsic function for float types.
652655
/// The FSign instruction in the GLSL instruction set returns a floating point
653656
/// result. The HLSL sign function, however, returns an integer. An extra
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// RUN: %dxc -T vs_6_0 -E main -fcgl -Vd %s -spirv | FileCheck %s
2+
3+
// CHECK-DAG: %int_0 = OpConstant %int 0
4+
// CHECK-DAG: %int_1 = OpConstant %int 1
5+
// CHECK-DAG: %uint_0 = OpConstant %uint 0
6+
// CHECK-DAG: %v3int = OpTypeVector %int 3
7+
// CHECK-DAG: %v3uint = OpTypeVector %uint 3
8+
// CHECK-DAG: [[zeros_uint3:%[0-9]+]] = OpConstantComposite %v3uint %uint_0 %uint_0 %uint_0
9+
// CHECK-DAG: [[zeros_int3:%[0-9]+]] = OpConstantComposite %v3int %int_0 %int_0 %int_0
10+
// CHECK-DAG: [[ones_int3:%[0-9]+]] = OpConstantComposite %v3int %int_1 %int_1 %int_1
11+
12+
void main() {
13+
int result;
14+
int3 result3;
15+
int3x3 result3x3;
16+
17+
// CHECK: [[a:%[0-9]+]] = OpLoad %uint %a
18+
// CHECK-NEXT: [[cmp_a:%[0-9]+]] = OpUGreaterThan %bool [[a]] %uint_0
19+
// CHECK-NEXT: [[select_a:%[0-9]+]] = OpSelect %int [[cmp_a]] %int_1 %int_0
20+
// CHECK-NEXT: OpStore %result [[select_a]]
21+
uint a;
22+
result = sign(a);
23+
24+
// CHECK: [[b:%[0-9]+]] = OpLoad %uint %b
25+
// CHECK-NEXT: [[cmp_b:%[0-9]+]] = OpUGreaterThan %bool [[b]] %uint_0
26+
// CHECK-NEXT: [[select_b:%[0-9]+]] = OpSelect %int [[cmp_b]] %int_1 %int_0
27+
// CHECK-NEXT: OpStore %result [[select_b]]
28+
uint1 b;
29+
result = sign(b);
30+
31+
// CHECK: [[c:%[0-9]+]] = OpLoad %v3uint %c
32+
// CHECK-NEXT: [[cmp_c:%[0-9]+]] = OpUGreaterThan %v3bool [[c]] [[zeros_uint3]]
33+
// CHECK-NEXT: [[select_c:%[0-9]+]] = OpSelect %v3int [[cmp_c]] [[ones_int3]] [[zeros_int3]]
34+
// CHECK-NEXT: OpStore %result3 [[select_c]]
35+
uint3 c;
36+
result3 = sign(c);
37+
38+
39+
// CHECK: [[d:%[0-9]+]] = OpLoad %uint %d
40+
// CHECK-NEXT: [[cmp_d:%[0-9]+]] = OpUGreaterThan %bool [[d]] %uint_0
41+
// CHECK-NEXT: [[select_d:%[0-9]+]] = OpSelect %int [[cmp_d]] %int_1 %int_0
42+
// CHECK-NEXT: OpStore %result [[select_d]]
43+
uint1x1 d;
44+
result = sign(d);
45+
46+
// CHECK: [[e:%[0-9]+]] = OpLoad %v3uint %e
47+
// CHECK-NEXT: [[cmp_e:%[0-9]+]] = OpUGreaterThan %v3bool [[e]] [[zeros_uint3]]
48+
// CHECK-NEXT: [[select_e:%[0-9]+]] = OpSelect %v3int [[cmp_e]] [[ones_int3]] [[zeros_int3]]
49+
// CHECK-NEXT: OpStore %result3 [[select_e]]
50+
uint1x3 e;
51+
result3 = sign(e);
52+
53+
// CHECK: [[f:%[0-9]+]] = OpLoad %v3uint %f
54+
// CHECK-NEXT: [[cmp_f:%[0-9]+]] = OpUGreaterThan %v3bool [[f]] [[zeros_uint3]]
55+
// CHECK-NEXT: [[select_f:%[0-9]+]] = OpSelect %v3int [[cmp_f]] [[ones_int3]] [[zeros_int3]]
56+
// CHECK-NEXT: OpStore %result3 [[select_f]]
57+
uint3x1 f;
58+
result3 = sign(f);
59+
60+
// CHECK: [[h:%[0-9]+]] = OpLoad %_arr_v3uint_uint_3 %h
61+
// CHECK-NEXT: [[h_row0:%[0-9]+]] = OpCompositeExtract %v3uint [[h]] 0
62+
// CHECK-NEXT: [[cmp_h_row0:%[0-9]+]] = OpUGreaterThan %v3bool [[h_row0]] [[zeros_uint3]]
63+
// CHECK-NEXT: [[select_h_row0:%[0-9]+]] = OpSelect %v3int [[cmp_h_row0]] [[ones_int3]] [[zeros_int3]]
64+
// CHECK-NEXT: [[h_row1:%[0-9]+]] = OpCompositeExtract %v3uint [[h]] 1
65+
// CHECK-NEXT: [[cmp_h_row1:%[0-9]+]] = OpUGreaterThan %v3bool [[h_row1]] [[zeros_uint3]]
66+
// CHECK-NEXT: [[select_h_row1:%[0-9]+]] = OpSelect %v3int [[cmp_h_row1]] [[ones_int3]] [[zeros_int3]]
67+
// CHECK-NEXT: [[h_row2:%[0-9]+]] = OpCompositeExtract %v3uint [[h]] 2
68+
// CHECK-NEXT: [[cmp_h_row2:%[0-9]+]] = OpUGreaterThan %v3bool [[h_row2]] [[zeros_uint3]]
69+
// CHECK-NEXT: [[select_h_row2:%[0-9]+]] = OpSelect %v3int [[cmp_h_row2]] [[ones_int3]] [[zeros_int3]]
70+
// CHECK-NEXT: [[select_h:%[0-9]+]] = OpCompositeConstruct %_arr_v3uint_uint_3 [[select_h_row0]] [[select_h_row1]] [[select_h_row2]]
71+
// CHECK-NEXT: OpStore %result3x3 [[select_h]]
72+
uint3x3 h;
73+
result3x3 = sign(h);
74+
}

0 commit comments

Comments
 (0)