summary | refs | log | tree | commit | diff | stats
path: root/source/slang/slang-ir-specialize-function-call.cpp
diff options
context:
space:
mode:
author Ellie Hermaszewska <ellieh@nvidia.com> 2024-10-29 14:49:26 +0800
committer GitHub <noreply@github.com> 2024-10-29 14:49:26 +0800
commit f65d756bff8d4c5cbc15bd0322a2ae8e6b896a21 (patch)
tree ea1d61342cd29368e19135000ec2948813096205 /source/slang/slang-ir-specialize-function-call.cpp
parent a729c15e9dce9f5116a38afc66329ab2ca4cea54 (diff)
format

* format
* Minor test fixes
* enable checking cpp format in ci
Diffstat (limited to 'source/slang/slang-ir-specialize-function-call.cpp')
-rw-r--r-- source/slang/slang-ir-specialize-function-call.cpp 167
1 files changed, 80 insertions, 87 deletions
diff --git a/source/slang/slang-ir-specialize-function-call.cpp b/source/slang/slang-ir-specialize-function-call.cpp
index d260f8105..7a9fc5f6f 100644
--- a/source/slang/slang-ir-specialize-function-call.cpp
+++ b/source/slang/slang-ir-specialize-function-call.cpp
@@ -1,16 +1,18 @@
// slang-ir-specialize-function-call.cpp
#include "slang-ir-specialize-function-call.h"
-#include "slang-ir.h"
#include "slang-ir-clone.h"
#include "slang-ir-insts.h"
#include "slang-ir-ssa-simplification.h"
#include "slang-ir-util.h"
+#include "slang-ir.h"
namespace Slang
{
-bool FunctionCallSpecializeCondition::isParamSuitableForSpecialization(IRParam* param, IRInst* inArg)
+bool FunctionCallSpecializeCondition::isParamSuitableForSpecialization(
+ IRParam* param,
+ IRInst* inArg)
{
SLANG_UNUSED(param);
@@ -31,10 +33,12 @@ bool FunctionCallSpecializeCondition::isParamSuitableForSpecialization(IRParam*
// specialize a callee to refer to the same
// global parameter directly.
//
- if (as<IRGlobalParam>(arg)) return true;
+ if (as<IRGlobalParam>(arg))
+ return true;
// Similarly for these global values
- if (as<IRGlobalValueWithCode>(arg)) return true;
+ if (as<IRGlobalValueWithCode>(arg))
+ return true;
// As we will see later, we can also
// specialize a call when the argument
@@ -87,7 +91,7 @@ struct FunctionParameterSpecializationContext
// `specializeFunctionParameters` function.
//
CodeGenContext* codeGenContext;
- IRModule* module;
+ IRModule* module;
// The condition on which parameters to specialize.
FunctionCallSpecializeCondition* condition;
@@ -98,9 +102,9 @@ struct FunctionParameterSpecializationContext
// of call sites in the program that may be worth
// considering for specialization.
//
- List<IRCall*> workList;
+ List<IRCall*> workList;
- IRBuilder builderStorage;
+ IRBuilder builderStorage;
IRBuilder* getBuilder() { return &builderStorage; }
// With the basic state out of the way, let's walk
@@ -122,7 +126,7 @@ struct FunctionParameterSpecializationContext
// We will process the work list until it goes dry,
// treating it like a stack of work items.
//
- while( workList.getCount() )
+ while (workList.getCount())
{
auto call = workList.getLast();
workList.removeLast();
@@ -134,7 +138,7 @@ struct FunctionParameterSpecializationContext
// become candidates for specialization, so
// our work list may grow along the way.
//
- if( canSpecializeCall(call) )
+ if (canSpecializeCall(call))
{
specializeCall(call);
changed = true;
@@ -149,7 +153,7 @@ struct FunctionParameterSpecializationContext
{
// If we have a call site, then add it to the list.
//
- if( auto call = as<IRCall>(inst) )
+ if (auto call = as<IRCall>(inst))
{
workList.add(call);
}
@@ -157,7 +161,7 @@ struct FunctionParameterSpecializationContext
// Recursively walk through any children, to
// see if we uncover more call sites.
//
- for( auto child : inst->getChildren() )
+ for (auto child : inst->getChildren())
{
addCallsToWorkListRec(child);
}
@@ -175,13 +179,17 @@ struct FunctionParameterSpecializationContext
// way to generate a specialized callee function.
//
auto func = as<IRFunc>(call->getCallee());
- if(!func)
+ if (!func)
return false;
- if(!func->isDefinition())
+ if (!func->isDefinition())
return false;
UnownedStringSlice def;
IRInst* intrinsicInst;
- if (findTargetIntrinsicDefinition(func, codeGenContext->getTargetReq()->getTargetCaps(), def, intrinsicInst))
+ if (findTargetIntrinsicDefinition(
+ func,
+ codeGenContext->getTargetReq()->getTargetCaps(),
+ def,
+ intrinsicInst))
return false;
// With the basic checks out of the way, there are
// two conditions we care about:
@@ -208,7 +216,7 @@ struct FunctionParameterSpecializationContext
//
bool anySpecializableParam = false;
UInt argCounter = 0;
- for( auto param : func->getParams() )
+ for (auto param : func->getParams())
{
UInt argIndex = argCounter++;
SLANG_ASSERT(argIndex < call->getArgCount());
@@ -219,7 +227,7 @@ struct FunctionParameterSpecializationContext
//
auto paramWantSpecialization = doesParamWantSpecialization(param, arg);
auto paramTypeWantSpecialization = doesParamTypeWantSpecialization(param, arg);
- if(!paramWantSpecialization && !paramTypeWantSpecialization)
+ if (!paramWantSpecialization && !paramTypeWantSpecialization)
continue;
// If we have run into a `param` or `arg` that wants specialization,
@@ -232,7 +240,7 @@ struct FunctionParameterSpecializationContext
// can bail out immediately because our second condition
// cannot be met.
//
- if(paramWantSpecialization && !isParamSuitableForSpecialization(param, arg))
+ if (paramWantSpecialization && !isParamSuitableForSpecialization(param, arg))
return false;
}
@@ -291,8 +299,8 @@ struct FunctionParameterSpecializationContext
//
struct CallSpecializationInfo
{
- Key key;
- List<IRInst*> newArgs;
+ Key key;
+ List<IRInst*> newArgs;
};
// Once we've collected the information about a call site
@@ -311,9 +319,9 @@ struct FunctionParameterSpecializationContext
//
struct FuncSpecializationInfo
{
- List<IRParam*> newParams;
- List<IRInst*> newBodyInsts;
- List<IRInst*> replacementsForOldParameters;
+ List<IRParam*> newParams;
+ List<IRInst*> newBodyInsts;
+ List<IRInst*> replacementsForOldParameters;
};
// Before diving into how the different passes collect
@@ -347,7 +355,7 @@ struct FunctionParameterSpecializationContext
// that is suitable to this call site.
//
IRFunc* newFunc = nullptr;
- if( !specializedFuncs.tryGetValue(callInfo.key, newFunc) )
+ if (!specializedFuncs.tryGetValue(callInfo.key, newFunc))
{
// If we didn't find a pre-existing specialized
// function, then we will go ahead and create one.
@@ -382,7 +390,6 @@ struct FunctionParameterSpecializationContext
newCall->insertBefore(oldCall);
oldCall->replaceUsesWith(newCall);
oldCall->removeAndDeallocate();
-
}
// Before diving into the details on how we gather information
@@ -450,10 +457,7 @@ struct FunctionParameterSpecializationContext
// argument list, and the "key" information that distinguishes
// what specialized callee we want/need.
//
- void gatherCallInfo(
- IRCall* oldCall,
- IRFunc* oldFunc,
- CallSpecializationInfo& callInfo)
+ void gatherCallInfo(IRCall* oldCall, IRFunc* oldFunc, CallSpecializationInfo& callInfo)
{
// The specialized callee key always needs to include
// the original function, since different functions
@@ -465,7 +469,7 @@ struct FunctionParameterSpecializationContext
// at parameter and argument pairs.
//
UInt oldArgCounter = 0;
- for( auto oldParam : oldFunc->getParams() )
+ for (auto oldParam : oldFunc->getParams())
{
UInt oldArgIndex = oldArgCounter++;
auto oldArg = oldCall->getArg(oldArgIndex);
@@ -474,15 +478,12 @@ struct FunctionParameterSpecializationContext
}
}
- void getCallInfoForParam(
- CallSpecializationInfo& ioInfo,
- IRParam* oldParam,
- IRInst* oldArg)
+ void getCallInfoForParam(CallSpecializationInfo& ioInfo, IRParam* oldParam, IRInst* oldArg)
{
// We know that the case where the parameter
// and argument don't want specialization is easy.
//
- if( !doesParamWantSpecialization(oldParam, oldArg) )
+ if (!doesParamWantSpecialization(oldParam, oldArg))
{
// The new call site will use the same argument
// value as the old one, and we don't need
@@ -513,14 +514,12 @@ struct FunctionParameterSpecializationContext
}
}
- void getCallInfoForArg(
- CallSpecializationInfo& ioInfo,
- IRInst* oldArg)
+ void getCallInfoForArg(CallSpecializationInfo& ioInfo, IRInst* oldArg)
{
// The base case we care about is when the original
// argument is a global shader parameter.
//
- if( auto oldGlobalParam = as<IRGlobalParam>(oldArg) )
+ if (auto oldGlobalParam = as<IRGlobalParam>(oldArg))
{
// In this case we don't need to pass anything
// as an argument at the new call site (the
@@ -532,17 +531,17 @@ struct FunctionParameterSpecializationContext
//
ioInfo.key.vals.add(oldGlobalParam);
}
- else if( auto globalConstant = as<IRGlobalValueWithCode>(oldArg) )
+ else if (auto globalConstant = as<IRGlobalValueWithCode>(oldArg))
{
// Similarly for other global constants
ioInfo.key.vals.add(globalConstant);
}
- else if( oldArg->getOp() == kIROp_GetElement )
+ else if (oldArg->getOp() == kIROp_GetElement)
{
// This is the case where the `oldArg` is
// in the form `oldBase[oldIndex]`
//
- auto oldBase = oldArg->getOperand(0);
+ auto oldBase = oldArg->getOperand(0);
auto oldIndex = oldArg->getOperand(1);
// Effectively, we act as if `oldBase` and
@@ -596,7 +595,7 @@ struct FunctionParameterSpecializationContext
IRInst* findNonuniformIndexInst(IRInst* inst)
{
- for(;;)
+ for (;;)
{
if (inst == nullptr)
return nullptr;
@@ -619,13 +618,10 @@ struct FunctionParameterSpecializationContext
// gathered once we decide we want to generate a
// specialized function, but it follows much the same flow.
//
- void gatherFuncInfo(
- IRCall* oldCall,
- IRFunc* oldFunc,
- FuncSpecializationInfo& funcInfo)
+ void gatherFuncInfo(IRCall* oldCall, IRFunc* oldFunc, FuncSpecializationInfo& funcInfo)
{
UInt oldArgCounter = 0;
- for( auto oldParam : oldFunc->getParams() )
+ for (auto oldParam : oldFunc->getParams())
{
UInt oldArgIndex = oldArgCounter++;
auto oldArg = oldCall->getArg(oldArgIndex);
@@ -644,7 +640,8 @@ struct FunctionParameterSpecializationContext
}
}
- // Wrap `argType` with a parameter direction type if `oldParam` has such a parameter direction type.
+ // Wrap `argType` with a parameter direction type if `oldParam` has such a parameter direction
+ // type.
IRType* maybeWrapParameterDirectionType(IRParam* oldParam, IRType* argType)
{
IRType* paramType = oldParam->getDataType();
@@ -656,7 +653,10 @@ struct FunctionParameterSpecializationContext
case kIROp_RefType:
case kIROp_ConstRefType:
argType = as<IRPtrTypeBase>(argType)->getValueType();
- resultType = getBuilder()->getPtrType(paramType->getOp(), argType, as<IRPtrTypeBase>(paramType)->getAddressSpace());
+ resultType = getBuilder()->getPtrType(
+ paramType->getOp(),
+ argType,
+ as<IRPtrTypeBase>(paramType)->getAddressSpace());
break;
}
if (auto rate = paramType->getRate())
@@ -670,13 +670,13 @@ struct FunctionParameterSpecializationContext
IRInst* getSpecializedValueForParam(
FuncSpecializationInfo& ioInfo,
- IRParam* oldParam,
- IRInst* oldArg)
+ IRParam* oldParam,
+ IRInst* oldArg)
{
// As always, the easy case is when the parameter of
// the original function doesn't need specialization.
//
- if( !doesParamWantSpecialization(oldParam, oldArg) )
+ if (!doesParamWantSpecialization(oldParam, oldArg))
{
// The specialized callee will need a new parameter
// that fills the same role as the old one, so we
@@ -708,9 +708,7 @@ struct FunctionParameterSpecializationContext
}
}
- IRInst* getSpecializedValueForArg(
- FuncSpecializationInfo& ioInfo,
- IRInst* oldArg)
+ IRInst* getSpecializedValueForArg(FuncSpecializationInfo& ioInfo, IRInst* oldArg)
{
// The logic here parallels `getCallInfoForArg`,
// and only differs in what information it is gathering.
@@ -718,7 +716,7 @@ struct FunctionParameterSpecializationContext
// As before, the base case is when we have a global
// shader parameter.
//
- if( auto globalParam = as<IRGlobalParam>(oldArg) )
+ if (auto globalParam = as<IRGlobalParam>(oldArg))
{
// The specialized function will not need any
// parameter in this case, and the global itself
@@ -727,18 +725,18 @@ struct FunctionParameterSpecializationContext
//
return globalParam;
}
- if( auto globalFunc = as<IRGlobalValueWithCode>(oldArg) )
+ if (auto globalFunc = as<IRGlobalValueWithCode>(oldArg))
{
// As above, the identity of the specialized function is sufficient
// to resolve the uses
return globalFunc;
}
- else if( oldArg->getOp() == kIROp_GetElement )
+ else if (oldArg->getOp() == kIROp_GetElement)
{
// This is the case where the argument is
// in the form `oldBase[oldIndex]`.
//
- auto oldBase = oldArg->getOperand(0);
+ auto oldBase = oldArg->getOperand(0);
auto oldIndex = oldArg->getOperand(1);
// In `getCallInfoForArg` this case was
@@ -793,10 +791,7 @@ struct FunctionParameterSpecializationContext
// of things, and then inserted to a more permanent location later.
//
builder->setInsertLoc(IRInsertLoc());
- auto newVal = builder->emitElementExtract(
- oldArg->getFullType(),
- newBase,
- newIndex);
+ auto newVal = builder->emitElementExtract(oldArg->getFullType(), newBase, newIndex);
// Because our new instruction wasn't
// actually inserted anywhere, we need to
@@ -815,9 +810,7 @@ struct FunctionParameterSpecializationContext
auto builder = getBuilder();
builder->setInsertLoc(IRInsertLoc());
- auto newVal = builder->emitLoad(
- oldArg->getFullType(),
- newPtr);
+ auto newVal = builder->emitLoad(oldArg->getFullType(), newPtr);
ioInfo.newBodyInsts.add(newVal);
return newVal;
@@ -840,9 +833,9 @@ struct FunctionParameterSpecializationContext
// the information we have gathered.
//
IRFunc* generateSpecializedFunc(
- IRFunc* oldFunc,
- FuncSpecializationInfo const& funcInfo,
- CallSpecializationInfo const& callInfo)
+ IRFunc* oldFunc,
+ FuncSpecializationInfo const& funcInfo,
+ CallSpecializationInfo const& callInfo)
{
// We will make use of the infrastructure for cloning
// IR code, that is defined in `ir-clone.{h,cpp}`.
@@ -859,7 +852,7 @@ struct FunctionParameterSpecializationContext
// already gathered.
//
UInt paramCounter = 0;
- for( auto oldParam : oldFunc->getParams() )
+ for (auto oldParam : oldFunc->getParams())
{
UInt paramIndex = paramCounter++;
auto newVal = funcInfo.replacementsForOldParameters[paramIndex];
@@ -876,7 +869,7 @@ struct FunctionParameterSpecializationContext
// their types.
//
List<IRType*> paramTypes;
- for( auto param : funcInfo.newParams )
+ for (auto param : funcInfo.newParams)
{
paramTypes.add(param->getFullType());
}
@@ -898,11 +891,7 @@ struct FunctionParameterSpecializationContext
// to perform the second phase of cloning, which will recursively
// clone any nested decorations, blocks, and instructions.
//
- cloneInstDecorationsAndChildren(
- &cloneEnv,
- builder->getModule(),
- oldFunc,
- newFunc);
+ cloneInstDecorationsAndChildren(&cloneEnv, builder->getModule(), oldFunc, newFunc);
// If we have added a Linkage decoration, we want to remove and destroy it,
// because the linkage should only be on the original function and
@@ -918,7 +907,7 @@ struct FunctionParameterSpecializationContext
const auto end = decorationList.end();
auto cur = decorationList.begin();
- while(cur != end)
+ while (cur != end)
{
IRDecoration* decoration = *cur;
@@ -930,10 +919,12 @@ struct FunctionParameterSpecializationContext
{
decoration->removeAndDeallocate();
}
- else if (as<IRReadNoneDecoration>(decoration) || as<IRNoSideEffectDecoration>(decoration))
+ else if (
+ as<IRReadNoneDecoration>(decoration) ||
+ as<IRNoSideEffectDecoration>(decoration))
{
// After specialization, the function may no longer be side effect free
- // because the parameter we substituted in may be a global param.
+ // because the parameter we substituted in may be a global param.
decoration->removeAndDeallocate();
}
}
@@ -963,11 +954,11 @@ struct FunctionParameterSpecializationContext
// which has the effect of arranging them in the output
// in the order they are enumerated here.
//
- for( auto newParam : funcInfo.newParams )
+ for (auto newParam : funcInfo.newParams)
{
newParam->insertBefore(newFirstOrdinary);
}
- for( auto newBodyInst : funcInfo.newBodyInsts )
+ for (auto newBodyInst : funcInfo.newBodyInsts)
{
newBodyInst->insertBefore(newFirstOrdinary);
}
@@ -1011,15 +1002,18 @@ struct FunctionParameterSpecializationContext
}
}
- simplifyFunc(codeGenContext->getTargetProgram(), newFunc, IRSimplificationOptions::getFast(codeGenContext->getTargetProgram()));
+ simplifyFunc(
+ codeGenContext->getTargetProgram(),
+ newFunc,
+ IRSimplificationOptions::getFast(codeGenContext->getTargetProgram()));
return newFunc;
}
void maybeInsertNonUniformResourceIndex(
IRFunc* newFunc,
- FuncSpecializationInfo const& funcInfo,
- CallSpecializationInfo const& callInfo)
+ FuncSpecializationInfo const& funcInfo,
+ CallSpecializationInfo const& callInfo)
{
auto builder = getBuilder();
uint32_t paramIndex = 0;
@@ -1051,7 +1045,6 @@ struct FunctionParameterSpecializationContext
}
paramIndex++;
}
-
}
};
@@ -1061,7 +1054,7 @@ struct FunctionParameterSpecializationContext
//
bool specializeFunctionCalls(
CodeGenContext* codeGenContext,
- IRModule* module,
+ IRModule* module,
FunctionCallSpecializeCondition* condition)
{
FunctionParameterSpecializationContext context;
@@ -1072,4 +1065,4 @@ bool specializeFunctionCalls(
return context.processModule();
}
-} // namesapce Slang
+} // namespace Slang