summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-specialize-arrays.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-specialize-arrays.cpp')
-rw-r--r--source/slang/slang-ir-specialize-arrays.cpp47
1 files changed, 45 insertions, 2 deletions
diff --git a/source/slang/slang-ir-specialize-arrays.cpp b/source/slang/slang-ir-specialize-arrays.cpp
index 3f42fb4b0..9bc0042db 100644
--- a/source/slang/slang-ir-specialize-arrays.cpp
+++ b/source/slang/slang-ir-specialize-arrays.cpp
@@ -13,8 +13,11 @@ struct ArrayParameterSpecializationCondition : FunctionCallSpecializeCondition
// This pass is intended to specialize functions
// with struct parameters that has array fields
// to avoid performance problems for GLSL targets.
-
// Returns true if `type` is an `IRStructType` with array-typed fields.
+ // It will also specialize functions with unsized array parameters into
+ // sized arrays, if the function is called with an argument that has a
+ // sized array type.
+ //
bool isStructTypeWithArray(IRType* type)
{
if (auto structType = as<IRStructType>(type))
@@ -38,8 +41,47 @@ struct ArrayParameterSpecializationCondition : FunctionCallSpecializeCondition
bool doesParamWantSpecialization(IRParam* param, IRInst* arg)
{
SLANG_UNUSED(arg);
- return isStructTypeWithArray(param->getDataType());
+ if (isKhronosTarget(codeGenContext->getTargetReq()))
+ return isStructTypeWithArray(param->getDataType());
+ return false;
}
+
+ bool doesParamTypeWantSpecialization(IRParam* param, IRInst* arg)
+ {
+ auto paramType = param->getDataType();
+ auto argType = arg->getDataType();
+ if (auto outTypeBase = as<IROutTypeBase>(paramType))
+ {
+ paramType = outTypeBase->getValueType();
+ SLANG_ASSERT(as<IRPtrTypeBase>(argType));
+ argType = as<IRPtrTypeBase>(argType)->getValueType();
+ }
+ else if (auto refType = as<IRRefType>(paramType))
+ {
+ paramType = refType->getValueType();
+ SLANG_ASSERT(as<IRPtrTypeBase>(argType));
+ argType = as<IRPtrTypeBase>(argType)->getValueType();
+ }
+ else if (auto constRefType = as<IRConstRefType>(paramType))
+ {
+ paramType = constRefType->getValueType();
+ SLANG_ASSERT(as<IRPtrTypeBase>(argType));
+ argType = as<IRPtrTypeBase>(argType)->getValueType();
+ }
+ auto arrayType = as<IRUnsizedArrayType>(paramType);
+ if (!arrayType)
+ return false;
+ auto argArrayType = as<IRArrayType>(argType);
+ if (!argArrayType)
+ return false;
+ if (as<IRIntLit>(argArrayType->getElementCount()))
+ {
+ return true;
+ }
+ return false;
+ }
+
+ CodeGenContext* codeGenContext = nullptr;
};
void specializeArrayParameters(
@@ -47,6 +89,7 @@ void specializeArrayParameters(
IRModule* module)
{
ArrayParameterSpecializationCondition condition;
+ condition.codeGenContext = codeGenContext;
specializeFunctionCalls(codeGenContext, module, &condition);
}