diff options
| author | Mukund Keshava <mkeshava@nvidia.com> | 2025-03-01 12:19:26 +0530 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-02-28 22:49:26 -0800 |
| commit | b86925c1929186c122536b9a7ed75131faceddb7 (patch) | |
| tree | 86f1078bffd1001bcafe6fbaa2d0da297389b997 /source/slang | |
| parent | dd9d24d29c4a9e05a4510eb9959fafa0ed36618b (diff) | |
Consolidate multiple inouts/outs into struct (#6435)
* Consolidate multiple inout/outs into struct
Fixes #5121
VUID-StandaloneSpirv-IncomingRayPayloadKHR-04700 requires that there be
only one IncomingRayPayloadKHR per entry point. This change does two
things:
1. If an entry point has the one inout or out, or has only ins, then
stay with current implementation.
2. If there are multiple outs/inouts, then create a new structure to
consolidate these fields and emit this structure.
These two code paths are split into two separate functions for clarity.
This patch also adds a new test: multipleinout.slang to test this.
* Address review comments
* Refactor code as per review comments
* format code
* fix failing tests
---------
Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/slang-ir-glsl-legalize.cpp | 155 |
1 files changed, 152 insertions, 3 deletions
diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp index 1123e1f2a..455c924ca 100644 --- a/source/slang/slang-ir-glsl-legalize.cpp +++ b/source/slang/slang-ir-glsl-legalize.cpp @@ -2390,7 +2390,7 @@ IRInst* materializeValue(IRBuilder* builder, ScalarizedVal const& val) } } -void legalizeRayTracingEntryPointParameterForGLSL( +void handleSingleParam( GLSLLegalizationContext* context, IRFunc* func, IRParam* pp, @@ -2442,6 +2442,136 @@ void legalizeRayTracingEntryPointParameterForGLSL( builder->addDependsOnDecoration(func, globalParam); } +static void consolidateParameters(GLSLLegalizationContext* context, List<IRParam*>& params) +{ + auto builder = context->getBuilder(); + + // Create a struct type to hold all parameters + IRInst* consolidatedVar = nullptr; + auto structType = builder->createStructType(); + + // Inside the structure, add fields for each parameter + for (auto _param : params) + { + auto _paramType = _param->getDataType(); + IRType* valueType = _paramType; + + if (as<IROutTypeBase>(_paramType)) + valueType = as<IROutTypeBase>(_paramType)->getValueType(); + + auto key = builder->createStructKey(); + if (auto nameDecor = _param->findDecoration<IRNameHintDecoration>()) + builder->addNameHintDecoration(key, nameDecor->getName()); + auto field = builder->createStructField(structType, key, valueType); + field->removeFromParent(); + field->insertAtEnd(structType); + } + + // Create a global variable to hold the consolidated struct + consolidatedVar = builder->createGlobalVar(structType); + auto ptrType = builder->getPtrType(kIROp_PtrType, structType, AddressSpace::IncomingRayPayload); + consolidatedVar->setFullType(ptrType); + consolidatedVar->moveToEnd(); + + // Add the ray payload decoration and assign location 0. + builder->addVulkanRayPayloadDecoration(consolidatedVar, 0); + + // Replace each parameter with a field in the consolidated struct + for (Index i = 0; i < params.getCount(); ++i) + { + auto _param = params[i]; + + // Find the i-th field + IRStructField* targetField = nullptr; + Index fieldIndex = 0; + for (auto field : structType->getFields()) + { + if (fieldIndex == i) + { + targetField = field; + break; + } + fieldIndex++; + } + SLANG_ASSERT(targetField); + + // Create the field address with the correct type + auto _paramType = _param->getDataType(); + auto fieldType = targetField->getFieldType(); + + // If the parameter is an out/inout type, we need to create a pointer type + IRType* fieldPtrType = nullptr; + if (as<IROutType>(_paramType)) + { + fieldPtrType = builder->getPtrType(kIROp_OutType, fieldType); + } + else if (as<IRInOutType>(_paramType)) + { + fieldPtrType = builder->getPtrType(kIROp_InOutType, fieldType); + } + + auto fieldAddr = + builder->emitFieldAddress(fieldPtrType, consolidatedVar, targetField->getKey()); + + // Replace parameter uses with field address + _param->replaceUsesWith(fieldAddr); + } +} + +// Consolidate ray tracing parameters for an entry point function +void consolidateRayTracingParameters(GLSLLegalizationContext* context, IRFunc* func) +{ + auto builder = context->getBuilder(); + auto firstBlock = func->getFirstBlock(); + if (!firstBlock) + return; + + // Collect all out/inout parameters that need to be consolidated + List<IRParam*> outParams; + List<IRParam*> params; + + for (auto param = firstBlock->getFirstParam(); param; param = param->getNextParam()) + { + builder->setInsertBefore(firstBlock->getFirstOrdinaryInst()); + if (as<IROutType>(param->getDataType()) || as<IRInOutType>(param->getDataType())) + { + outParams.add(param); + } + params.add(param); + } + + // We don't need consolidation here. + if (outParams.getCount() <= 1) + { + for (auto param : params) + { + auto paramLayoutDecoration = param->findDecoration<IRLayoutDecoration>(); + SLANG_ASSERT(paramLayoutDecoration); + auto paramLayout = as<IRVarLayout>(paramLayoutDecoration->getLayout()); + handleSingleParam(context, func, param, paramLayout); + } + return; + } + else + { + // We need consolidation here, but before that, handle parameters other than inout/out. + for (auto param : params) + { + if (outParams.contains(param)) + { + continue; + } + auto paramLayoutDecoration = param->findDecoration<IRLayoutDecoration>(); + SLANG_ASSERT(paramLayoutDecoration); + auto paramLayout = as<IRVarLayout>(paramLayoutDecoration->getLayout()); + handleSingleParam(context, func, param, paramLayout); + } + + // Now, consolidate the inout/out parameters + consolidateParameters(context, outParams); + } +} + static void legalizeMeshPayloadInputParam( GLSLLegalizationContext* context, CodeGenContext* codeGenContext, @@ -3129,7 +3259,6 @@ void legalizeEntryPointParameterForGLSL( } } - // We need to create a global variable that will replace the parameter. // It seems superficially obvious that the variable should have // the same type as the parameter. @@ -3286,7 +3415,6 @@ void legalizeEntryPointParameterForGLSL( case Stage::Intersection: case Stage::Miss: case Stage::RayGeneration: - legalizeRayTracingEntryPointParameterForGLSL(context, func, pp, paramLayout); return; } @@ -3916,12 +4044,33 @@ void legalizeEntryPointForGLSL( invokePathConstantFuncInHullShader(&context, codeGenContext, scalarizedGlobalOutput); } + // Special handling for ray tracing shaders + bool isRayTracingShader = false; + switch (stage) + { + case Stage::AnyHit: + case Stage::Callable: + case Stage::ClosestHit: + case Stage::Intersection: + case Stage::Miss: + case Stage::RayGeneration: + isRayTracingShader = true; + consolidateRayTracingParameters(&context, func); + break; + default: + break; + } + // Next we will walk through any parameters of the entry-point function, // and turn them into global variables. if (auto firstBlock = func->getFirstBlock()) { for (auto pp = firstBlock->getFirstParam(); pp; pp = pp->getNextParam()) { + if (isRayTracingShader) + { + continue; + } // Any initialization code we insert for parameters needs // to be at the start of the "ordinary" instructions in the block: builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); |
