diff options
Diffstat (limited to 'source/slang/slang-ir-specialize-function-call.cpp')
| -rw-r--r-- | source/slang/slang-ir-specialize-function-call.cpp | 75 |
1 files changed, 72 insertions, 3 deletions
diff --git a/source/slang/slang-ir-specialize-function-call.cpp b/source/slang/slang-ir-specialize-function-call.cpp index c4928a230..a41ca1e99 100644 --- a/source/slang/slang-ir-specialize-function-call.cpp +++ b/source/slang/slang-ir-specialize-function-call.cpp @@ -67,6 +67,13 @@ bool FunctionCallSpecializeCondition::isParamSuitableForSpecialization(IRParam* } } +bool FunctionCallSpecializeCondition::doesParamTypeWantSpecialization(IRParam* param, IRInst* arg) +{ + SLANG_UNUSED(param); + SLANG_UNUSED(arg); + return false; +} + struct FunctionParameterSpecializationContext { // This type implements a pass to specialize functions @@ -209,7 +216,9 @@ struct FunctionParameterSpecializationContext // If neither the parameter nor the argument wants specialization, // then we need to keep looking. // - if(!doesParamWantSpecialization(param, arg)) + auto paramWantSpecialization = doesParamWantSpecialization(param, arg); + auto paramTypeWantSpecialization = doesParamTypeWantSpecialization(param, arg); + if(!paramWantSpecialization && !paramTypeWantSpecialization) continue; // If we have run into a `param` or `arg` that wants specialization, @@ -222,7 +231,7 @@ struct FunctionParameterSpecializationContext // can bail out immediately because our second condition // cannot be met. // - if(!isParamSuitableForSpecialization(param, arg)) + if(paramWantSpecialization && !isParamSuitableForSpecialization(param, arg)) return false; } @@ -242,6 +251,11 @@ struct FunctionParameterSpecializationContext return condition->doesParamWantSpecialization(param, arg); } + bool doesParamTypeWantSpecialization(IRParam* param, IRInst* arg) + { + return condition->doesParamTypeWantSpecialization(param, arg); + } + bool isParamSuitableForSpecialization(IRParam* param, IRInst* arg) { return condition->isParamSuitableForSpecialization(param, arg); @@ -474,6 +488,11 @@ struct FunctionParameterSpecializationContext // specialized callee based on this paramter. // ioInfo.newArgs.add(oldArg); + + if (doesParamTypeWantSpecialization(oldParam, oldArg)) + { + ioInfo.key.vals.add(oldArg->getDataType()); + } } else { @@ -587,6 +606,30 @@ struct FunctionParameterSpecializationContext } } + // Wrap `argType` with a parameter direction type if `oldParam` has such a parameter direction type. + IRType* maybeWrapParameterDirectionType(IRParam* oldParam, IRType* argType) + { + IRType* paramType = oldParam->getDataType(); + IRType* resultType = argType; + switch (paramType->getOp()) + { + case kIROp_InOutType: + case kIROp_OutType: + case kIROp_RefType: + case kIROp_ConstRefType: + argType = as<IRPtrTypeBase>(argType)->getValueType(); + resultType = getBuilder()->getPtrType(paramType->getOp(), argType, as<IRPtrTypeBase>(paramType)->getAddressSpace()); + break; + } + if (auto rate = paramType->getRate()) + { + IRBuilder builder(oldParam); + builder.setInsertAfter(resultType); + resultType = builder.getRateQualifiedType(rate, resultType); + } + return resultType; + } + IRInst* getSpecializedValueForParam( FuncSpecializationInfo& ioInfo, IRParam* oldParam, @@ -601,7 +644,16 @@ struct FunctionParameterSpecializationContext // that fills the same role as the old one, so we // create it here. // - auto newParam = getBuilder()->createParam(oldParam->getFullType()); + IRType* paramType = nullptr; + if (doesParamTypeWantSpecialization(oldParam, oldArg)) + { + paramType = maybeWrapParameterDirectionType(oldParam, oldArg->getDataType()); + } + else + { + paramType = oldParam->getFullType(); + } + auto newParam = getBuilder()->createParam(paramType); ioInfo.newParams.add(newParam); // The new parameter will be used as the replacement @@ -891,6 +943,23 @@ struct FunctionParameterSpecializationContext // addCallsToWorkListRec(newFunc); + // If one of the new parameters has a more specialized type, + // we need to update the type of load instructions from that + // parameter, if there are any. + for (auto newParam : funcInfo.newParams) + { + if (!as<IRParam>(newParam)) + continue; + for (auto use = newParam->firstUse; use; use = use->nextUse) + { + auto user = use->getUser(); + if (auto load = as<IRLoad>(user)) + { + load->setFullType(as<IRPtrTypeBase>(newParam->getDataType())->getValueType()); + } + } + } + simplifyFunc(codeGenContext->getTargetProgram(), newFunc, IRSimplificationOptions::getFast(codeGenContext->getTargetProgram())); return newFunc; |
