diff options
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize-function-call.cpp | 93 | ||||
| -rw-r--r-- | tests/compute/nonuniformres-as-function-parameter.slang | 141 |
3 files changed, 233 insertions, 2 deletions
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 84ee634a1..39de083f0 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -1173,6 +1173,7 @@ INST_RANGE(Layout, VarLayout, EntryPointLayout) INST(UNormAttr, unorm, 0, HOISTABLE) INST(SNormAttr, snorm, 0, HOISTABLE) INST(NoDiffAttr, no_diff, 0, HOISTABLE) + INST(NonUniformAttr, nonuniform, 0, HOISTABLE) /* SemanticAttr */ INST(UserSemanticAttr, userSemantic, 2, HOISTABLE) diff --git a/source/slang/slang-ir-specialize-function-call.cpp b/source/slang/slang-ir-specialize-function-call.cpp index a41ca1e99..98aba0fae 100644 --- a/source/slang/slang-ir-specialize-function-call.cpp +++ b/source/slang/slang-ir-specialize-function-call.cpp @@ -5,6 +5,7 @@ #include "slang-ir-clone.h" #include "slang-ir-insts.h" #include "slang-ir-ssa-simplification.h" +#include "slang-ir-util.h" namespace Slang { @@ -363,7 +364,7 @@ struct FunctionParameterSpecializationContext // a new callee function based on the original // function and the information we gathered. // - newFunc = generateSpecializedFunc(oldFunc, funcInfo); + newFunc = generateSpecializedFunc(oldFunc, funcInfo, callInfo); specializedFuncs.add(callInfo.key, newFunc); } @@ -381,6 +382,7 @@ struct FunctionParameterSpecializationContext newCall->insertBefore(oldCall); oldCall->replaceUsesWith(newCall); oldCall->removeAndDeallocate(); + } // Before diving into the details on how we gather information @@ -559,6 +561,21 @@ struct FunctionParameterSpecializationContext // the arguments at the new call site, and // don't add anything to the specialization key. // + // However, we do record two extra facts in the key so that + // specialization handles two corner cases correctly: whether + // oldIndex comes from NonUniformResourceIndex(), and the data type of oldIndex. + // To do so, we form an IRAttributedType that encodes both facts + // and add it to the key of the call info. 
+ + List<IRAttr*> irAttrs; + if (findNonuniformIndexInst(oldIndex)) + { + IRAttr* attr = getBuilder()->getAttr(kIROp_NonUniformAttr); + irAttrs.add(attr); + } + auto irType = getBuilder()->getAttributedType(oldIndex->getDataType(), irAttrs); + ioInfo.key.vals.add(irType); + ioInfo.newArgs.add(oldIndex); } else if (oldArg->getOp() == kIROp_Load) @@ -577,6 +594,27 @@ struct FunctionParameterSpecializationContext } } + IRInst* findNonuniformIndexInst(IRInst* inst) + { + while(1) + { + if (inst == nullptr) + return nullptr; + + if (inst->getOp() == kIROp_NonUniformResourceIndex) + return inst; + + if (inst->getOp() == kIROp_IntCast) + { + inst = inst->getOperand(0); + } + else + { + return nullptr; + } + } + } + // The remaining information we've discussed is only // gathered once we decide we want to generate a // specialized function, but it follows much the same flow. @@ -803,7 +841,8 @@ struct FunctionParameterSpecializationContext // IRFunc* generateSpecializedFunc( IRFunc* oldFunc, - FuncSpecializationInfo const& funcInfo) + FuncSpecializationInfo const& funcInfo, + CallSpecializationInfo const& callInfo) { // We will make use of the infrastructure for cloning // IR code, that is defined in `ir-clone.{h,cpp}`. @@ -933,6 +972,18 @@ struct FunctionParameterSpecializationContext newBodyInst->insertBefore(newFirstOrdinary); } + // We need to handle a corner case where the new argument of + // the callee of this specialized function could be a use of + // NonUniformResourceIndex(), in such case, any indexing operation + // on the global buffer by using this new argument should be + // decorated with NonUniformDecoration. However, inside the new + // specialized function, we don't have that information anymore. + // Therefore, we will need to scan the new argument list to find out + // this case, and insert the NonUniformResourceIndex() instruction + // on the corresponding parameter of the new specialized function. 
+ maybeInsertNonUniformResourceIndex(newFunc, funcInfo, callInfo); + + // At this point we've created a new specialized function, // and as such it may contain call sites that were not // covered when we built our initial work list. @@ -964,6 +1015,44 @@ struct FunctionParameterSpecializationContext return newFunc; } + + void maybeInsertNonUniformResourceIndex( + IRFunc* newFunc, + FuncSpecializationInfo const& funcInfo, + CallSpecializationInfo const& callInfo) + { + auto builder = getBuilder(); + uint32_t paramIndex = 0; + + SLANG_ASSERT(callInfo.newArgs.getCount() == funcInfo.newParams.getCount()); + + // Iterate over the new arguments, new parameters pair, and + // find out if there is any use of NonUniformResourceIndex() + // in the new arguments. + for (auto newArg : callInfo.newArgs) + { + if (auto nonuniformIdxInst = findNonuniformIndexInst(newArg)) + { + auto firstOrdinary = newFunc->getFirstOrdinaryInst(); + + IRCloneEnv cloneEnv; + auto newParam = funcInfo.newParams[paramIndex]; + + // Clone the NonUniformResourceIndex call and insert it at beginning + // of the function. Then replace every use of the parameter with the + // NonUniformResourceIndex. + auto clonedInst = cloneInstAndOperands(&cloneEnv, builder, nonuniformIdxInst); + clonedInst->insertBefore(firstOrdinary); + newParam->replaceUsesWith(clonedInst); + + // At last, set the operand of the NonUniformResourceIndex to the new parameter + // because we haven't done it yet during inst clone. 
+ clonedInst->setOperand(0, newParam); + } + paramIndex++; + } + + } }; // The top-level function for invoking the specialization pass diff --git a/tests/compute/nonuniformres-as-function-parameter.slang b/tests/compute/nonuniformres-as-function-parameter.slang new file mode 100644 index 000000000..daffc848e --- /dev/null +++ b/tests/compute/nonuniformres-as-function-parameter.slang @@ -0,0 +1,141 @@ +//TEST:SIMPLE(filecheck=CHECK_SPV):-target spirv -entry main -stage compute +//TEST:SIMPLE(filecheck=CHECK_GLSL_SPV):-target spirv -entry main -stage compute -emit-spirv-via-glsl +//TEST:SIMPLE(filecheck=CHECK_GLSL):-target glsl -entry main -stage compute +//TEST:SIMPLE(filecheck=CHECK_HLSL):-target hlsl -entry main -stage compute +RWStructuredBuffer<uint> globalBuffer[] : register(t0, space0); +RWStructuredBuffer<uint3> outputBuffer; + +struct MyStruct +{ + uint a; + uint b; + uint c; +}; + + +MyStruct func(RWStructuredBuffer<uint> buffer) +{ + MyStruct a; + + // CHECK_GLSL: globalBuffer_0[nonuniformEXT({{.*}})] + // CHECK_GLSL: globalBuffer_0[nonuniformEXT({{.*}})] + + // For test case 3, where the caller passes globalBuffer[bufferIdx3] to the function, + // we should not see nonuniformEXT here. 
+ + // CHECK_GLSL: globalBuffer_0[_{{.*}})] + // CHECK_GLSL: globalBuffer_0[_{{.*}})] + a.a = buffer[0]; + a.b = a.a + 1; + a.c = a.a + a.b + 1; + + return a; +} + +[shader("compute")] +[numthreads(1, 1, 1)] +void main(uint2 pixelIndex : SV_DispatchThreadID) +{ + + // CHECK_SPV: OpDecorate %[[VAR1:[a-zA-Z0-9_]+]] NonUniform + // CHECK_SPV: OpDecorate %[[VAR2:[a-zA-Z0-9_]+]] NonUniform + // CHECK_SPV: OpDecorate %[[VAR3:[a-zA-Z0-9_]+]] NonUniform + + + // CHECK_GLSL_SPV: OpDecorate %[[VAR1:[a-zA-Z0-9_]+]] NonUniform + // CHECK_GLSL_SPV: OpDecorate %[[VAR2:[a-zA-Z0-9_]+]] NonUniform + // CHECK_GLSL_SPV: OpDecorate %[[VAR3:[a-zA-Z0-9_]+]] NonUniform + // CHECK_GLSL_SPV: OpDecorate %[[VAR4:[a-zA-Z0-9_]+]] NonUniform + // CHECK_GLSL_SPV: OpDecorate %[[VAR5:[a-zA-Z0-9_]+]] NonUniform + // CHECK_GLSL_SPV: OpDecorate %[[VAR6:[a-zA-Z0-9_]+]] NonUniform + // CHECK_GLSL_SPV: OpDecorate %[[VAR7:[a-zA-Z0-9_]+]] NonUniform + // CHECK_GLSL_SPV: OpDecorate %[[VAR8:[a-zA-Z0-9_]+]] NonUniform + // CHECK_GLSL_SPV: OpDecorate %[[VAR9:[a-zA-Z0-9_]+]] NonUniform + // CHECK_GLSL_SPV: OpDecorate %[[VAR10:[a-zA-Z0-9_]+]] NonUniform + // CHECK_GLSL_SPV: OpDecorate %[[VAR11:[a-zA-Z0-9_]+]] NonUniform + // CHECK_GLSL_SPV: OpDecorate %[[VAR12:[a-zA-Z0-9_]+]] NonUniform + // CHECK_GLSL_SPV: OpDecorate %[[VAR13:[a-zA-Z0-9_]+]] NonUniform + // CHECK_GLSL_SPV: OpDecorate %[[VAR14:[a-zA-Z0-9_]+]] NonUniform + // CHECK_GLSL_SPV: OpDecorate %[[VAR15:[a-zA-Z0-9_]+]] NonUniform + // CHECK_GLSL_SPV: OpDecorate %[[VAR16:[a-zA-Z0-9_]+]] NonUniform + // CHECK_GLSL_SPV: OpDecorate %[[VAR17:[a-zA-Z0-9_]+]] NonUniform + // CHECK_GLSL_SPV: OpDecorate %[[VAR18:[a-zA-Z0-9_]+]] NonUniform + + + // Test case 1: slang will specialize the func call to 'MyStruct func(uint)' + uint bufferIdx = pixelIndex.x; + uint nonUniformIdx = NonUniformResourceIndex(bufferIdx); + RWStructuredBuffer<uint> buffer = globalBuffer[nonUniformIdx]; + + // CHECK_SPV: %[[VAR1]] = OpAccessChain %_ptr_StorageBuffer_RWStructuredBuffer{{.*}} 
%{{.*}} %bufferIdx + + // CHECK_GLSL_SPV: %[[VAR1]] = OpCopyObject %uint %{{.*}} + + // CHECK_GLSL_SPV: %[[VAR4]] = OpCopyObject %uint %[[VAR1]] + // CHECK_GLSL_SPV: %[[VAR5]] = OpAccessChain %_ptr_Uniform_uint %globalBuffer_0 %[[VAR4]] %int_0 %int_0 + // CHECK_GLSL_SPV: %[[VAR6]] = OpLoad %uint %[[VAR5]] + + // CHECK_GLSL_SPV: %[[VAR7]] = OpCopyObject %uint %[[VAR1]] + // CHECK_GLSL_SPV: %[[VAR8]] = OpAccessChain %_ptr_Uniform_uint %globalBuffer_0 %[[VAR7]] %int_0 %int_0 + // CHECK_GLSL_SPV: %[[VAR9]] = OpLoad %uint %[[VAR8]] + + // CHECK_GLSL: func_0({{.*}}nonuniformEXT({{.*}})) + // CHECK_HLSL: func_0(globalBuffer_0[NonUniformResourceIndex({{.*}})]) + MyStruct myStruct = func(buffer); + + int bufferIdx2 = pixelIndex.y; + + // Test case 2: Make sure we cover the case where the index has a different data type. + // In this case, slang will specialize the function to 'MyStruct func(int)' + // CHECK_SPV: %[[VAR2]] = OpAccessChain %_ptr_StorageBuffer_RWStructuredBuffer{{.*}} %{{.*}} %bufferIdx2 + + + // CHECK_GLSL_SPV: %[[VAR2]] = OpCopyObject %int %{{.*}} + + // CHECK_GLSL-SPV: %[[VAR10]] = OpCopyObject %int %[[VAR2]] + // CHECK_GLSL-SPV: %[[VAR11]] = OpAccessChain %_ptr_Uniform_uint %globalBuffer_0 %[[VAR10]] %int_0 %int_0 + // CHECK_GLSL-SPV: %[[VAR12]] = OpLoad %uint %[[VAR11]] + + // CHECK_GLSL-SPV: %[[VAR13]] = OpCopyObject %int %[[VAR2]] + // CHECK_GLSL-SPV: %[[VAR14]] = OpAccessChain %_ptr_Uniform_uint %globalBuffer_0 %[[VAR13]] %int_0 %int_0 + // CHECK_GLSL-SPV: %[[VAR15]] = OpLoad %uint %[[VAR14]] + RWStructuredBuffer<uint> buffer2 = globalBuffer[NonUniformResourceIndex(bufferIdx2)]; + + // CHECK_GLSL: func_1({{.*}}nonuniformEXT({{.*}})) + // CHECK_HLSL: func_0(globalBuffer_0[NonUniformResourceIndex({{.*}})]) + MyStruct myStruct2 = func(buffer2); + + // Test case 3: Check that uniformity is handled correctly: no NonUniformResourceIndex is used here, so nothing should propagate + // to the function, and no NonUniform decoration should appear. 
+ int bufferIdx3 = pixelIndex.y; + RWStructuredBuffer<uint> buffer3 = globalBuffer[bufferIdx3]; + + // CHECK_SPV: %[[VAR4:[a-zA-Z0-9_]+]] = OpAccessChain %_ptr_StorageBuffer_RWStructuredBuffer{{.*}} %{{.*}} %bufferIdx2 + + // Test to make sure this command is not decorated with NonUniform: + // CHECK_SPV-NOT: OpDecorate %[[VAR4]] NonUniform + MyStruct myStruct3 = func(buffer3); + + + // Test case 4: Test to make sure we correctly cover the case that intCast or uintCast of a NonUniformResourceIndex + // is still a NonUniformResourceIndex. + + // CHECK_SPV: %[[VAR5:[a-zA-Z0-9_]+]] = OpBitcast %uint %{{.*}} + // CHECK_SPV: %[[VAR3]] = OpAccessChain %_ptr_StorageBuffer_RWStructuredBuffer{{.*}} %{{.*}} %[[VAR5]] + + // CHECK_GLSL-SPV: %[[VAR19:[a-zA-Z0-9_]+]] = OpBitcast %int %[[VAR3]] + // CHECK_GLSL-SPV: %[[VAR16]] = OpCopyObject %int %[[VAR19]] + // CHECK_GLSL-SPV: %[[VAR17]] = OpAccessChain %_ptr_Uniform_uint %globalBuffer_0 %[[VAR16]] %int_0 %int_0 + // CHECK_GLSL-SPV: %[[VAR18]] = OpLoad %uint %[[VAR17]] + // + // Since after the nested cast, the index data type is 'uint' now, make sure it calls the same function as the test case 1. + // CHECK_GLSL: func_0({{.*}}nonuniformEXT({{.*}})) + RWStructuredBuffer<uint> buffer4 = globalBuffer[(uint)((int)NonUniformResourceIndex(bufferIdx))]; + MyStruct myStruct4 = func(buffer4); + + outputBuffer[0] = uint3(myStruct.a, myStruct.b, myStruct.c); + outputBuffer[1] = uint3(myStruct2.a, myStruct2.b, myStruct2.c); + outputBuffer[2] = uint3(myStruct3.a, myStruct3.b, myStruct3.c); + outputBuffer[3] = uint3(myStruct4.a, myStruct4.b, myStruct4.c); +} + |
