summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ir-glsl-legalize.cpp155
-rw-r--r--tests/vkray/multipleinout.slang36
2 files changed, 188 insertions, 3 deletions
diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp
index 1123e1f2a..455c924ca 100644
--- a/source/slang/slang-ir-glsl-legalize.cpp
+++ b/source/slang/slang-ir-glsl-legalize.cpp
@@ -2390,7 +2390,7 @@ IRInst* materializeValue(IRBuilder* builder, ScalarizedVal const& val)
}
}
-void legalizeRayTracingEntryPointParameterForGLSL(
+void handleSingleParam(
GLSLLegalizationContext* context,
IRFunc* func,
IRParam* pp,
@@ -2442,6 +2442,136 @@ void legalizeRayTracingEntryPointParameterForGLSL(
builder->addDependsOnDecoration(func, globalParam);
}
+static void consolidateParameters(GLSLLegalizationContext* context, List<IRParam*>& params)
+{
+ auto builder = context->getBuilder();
+
+ // Create a struct type to hold all parameters
+ IRInst* consolidatedVar = nullptr;
+ auto structType = builder->createStructType();
+
+ // Inside the structure, add fields for each parameter
+ for (auto _param : params)
+ {
+ auto _paramType = _param->getDataType();
+ IRType* valueType = _paramType;
+
+ if (as<IROutTypeBase>(_paramType))
+ valueType = as<IROutTypeBase>(_paramType)->getValueType();
+
+ auto key = builder->createStructKey();
+ if (auto nameDecor = _param->findDecoration<IRNameHintDecoration>())
+ builder->addNameHintDecoration(key, nameDecor->getName());
+ auto field = builder->createStructField(structType, key, valueType);
+ field->removeFromParent();
+ field->insertAtEnd(structType);
+ }
+
+ // Create a global variable to hold the consolidated struct
+ consolidatedVar = builder->createGlobalVar(structType);
+ auto ptrType = builder->getPtrType(kIROp_PtrType, structType, AddressSpace::IncomingRayPayload);
+ consolidatedVar->setFullType(ptrType);
+ consolidatedVar->moveToEnd();
+
+ // Add the ray payload decoration and assign location 0.
+ builder->addVulkanRayPayloadDecoration(consolidatedVar, 0);
+
+ // Replace each parameter with a field in the consolidated struct
+ for (Index i = 0; i < params.getCount(); ++i)
+ {
+ auto _param = params[i];
+
+ // Find the i-th field
+ IRStructField* targetField = nullptr;
+ Index fieldIndex = 0;
+ for (auto field : structType->getFields())
+ {
+ if (fieldIndex == i)
+ {
+ targetField = field;
+ break;
+ }
+ fieldIndex++;
+ }
+ SLANG_ASSERT(targetField);
+
+ // Create the field address with the correct type
+ auto _paramType = _param->getDataType();
+ auto fieldType = targetField->getFieldType();
+
+ // If the parameter is an out/inout type, we need to create a pointer type
+ IRType* fieldPtrType = nullptr;
+ if (as<IROutType>(_paramType))
+ {
+ fieldPtrType = builder->getPtrType(kIROp_OutType, fieldType);
+ }
+ else if (as<IRInOutType>(_paramType))
+ {
+ fieldPtrType = builder->getPtrType(kIROp_InOutType, fieldType);
+ }
+
+ auto fieldAddr =
+ builder->emitFieldAddress(fieldPtrType, consolidatedVar, targetField->getKey());
+
+ // Replace parameter uses with field address
+ _param->replaceUsesWith(fieldAddr);
+ }
+}
+
+// Consolidate ray tracing parameters for an entry point function
+void consolidateRayTracingParameters(GLSLLegalizationContext* context, IRFunc* func)
+{
+ auto builder = context->getBuilder();
+ auto firstBlock = func->getFirstBlock();
+ if (!firstBlock)
+ return;
+
+ // Collect all out/inout parameters that need to be consolidated
+ List<IRParam*> outParams;
+ List<IRParam*> params;
+
+ for (auto param = firstBlock->getFirstParam(); param; param = param->getNextParam())
+ {
+ builder->setInsertBefore(firstBlock->getFirstOrdinaryInst());
+ if (as<IROutType>(param->getDataType()) || as<IRInOutType>(param->getDataType()))
+ {
+ outParams.add(param);
+ }
+ params.add(param);
+ }
+
+ // We don't need consolidation here.
+ if (outParams.getCount() <= 1)
+ {
+ for (auto param : params)
+ {
+ auto paramLayoutDecoration = param->findDecoration<IRLayoutDecoration>();
+ SLANG_ASSERT(paramLayoutDecoration);
+ auto paramLayout = as<IRVarLayout>(paramLayoutDecoration->getLayout());
+ handleSingleParam(context, func, param, paramLayout);
+ }
+ return;
+ }
+ else
+ {
+ // We need consolidation here, but before that, handle parameters other than inout/out.
+ for (auto param : params)
+ {
+ if (outParams.contains(param))
+ {
+ continue;
+ }
+ auto paramLayoutDecoration = param->findDecoration<IRLayoutDecoration>();
+ SLANG_ASSERT(paramLayoutDecoration);
+ auto paramLayout = as<IRVarLayout>(paramLayoutDecoration->getLayout());
+ handleSingleParam(context, func, param, paramLayout);
+ }
+
+ // Now, consolidate the inout/out parameters
+ consolidateParameters(context, outParams);
+ }
+}
+
static void legalizeMeshPayloadInputParam(
GLSLLegalizationContext* context,
CodeGenContext* codeGenContext,
@@ -3129,7 +3259,6 @@ void legalizeEntryPointParameterForGLSL(
}
}
-
// We need to create a global variable that will replace the parameter.
// It seems superficially obvious that the variable should have
// the same type as the parameter.
@@ -3286,7 +3415,6 @@ void legalizeEntryPointParameterForGLSL(
case Stage::Intersection:
case Stage::Miss:
case Stage::RayGeneration:
- legalizeRayTracingEntryPointParameterForGLSL(context, func, pp, paramLayout);
return;
}
@@ -3916,12 +4044,33 @@ void legalizeEntryPointForGLSL(
invokePathConstantFuncInHullShader(&context, codeGenContext, scalarizedGlobalOutput);
}
+ // Special handling for ray tracing shaders
+ bool isRayTracingShader = false;
+ switch (stage)
+ {
+ case Stage::AnyHit:
+ case Stage::Callable:
+ case Stage::ClosestHit:
+ case Stage::Intersection:
+ case Stage::Miss:
+ case Stage::RayGeneration:
+ isRayTracingShader = true;
+ consolidateRayTracingParameters(&context, func);
+ break;
+ default:
+ break;
+ }
+
// Next we will walk through any parameters of the entry-point function,
// and turn them into global variables.
if (auto firstBlock = func->getFirstBlock())
{
for (auto pp = firstBlock->getFirstParam(); pp; pp = pp->getNextParam())
{
+ if (isRayTracingShader)
+ {
+ continue;
+ }
// Any initialization code we insert for parameters needs
// to be at the start of the "ordinary" instructions in the block:
builder.setInsertBefore(firstBlock->getFirstOrdinaryInst());
diff --git a/tests/vkray/multipleinout.slang b/tests/vkray/multipleinout.slang
new file mode 100644
index 000000000..52e1758b0
--- /dev/null
+++ b/tests/vkray/multipleinout.slang
@@ -0,0 +1,36 @@
+//TEST:SIMPLE(filecheck=CHECK): -stage closesthit -entry main -target spirv -emit-spirv-directly
+
+// This test checks whether the spirv generated when there are multiple inout or out variables, they
+// all get consolidated into one IncomingRayPayloadKHR.
+
+struct ReflectionRay
+{
+ float4 color;
+};
+
+StructuredBuffer<float4> colors;
+
+[shader("closesthit")]
+void main(
+ BuiltInTriangleIntersectionAttributes attributes,
+ inout ReflectionRay ioPayload,
+ out float3 dummy)
+{
+ uint materialID = (InstanceIndex() << 1)
+ + InstanceID()
+ + PrimitiveIndex()
+ + HitKind();
+
+ ioPayload.color = colors[materialID];
+ dummy = HitTriangleVertexPosition(0);
+}
+
+// CHECK: OpEntryPoint ClosestHitKHR %main "main" %{{.*}} %{{.*}} %gl_PrimitiveID %{{.*}} %gl_InstanceID %colors %{{.*}}
+// CHECK: %_struct_{{.*}} = OpTypeStruct %ReflectionRay %v3float
+// CHECK: %_ptr_IncomingRayPayloadKHR__struct_{{.*}} = OpTypePointer IncomingRayPayloadKHR %_struct_{{.*}}
+// CHECK: %main = OpFunction %void None %{{.*}}
+// CHECK: %materialID = OpIAdd %uint %{{.*}} %{{.*}}
+// CHECK: %{{.*}} = OpAccessChain %_ptr_StorageBuffer_v4float %colors %int_0 %materialID
+// CHECK: %{{.*}} = OpLoad %v4float %{{.*}}
+// CHECK: %{{.*}} = OpAccessChain %_ptr_Input_v3float %{{.*}} %uint_0
+// CHECK: %{{.*}} = OpLoad %v3float %{{.*}}