summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-specialize.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-specialize.cpp')
-rw-r--r--source/slang/slang-ir-specialize.cpp352
1 files changed, 203 insertions, 149 deletions
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp
index c9e94352e..a56dae025 100644
--- a/source/slang/slang-ir-specialize.cpp
+++ b/source/slang/slang-ir-specialize.cpp
@@ -8,6 +8,7 @@
#include "slang-ir-lower-witness-lookup.h"
#include "slang-ir-dce.h"
#include "slang-ir-sccp.h"
+#include "slang-ir-util.h"
#include "../core/slang-performance-profiler.h"
namespace Slang
@@ -85,6 +86,7 @@ struct SpecializationContext
{
case kIROp_GlobalGenericParam:
case kIROp_LookupWitness:
+ case kIROp_GetTupleElement:
return false;
case kIROp_Specialize:
// The `specialize` instruction is a bit special,
@@ -589,9 +591,6 @@ struct SpecializationContext
case kIROp_Expand:
return maybeSpecializeExpand(as<IRExpand>(inst));
- case kIROp_ExpandTypeOrVal:
- return maybeSpecializeExpandTypeOrVal(as<IRExpandType>(inst));
-
case kIROp_GetTupleElement:
return maybeSpecializeFoldableInst(inst);
@@ -605,6 +604,15 @@ struct SpecializationContext
case kIROp_CountOf:
return maybeSpecializeCountOf(inst);
+
+ case kIROp_Func:
+
+ if (tryExpandParameterPack(as<IRFunc>(inst)))
+ {
+ addUsersToWorkList(inst);
+ return true;
+ }
+ return false;
}
}
@@ -1010,6 +1018,9 @@ struct SpecializationContext
workList.removeLast();
workListSet.remove(inst);
+ if (!inst->getParent() && inst->getOp() != kIROp_Module)
+ continue;
+
// For each instruction we process, we want to perform
// a few steps.
//
@@ -1182,11 +1193,8 @@ struct SpecializationContext
auto newWrapExistential = builder.emitWrapExistential(
resultType, newCall, slotOperandCount, slotOperands.getArrayView().getBuffer());
inst->replaceUsesWith(newWrapExistential);
- workList.remove(inst);
inst->removeAndDeallocate();
addUsersToWorkList(newWrapExistential);
-
- workList.remove(wrapExistential);
SLANG_ASSERT(!wrapExistential->hasUses());
wrapExistential->removeAndDeallocate();
return true;
@@ -1209,6 +1217,14 @@ struct SpecializationContext
if (maybeSpecializeBufferLoadCall(inst))
return false;
+ // If any arguments are value packs, we need to flatten them.
+ bool isCalleeFullyExpanded = false;
+ tryExpandParameterPack(as<IRFunc>(inst->getCallee()), &isCalleeFullyExpanded);
+ if (isCalleeFullyExpanded)
+ {
+ inst = tryExpandArgPack((IRCall*)inst);
+ }
+
// We can only specialize a call when the callee function is known.
//
auto calleeFunc = as<IRFunc>(inst->getCallee());
@@ -2402,13 +2418,9 @@ struct SpecializationContext
break;
}
}
- auto type = clonePatternVal(*subEnv, builder, childInst->getFullType(), index);
- for (UInt i = 0; i < childInst->getOperandCount(); i++)
- {
- clonePatternVal(*subEnv, builder, childInst->getOperand(i), index);
- }
auto newInst = cloneInst(subEnv, builder, childInst);
- newInst = builder->replaceOperand(&newInst->typeUse, type);
+ if (newInst != childInst)
+ addToWorkList(newInst);
subEnv->mapOldValToNew[childInst] = newInst;
IRBuilder subBuilder(*builder);
subBuilder.setInsertInto(newInst);
@@ -2419,6 +2431,32 @@ struct SpecializationContext
return newInst;
}
+ // Helper that builds a pack from `elements`, with the kind of pack chosen from `type`:
+ // a witness pack for a witness-table type, a type pack for a type kind or `IRTypeType`, and a value pack otherwise.
+ // Returns the pack that was built.
+ IRInst* makeSpecializedPack(IRBuilder& builder, IRType* type, ArrayView<IRInst*> elements)
+ {
+ IRInst* resultPack = nullptr;
+ if (as<IRWitnessTableType>(type))
+ {
+ List<IRType*> types;
+ for (auto element : elements)
+ types.add(element->getDataType()); // the type pack lists each element's data type
+ auto newTypePack = builder.getTypePack(elements.getCount(), types.getBuffer());
+ resultPack = builder.emitMakeWitnessPack(newTypePack, elements);
+ }
+ else if (as<IRTypeKind>(type) || as<IRTypeType>(type))
+ {
+ auto newTypePack = builder.getTypePack(elements.getCount(), (IRType* const*)elements.getBuffer()); // elements are treated as types here
+ resultPack = newTypePack; // the type pack itself is the result; no separate make-inst is emitted
+ }
+ else
+ {
+ resultPack = builder.emitMakeValuePack((UInt)elements.getCount(), elements.getBuffer());
+ }
+ return resultPack;
+ }
+
bool maybeSpecializeExpand(IRExpand* expandInst)
{
if (expandInst->getCaptureCount() == 0)
@@ -2440,44 +2478,57 @@ struct SpecializationContext
}
if (elementCount == 0)
{
- auto resultValuePack = builder.emitMakeValuePack(0, (IRInst*const*)nullptr);
- expandInst->replaceUsesWith(resultValuePack);
+ auto resultPack = makeSpecializedPack(builder, expandInst->getDataType(), elements.getArrayView());
+ expandInst->replaceUsesWith(resultPack);
expandInst->removeAndDeallocate();
- addUsersToWorkList(resultValuePack);
+ addUsersToWorkList(resultPack);
return true;
}
+
+ bool isMultiBlock = as<IRYield>(expandInst->getFirstBlock()->getTerminator()) == nullptr;
for (UInt i = 0; i < elementCount; i++)
{
IRCloneEnv cloneEnv;
- IRBlock* firstBlock = nullptr;
IRBuilder subBuilder = builder;
- for (auto childBlock : expandInst->getBlocks())
+ IRBlock* mergeBlock = nullptr;
+ if (isMultiBlock)
{
- auto newBlock = subBuilder.emitBlock();
- if (!firstBlock)
- firstBlock = newBlock;
- cloneEnv.mapOldValToNew[childBlock] = newBlock;
+ IRBlock* firstBlock = nullptr;
+ for (auto childBlock : expandInst->getBlocks())
+ {
+ auto newBlock = subBuilder.emitBlock();
+ if (!firstBlock)
+ firstBlock = newBlock;
+ cloneEnv.mapOldValToNew[childBlock] = newBlock;
+ }
+
+ builder.emitBranch(firstBlock);
+
+ mergeBlock = subBuilder.emitBlock();
+ builder.setInsertInto(mergeBlock);
}
+
auto indexParam = expandInst->getFirstBlock()->getFirstParam();
SLANG_ASSERT(indexParam);
cloneEnv.mapOldValToNew[indexParam] = subBuilder.getIntValue(subBuilder.getIntType(), i);
- builder.emitBranch(firstBlock);
-
- IRBlock* mergeBlock = subBuilder.emitBlock();
- builder.setInsertInto(mergeBlock);
-
for (auto childBlock : expandInst->getBlocks())
{
- auto newBlock = cloneEnv.mapOldValToNew[childBlock];
- subBuilder.setInsertInto(newBlock);
+ if (isMultiBlock)
+ {
+ auto newBlock = cloneEnv.mapOldValToNew[childBlock];
+ subBuilder.setInsertInto(newBlock);
+ }
for (auto child : childBlock->getChildren())
{
if (as<IRYield>(child))
{
- elements.add(cloneEnv.mapOldValToNew[child->getOperand(0)]);
- subBuilder.emitBranch(mergeBlock);
+ auto currentResult = child->getOperand(0);
+ currentResult = findCloneForOperand(&cloneEnv, currentResult);
+ elements.add(currentResult);
+ if (isMultiBlock)
+ subBuilder.emitBranch(mergeBlock);
continue;
}
specializeExpandChildInst(cloneEnv, &subBuilder, child, i);
@@ -2486,129 +2537,22 @@ struct SpecializationContext
}
}
- auto resultValuePack = builder.emitMakeValuePack((UInt)elements.getCount(), elements.getBuffer());
- auto currentBlock = builder.getBlock();
- for (auto nextInst = expandInst->next; nextInst;)
- {
- auto next = nextInst->next;
- nextInst->insertAtEnd(currentBlock);
- nextInst = next;
- }
- addUsersToWorkList(expandInst);
- expandInst->replaceUsesWith(resultValuePack);
- expandInst->removeAndDeallocate();
- return true;
- }
- IRInst* clonePatternValImpl(IRCloneEnv& cloneEnv, IRBuilder* builder, IRInst* val, UInt indexInPack)
- {
- if (!val)
- return val;
-
- switch (val->getOp())
- {
- case kIROp_ExpandTypeOrVal:
- return val;
- case kIROp_Each:
+ IRInst* resultPack = makeSpecializedPack(builder, expandInst->getDataType(), elements.getArrayView());
+ if (isMultiBlock)
{
- auto eachInst = as<IREach>(val);
- auto packInst = eachInst->getElement();
- if (auto typePack = as<IRTypePack>(packInst))
- {
- SLANG_RELEASE_ASSERT(indexInPack < typePack->getOperandCount());
- return typePack->getOperand(indexInPack);
- }
- else if (auto makeValuePack = as<IRMakeValuePack>(packInst))
- {
- SLANG_RELEASE_ASSERT(indexInPack < makeValuePack->getOperandCount());
- return makeValuePack->getOperand(indexInPack);
- }
- else if (!as<IRTypeKind>(packInst->getDataType()))
+ auto currentBlock = builder.getBlock();
+ for (auto nextInst = expandInst->next; nextInst;)
{
- auto type = clonePatternVal(cloneEnv, builder, val, indexInPack);
- return builder->emitGetTupleElement((IRType*)type, packInst, indexInPack);
+ auto next = nextInst->next;
+ nextInst->insertAtEnd(currentBlock);
+ nextInst = next;
}
- return val;
- }
- default:
- break;
- }
- bool anyChange = false;
- ShortList<IRInst*> operands;
- for (UInt i = 0; i < val->getOperandCount(); i++)
- {
- auto newOperand = clonePatternVal(cloneEnv, builder, val->getOperand(i), indexInPack);
- if (newOperand != val->getOperand(i))
- anyChange = true;
- operands.add(newOperand);
- }
- auto newType = clonePatternVal(cloneEnv, builder, val->getFullType(), indexInPack);
- if (newType != val->getFullType())
- anyChange = true;
- if (!anyChange)
- return val;
-
- auto newVal = builder->emitIntrinsicInst((IRType*)newType, val->getOp(), operands.getCount(), operands.getArrayView().getBuffer());
- if (newVal != val)
- {
- cloneInstDecorationsAndChildren(&cloneEnv, module, val, newVal);
- }
- return newVal;
- }
-
- IRInst* clonePatternVal(IRCloneEnv& cloneEnv, IRBuilder* builder, IRInst* val, UInt indexInPack)
- {
- if (auto clonedVal = cloneEnv.mapOldValToNew.tryGetValue(val))
- return *clonedVal;
- cloneEnv.mapOldValToNew[val] = val;
- auto result = clonePatternValImpl(cloneEnv, builder, val, indexInPack);
- cloneEnv.mapOldValToNew[val] = result;
- return result;
- }
-
- bool maybeSpecializeExpandTypeOrVal(IRExpandType* expandInst)
- {
- if (expandInst->getCaptureCount() == 0)
- return false;
-
- for (UInt i = 0; i < expandInst->getCaptureCount(); i++)
- {
- if (!as<IRTypePack>(expandInst->getCaptureType(i)))
- return false;
- }
- IRBuilder builder(expandInst);
- builder.setInsertBefore(expandInst);
- List<IRInst*> elements;
- UInt elementCount = 0;
- if (auto firstTypePack = as<IRTypePack>(expandInst->getCaptureType(0)))
- {
- elementCount = firstTypePack->getOperandCount();
- }
- for (UInt i = 0; i < elementCount; i++)
- {
- IRCloneEnv cloneEnv;
- auto element = clonePatternVal(cloneEnv, &builder, expandInst->getPatternType(), i);
- elements.add(element);
}
addUsersToWorkList(expandInst);
- if (as<IRWitnessTableType>(expandInst->getDataType()))
- {
- List<IRType*> types;
- for (auto element : elements)
- types.add(element->getDataType());
- auto newTypePack = builder.getTypePack(elements.getCount(), types.getBuffer());
- auto result = builder.emitMakeWitnessPack(newTypePack, elements.getArrayView());
- expandInst->replaceUsesWith(result);
- expandInst->removeAndDeallocate();
- return true;
- }
- else
- {
- auto newTypePack = builder.getTypePack(elements.getCount(), (IRType*const*)elements.getBuffer());
- expandInst->replaceUsesWith(newTypePack);
- expandInst->removeAndDeallocate();
- return true;
- }
+ expandInst->replaceUsesWith(resultPack);
+ expandInst->removeAndDeallocate();
+ return true;
}
// The handling of specialization for global generic type
@@ -2680,6 +2624,108 @@ struct SpecializationContext
}
}
}
+
+
+ // If `func` has any parameters whose types are `IRTypePack`, expand each of them
+ // into one parameter per pack element, so that the function has no parameters of type `IRTypePack`.
+ // Returns true if changes are made; returns false without touching `outIsFullyExpanded` when `func` is null.
+ // For example, this function turns `int f(TypePack<int, float> v)` into
+ // ```
+ // int f(int v0, float v1)
+ // {
+ // v = MakeValuePack(v0, v1);
+ // ...
+ // }
+ // ```
+ // `outIsFullyExpanded` (optional) is set to false if any parameter's type is an `IRExpand` (nothing is changed in that case), and to true otherwise.
+ bool tryExpandParameterPack(IRFunc* func, bool* outIsFullyExpanded = nullptr)
+ {
+ if (!func)
+ return false;
+ if (outIsFullyExpanded)
+ *outIsFullyExpanded = true;
+ ShortList<IRInst*> params; // parameters whose type is an `IRTypePack`
+ for (auto param : func->getParams())
+ {
+ if (as<IRTypePack>(param->getDataType()))
+ params.add(param);
+ if (as<IRExpand>(param->getDataType()))
+ {
+ if (outIsFullyExpanded)
+ *outIsFullyExpanded = false;
+ return false; // bail out before modifying anything
+ }
+ }
+ if (params.getCount() == 0)
+ return false;
+
+ IRBuilder builder(func);
+ for (auto param : params)
+ {
+ builder.setInsertBefore(param);
+ auto typePack = as<IRTypePack>(param->getDataType());
+ ShortList<IRInst*> newParams;
+ for (UInt i = 0; i < typePack->getOperandCount(); i++)
+ {
+ auto newParam = builder.createParam((IRType*)typePack->getOperand(i)); // one new parameter per element type of the pack
+ newParam->insertBefore(param);
+ newParams.add(newParam);
+ }
+ setInsertBeforeOrdinaryInst(&builder, param); // NOTE(review): presumably moves the insert point past the parameter list - confirm against slang-ir-util
+ auto val = builder.emitMakeValuePack(typePack, (UInt)newParams.getCount(), newParams.getArrayView().getBuffer());
+ param->replaceUsesWith(val); // existing uses of the pack parameter now refer to the rebuilt value pack
+ param->removeAndDeallocate();
+ addUsersToWorkList(val);
+ }
+
+ fixUpFuncType(func); // NOTE(review): presumably recomputes the function type from the new parameter list - confirm
+ return true;
+ }
+
+ // If any argument of `call` has a type that is an `IRTypePack`, replace the call with a new one in which each such argument is expanded into its elements,
+ // so that the call has no arguments of type `IRTypePack`.
+ // For example, we will turn `f(MakeValuePack(a, b))` into `f(a, b)`.
+ // Returns `call` unchanged when no argument is a pack; otherwise returns the new call, and `call` has been removed and deallocated.
+ IRCall* tryExpandArgPack(IRCall* call)
+ {
+ bool anyArgPack = false;
+ for (UInt i = 0; i < call->getArgCount(); i++)
+ {
+ auto arg = call->getArg(i);
+ if (as<IRTypePack>(arg->getDataType()))
+ {
+ anyArgPack = true;
+ break;
+ }
+ }
+ if (!anyArgPack)
+ return call;
+ IRBuilder builder(call);
+ builder.setInsertBefore(call);
+ List<IRInst*> newArgs; // flattened argument list for the replacement call
+ for (UInt i = 0; i < call->getArgCount(); i++)
+ {
+ auto arg = call->getArg(i);
+ if (auto typePack = as<IRTypePack>(arg->getDataType()))
+ {
+ for (UInt elementIndex = 0; elementIndex < typePack->getOperandCount(); elementIndex++)
+ {
+ auto newArg = builder.emitGetTupleElement((IRType*)typePack->getOperand(elementIndex), arg, elementIndex); // extract element `elementIndex` of the pack argument
+ newArgs.add(newArg);
+ }
+ }
+ else
+ {
+ newArgs.add(arg); // non-pack arguments are passed through as-is
+ }
+ }
+ auto newCall = builder.emitCallInst(call->getFullType(), call->getCallee(), newArgs.getArrayView());
+ call->replaceUsesWith(newCall);
+ call->transferDecorationsTo(newCall);
+ call->removeAndDeallocate();
+ return newCall;
+ }
+
};
bool specializeModule(
@@ -2785,6 +2831,13 @@ IRInst* specializeGenericImpl(
IRBuilder* builder = &builderStorage;
builder->setInsertBefore(genericVal);
+ List<IRInst*> pendingWorkList;
+ SLANG_DEFER
+ (
+ for (Index ii = pendingWorkList.getCount() - 1; ii >= 0; ii--)
+ context->addToWorkList(pendingWorkList[ii]);
+ );
+
// Now we will run through the body of the generic and
// clone each of its instructions into the global scope,
// until we reach a `return` instruction.
@@ -2825,10 +2878,11 @@ IRInst* specializeGenericImpl(
{
if (auto func = as<IRFunc>(specializedVal))
{
+ context->tryExpandParameterPack(func);
simplifyFunc(context->targetProgram, func, IRSimplificationOptions::getFast(context->targetProgram));
}
}
-
+ pendingWorkList.add(specializedVal);
return specializedVal;
}
@@ -2848,7 +2902,7 @@ IRInst* specializeGenericImpl(
//
if (context)
{
- context->addToWorkList(clonedInst);
+ pendingWorkList.add(clonedInst);
}
}
}