summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-lower-expand-type.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-08-23 21:45:59 -0700
committerGitHub <noreply@github.com>2024-08-23 21:45:59 -0700
commitb2ca2d5a4efeae807d3c3f48f60235e47413b559 (patch)
tree643d2bab5776e5f8f7cfa722975af9e826d77c9d /source/slang/slang-ir-lower-expand-type.cpp
parente4088cd602bd4d5a72fea67a787b1319acfc044d (diff)
Make variadic generics work with interfaces and forward autodiff. (#4905)
Diffstat (limited to 'source/slang/slang-ir-lower-expand-type.cpp')
-rw-r--r--source/slang/slang-ir-lower-expand-type.cpp167
1 files changed, 167 insertions, 0 deletions
diff --git a/source/slang/slang-ir-lower-expand-type.cpp b/source/slang/slang-ir-lower-expand-type.cpp
new file mode 100644
index 000000000..8b68b1fc1
--- /dev/null
+++ b/source/slang/slang-ir-lower-expand-type.cpp
@@ -0,0 +1,167 @@
+#include "slang-ir-lower-expand-type.h"
+
+#include "slang-ir-insts.h"
+#include "slang-ir-util.h"
+#include "slang-ir-clone.h"
+
+namespace Slang
+{
+ IRInst* clonePatternVal(IRCloneEnv& cloneEnv, IRBuilder* builder, IRInst* val, IRInst* eachIndex);
+
+ IRInst* clonePatternValImpl(IRCloneEnv& cloneEnv, IRBuilder* builder, IRInst* val, IRInst* eachIndex)
+ {
+ if (!val)
+ return val;
+
+ switch (val->getOp())
+ {
+ case kIROp_ExpandTypeOrVal:
+ return val;
+ case kIROp_Each:
+ {
+ auto eachInst = as<IREach>(val);
+ auto packInst = eachInst->getElement();
+ packInst = clonePatternValImpl(cloneEnv, builder, packInst, eachIndex);
+ auto result = builder->emitGetTupleElement(val->getFullType(), packInst, eachIndex);
+ return result;
+ }
+ case kIROp_Specialize:
+ case kIROp_LookupWitness:
+ case kIROp_ExtractExistentialType:
+ case kIROp_ExtractExistentialWitnessTable:
+ break;
+ default:
+ // If the value is not a type, and it is not in a block, then it is some global inst
+ // that shouldn't be deep copied into current block, such as a IRFunc.
+ if (!as<IRType>(val) && getBlock(val->getParent()) == nullptr)
+ return val;
+ break;
+ }
+ bool anyChange = false;
+ ShortList<IRInst*> operands;
+ for (UInt i = 0; i < val->getOperandCount(); i++)
+ {
+ auto newOperand = clonePatternVal(cloneEnv, builder, val->getOperand(i), eachIndex);
+ if (newOperand != val->getOperand(i))
+ anyChange = true;
+ operands.add(newOperand);
+ }
+ auto newType = clonePatternVal(cloneEnv, builder, val->getFullType(), eachIndex);
+ if (newType != val->getFullType())
+ anyChange = true;
+ if (!anyChange)
+ return val;
+
+ auto newVal = builder->emitIntrinsicInst((IRType*)newType, val->getOp(), operands.getCount(), operands.getArrayView().getBuffer());
+ if (newVal != val)
+ {
+ cloneInstDecorationsAndChildren(&cloneEnv, builder->getModule(), val, newVal);
+ }
+ return newVal;
+ }
+
+ IRInst* clonePatternVal(IRCloneEnv& cloneEnv, IRBuilder* builder, IRInst* val, IRInst* eachIndex)
+ {
+ if (auto clonedVal = cloneEnv.mapOldValToNew.tryGetValue(val))
+ return *clonedVal;
+ cloneEnv.mapOldValToNew[val] = val;
+ auto result = clonePatternValImpl(cloneEnv, builder, val, eachIndex);
+ cloneEnv.mapOldValToNew[val] = result;
+ return result;
+ }
+
+ // Translate a `IRExpandType` into an `IRExpand` where the `PatternType` is defined
+ // inside the `IRExpand` body.
+ //
+ IRInst* lowerExpandTypeImpl(IRExpandType* expandType)
+ {
+ // Turn `IRExpandType` into an `IRExpand` instruction.
+ IRBuilder builder(expandType);
+ builder.setInsertBefore(expandType);
+ List<IRInst*> capturedArgs;
+ IRCloneEnv cloneEnv;
+ for (UInt i = 0; i < expandType->getCaptureCount(); i++)
+ {
+ auto capturedArg = expandType->getCaptureType(i);
+ capturedArgs.add(capturedArg);
+ }
+ auto result = builder.emitExpandInst(expandType->getFullType(), expandType->getCaptureCount(), capturedArgs.getBuffer());
+ builder.setInsertInto(result);
+ builder.emitBlock();
+ auto eachIndex = builder.emitParam(builder.getIntType());
+ auto newPatternType = clonePatternVal(cloneEnv, &builder, expandType->getPatternType(), eachIndex);
+ builder.emitYield(newPatternType);
+ return result;
+ }
+
+ // Process the body of an `IRExpand` instruction, and replace the type of children insts if it
+ // is an `IRExpandType`.
+ //
+ void processExpandVal(IRExpand* expandVal)
+ {
+ IRBuilder builder(expandVal);
+ IRCloneEnv cloneEnv;
+ auto eachIndex = expandVal->getFirstBlock()->getFirstParam();
+ for (auto block : expandVal->getBlocks())
+ {
+ for (auto inst : block->getModifiableChildren())
+ {
+ builder.setInsertBefore(inst);
+ auto newType = clonePatternVal(cloneEnv, &builder, inst->getFullType(), eachIndex);
+ if (newType != inst->getFullType())
+ {
+ inst = builder.replaceOperand(&inst->typeUse, newType);
+ }
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ auto oldOperand = inst->getOperand(i);
+ if (!oldOperand)
+ continue;
+ if (isChildInstOf(oldOperand, expandVal))
+ continue;
+ auto newOperand = clonePatternVal(cloneEnv, &builder, oldOperand, eachIndex);
+ if (newOperand != inst->getOperand(i))
+ {
+ inst = builder.replaceOperand(inst->getOperands() + i, newOperand);
+ }
+ }
+ }
+ }
+ }
+
+ void lowerExpandType(IRModule* module)
+ {
+ // Use a work list to process all instructions in the module, and lower any `IRExpandType` we see
+ // along the way.
+
+ List<IRInst*> workList;
+ for (auto type : module->getGlobalInsts())
+ {
+ workList.add(type);
+ }
+
+ while (workList.getCount() != 0)
+ {
+ auto inst = workList.getLast();
+ workList.removeLast();
+
+ if (auto expandType = as<IRExpandType>(inst))
+ {
+ inst = lowerExpandTypeImpl(expandType);
+ if (inst != expandType)
+ {
+ expandType->replaceUsesWith(inst);
+ expandType->removeAndDeallocate();
+ }
+ }
+ else if (auto expandVal = as<IRExpand>(inst))
+ {
+ processExpandVal(expandVal);
+ }
+ for (auto child : inst->getChildren())
+ {
+ workList.add(child);
+ }
+ }
+ }
+}