summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-specialize-function-call.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-07-26 19:42:15 -0700
committerGitHub <noreply@github.com>2024-07-26 19:42:15 -0700
commit7e2bc8e06f61d554bae9bbebc1db0302eb3f1d8a (patch)
tree0f10e4a45cb81af2908da61743a4518de27748e2 /source/slang/slang-ir-specialize-function-call.cpp
parentc0bff66541302309ff4833e8d4ae2eba1561498a (diff)
Allow passing sized array to unsized array parameter. (#4744)
Diffstat (limited to 'source/slang/slang-ir-specialize-function-call.cpp')
-rw-r--r--source/slang/slang-ir-specialize-function-call.cpp75
1 files changed, 72 insertions, 3 deletions
diff --git a/source/slang/slang-ir-specialize-function-call.cpp b/source/slang/slang-ir-specialize-function-call.cpp
index c4928a230..a41ca1e99 100644
--- a/source/slang/slang-ir-specialize-function-call.cpp
+++ b/source/slang/slang-ir-specialize-function-call.cpp
@@ -67,6 +67,13 @@ bool FunctionCallSpecializeCondition::isParamSuitableForSpecialization(IRParam*
}
}
+bool FunctionCallSpecializeCondition::doesParamTypeWantSpecialization(IRParam* param, IRInst* arg)
+{
+ SLANG_UNUSED(param);
+ SLANG_UNUSED(arg);
+ return false;
+}
+
struct FunctionParameterSpecializationContext
{
// This type implements a pass to specialize functions
@@ -209,7 +216,9 @@ struct FunctionParameterSpecializationContext
// If neither the parameter nor the argument wants specialization,
// then we need to keep looking.
//
- if(!doesParamWantSpecialization(param, arg))
+ auto paramWantSpecialization = doesParamWantSpecialization(param, arg);
+ auto paramTypeWantSpecialization = doesParamTypeWantSpecialization(param, arg);
+ if(!paramWantSpecialization && !paramTypeWantSpecialization)
continue;
// If we have run into a `param` or `arg` that wants specialization,
@@ -222,7 +231,7 @@ struct FunctionParameterSpecializationContext
// can bail out immediately because our second condition
// cannot be met.
//
- if(!isParamSuitableForSpecialization(param, arg))
+ if(paramWantSpecialization && !isParamSuitableForSpecialization(param, arg))
return false;
}
@@ -242,6 +251,11 @@ struct FunctionParameterSpecializationContext
return condition->doesParamWantSpecialization(param, arg);
}
+ bool doesParamTypeWantSpecialization(IRParam* param, IRInst* arg)
+ {
+ return condition->doesParamTypeWantSpecialization(param, arg);
+ }
+
bool isParamSuitableForSpecialization(IRParam* param, IRInst* arg)
{
return condition->isParamSuitableForSpecialization(param, arg);
@@ -474,6 +488,11 @@ struct FunctionParameterSpecializationContext
// specialized callee based on this paramter.
//
ioInfo.newArgs.add(oldArg);
+
+ if (doesParamTypeWantSpecialization(oldParam, oldArg))
+ {
+ ioInfo.key.vals.add(oldArg->getDataType());
+ }
}
else
{
@@ -587,6 +606,30 @@ struct FunctionParameterSpecializationContext
}
}
+ // Wrap `argType` with a parameter direction type if `oldParam` has such a parameter direction type.
+ IRType* maybeWrapParameterDirectionType(IRParam* oldParam, IRType* argType)
+ {
+ IRType* paramType = oldParam->getDataType();
+ IRType* resultType = argType;
+ switch (paramType->getOp())
+ {
+ case kIROp_InOutType:
+ case kIROp_OutType:
+ case kIROp_RefType:
+ case kIROp_ConstRefType:
+ argType = as<IRPtrTypeBase>(argType)->getValueType();
+ resultType = getBuilder()->getPtrType(paramType->getOp(), argType, as<IRPtrTypeBase>(paramType)->getAddressSpace());
+ break;
+ }
+ if (auto rate = paramType->getRate())
+ {
+ IRBuilder builder(oldParam);
+ builder.setInsertAfter(resultType);
+ resultType = builder.getRateQualifiedType(rate, resultType);
+ }
+ return resultType;
+ }
+
IRInst* getSpecializedValueForParam(
FuncSpecializationInfo& ioInfo,
IRParam* oldParam,
@@ -601,7 +644,16 @@ struct FunctionParameterSpecializationContext
// that fills the same role as the old one, so we
// create it here.
//
- auto newParam = getBuilder()->createParam(oldParam->getFullType());
+ IRType* paramType = nullptr;
+ if (doesParamTypeWantSpecialization(oldParam, oldArg))
+ {
+ paramType = maybeWrapParameterDirectionType(oldParam, oldArg->getDataType());
+ }
+ else
+ {
+ paramType = oldParam->getFullType();
+ }
+ auto newParam = getBuilder()->createParam(paramType);
ioInfo.newParams.add(newParam);
// The new parameter will be used as the replacement
@@ -891,6 +943,23 @@ struct FunctionParameterSpecializationContext
//
addCallsToWorkListRec(newFunc);
+ // If one of the new parameters has a more specialized type,
+ // we need to update the type of load instructions from that
+ // parameter, if there are any.
+ for (auto newParam : funcInfo.newParams)
+ {
+ if (!as<IRParam>(newParam))
+ continue;
+ for (auto use = newParam->firstUse; use; use = use->nextUse)
+ {
+ auto user = use->getUser();
+ if (auto load = as<IRLoad>(user))
+ {
+ load->setFullType(as<IRPtrTypeBase>(newParam->getDataType())->getValueType());
+ }
+ }
+ }
+
simplifyFunc(codeGenContext->getTargetProgram(), newFunc, IRSimplificationOptions::getFast(codeGenContext->getTargetProgram()));
return newFunc;