summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r-- source/slang/slang-ir-inst-defs.h 1
-rw-r--r-- source/slang/slang-ir-specialize-function-call.cpp 93
-rw-r--r-- tests/compute/nonuniformres-as-function-parameter.slang 141
3 files changed, 233 insertions, 2 deletions
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 84ee634a1..39de083f0 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -1173,6 +1173,7 @@ INST_RANGE(Layout, VarLayout, EntryPointLayout)
INST(UNormAttr, unorm, 0, HOISTABLE)
INST(SNormAttr, snorm, 0, HOISTABLE)
INST(NoDiffAttr, no_diff, 0, HOISTABLE)
+ INST(NonUniformAttr, nonuniform, 0, HOISTABLE)
/* SemanticAttr */
INST(UserSemanticAttr, userSemantic, 2, HOISTABLE)
diff --git a/source/slang/slang-ir-specialize-function-call.cpp b/source/slang/slang-ir-specialize-function-call.cpp
index a41ca1e99..98aba0fae 100644
--- a/source/slang/slang-ir-specialize-function-call.cpp
+++ b/source/slang/slang-ir-specialize-function-call.cpp
@@ -5,6 +5,7 @@
#include "slang-ir-clone.h"
#include "slang-ir-insts.h"
#include "slang-ir-ssa-simplification.h"
+#include "slang-ir-util.h"
namespace Slang
{
@@ -363,7 +364,7 @@ struct FunctionParameterSpecializationContext
// a new callee function based on the original
// function and the information we gathered.
//
- newFunc = generateSpecializedFunc(oldFunc, funcInfo);
+ newFunc = generateSpecializedFunc(oldFunc, funcInfo, callInfo);
specializedFuncs.add(callInfo.key, newFunc);
}
@@ -381,6 +382,7 @@ struct FunctionParameterSpecializationContext
newCall->insertBefore(oldCall);
oldCall->replaceUsesWith(newCall);
oldCall->removeAndDeallocate();
+
}
// Before diving into the details on how we gather information
@@ -559,6 +561,21 @@ struct FunctionParameterSpecializationContext
// the arguments at the new call site, and
// don't add anything to the specialization key.
//
+ // In addition, the specialization key must record two properties of
+ // the index so that call sites differing in them are not merged:
+ // whether `oldIndex` is a nonuniform resource index, and the data
+ // type of `oldIndex`. We encode both in an IRAttributedType and add
+ // that type to the key of the call info.
+
+ List<IRAttr*> irAttrs;
+ if (findNonuniformIndexInst(oldIndex))
+ {
+ IRAttr* attr = getBuilder()->getAttr(kIROp_NonUniformAttr);
+ irAttrs.add(attr);
+ }
+ auto irType = getBuilder()->getAttributedType(oldIndex->getDataType(), irAttrs);
+ ioInfo.key.vals.add(irType);
+
ioInfo.newArgs.add(oldIndex);
}
else if (oldArg->getOp() == kIROp_Load)
@@ -577,6 +594,27 @@ struct FunctionParameterSpecializationContext
}
}
+ // Follows a chain of IntCast instructions starting at `inst` and
+ // returns the NonUniformResourceIndex instruction that terminates it.
+ // Returns nullptr for a null `inst` or a chain ending in any other op.
+ IRInst* findNonuniformIndexInst(IRInst* inst)
+ {
+     for (IRInst* cursor = inst; cursor != nullptr; cursor = cursor->getOperand(0))
+     {
+         switch (cursor->getOp())
+         {
+         case kIROp_NonUniformResourceIndex:
+             return cursor;
+         case kIROp_IntCast:
+             // Step through the cast to its first operand.
+             continue;
+         default:
+             return nullptr;
+         }
+     }
+     return nullptr;
+ }
+
// The remaining information we've discussed is only
// gathered once we decide we want to generate a
// specialized function, but it follows much the same flow.
@@ -803,7 +841,8 @@ struct FunctionParameterSpecializationContext
//
IRFunc* generateSpecializedFunc(
IRFunc* oldFunc,
- FuncSpecializationInfo const& funcInfo)
+ FuncSpecializationInfo const& funcInfo,
+ CallSpecializationInfo const& callInfo)
{
// We will make use of the infrastructure for cloning
// IR code, that is defined in `ir-clone.{h,cpp}`.
@@ -933,6 +972,18 @@ struct FunctionParameterSpecializationContext
newBodyInst->insertBefore(newFirstOrdinary);
}
+ // We need to handle a corner case where a new argument at the call
+ // site of this specialized function is the result of
+ // NonUniformResourceIndex(). In that case, any indexing operation on
+ // the global buffer that uses this argument should be decorated with
+ // NonUniformDecoration. However, inside the new specialized function
+ // that information is no longer available. Therefore, we scan the
+ // new argument list for this case and insert a
+ // NonUniformResourceIndex() instruction on the corresponding
+ // parameter of the new specialized function.
+ maybeInsertNonUniformResourceIndex(newFunc, funcInfo, callInfo);
+
+
// At this point we've created a new specialized function,
// and as such it may contain call sites that were not
// covered when we built our initial work list.
@@ -964,6 +1015,44 @@ struct FunctionParameterSpecializationContext
return newFunc;
}
+
+ // Re-establishes non-uniformity inside a freshly specialized function.
+ // For each new call argument that is (a cast chain ending in) a
+ // NonUniformResourceIndex, the matching new parameter of `newFunc` is
+ // wrapped in a clone of that instruction ahead of the first ordinary inst.
+ void maybeInsertNonUniformResourceIndex(
+     IRFunc* newFunc,
+     FuncSpecializationInfo const& funcInfo,
+     CallSpecializationInfo const& callInfo)
+ {
+     auto irBuilder = getBuilder();
+     uint32_t nextParamIndex = 0;
+
+     // Arguments and parameters are paired up by position.
+     SLANG_ASSERT(callInfo.newArgs.getCount() == funcInfo.newParams.getCount());
+
+     for (auto arg : callInfo.newArgs)
+     {
+         const uint32_t paramIndex = nextParamIndex++;
+         auto nonUniformInst = findNonuniformIndexInst(arg);
+         if (!nonUniformInst)
+             continue;
+
+         auto insertionPoint = newFunc->getFirstOrdinaryInst();
+         auto param = funcInfo.newParams[paramIndex];
+
+         // Place a copy of the NonUniformResourceIndex ahead of the first
+         // ordinary instruction and route every use of the parameter to it.
+         IRCloneEnv env;
+         auto wrapper = cloneInstAndOperands(&env, irBuilder, nonUniformInst);
+         wrapper->insertBefore(insertionPoint);
+         param->replaceUsesWith(wrapper);
+
+         // Finally make the parameter the copy's operand, which the clone
+         // step above did not do.
+         wrapper->setOperand(0, param);
+     }
+ }
};
// The top-level function for invoking the specialization pass
diff --git a/tests/compute/nonuniformres-as-function-parameter.slang b/tests/compute/nonuniformres-as-function-parameter.slang
new file mode 100644
index 000000000..daffc848e
--- /dev/null
+++ b/tests/compute/nonuniformres-as-function-parameter.slang
@@ -0,0 +1,141 @@
+//TEST:SIMPLE(filecheck=CHECK_SPV):-target spirv -entry main -stage compute
+//TEST:SIMPLE(filecheck=CHECK_GLSL_SPV):-target spirv -entry main -stage compute -emit-spirv-via-glsl
+//TEST:SIMPLE(filecheck=CHECK_GLSL):-target glsl -entry main -stage compute
+//TEST:SIMPLE(filecheck=CHECK_HLSL):-target hlsl -entry main -stage compute
+RWStructuredBuffer<uint> globalBuffer[] : register(t0, space0);
+RWStructuredBuffer<uint3> outputBuffer;
+
+struct MyStruct // Result type of func(); main() copies its three fields into outputBuffer.
+{
+ uint a;
+ uint b;
+ uint c;
+};
+
+
+MyStruct func(RWStructuredBuffer<uint> buffer)
+{
+ MyStruct a;
+
+ // CHECK_GLSL: globalBuffer_0[nonuniformEXT({{.*}})]
+ // CHECK_GLSL: globalBuffer_0[nonuniformEXT({{.*}})]
+
+ // In test case 3 the caller passes globalBuffer[bufferIdx3], which is indexed without
+ // NonUniformResourceIndex, so we should not see nonuniformEXT for that specialization.
+
+ // CHECK_GLSL: globalBuffer_0[_{{.*}})]
+ // CHECK_GLSL: globalBuffer_0[_{{.*}})]
+ a.a = buffer[0];
+ a.b = a.a + 1;
+ a.c = a.a + a.b + 1;
+
+ return a;
+}
+
+[shader("compute")]
+[numthreads(1, 1, 1)]
+void main(uint2 pixelIndex : SV_DispatchThreadID)
+{
+
+ // CHECK_SPV: OpDecorate %[[VAR1:[a-zA-Z0-9_]+]] NonUniform
+ // CHECK_SPV: OpDecorate %[[VAR2:[a-zA-Z0-9_]+]] NonUniform
+ // CHECK_SPV: OpDecorate %[[VAR3:[a-zA-Z0-9_]+]] NonUniform
+
+
+ // CHECK_GLSL_SPV: OpDecorate %[[VAR1:[a-zA-Z0-9_]+]] NonUniform
+ // CHECK_GLSL_SPV: OpDecorate %[[VAR2:[a-zA-Z0-9_]+]] NonUniform
+ // CHECK_GLSL_SPV: OpDecorate %[[VAR3:[a-zA-Z0-9_]+]] NonUniform
+ // CHECK_GLSL_SPV: OpDecorate %[[VAR4:[a-zA-Z0-9_]+]] NonUniform
+ // CHECK_GLSL_SPV: OpDecorate %[[VAR5:[a-zA-Z0-9_]+]] NonUniform
+ // CHECK_GLSL_SPV: OpDecorate %[[VAR6:[a-zA-Z0-9_]+]] NonUniform
+ // CHECK_GLSL_SPV: OpDecorate %[[VAR7:[a-zA-Z0-9_]+]] NonUniform
+ // CHECK_GLSL_SPV: OpDecorate %[[VAR8:[a-zA-Z0-9_]+]] NonUniform
+ // CHECK_GLSL_SPV: OpDecorate %[[VAR9:[a-zA-Z0-9_]+]] NonUniform
+ // CHECK_GLSL_SPV: OpDecorate %[[VAR10:[a-zA-Z0-9_]+]] NonUniform
+ // CHECK_GLSL_SPV: OpDecorate %[[VAR11:[a-zA-Z0-9_]+]] NonUniform
+ // CHECK_GLSL_SPV: OpDecorate %[[VAR12:[a-zA-Z0-9_]+]] NonUniform
+ // CHECK_GLSL_SPV: OpDecorate %[[VAR13:[a-zA-Z0-9_]+]] NonUniform
+ // CHECK_GLSL_SPV: OpDecorate %[[VAR14:[a-zA-Z0-9_]+]] NonUniform
+ // CHECK_GLSL_SPV: OpDecorate %[[VAR15:[a-zA-Z0-9_]+]] NonUniform
+ // CHECK_GLSL_SPV: OpDecorate %[[VAR16:[a-zA-Z0-9_]+]] NonUniform
+ // CHECK_GLSL_SPV: OpDecorate %[[VAR17:[a-zA-Z0-9_]+]] NonUniform
+ // CHECK_GLSL_SPV: OpDecorate %[[VAR18:[a-zA-Z0-9_]+]] NonUniform
+
+
+ // Test case 1: Slang is expected to specialize this call to 'MyStruct func(uint)'.
+ uint bufferIdx = pixelIndex.x;
+ uint nonUniformIdx = NonUniformResourceIndex(bufferIdx);
+ RWStructuredBuffer<uint> buffer = globalBuffer[nonUniformIdx];
+
+ // CHECK_SPV: %[[VAR1]] = OpAccessChain %_ptr_StorageBuffer_RWStructuredBuffer{{.*}} %{{.*}} %bufferIdx
+
+ // CHECK_GLSL_SPV: %[[VAR1]] = OpCopyObject %uint %{{.*}}
+
+ // CHECK_GLSL_SPV: %[[VAR4]] = OpCopyObject %uint %[[VAR1]]
+ // CHECK_GLSL_SPV: %[[VAR5]] = OpAccessChain %_ptr_Uniform_uint %globalBuffer_0 %[[VAR4]] %int_0 %int_0
+ // CHECK_GLSL_SPV: %[[VAR6]] = OpLoad %uint %[[VAR5]]
+
+ // CHECK_GLSL_SPV: %[[VAR7]] = OpCopyObject %uint %[[VAR1]]
+ // CHECK_GLSL_SPV: %[[VAR8]] = OpAccessChain %_ptr_Uniform_uint %globalBuffer_0 %[[VAR7]] %int_0 %int_0
+ // CHECK_GLSL_SPV: %[[VAR9]] = OpLoad %uint %[[VAR8]]
+
+ // CHECK_GLSL: func_0({{.*}}nonuniformEXT({{.*}}))
+ // CHECK_HLSL: func_0(globalBuffer_0[NonUniformResourceIndex({{.*}})])
+ MyStruct myStruct = func(buffer);
+
+ int bufferIdx2 = pixelIndex.y;
+
+ // Test case 2: Cover an index with a different data type (int instead of uint).
+ // In this case, Slang is expected to specialize the function to 'MyStruct func(int)'.
+ // CHECK_SPV: %[[VAR2]] = OpAccessChain %_ptr_StorageBuffer_RWStructuredBuffer{{.*}} %{{.*}} %bufferIdx2
+
+
+ // CHECK_GLSL_SPV: %[[VAR2]] = OpCopyObject %int %{{.*}}
+
+ // CHECK_GLSL-SPV: %[[VAR10]] = OpCopyObject %int %[[VAR2]]
+ // CHECK_GLSL-SPV: %[[VAR11]] = OpAccessChain %_ptr_Uniform_uint %globalBuffer_0 %[[VAR10]] %int_0 %int_0
+ // CHECK_GLSL-SPV: %[[VAR12]] = OpLoad %uint %[[VAR11]]
+
+ // CHECK_GLSL-SPV: %[[VAR13]] = OpCopyObject %int %[[VAR2]]
+ // CHECK_GLSL-SPV: %[[VAR14]] = OpAccessChain %_ptr_Uniform_uint %globalBuffer_0 %[[VAR13]] %int_0 %int_0
+ // CHECK_GLSL-SPV: %[[VAR15]] = OpLoad %uint %[[VAR14]]
+ RWStructuredBuffer<uint> buffer2 = globalBuffer[NonUniformResourceIndex(bufferIdx2)];
+
+ // CHECK_GLSL: func_1({{.*}}nonuniformEXT({{.*}}))
+ // CHECK_HLSL: func_0(globalBuffer_0[NonUniformResourceIndex({{.*}})])
+ MyStruct myStruct2 = func(buffer2);
+
+ // Test case 3: Check that uniformity is handled correctly: the index is not wrapped in NonUniformResourceIndex,
+ // so nothing non-uniform propagates into the function and no NonUniform decoration should appear.
+ int bufferIdx3 = pixelIndex.y;
+ RWStructuredBuffer<uint> buffer3 = globalBuffer[bufferIdx3];
+
+ // CHECK_SPV: %[[VAR4:[a-zA-Z0-9_]+]] = OpAccessChain %_ptr_StorageBuffer_RWStructuredBuffer{{.*}} %{{.*}} %bufferIdx2
+
+ // Make sure this access chain is not decorated with NonUniform:
+ // CHECK_SPV-NOT: OpDecorate %[[VAR4]] NonUniform
+ MyStruct myStruct3 = func(buffer3);
+
+
+ // Test case 4: Make sure an int or uint cast applied to the result of NonUniformResourceIndex
+ // is still treated as a NonUniformResourceIndex.
+
+ // CHECK_SPV: %[[VAR5:[a-zA-Z0-9_]+]] = OpBitcast %uint %{{.*}}
+ // CHECK_SPV: %[[VAR3]] = OpAccessChain %_ptr_StorageBuffer_RWStructuredBuffer{{.*}} %{{.*}} %[[VAR5]]
+
+ // CHECK_GLSL-SPV: %[[VAR19:[a-zA-Z0-9_]+]] = OpBitcast %int %[[VAR3]]
+ // CHECK_GLSL-SPV: %[[VAR16]] = OpCopyObject %int %[[VAR19]]
+ // CHECK_GLSL-SPV: %[[VAR17]] = OpAccessChain %_ptr_Uniform_uint %globalBuffer_0 %[[VAR16]] %int_0 %int_0
+ // CHECK_GLSL-SPV: %[[VAR18]] = OpLoad %uint %[[VAR17]]
+ // NOTE(review): the hyphenated "GLSL-SPV" spelling above and in test case 2 matches none of the prefixes declared in the TEST lines, so those checks look inactive - confirm.
+ // After the nested casts the index data type is 'uint' again, so make sure this call uses the same function as test case 1.
+ // CHECK_GLSL: func_0({{.*}}nonuniformEXT({{.*}}))
+ RWStructuredBuffer<uint> buffer4 = globalBuffer[(uint)((int)NonUniformResourceIndex(bufferIdx))];
+ MyStruct myStruct4 = func(buffer4);
+
+ outputBuffer[0] = uint3(myStruct.a, myStruct.b, myStruct.c);
+ outputBuffer[1] = uint3(myStruct2.a, myStruct2.b, myStruct2.c);
+ outputBuffer[2] = uint3(myStruct3.a, myStruct3.b, myStruct3.c);
+ outputBuffer[3] = uint3(myStruct4.a, myStruct4.b, myStruct4.c);
+}
+