summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-12-19 11:47:19 -0800
committerGitHub <noreply@github.com>2022-12-19 11:47:19 -0800
commit216dfba0af66210a46ef0df18beb73d975fdf727 (patch)
treef397ea5bf8d47d7a5d90dc95edfb472f2e49d762 /source
parent36220da1e29c891972fef32c8575c15f868b9959 (diff)
Separate primal computations from unzipped function into an explicit function. (#2569)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-propagate.h6
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp105
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp552
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h7
-rw-r--r--source/slang/slang-ir-clone.cpp18
-rw-r--r--source/slang/slang-ir-clone.h3
-rw-r--r--source/slang/slang-ir-dce.cpp40
-rw-r--r--source/slang/slang-ir-dce.h4
-rw-r--r--source/slang/slang-ir-inst-defs.h4
-rw-r--r--source/slang/slang-ir-inst-pass-base.h11
-rw-r--r--source/slang/slang-ir-insts.h17
-rw-r--r--source/slang/slang-ir-peephole.cpp15
-rw-r--r--source/slang/slang-ir-peephole.h2
-rw-r--r--source/slang/slang-ir-sccp.cpp16
-rw-r--r--source/slang/slang-ir-sccp.h3
-rw-r--r--source/slang/slang-ir-ssa-simplification.cpp25
-rw-r--r--source/slang/slang-ir-ssa-simplification.h3
-rw-r--r--source/slang/slang-ir-ssa.cpp14
-rw-r--r--source/slang/slang-ir-ssa.h2
-rw-r--r--source/slang/slang-ir-util.h11
-rw-r--r--source/slang/slang-ir.cpp10
-rw-r--r--source/slang/slang-lower-to-ir.cpp19
22 files changed, 837 insertions, 50 deletions
diff --git a/source/slang/slang-ir-autodiff-propagate.h b/source/slang/slang-ir-autodiff-propagate.h
index 0d5686899..4edf20142 100644
--- a/source/slang/slang-ir-autodiff-propagate.h
+++ b/source/slang/slang-ir-autodiff-propagate.h
@@ -10,12 +10,12 @@
namespace Slang
{
-bool isDifferentialInst(IRInst* inst)
+inline bool isDifferentialInst(IRInst* inst)
{
return inst->findDecoration<IRDifferentialInstDecoration>();
}
-bool isMixedDifferentialInst(IRInst* inst)
+inline bool isMixedDifferentialInst(IRInst* inst)
{
return inst->findDecoration<IRMixedDifferentialInstDecoration>();
}
@@ -104,4 +104,4 @@ struct DiffPropagationPass : InstPassBase
}
};
-} \ No newline at end of file
+}
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index c7fbc415a..56002231a 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -491,6 +491,98 @@ struct BackwardDiffTranscriber
builder.emitBranch(firstBlock);
}
+ void cleanUpUnusedPrimalIntermediate(IRInst* func, IRInst* primalFunc, IRInst* intermediateType)
+ {
+ IRStructType* structType = as<IRStructType>(intermediateType);
+ if (!structType)
+ {
+ auto genType = as<IRGeneric>(intermediateType);
+ structType = as<IRStructType>(findGenericReturnVal(genType));
+ SLANG_RELEASE_ASSERT(structType);
+ }
+
+ // Collect fields that are never fetched by reverse func.
+ OrderedHashSet<IRStructKey*> fieldsToCleanup;
+ for (auto children : structType->getChildren())
+ {
+ if (auto field = as<IRStructField>(children))
+ {
+ auto structKey = field->getKey();
+ bool usedByRevFunc = false;
+ for (auto use = structKey->firstUse; use; use = use->nextUse)
+ {
+ if (isChildInstOf(use->getUser(), func))
+ {
+ usedByRevFunc = true;
+ break;
+ }
+ }
+ if (!usedByRevFunc)
+ {
+ List<IRInst*> users;
+ for (auto use = structKey->firstUse; use; use = use->nextUse)
+ {
+ users.add(use->getUser());
+ }
+ for (auto user : users)
+ {
+ if (!isChildInstOf(user, primalFunc))
+ continue;
+ if (auto addr = as<IRFieldAddress>(user))
+ {
+ if (addr->hasMoreThanOneUse())
+ continue;
+ if (addr->firstUse)
+ {
+ if (addr->firstUse->getUser()->getOp() == kIROp_Store)
+ {
+ addr->firstUse->getUser()->removeAndDeallocate();
+ }
+ addr->removeAndDeallocate();
+ }
+ }
+ }
+
+ bool hasNonTrivialUse = false;
+ for (auto use = structKey->firstUse; use; use = use->nextUse)
+ {
+ switch (use->getUser()->getOp())
+ {
+ case kIROp_PrimalValueStructKeyDecoration:
+ case kIROp_StructField:
+ continue;
+ default:
+ hasNonTrivialUse = true;
+ break;
+ }
+ }
+ if (!hasNonTrivialUse)
+ {
+ fieldsToCleanup.Add(structKey);
+ }
+ }
+ }
+ }
+
+ // Actually remove fields from struct.
+ for (auto children : structType->getChildren())
+ {
+ if (auto field = as<IRStructField>(children))
+ {
+ if (fieldsToCleanup.Contains(field->getKey()))
+ {
+ auto key = field->getKey();
+ List<IRInst*> keyUsers;
+ for (auto use = key->firstUse; use; use = use->nextUse)
+ keyUsers.add(use->getUser());
+ for (auto keyUser : keyUsers)
+ keyUser->removeAndDeallocate();
+ key->removeAndDeallocate();
+ }
+ }
+ }
+ }
+
// Transcribe a function definition.
InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc)
{
@@ -520,12 +612,9 @@ struct BackwardDiffTranscriber
// second block of the unzipped function.
//
IRFunc* unzippedFwdDiffFunc = diffUnzipPass->unzipDiffInsts(fwdDiffFunc);
-
+
// Clone the primal blocks from unzippedFwdDiffFunc
// to the reverse-mode function.
- // TODO: This is the spot where we can make a decision to split
- // the primal and differential into two different funcitons
- // instead of two blocks in the same function.
//
// Special care needs to be taken for the first block since it holds the parameters
@@ -547,6 +636,11 @@ struct BackwardDiffTranscriber
block->insertAtEnd(diffFunc);
}
+ // Extracts the primal computations into its own func, and replace the primal insts
+ // with the intermediate results computed from the extracted func.
+ IRInst* intermediateType = nullptr;
+ auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc(diffFunc, unzippedFwdDiffFunc, intermediateType);
+
// Transpose the first block (parameter block)
transcribeParameterBlock(builder, diffFunc);
@@ -563,6 +657,9 @@ struct BackwardDiffTranscriber
unzippedFwdDiffFunc->removeAndDeallocate();
fwdDiffFunc->removeAndDeallocate();
+ eliminateDeadCode(diffFunc);
+ cleanUpUnusedPrimalIntermediate(diffFunc, extractedPrimalFunc, intermediateType);
+
return InstPair(primalFunc, diffFunc);
}
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
new file mode 100644
index 000000000..8dfedcb94
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -0,0 +1,552 @@
+#include "slang-ir-autodiff-unzip.h"
+#include "slang-ir-ssa-simplification.h"
+#include "slang-ir-util.h"
+
+namespace Slang
+{
+struct ExtractPrimalFuncContext
+{
+ SharedIRBuilder* sharedBuilder;
+
+ void init(SharedIRBuilder* inSharedBuilder)
+ {
+ sharedBuilder = inSharedBuilder;
+ }
+
+ IRInst* cloneGenericHeader(IRBuilder& builder, IRCloneEnv& cloneEnv, IRGeneric* gen)
+ {
+ auto newGeneric = builder.emitGeneric();
+ newGeneric->setFullType(builder.getTypeKind());
+ for (auto decor : gen->getDecorations())
+ cloneDecoration(decor, newGeneric);
+ builder.emitBlock();
+ auto originalBlock = gen->getFirstBlock();
+ for (auto child = originalBlock->getFirstChild(); child != originalBlock->getLastParam();
+ child = child->getNextInst())
+ {
+ cloneInst(&cloneEnv, &builder, child);
+ }
+ return newGeneric;
+ }
+
+ IRInst* createGenericIntermediateType(IRGeneric* gen)
+ {
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertBefore(gen);
+ IRCloneEnv intermediateTypeCloneEnv;
+ auto clonedGen = cloneGenericHeader(builder, intermediateTypeCloneEnv, gen);
+ auto structType = builder.createStructType();
+ builder.emitReturn(structType);
+ auto func = findGenericReturnVal(gen);
+ if (auto nameHint = func->findDecoration<IRNameHintDecoration>())
+ {
+ StringBuilder newName;
+ newName << nameHint->getName() << "_Intermediates";
+ builder.addNameHintDecoration(structType, UnownedStringSlice(newName.getBuffer()));
+ }
+ return clonedGen;
+ }
+
+ IRInst* createIntermediateType(IRGlobalValueWithCode* func)
+ {
+ if (func->getOp() == kIROp_Generic)
+ return createGenericIntermediateType(as<IRGeneric>(func));
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertBefore(func);
+ auto intermediateType = builder.createStructType();
+ if (auto nameHint = func->findDecoration<IRNameHintDecoration>())
+ {
+ StringBuilder newName;
+ newName << nameHint->getName() << "_Intermediates";
+ builder.addNameHintDecoration(
+ intermediateType, UnownedStringSlice(newName.getBuffer()));
+ }
+ return intermediateType;
+ }
+
+ // Specialize `genericToSpecialize` with the generic parameters defined in `userGeneric`.
+ // For example:
+ // ```
+ // int f<T>(T a);
+ // ```
+ // will be extended into
+ // ```
+ // struct IntermediateFor_f<T> { T t0; }
+ // int f_primal<T>(T a, IntermediateFor_f<T> imm);
+ // ```
+ // Given a user generic `f_primal<T>` and a used value parameterized on the same set of generic parameters
+ // `IntermediateFor_f`, `genericToSpecialize` constructs `IntermediateFor_f<T>` (using the parameter list
+ // from user generic).
+ //
+ IRInst* specializeWithGeneric(
+ IRBuilder& builder, IRInst* genericToSpecialize, IRGeneric* userGeneric)
+ {
+ List<IRInst*> genArgs;
+ for (auto param : userGeneric->getFirstBlock()->getParams())
+ {
+ genArgs.add(param);
+ }
+ return builder.emitSpecializeInst(
+ builder.getTypeKind(),
+ genericToSpecialize,
+ (UInt)genArgs.getCount(),
+ genArgs.getBuffer());
+ }
+
+ IRInst* generatePrimalFuncType(
+ IRGlobalValueWithCode* destFunc, IRGlobalValueWithCode* fwdFunc, IRInst*& outIntermediateType)
+ {
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertBefore(destFunc);
+ IRFuncType* originalFuncType = nullptr;
+ outIntermediateType = createIntermediateType(destFunc);
+
+ if (auto gen = as<IRGeneric>(destFunc))
+ {
+ auto func = findGenericReturnVal(gen);
+ builder.setInsertBefore(func);
+ outIntermediateType =
+ specializeWithGeneric(builder, outIntermediateType, gen);
+ SLANG_RELEASE_ASSERT(func);
+ originalFuncType = as<IRFuncType>(as<IRGeneric>(fwdFunc)->getDataType());
+ }
+ else
+ {
+ originalFuncType = as<IRFuncType>(fwdFunc->getDataType());
+ }
+
+ SLANG_RELEASE_ASSERT(originalFuncType);
+ List<IRType*> paramTypes;
+ for (UInt i = 0; i < originalFuncType->getParamCount(); i++)
+ paramTypes.add(originalFuncType->getParamType(i));
+ paramTypes.add(builder.getInOutType((IRType*)outIntermediateType));
+ auto newFuncType = builder.getFuncType(paramTypes, originalFuncType->getResultType());
+ return newFuncType;
+ }
+
+ bool isDiffInst(IRInst* inst)
+ {
+ if (inst->findDecoration<IRDifferentialInstDecoration>() ||
+ inst->findDecoration<IRMixedDifferentialInstDecoration>())
+ return true;
+ return false;
+ }
+
+ IRInst* insertIntoReturnBlock(IRBuilder& builder, IRInst* inst)
+ {
+ if (!isDiffInst(inst))
+ return inst;
+
+ switch (inst->getOp())
+ {
+ case kIROp_Return:
+ {
+ IRInst* val = builder.getVoidValue();
+ if (inst->getOperandCount() != 0)
+ {
+ val = insertIntoReturnBlock(builder, inst->getOperand(0));
+ }
+ return builder.emitReturn(val);
+ }
+ case kIROp_MakeDifferentialPair:
+ {
+ auto diff = builder.emitDefaultConstruct(inst->getOperand(1)->getDataType());
+ auto primal = insertIntoReturnBlock(builder, inst->getOperand(0));
+ return builder.emitMakeDifferentialPair(inst->getDataType(), primal, diff);
+ }
+ default:
+ SLANG_UNREACHABLE("unknown case of mixed inst.");
+ }
+ }
+
+ bool shouldStoreInst(IRInst* inst)
+ {
+ if (!inst->getDataType())
+ {
+ return false;
+ }
+
+ // Only store allowed types.
+ if (isScalarIntegerType(inst->getDataType()))
+ {
+ }
+ else if (as<IRResourceTypeBase>(inst->getDataType()))
+ {
+ }
+ else
+ {
+ switch (inst->getDataType()->getOp())
+ {
+ case kIROp_StructType:
+ case kIROp_OptionalType:
+ case kIROp_TupleType:
+ case kIROp_ArrayType:
+ case kIROp_DifferentialPairType:
+ case kIROp_InterfaceType:
+ case kIROp_AnyValueType:
+ case kIROp_ClassType:
+ case kIROp_FloatType:
+ case kIROp_HalfType:
+ case kIROp_DoubleType:
+ case kIROp_VectorType:
+ case kIROp_MatrixType:
+ case kIROp_Param:
+ case kIROp_Specialize:
+ case kIROp_LookupWitness:
+ break;
+ default:
+ return false;
+ }
+ }
+
+ // Never store certain opcodes.
+ switch (inst->getOp())
+ {
+ case kIROp_CastFloatToInt:
+ case kIROp_CastIntToFloat:
+ case kIROp_IntCast:
+ case kIROp_FloatCast:
+ case kIROp_MakeVectorFromScalar:
+ case kIROp_MakeMatrixFromScalar:
+ case kIROp_Reinterpret:
+ case kIROp_BitCast:
+ case kIROp_DefaultConstruct:
+ case kIROp_MakeStruct:
+ case kIROp_MakeTuple:
+ case kIROp_MakeArray:
+ case kIROp_MakeDifferentialPair:
+ case kIROp_MakeOptionalNone:
+ case kIROp_MakeOptionalValue:
+ case kIROp_DifferentialPairGetDifferential:
+ case kIROp_DifferentialPairGetPrimal:
+ return false;
+ case kIROp_GetElement:
+ case kIROp_FieldExtract:
+ case kIROp_swizzle:
+ case kIROp_OptionalHasValue:
+ case kIROp_GetOptionalValue:
+ case kIROp_MatrixReshape:
+ case kIROp_VectorReshape:
+ // If the operand is already stored, don't store the result of these insts.
+ if (inst->getOperand(0)->findDecoration<IRPrimalValueStructKeyDecoration>())
+ {
+ return false;
+ }
+ break;
+ default:
+ break;
+ }
+
+ // Only store if the inst has differential inst user.
+ bool hasDiffUser = false;
+ for (auto use = inst->firstUse; use; use = use->nextUse)
+ {
+ auto user = use->getUser();
+ if (isDiffInst(user))
+ {
+ // Ignore uses that is a return or MakeDiffPair
+ switch (user->getOp())
+ {
+ case kIROp_Return:
+ continue;
+ case kIROp_MakeDifferentialPair:
+ if (!user->hasMoreThanOneUse() && user->firstUse &&
+ user->firstUse->getUser()->getOp() == kIROp_Return)
+ continue;
+ break;
+ default:
+ break;
+ }
+ hasDiffUser = true;
+ break;
+ }
+ }
+ if (!hasDiffUser)
+ return false;
+
+ return true;
+ }
+
+ // Given a `genericA<Param1, Param1,...> { instX(Param1, Param2) }`,
+ // and a clone of it `genericB<ParamB_1, ParamB_2,...> { }`.
+ // `GenericChildrenMigrationContext(genericA, genericB)::getCorrespondingInst(instX)`
+ // returns a clone of `instX` in `genericB` that references the new generic params
+ // as `instX_clone` in `genericB<ParamB_1, ParamB_2,...> { instX_clone(ParamB_1, ParamB_2) }`.
+ struct GenericChildrenMigrationContext
+ {
+ IRCloneEnv cloneEnv;
+ IRGeneric* oldGeneric = nullptr;
+ IRGeneric* newGeneric = nullptr;
+ IRInst* newGenericRetVal = nullptr;
+
+ void init(IRGeneric* oldGen, IRGeneric* newGen)
+ {
+ oldGeneric = oldGen;
+ newGeneric = newGen;
+ newGenericRetVal = findGenericReturnVal(newGen);
+
+ IRInst* oldParam = oldGen->getFirstParam();
+ IRInst* newParam = newGen->getFirstParam();
+ while (oldParam)
+ {
+ oldParam = as<IRParam>(oldParam->getNextInst());
+ newParam = as<IRParam>(newParam->getNextInst());
+ if (!oldParam)
+ {
+ SLANG_RELEASE_ASSERT(!newParam);
+ break;
+ }
+ SLANG_RELEASE_ASSERT(newParam);
+ cloneEnv.mapOldValToNew[oldParam] = newParam;
+ }
+ }
+ IRInst* getCorrespondingInst(IRBuilder& builder, IRInst* oldChild)
+ {
+ if (!oldGeneric)
+ return oldChild;
+ auto parent = oldChild->getParent();
+ bool found = false;
+ while (parent)
+ {
+ if (parent == oldGeneric)
+ {
+ found = true;
+ break;
+ }
+ parent = parent->getParent();
+ }
+ if (!found)
+ return oldChild;
+ for (UInt i = 0; i < oldChild->getOperandCount(); i++)
+ {
+ auto operand = oldChild->getOperand(i);
+ if (cloneEnv.mapOldValToNew.ContainsKey(operand))
+ {}
+ else
+ {
+ getCorrespondingInst(builder, operand);
+ }
+ }
+ auto cloned = cloneInst(&cloneEnv, &builder, oldChild);
+ return cloned;
+ }
+ };
+
+ void storeInst(
+ IRBuilder& builder,
+ IRInst* inst,
+ GenericChildrenMigrationContext& genericContext,
+ IRInst* intermediateOutput)
+ {
+ IRBuilder genTypeBuilder(sharedBuilder);
+ auto ptrStructType = as<IRPtrTypeBase>(intermediateOutput->getDataType() );
+ SLANG_RELEASE_ASSERT(ptrStructType);
+ auto structType = as<IRStructType>(ptrStructType->getValueType());
+ genTypeBuilder.setInsertBefore(structType);
+ auto fieldType = genericContext.getCorrespondingInst(genTypeBuilder, inst->getDataType());
+ SLANG_RELEASE_ASSERT(structType);
+ auto structKey = genTypeBuilder.createStructKey();
+ if (auto nameHint = inst->findDecoration<IRNameHintDecoration>())
+ cloneDecoration(nameHint, structKey);
+ genTypeBuilder.setInsertInto(structType);
+ genTypeBuilder.createStructField(structType, structKey, (IRType*)fieldType);
+ builder.addPrimalValueStructKeyDecoration(inst, structKey);
+ builder.emitStore(
+ builder.emitFieldAddress(
+ builder.getPtrType(inst->getFullType()), intermediateOutput, structKey),
+ inst);
+ }
+
+ IRGlobalValueWithCode* turnUnzippedFuncIntoPrimalFunc(IRGlobalValueWithCode* unzippedFunc, IRGlobalValueWithCode* fwdFunc, IRInst*& outIntermediateType)
+ {
+ // Note: this transformation assumes the original func has only one return.
+
+ IRBuilder builder(sharedBuilder);
+
+ IRFunc* func = nullptr;
+ IRInst* intermediateType = nullptr;
+ auto newFuncType = generatePrimalFuncType(unzippedFunc, fwdFunc, intermediateType);
+ if (auto gen = as<IRGeneric>(unzippedFunc))
+ {
+ func = as<IRFunc>(findGenericReturnVal(gen));
+ SLANG_RELEASE_ASSERT(func);
+ builder.setInsertBefore(func);
+ auto spec = as<IRSpecialize>(intermediateType);
+ SLANG_RELEASE_ASSERT(spec);
+ outIntermediateType = spec->getBase();
+ }
+ else
+ {
+ func = as<IRFunc>(unzippedFunc);
+ SLANG_RELEASE_ASSERT(func);
+ outIntermediateType = intermediateType;
+ }
+ func->setFullType((IRType*)newFuncType);
+
+ // Go through all the insts and preserve the primal blocks.
+ // Create a return block to replace all branches into a non-primal block.
+ builder.setInsertInto(func);
+ auto returnBlock = builder.emitBlock();
+ for (auto block : func->getBlocks())
+ {
+ auto term = block->getTerminator();
+ if (auto ret = as<IRReturn>(term))
+ {
+ insertIntoReturnBlock(builder, ret);
+ break;
+ }
+ }
+
+ auto paramBlock = func->getFirstBlock();
+ builder.setInsertInto(paramBlock);
+ auto outIntermediary =
+ builder.emitParam(builder.getInOutType((IRType*)intermediateType));
+
+ auto firstBlock = *(paramBlock->getSuccessors().begin());
+
+ GenericChildrenMigrationContext genericMigrationContext;
+ if (auto gen = as<IRGeneric>(unzippedFunc))
+ {
+ auto spec = as<IRSpecialize>(intermediateType);
+ SLANG_RELEASE_ASSERT(spec);
+ genericMigrationContext.init(gen, as<IRGeneric>(spec->getBase()));
+ }
+
+ for (auto block : func->getBlocks())
+ {
+ if (block == paramBlock)
+ continue;
+ if (block->findDecoration<IRDifferentialInstDecoration>() ||
+ block->findDecoration<IRMixedDifferentialInstDecoration>())
+ {
+ if (block->getFirstParam() == nullptr)
+ {
+ // If the block does not have any PHI nodes, just remove it and
+ // replace all its uses with returnBlock.
+ block->replaceUsesWith(returnBlock);
+ block->removeAndDeallocate();
+ }
+ else
+ {
+ // If the block has Phi nodes, we can't directly replace it with
+ // `returnBlock`, but we can turn the block into a trivial branch
+ // into `returnBlock` to safely preserve the invariants of Phi nodes.
+ auto inst = block->getLastParam()->getNextInst();
+ for (; inst; inst = inst->getNextInst())
+ inst->removeAndDeallocate();
+ builder.setInsertInto(block);
+ builder.emitBranch(returnBlock);
+ }
+ }
+ else
+ {
+ // For primal insts, decide whether or not to store its result in
+ // output intermediary struct.
+ for (auto inst : block->getChildren())
+ {
+ if (shouldStoreInst(inst))
+ {
+ builder.setInsertAfter(inst);
+ storeInst(builder, inst, genericMigrationContext, outIntermediary);
+ }
+ }
+ }
+ }
+
+ builder.setInsertBefore(firstBlock->getFirstOrdinaryInst());
+ auto defVal = builder.emitDefaultConstructRaw((IRType*)intermediateType);
+ builder.emitStore(outIntermediary, defVal);
+ return unzippedFunc;
+ }
+};
+
+static void copyPrimalValueStructKeyDecorations(IRInst* inst, IRCloneEnv& cloneEnv)
+{
+ IRInst* newInst = nullptr;
+ if (cloneEnv.mapOldValToNew.TryGetValue(inst, newInst))
+ {
+ if (auto decor = newInst->findDecoration<IRPrimalValueStructKeyDecoration>())
+ {
+ cloneDecoration(decor, inst);
+ }
+ }
+
+ for (auto child : inst->getChildren())
+ {
+ copyPrimalValueStructKeyDecorations(child, cloneEnv);
+ }
+}
+
+IRGlobalValueWithCode* DiffUnzipPass::extractPrimalFunc(
+ IRGlobalValueWithCode* func, IRGlobalValueWithCode* fwdFunc, IRInst*& intermediateType)
+{
+ IRBuilder builder(this->autodiffContext->sharedBuilder);
+ builder.setInsertBefore(func);
+
+ IRCloneEnv subEnv;
+ subEnv.squashChildrenMapping = true;
+ subEnv.parent = &cloneEnv;
+ auto clonedFunc = as<IRGlobalValueWithCode>(cloneInst(&subEnv, &builder, func));
+
+ ExtractPrimalFuncContext context;
+ context.init(autodiffContext->sharedBuilder);
+
+ intermediateType = nullptr;
+ auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, fwdFunc, intermediateType);
+ IRInst* specializedPrimalFunc = primalFunc;
+
+ // Copy PrimalValueStructKey decorations from primal func.
+ copyPrimalValueStructKeyDecorations(func, subEnv);
+
+ IRInst* specializedIntermediateType = intermediateType;
+ auto innerFunc = as<IRFunc>(func);
+
+ if (auto genFunc = as<IRGeneric>(func))
+ {
+ innerFunc = as<IRFunc>(findGenericReturnVal(genFunc));
+ builder.setInsertBefore(innerFunc);
+ specializedIntermediateType = context.specializeWithGeneric(builder, intermediateType, genFunc);
+ specializedPrimalFunc = context.specializeWithGeneric(builder, primalFunc, genFunc);
+ }
+ SLANG_RELEASE_ASSERT(innerFunc);
+
+ // Insert a call to primal func at start of the function.
+ auto paramBlock = innerFunc->getFirstBlock();
+ auto firstBlock = *(paramBlock->getSuccessors().begin());
+ builder.setInsertBefore(firstBlock->getFirstInst());
+ auto intermediateVar = builder.emitVar((IRType*)specializedIntermediateType);
+ List<IRInst*> args;
+ for (auto param : paramBlock->getParams())
+ {
+ args.add(param);
+ }
+ args.add(intermediateVar);
+ builder.emitCallInst(innerFunc->getResultType(), specializedPrimalFunc, args);
+
+ // Replace all insts that has intermediate results with a load of the intermediate.
+ List<IRInst*> instsToRemove;
+ for (auto block : innerFunc->getBlocks())
+ {
+ for (auto inst : block->getOrdinaryInsts())
+ {
+ if (auto structKeyDecor = inst->findDecoration<IRPrimalValueStructKeyDecoration>())
+ {
+ builder.setInsertBefore(inst);
+ auto addr = builder.emitFieldAddress(builder.getPtrType(inst->getDataType()), intermediateVar, structKeyDecor->getStructKey());
+ auto val = builder.emitLoad(addr);
+ inst->replaceUsesWith(val);
+ instsToRemove.add(inst);
+ }
+ }
+ }
+ for (auto inst : instsToRemove)
+ {
+ inst->removeAndDeallocate();
+ }
+
+ // Run simplification to DCE unnecessary insts.
+ eliminateDeadCode(innerFunc);
+
+ return primalFunc;
+}
+} // namespace Slang
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index 2bfe972ec..35aa55dd3 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -6,6 +6,7 @@
#include "slang-compiler.h"
#include "slang-ir-autodiff.h"
+#include "slang-ir-autodiff-fwd.h"
#include "slang-ir-autodiff-propagate.h"
namespace Slang
@@ -51,7 +52,9 @@ struct DiffUnzipPass
// Clone the entire function.
// TODO: Maybe don't clone? The reverse-mode process seems to clone several times.
// TODO: Looks like we get a copy of the decorations?
- IRFunc* unzippedFunc = as<IRFunc>(cloneInst(&cloneEnv, builder, func));
+ IRCloneEnv subEnv;
+ subEnv.parent = &cloneEnv;
+ IRFunc* unzippedFunc = as<IRFunc>(cloneInst(&subEnv, builder, func));
builder->setInsertInto(unzippedFunc);
@@ -86,6 +89,8 @@ struct DiffUnzipPass
return unzippedFunc;
}
+ IRGlobalValueWithCode* extractPrimalFunc(IRGlobalValueWithCode* func, IRGlobalValueWithCode* fwdFunc, IRInst*& intermediateType);
+
bool isRelevantDifferentialPair(IRType* type)
{
if (as<IRDifferentialPairType>(type))
diff --git a/source/slang/slang-ir-clone.cpp b/source/slang/slang-ir-clone.cpp
index 634aff75d..5b5ace64b 100644
--- a/source/slang/slang-ir-clone.cpp
+++ b/source/slang/slang-ir-clone.cpp
@@ -208,8 +208,9 @@ static void _cloneInstDecorationsAndChildren(
// The public version of `cloneInstDecorationsAndChildren` is then
// just a wrapper over the internal one that sets up a temporary
-// environment to use for the cloning process, so that we do
-// not leave any lasting changes in the user-provided `env`.
+// environment to use for the cloning process when `env->squashChildrenMapping` is false (default),
+// so that we do not leave any lasting changes in the user-provided `env` unless the caller
+// explicitly asks for it.
//
void cloneInstDecorationsAndChildren(
IRCloneEnv* env,
@@ -221,10 +222,17 @@ void cloneInstDecorationsAndChildren(
SLANG_ASSERT(oldInst);
SLANG_ASSERT(newInst);
+ IRCloneEnv* subEnv = nullptr;
IRCloneEnv subEnvStorage;
- auto subEnv = &subEnvStorage;
- subEnv->parent = env;
-
+ if (env->squashChildrenMapping)
+ {
+ subEnv = env;
+ }
+ else
+ {
+ subEnv = &subEnvStorage;
+ subEnv->parent = env;
+ }
_cloneInstDecorationsAndChildren(subEnv, sharedBuilder, oldInst, newInst);
}
diff --git a/source/slang/slang-ir-clone.h b/source/slang/slang-ir-clone.h
index b483b5fcb..824806d57 100644
--- a/source/slang/slang-ir-clone.h
+++ b/source/slang/slang-ir-clone.h
@@ -38,6 +38,9 @@ struct IRCloneEnv
/// A parent environment to fall back to if `mapOldValToNew` doesn't contain a key.
IRCloneEnv* parent = nullptr;
+
+ /// Should `mapOldValToNew` keep a copy of children's oldToNew mapping?
+ bool squashChildrenMapping = false;
};
/// Look up the replacement for `oldVal`, if any, registered in `env`.
diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp
index 14edf21d7..ae04acbd1 100644
--- a/source/slang/slang-ir-dce.cpp
+++ b/source/slang/slang-ir-dce.cpp
@@ -103,18 +103,14 @@ struct DeadCodeEliminationContext
return undefInst;
}
- // Given the basic infrastructrure above, let's
- // dive into the task of actually finding all
- // the live code in a module.
- //
- bool processModule()
+ bool processInst(IRInst* root)
{
- // First of all, we know that the root module instruction
+ // First of all, we know that the root instruction
// should be considered as live, because otherwise
// we'd end up eliminating it, so that is a
// good place to start.
//
- markInstAsLive(module->getModuleInst());
+ markInstAsLive(root);
// Ensure there is a global undef inst that is always alive.
// This undef inst will be used to fill in weak-referencing uses
@@ -128,7 +124,7 @@ struct DeadCodeEliminationContext
// processing entries off of our work list
// until it goes dry.
//
- while( workList.getCount() )
+ while (workList.getCount())
{
auto inst = workList.getLast();
workList.removeLast();
@@ -151,7 +147,7 @@ struct DeadCodeEliminationContext
//
markInstAsLive(inst->getFullType());
UInt operandCount = inst->getOperandCount();
- for( UInt ii = 0; ii < operandCount; ++ii )
+ for (UInt ii = 0; ii < operandCount; ++ii)
{
// There are some type of operands that needs to be treated as
// "weak" references -- they can never hold things alive, and
@@ -182,9 +178,9 @@ struct DeadCodeEliminationContext
// decision of whether a child (or decoration)
// should be live when its parent is to a subroutine.
//
- for( auto child : inst->getDecorationsAndChildren() )
+ for (auto child : inst->getDecorationsAndChildren())
{
- if(shouldInstBeLiveIfParentIsLive(child))
+ if (shouldInstBeLiveIfParentIsLive(child))
{
// In this case, we know `inst` is live and
// its `child` should be live if its parent is,
@@ -203,7 +199,16 @@ struct DeadCodeEliminationContext
// recursively and eliminate those that are "dead" by
// virtue of not having been found live.
//
- return eliminateDeadInstsRec(module->getModuleInst());
+ return eliminateDeadInstsRec(root);
+ }
+
+ // Given the basic infrastructrure above, let's
+ // dive into the task of actually finding all
+ // the live code in a module.
+ //
+ bool processModule()
+ {
+ return processInst(module->getModuleInst());
}
bool eliminateDeadInstsRec(IRInst* inst)
@@ -421,4 +426,15 @@ bool eliminateDeadCode(
return context.processModule();
}
+bool eliminateDeadCode(
+ IRInst* root,
+ IRDeadCodeEliminationOptions const& options)
+{
+ DeadCodeEliminationContext context;
+ context.module = root->getModule();
+ context.options = options;
+
+ return context.processInst(root);
+}
+
}
diff --git a/source/slang/slang-ir-dce.h b/source/slang/slang-ir-dce.h
index 0abf17f9f..d8819e042 100644
--- a/source/slang/slang-ir-dce.h
+++ b/source/slang/slang-ir-dce.h
@@ -24,6 +24,10 @@ namespace Slang
IRModule* module,
IRDeadCodeEliminationOptions const& options = IRDeadCodeEliminationOptions());
+ bool eliminateDeadCode(
+ IRInst* root,
+ IRDeadCodeEliminationOptions const& options = IRDeadCodeEliminationOptions());
+
bool shouldInstBeLiveIfParentIsLive(IRInst* inst, IRDeadCodeEliminationOptions options);
bool isWeakReferenceOperand(IRInst* inst, UInt operandIndex);
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index c74388406..8440f4181 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -741,6 +741,10 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/// BOTH a differential and a primal value.
INST(MixedDifferentialInstDecoration, mixedDiffInstDecoration, 1, 0)
+ /// Used by the auto-diff pass to mark insts whose result is stored
+ /// in an intermediary struct for reuse in backward propagation phase.
+ INST(PrimalValueStructKeyDecoration, primalValueKey, 1, 0)
+
/// Used by the auto-diff pass to hold a reference to a
/// differential member of a type in its associated differential type.
INST(DerivativeMemberDecoration, derivativeMemberDecoration, 1, 0)
diff --git a/source/slang/slang-ir-inst-pass-base.h b/source/slang/slang-ir-inst-pass-base.h
index b5a1f168a..86c2cb0fe 100644
--- a/source/slang/slang-ir-inst-pass-base.h
+++ b/source/slang/slang-ir-inst-pass-base.h
@@ -89,12 +89,12 @@ namespace Slang
}
template <typename Func>
- void processAllInsts(const Func& f)
+ void processChildInsts(IRInst* root, const Func& f)
{
workList.clear();
workListSet.Clear();
- addToWorkList(module->getModuleInst());
+ addToWorkList(root);
while (workList.getCount() != 0)
{
@@ -108,6 +108,13 @@ namespace Slang
}
}
}
+
+ template <typename Func>
+ void processAllInsts(const Func& f)
+ {
+ processChildInsts(module->getModuleInst(), f);
+ }
+
};
}
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 67f17f5b2..6373334bf 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -610,6 +610,17 @@ struct IRDifferentialInstDecoration : IRDecoration
IRType* getPrimalType() { return as<IRType>(getOperand(0)); }
};
+struct IRPrimalValueStructKeyDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_PrimalValueStructKeyDecoration
+ };
+
+ IR_LEAF_ISA(PrimalValueStructKeyDecoration)
+
+ IRStructKey* getStructKey() { return as<IRStructKey>(getOperand(0)); }
+};
struct IRMixedDifferentialInstDecoration : IRDecoration
{
@@ -2657,6 +2668,8 @@ public:
IRInst* addDifferentiableTypeDictionaryDecoration(IRInst* target);
+ IRInst* addPrimalValueStructKeyDecoration(IRInst* target, IRStructKey* key);
+
// Add a differentiable type entry to the appropriate dictionary.
IRInst* addDifferentiableTypeEntry(IRInst* dictDecoration, IRInst* irType, IRInst* conformanceWitness);
@@ -2718,6 +2731,10 @@ public:
/// Otherwise, returns nullptr if we can't materialize the inst.
IRInst* emitDefaultConstruct(IRType* type, bool fallback = true);
+ /// Emits a raw `DefaultConstruct` opcode without attempting to fold/materialize
+ /// the inst.
+ IRInst* emitDefaultConstructRaw(IRType* type);
+
IRInst* emitCast(
IRType* type,
IRInst* value);
diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp
index 66cde68de..21e17b546 100644
--- a/source/slang/slang-ir-peephole.cpp
+++ b/source/slang/slang-ir-peephole.cpp
@@ -241,16 +241,21 @@ struct PeepholeContext : InstPassBase
}
}
- bool processModule()
+ bool processFunc(IRInst* func)
{
SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
sharedBuilder->init(module);
sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap();
changed = false;
- processAllInsts([this](IRInst* inst) { processInst(inst); });
+ processChildInsts(func, [this](IRInst* inst) { processInst(inst); });
return changed;
}
+
+ bool processModule()
+ {
+ return processFunc(module->getModuleInst());
+ }
};
bool peepholeOptimize(IRModule* module)
@@ -259,4 +264,10 @@ bool peepholeOptimize(IRModule* module)
return context.processModule();
}
+bool peepholeOptimize(IRInst* func)
+{
+ PeepholeContext context = PeepholeContext(func->getModule());
+ return context.processFunc(func);
+}
+
} // namespace Slang
diff --git a/source/slang/slang-ir-peephole.h b/source/slang/slang-ir-peephole.h
index e05c533eb..dc1b5527a 100644
--- a/source/slang/slang-ir-peephole.h
+++ b/source/slang/slang-ir-peephole.h
@@ -5,7 +5,9 @@ namespace Slang
{
struct IRModule;
struct IRCall;
+ struct IRInst;
/// Apply peephole optimizations.
bool peepholeOptimize(IRModule* module);
+ bool peepholeOptimize(IRInst* func);
}
diff --git a/source/slang/slang-ir-sccp.cpp b/source/slang/slang-ir-sccp.cpp
index fbc00848b..c03eee695 100644
--- a/source/slang/slang-ir-sccp.cpp
+++ b/source/slang/slang-ir-sccp.cpp
@@ -1678,5 +1678,21 @@ bool applySparseConditionalConstantPropagation(
return changed;
}
+
+bool applySparseConditionalConstantPropagation(IRInst* func)
+{
+ SharedSCCPContext shared;
+ shared.module = func->getModule();
+ shared.sharedBuilder.init(shared.module);
+ shared.sharedBuilder.deduplicateAndRebuildGlobalNumberingMap();
+
+ SCCPContext globalContext;
+ globalContext.shared = &shared;
+ globalContext.code = nullptr;
+
+ // Run recursive SCCP passes on each child code block.
+ return applySparseConditionalConstantPropagationRec(globalContext, func);
+}
+
}
diff --git a/source/slang/slang-ir-sccp.h b/source/slang/slang-ir-sccp.h
index 06c5769c8..23c903eeb 100644
--- a/source/slang/slang-ir-sccp.h
+++ b/source/slang/slang-ir-sccp.h
@@ -4,6 +4,7 @@
namespace Slang
{
struct IRModule;
+ struct IRInst;
/// Apply Sparse Conditional Constant Propagation (SCCP) to a module.
///
@@ -15,5 +16,7 @@ namespace Slang
/// Returns true if IR is changed.
bool applySparseConditionalConstantPropagation(
IRModule* module);
+
+ bool applySparseConditionalConstantPropagation(IRInst* func);
}
diff --git a/source/slang/slang-ir-ssa-simplification.cpp b/source/slang/slang-ir-ssa-simplification.cpp
index f723325c4..4b604e03a 100644
--- a/source/slang/slang-ir-ssa-simplification.cpp
+++ b/source/slang/slang-ir-ssa-simplification.cpp
@@ -10,8 +10,6 @@
namespace Slang
{
- struct IRModule;
-
// Run a combination of SSA, SCCP, SimplifyCFG, and DeadCodeElimination pass
// until no more changes are possible.
void simplifyIR(IRModule* module)
@@ -37,4 +35,27 @@ namespace Slang
iterationCounter++;
}
}
+
+ void simplifyFunc(IRGlobalValueWithCode* func)
+ {
+ bool changed = true;
+ const int kMaxIterations = 8;
+ int iterationCounter = 0;
+ while (changed && iterationCounter < kMaxIterations)
+ {
+ changed = false;
+ changed |= applySparseConditionalConstantPropagation(func);
+ changed |= peepholeOptimize(func);
+ changed |= simplifyCFG(func);
+
+ // Note: we disregard the `changed` state from dead code elimination pass since
+ // SCCP pass could be generating temporarily evaluated constant values and never actually use them.
+ // DCE will always remove those nearly generated consts and always returns true here.
+ eliminateDeadCode(func);
+
+ changed |= constructSSA(func);
+
+ iterationCounter++;
+ }
+ }
}
diff --git a/source/slang/slang-ir-ssa-simplification.h b/source/slang/slang-ir-ssa-simplification.h
index 19a39e8d4..ee8343003 100644
--- a/source/slang/slang-ir-ssa-simplification.h
+++ b/source/slang/slang-ir-ssa-simplification.h
@@ -4,8 +4,11 @@
namespace Slang
{
struct IRModule;
+ struct IRGlobalValueWithCode;
// Run a combination of SSA, SCCP, SimplifyCFG, and DeadCodeElimination pass
// until no more changes are possible.
void simplifyIR(IRModule* module);
+
+ void simplifyFunc(IRGlobalValueWithCode* func);
}
diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp
index d84b48c3d..2dee189dc 100644
--- a/source/slang/slang-ir-ssa.cpp
+++ b/source/slang/slang-ir-ssa.cpp
@@ -1237,4 +1237,18 @@ bool constructSSA(IRModule* module)
return changed;
}
+bool constructSSA(IRInst* globalVal)
+{
+ switch (globalVal->getOp())
+ {
+ case kIROp_Func:
+ case kIROp_GlobalVar:
+ return constructSSA(globalVal->getModule(), (IRGlobalValueWithCode*)globalVal);
+
+ default:
+ break;
+ }
+ return false;
+}
+
}
diff --git a/source/slang/slang-ir-ssa.h b/source/slang/slang-ir-ssa.h
index b327802a1..d455439df 100644
--- a/source/slang/slang-ir-ssa.h
+++ b/source/slang/slang-ir-ssa.h
@@ -5,6 +5,8 @@ namespace Slang
{
struct IRModule;
struct IRGlobalValueWithCode;
+ struct IRInst;
bool constructSSA(IRModule* module, IRGlobalValueWithCode* globalVal);
bool constructSSA(IRModule* module);
+ bool constructSSA(IRInst* globalVal);
}
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index b09b2f4d0..385d05b28 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -33,6 +33,17 @@ inline bool isScalarIntegerType(IRType* type)
return getTypeStyle(type->getOp()) == kIROp_IntType;
}
+inline bool isChildInstOf(IRInst* inst, IRInst* parent)
+{
+ while (inst)
+ {
+ if (inst == parent)
+ return true;
+ inst = inst->getParent();
+ }
+ return false;
+}
+
}
#endif
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index b4528a452..33130cfb3 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3301,6 +3301,11 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitDefaultConstructRaw(IRType* type)
+ {
+ return emitIntrinsicInst(type, kIROp_DefaultConstruct, 0, nullptr);
+ }
+
IRInst* IRBuilder::emitDefaultConstruct(IRType* type, bool fallback)
{
IRType* actualType = type;
@@ -3809,6 +3814,11 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::addPrimalValueStructKeyDecoration(IRInst* target, IRStructKey* key)
+ {
+ return addDecoration(target, kIROp_PrimalValueStructKeyDecoration, key);
+ }
+
RefPtr<IRModule> IRModule::create(Session* session)
{
RefPtr<IRModule> module = new IRModule(session);
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 7fd0bbee7..a84cf9b8d 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -7861,25 +7861,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
return as<IRStringLit>(builder->getStringValue(stringLitExpr->value.getUnownedSlice()));
}
- IRInst* lowerFuncType(FunctionDeclBase* decl)
- {
- NestedContext nestedContextFuncType(this);
- auto funcTypeBuilder = nestedContextFuncType.getBuilder();
- auto funcTypeContext = nestedContextFuncType.getContext();
-
- auto outerGenerics = emitOuterGenerics(funcTypeContext, decl, decl);
-
- FuncDeclBaseTypeInfo info;
- _lowerFuncDeclBaseTypeInfo(
- funcTypeContext,
- createDefaultSpecializedDeclRef(funcTypeContext, nullptr, decl),
- info);
-
- auto irFuncType = info.type;
-
- return finishOuterGenerics(funcTypeBuilder, irFuncType, outerGenerics);
- }
-
bool isClassType(IRType* type)
{
if (auto specialize = as<IRSpecialize>(type))