diff options
| author | Yong He <yonghe@outlook.com> | 2024-08-23 21:45:59 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-08-23 21:45:59 -0700 |
| commit | b2ca2d5a4efeae807d3c3f48f60235e47413b559 (patch) | |
| tree | 643d2bab5776e5f8f7cfa722975af9e826d77c9d /source/slang/slang-ir-specialize.cpp | |
| parent | e4088cd602bd4d5a72fea67a787b1319acfc044d (diff) | |
Make variadic generics work with interfaces and forward autodiff. (#4905)
Diffstat (limited to 'source/slang/slang-ir-specialize.cpp')
| -rw-r--r-- | source/slang/slang-ir-specialize.cpp | 352 |
1 files changed, 203 insertions, 149 deletions
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index c9e94352e..a56dae025 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -8,6 +8,7 @@ #include "slang-ir-lower-witness-lookup.h" #include "slang-ir-dce.h" #include "slang-ir-sccp.h" +#include "slang-ir-util.h" #include "../core/slang-performance-profiler.h" namespace Slang @@ -85,6 +86,7 @@ struct SpecializationContext { case kIROp_GlobalGenericParam: case kIROp_LookupWitness: + case kIROp_GetTupleElement: return false; case kIROp_Specialize: // The `specialize` instruction is a bit sepcial, @@ -589,9 +591,6 @@ struct SpecializationContext case kIROp_Expand: return maybeSpecializeExpand(as<IRExpand>(inst)); - case kIROp_ExpandTypeOrVal: - return maybeSpecializeExpandTypeOrVal(as<IRExpandType>(inst)); - case kIROp_GetTupleElement: return maybeSpecializeFoldableInst(inst); @@ -605,6 +604,15 @@ struct SpecializationContext case kIROp_CountOf: return maybeSpecializeCountOf(inst); + + case kIROp_Func: + + if (tryExpandParameterPack(as<IRFunc>(inst))) + { + addUsersToWorkList(inst); + return true; + } + return false; } } @@ -1010,6 +1018,9 @@ struct SpecializationContext workList.removeLast(); workListSet.remove(inst); + if (!inst->getParent() && inst->getOp() != kIROp_Module) + continue; + // For each instruction we process, we want to perform // a few steps. // @@ -1182,11 +1193,8 @@ struct SpecializationContext auto newWrapExistential = builder.emitWrapExistential( resultType, newCall, slotOperandCount, slotOperands.getArrayView().getBuffer()); inst->replaceUsesWith(newWrapExistential); - workList.remove(inst); inst->removeAndDeallocate(); addUsersToWorkList(newWrapExistential); - - workList.remove(wrapExistential); SLANG_ASSERT(!wrapExistential->hasUses()); wrapExistential->removeAndDeallocate(); return true; @@ -1209,6 +1217,14 @@ struct SpecializationContext if (maybeSpecializeBufferLoadCall(inst)) return false; + // If any arguments are value packs, we need to flatten them. + bool isCalleeFullyExpanded = false; + tryExpandParameterPack(as<IRFunc>(inst->getCallee()), &isCalleeFullyExpanded); + if (isCalleeFullyExpanded) + { + inst = tryExpandArgPack((IRCall*)inst); + } + // We can only specialize a call when the callee function is known. // auto calleeFunc = as<IRFunc>(inst->getCallee()); @@ -2402,13 +2418,9 @@ struct SpecializationContext break; } } - auto type = clonePatternVal(*subEnv, builder, childInst->getFullType(), index); - for (UInt i = 0; i < childInst->getOperandCount(); i++) - { - clonePatternVal(*subEnv, builder, childInst->getOperand(i), index); - } auto newInst = cloneInst(subEnv, builder, childInst); - newInst = builder->replaceOperand(&newInst->typeUse, type); + if (newInst != childInst) + addToWorkList(newInst); subEnv->mapOldValToNew[childInst] = newInst; IRBuilder subBuilder(*builder); subBuilder.setInsertInto(newInst); @@ -2419,6 +2431,32 @@ struct SpecializationContext return newInst; } + // A helper function to emit a MakeWitnessPack, MakeTypePack or MakeValuePack inst from + // a collection of elements, dependending on `type`. + // + IRInst* makeSpecializedPack(IRBuilder& builder, IRType* type, ArrayView<IRInst*> elements) + { + IRInst* resultPack = nullptr; + if (as<IRWitnessTableType>(type)) + { + List<IRType*> types; + for (auto element : elements) + types.add(element->getDataType()); + auto newTypePack = builder.getTypePack(elements.getCount(), types.getBuffer()); + resultPack = builder.emitMakeWitnessPack(newTypePack, elements); + } + else if (as<IRTypeKind>(type) || as<IRTypeType>(type)) + { + auto newTypePack = builder.getTypePack(elements.getCount(), (IRType* const*)elements.getBuffer()); + resultPack = newTypePack; + } + else + { + resultPack = builder.emitMakeValuePack((UInt)elements.getCount(), elements.getBuffer()); + } + return resultPack; + } + bool maybeSpecializeExpand(IRExpand* expandInst) { if (expandInst->getCaptureCount() == 0) @@ -2440,44 +2478,57 @@ struct SpecializationContext } if (elementCount == 0) { - auto resultValuePack = builder.emitMakeValuePack(0, (IRInst*const*)nullptr); - expandInst->replaceUsesWith(resultValuePack); + auto resultPack = makeSpecializedPack(builder, expandInst->getDataType(), elements.getArrayView()); + expandInst->replaceUsesWith(resultPack); expandInst->removeAndDeallocate(); - addUsersToWorkList(resultValuePack); + addUsersToWorkList(resultPack); return true; } + + bool isMultiBlock = as<IRYield>(expandInst->getFirstBlock()->getTerminator()) == nullptr; for (UInt i = 0; i < elementCount; i++) { IRCloneEnv cloneEnv; - IRBlock* firstBlock = nullptr; IRBuilder subBuilder = builder; - for (auto childBlock : expandInst->getBlocks()) + IRBlock* mergeBlock = nullptr; + if (isMultiBlock) { - auto newBlock = subBuilder.emitBlock(); - if (!firstBlock) - firstBlock = newBlock; - cloneEnv.mapOldValToNew[childBlock] = newBlock; + IRBlock* firstBlock = nullptr; + for (auto childBlock : expandInst->getBlocks()) + { + auto newBlock = subBuilder.emitBlock(); + if (!firstBlock) + firstBlock = newBlock; + cloneEnv.mapOldValToNew[childBlock] = newBlock; + } + + builder.emitBranch(firstBlock); + + mergeBlock = subBuilder.emitBlock(); + builder.setInsertInto(mergeBlock); } + auto indexParam = expandInst->getFirstBlock()->getFirstParam(); SLANG_ASSERT(indexParam); cloneEnv.mapOldValToNew[indexParam] = subBuilder.getIntValue(subBuilder.getIntType(), i); - builder.emitBranch(firstBlock); - - IRBlock* mergeBlock = subBuilder.emitBlock(); - builder.setInsertInto(mergeBlock); - for (auto childBlock : expandInst->getBlocks()) { - auto newBlock = cloneEnv.mapOldValToNew[childBlock]; - subBuilder.setInsertInto(newBlock); + if (isMultiBlock) + { + auto newBlock = cloneEnv.mapOldValToNew[childBlock]; + subBuilder.setInsertInto(newBlock); + } for (auto child : childBlock->getChildren()) { if (as<IRYield>(child)) { - elements.add(cloneEnv.mapOldValToNew[child->getOperand(0)]); - subBuilder.emitBranch(mergeBlock); + auto currentResult = child->getOperand(0); + currentResult = findCloneForOperand(&cloneEnv, currentResult); + elements.add(currentResult); + if (isMultiBlock) + subBuilder.emitBranch(mergeBlock); continue; } specializeExpandChildInst(cloneEnv, &subBuilder, child, i); @@ -2486,129 +2537,22 @@ struct SpecializationContext } } - auto resultValuePack = builder.emitMakeValuePack((UInt)elements.getCount(), elements.getBuffer()); - auto currentBlock = builder.getBlock(); - for (auto nextInst = expandInst->next; nextInst;) - { - auto next = nextInst->next; - nextInst->insertAtEnd(currentBlock); - nextInst = next; - } - addUsersToWorkList(expandInst); - expandInst->replaceUsesWith(resultValuePack); - expandInst->removeAndDeallocate(); - return true; - } - IRInst* clonePatternValImpl(IRCloneEnv& cloneEnv, IRBuilder* builder, IRInst* val, UInt indexInPack) - { - if (!val) - return val; - - switch (val->getOp()) - { - case kIROp_ExpandTypeOrVal: - return val; - case kIROp_Each: + IRInst* resultPack = makeSpecializedPack(builder, expandInst->getDataType(), elements.getArrayView()); + if (isMultiBlock) { - auto eachInst = as<IREach>(val); - auto packInst = eachInst->getElement(); - if (auto typePack = as<IRTypePack>(packInst)) - { - SLANG_RELEASE_ASSERT(indexInPack < typePack->getOperandCount()); - return typePack->getOperand(indexInPack); - } - else if (auto makeValuePack = as<IRMakeValuePack>(packInst)) - { - SLANG_RELEASE_ASSERT(indexInPack < makeValuePack->getOperandCount()); - return makeValuePack->getOperand(indexInPack); - } - else if (!as<IRTypeKind>(packInst->getDataType())) + auto currentBlock = builder.getBlock(); + for (auto nextInst = expandInst->next; nextInst;) { - auto type = clonePatternVal(cloneEnv, builder, val, indexInPack); - return builder->emitGetTupleElement((IRType*)type, packInst, indexInPack); + auto next = nextInst->next; + nextInst->insertAtEnd(currentBlock); + nextInst = next; } - return val; - } - default: - break; - } - bool anyChange = false; - ShortList<IRInst*> operands; - for (UInt i = 0; i < val->getOperandCount(); i++) - { - auto newOperand = clonePatternVal(cloneEnv, builder, val->getOperand(i), indexInPack); - if (newOperand != val->getOperand(i)) - anyChange = true; - operands.add(newOperand); - } - auto newType = clonePatternVal(cloneEnv, builder, val->getFullType(), indexInPack); - if (newType != val->getFullType()) - anyChange = true; - if (!anyChange) - return val; - - auto newVal = builder->emitIntrinsicInst((IRType*)newType, val->getOp(), operands.getCount(), operands.getArrayView().getBuffer()); - if (newVal != val) - { - cloneInstDecorationsAndChildren(&cloneEnv, module, val, newVal); - } - return newVal; - } - - IRInst* clonePatternVal(IRCloneEnv& cloneEnv, IRBuilder* builder, IRInst* val, UInt indexInPack) - { - if (auto clonedVal = cloneEnv.mapOldValToNew.tryGetValue(val)) - return *clonedVal; - cloneEnv.mapOldValToNew[val] = val; - auto result = clonePatternValImpl(cloneEnv, builder, val, indexInPack); - cloneEnv.mapOldValToNew[val] = result; - return result; - } - - bool maybeSpecializeExpandTypeOrVal(IRExpandType* expandInst) - { - if (expandInst->getCaptureCount() == 0) - return false; - - for (UInt i = 0; i < expandInst->getCaptureCount(); i++) - { - if (!as<IRTypePack>(expandInst->getCaptureType(i))) - return false; - } - IRBuilder builder(expandInst); - builder.setInsertBefore(expandInst); - List<IRInst*> elements; - UInt elementCount = 0; - if (auto firstTypePack = as<IRTypePack>(expandInst->getCaptureType(0))) - { - elementCount = firstTypePack->getOperandCount(); - } - for (UInt i = 0; i < elementCount; i++) - { - IRCloneEnv cloneEnv; - auto element = clonePatternVal(cloneEnv, &builder, expandInst->getPatternType(), i); - elements.add(element); } addUsersToWorkList(expandInst); - if (as<IRWitnessTableType>(expandInst->getDataType())) - { - List<IRType*> types; - for (auto element : elements) - types.add(element->getDataType()); - auto newTypePack = builder.getTypePack(elements.getCount(), types.getBuffer()); - auto result = builder.emitMakeWitnessPack(newTypePack, elements.getArrayView()); - expandInst->replaceUsesWith(result); - expandInst->removeAndDeallocate(); - return true; - } - else - { - auto newTypePack = builder.getTypePack(elements.getCount(), (IRType*const*)elements.getBuffer()); - expandInst->replaceUsesWith(newTypePack); - expandInst->removeAndDeallocate(); - return true; - } + expandInst->replaceUsesWith(resultPack); + expandInst->removeAndDeallocate(); + return true; } // The handling of specialization for global generic type @@ -2680,6 +2624,108 @@ struct SpecializationContext } } } + + + // If `func` has any parameters whose types are `IRTypePack`, then we will expand them + // into multiple parameters, so that the function has no parameters of type `IRTypePack`. + // returns true if changes are made. + // For example, this function turns `int f(TypePack<int, float> v)` into + // ``` + // int f(int v0, float v1) + // { + // v = MakeValuePack(v0,. v1); + // ... + // } + // ``` + // + bool tryExpandParameterPack(IRFunc* func, bool* outIsFullyExpanded = nullptr) + { + if (!func) + return false; + if (outIsFullyExpanded) + *outIsFullyExpanded = true; + ShortList<IRInst*> params; + for (auto param : func->getParams()) + { + if (as<IRTypePack>(param->getDataType())) + params.add(param); + if (as<IRExpand>(param->getDataType())) + { + if (outIsFullyExpanded) + *outIsFullyExpanded = false; + return false; + } + } + if (params.getCount() == 0) + return false; + + IRBuilder builder(func); + for (auto param : params) + { + builder.setInsertBefore(param); + auto typePack = as<IRTypePack>(param->getDataType()); + ShortList<IRInst*> newParams; + for (UInt i = 0; i < typePack->getOperandCount(); i++) + { + auto newParam = builder.createParam((IRType*)typePack->getOperand(i)); + newParam->insertBefore(param); + newParams.add(newParam); + } + setInsertBeforeOrdinaryInst(&builder, param); + auto val = builder.emitMakeValuePack(typePack, (UInt)newParams.getCount(), newParams.getArrayView().getBuffer()); + param->replaceUsesWith(val); + param->removeAndDeallocate(); + addUsersToWorkList(val); + } + + fixUpFuncType(func); + return true; + } + + // If any arguments in a call is a value pack, we will expand them into the argument list, + // so that the call has no arguments of type `IRTypePack`. + // For example, we will turn `f(MakeValuePack(a, b))` into `f(a, b)`. + // + IRCall* tryExpandArgPack(IRCall* call) + { + bool anyArgPack = false; + for (UInt i = 0; i < call->getArgCount(); i++) + { + auto arg = call->getArg(i); + if (as<IRTypePack>(arg->getDataType())) + { + anyArgPack = true; + break; + } + } + if (!anyArgPack) + return call; + IRBuilder builder(call); + builder.setInsertBefore(call); + List<IRInst*> newArgs; + for (UInt i = 0; i < call->getArgCount(); i++) + { + auto arg = call->getArg(i); + if (auto typePack = as<IRTypePack>(arg->getDataType())) + { + for (UInt elementIndex = 0; elementIndex < typePack->getOperandCount(); elementIndex++) + { + auto newArg = builder.emitGetTupleElement((IRType*)typePack->getOperand(elementIndex), arg, elementIndex); + newArgs.add(newArg); + } + } + else + { + newArgs.add(arg); + } + } + auto newCall = builder.emitCallInst(call->getFullType(), call->getCallee(), newArgs.getArrayView()); + call->replaceUsesWith(newCall); + call->transferDecorationsTo(newCall); + call->removeAndDeallocate(); + return newCall; + } + }; bool specializeModule( @@ -2785,6 +2831,13 @@ IRInst* specializeGenericImpl( IRBuilder* builder = &builderStorage; builder->setInsertBefore(genericVal); + List<IRInst*> pendingWorkList; + SLANG_DEFER + ( + for (Index ii = pendingWorkList.getCount() - 1; ii >= 0; ii--) + context->addToWorkList(pendingWorkList[ii]); + ); + // Now we will run through the body of the generic and // clone each of its instructions into the global scope, // until we reach a `return` instruction. @@ -2825,10 +2878,11 @@ IRInst* specializeGenericImpl( { if (auto func = as<IRFunc>(specializedVal)) { + context->tryExpandParameterPack(func); simplifyFunc(context->targetProgram, func, IRSimplificationOptions::getFast(context->targetProgram)); } } - + pendingWorkList.add(specializedVal); return specializedVal; } @@ -2848,7 +2902,7 @@ IRInst* specializeGenericImpl( // if (context) { - context->addToWorkList(clonedInst); + pendingWorkList.add(clonedInst); } } } |
