summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-specialize-arrays.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-07-26 19:42:15 -0700
committerGitHub <noreply@github.com>2024-07-26 19:42:15 -0700
commit7e2bc8e06f61d554bae9bbebc1db0302eb3f1d8a (patch)
tree0f10e4a45cb81af2908da61743a4518de27748e2 /source/slang/slang-ir-specialize-arrays.cpp
parentc0bff66541302309ff4833e8d4ae2eba1561498a (diff)
Allow passing sized array to unsized array parameter. (#4744)
Diffstat (limited to 'source/slang/slang-ir-specialize-arrays.cpp')
-rw-r--r--source/slang/slang-ir-specialize-arrays.cpp47
1 files changed, 45 insertions, 2 deletions
diff --git a/source/slang/slang-ir-specialize-arrays.cpp b/source/slang/slang-ir-specialize-arrays.cpp
index 3f42fb4b0..9bc0042db 100644
--- a/source/slang/slang-ir-specialize-arrays.cpp
+++ b/source/slang/slang-ir-specialize-arrays.cpp
@@ -13,8 +13,11 @@ struct ArrayParameterSpecializationCondition : FunctionCallSpecializeCondition
// This pass is intended to specialize functions
// with struct parameters that has array fields
// to avoid performance problems for GLSL targets.
-
// Returns true if `type` is an `IRStructType` with array-typed fields.
+ // It will also specialize functions with unsized array parameters into
+ // sized arrays, if the function is called with an argument that has a
+ // sized array type.
+ //
bool isStructTypeWithArray(IRType* type)
{
if (auto structType = as<IRStructType>(type))
@@ -38,8 +41,47 @@ struct ArrayParameterSpecializationCondition : FunctionCallSpecializeCondition
bool doesParamWantSpecialization(IRParam* param, IRInst* arg)
{
SLANG_UNUSED(arg);
- return isStructTypeWithArray(param->getDataType());
+ if (isKhronosTarget(codeGenContext->getTargetReq()))
+ return isStructTypeWithArray(param->getDataType());
+ return false;
}
+
+ bool doesParamTypeWantSpecialization(IRParam* param, IRInst* arg)
+ {
+ auto paramType = param->getDataType();
+ auto argType = arg->getDataType();
+ if (auto outTypeBase = as<IROutTypeBase>(paramType))
+ {
+ paramType = outTypeBase->getValueType();
+ SLANG_ASSERT(as<IRPtrTypeBase>(argType));
+ argType = as<IRPtrTypeBase>(argType)->getValueType();
+ }
+ else if (auto refType = as<IRRefType>(paramType))
+ {
+ paramType = refType->getValueType();
+ SLANG_ASSERT(as<IRPtrTypeBase>(argType));
+ argType = as<IRPtrTypeBase>(argType)->getValueType();
+ }
+ else if (auto constRefType = as<IRConstRefType>(paramType))
+ {
+ paramType = constRefType->getValueType();
+ SLANG_ASSERT(as<IRPtrTypeBase>(argType));
+ argType = as<IRPtrTypeBase>(argType)->getValueType();
+ }
+ auto arrayType = as<IRUnsizedArrayType>(paramType);
+ if (!arrayType)
+ return false;
+ auto argArrayType = as<IRArrayType>(argType);
+ if (!argArrayType)
+ return false;
+ if (as<IRIntLit>(argArrayType->getElementCount()))
+ {
+ return true;
+ }
+ return false;
+ }
+
+ CodeGenContext* codeGenContext = nullptr;
};
void specializeArrayParameters(
@@ -47,6 +89,7 @@ void specializeArrayParameters(
IRModule* module)
{
ArrayParameterSpecializationCondition condition;
+ condition.codeGenContext = codeGenContext;
specializeFunctionCalls(codeGenContext, module, &condition);
}