From 7e2bc8e06f61d554bae9bbebc1db0302eb3f1d8a Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 26 Jul 2024 19:42:15 -0700 Subject: Allow passing sized array to unsized array parameter. (#4744) --- source/slang/slang-ir-specialize-arrays.cpp | 47 +++++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 2 deletions(-) (limited to 'source/slang/slang-ir-specialize-arrays.cpp') diff --git a/source/slang/slang-ir-specialize-arrays.cpp b/source/slang/slang-ir-specialize-arrays.cpp index 3f42fb4b0..9bc0042db 100644 --- a/source/slang/slang-ir-specialize-arrays.cpp +++ b/source/slang/slang-ir-specialize-arrays.cpp @@ -13,8 +13,11 @@ struct ArrayParameterSpecializationCondition : FunctionCallSpecializeCondition // This pass is intended to specialize functions // with struct parameters that has array fields // to avoid performance problems for GLSL targets. - // Returns true if `type` is an `IRStructType` with array-typed fields. + // It will also specialize functions with unsized array parameters into + // sized arrays, if the function is called with an argument that has a + // sized array type. + // bool isStructTypeWithArray(IRType* type) { if (auto structType = as(type)) @@ -38,8 +41,47 @@ struct ArrayParameterSpecializationCondition : FunctionCallSpecializeCondition bool doesParamWantSpecialization(IRParam* param, IRInst* arg) { SLANG_UNUSED(arg); - return isStructTypeWithArray(param->getDataType()); + if (isKhronosTarget(codeGenContext->getTargetReq())) + return isStructTypeWithArray(param->getDataType()); + return false; } + + bool doesParamTypeWantSpecialization(IRParam* param, IRInst* arg) + { + auto paramType = param->getDataType(); + auto argType = arg->getDataType(); + if (auto outTypeBase = as(paramType)) + { + paramType = outTypeBase->getValueType(); + SLANG_ASSERT(as(argType)); + argType = as(argType)->getValueType(); + } + else if (auto refType = as(paramType)) + { + paramType = refType->getValueType(); + SLANG_ASSERT(as(argType)); + argType = as(argType)->getValueType(); + } + else if (auto constRefType = as(paramType)) + { + paramType = constRefType->getValueType(); + SLANG_ASSERT(as(argType)); + argType = as(argType)->getValueType(); + } + auto arrayType = as(paramType); + if (!arrayType) + return false; + auto argArrayType = as(argType); + if (!argArrayType) + return false; + if (as(argArrayType->getElementCount())) + { + return true; + } + return false; + } + + CodeGenContext* codeGenContext = nullptr; }; void specializeArrayParameters( @@ -47,6 +89,7 @@ void specializeArrayParameters( IRModule* module) { ArrayParameterSpecializationCondition condition; + condition.codeGenContext = codeGenContext; specializeFunctionCalls(codeGenContext, module, &condition); } -- cgit v1.2.3