diff options
| author | Yong He <yonghe@outlook.com> | 2022-12-19 11:47:19 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-12-19 11:47:19 -0800 |
| commit | 216dfba0af66210a46ef0df18beb73d975fdf727 (patch) | |
| tree | f397ea5bf8d47d7a5d90dc95edfb472f2e49d762 /source | |
| parent | 36220da1e29c891972fef32c8575c15f868b9959 (diff) | |
Separate primal computations from unzipped function into an explicit function. (#2569)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
22 files changed, 837 insertions, 50 deletions
diff --git a/source/slang/slang-ir-autodiff-propagate.h b/source/slang/slang-ir-autodiff-propagate.h index 0d5686899..4edf20142 100644 --- a/source/slang/slang-ir-autodiff-propagate.h +++ b/source/slang/slang-ir-autodiff-propagate.h @@ -10,12 +10,12 @@ namespace Slang { -bool isDifferentialInst(IRInst* inst) +inline bool isDifferentialInst(IRInst* inst) { return inst->findDecoration<IRDifferentialInstDecoration>(); } -bool isMixedDifferentialInst(IRInst* inst) +inline bool isMixedDifferentialInst(IRInst* inst) { return inst->findDecoration<IRMixedDifferentialInstDecoration>(); } @@ -104,4 +104,4 @@ struct DiffPropagationPass : InstPassBase } }; -}
\ No newline at end of file +} diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index c7fbc415a..56002231a 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -491,6 +491,98 @@ struct BackwardDiffTranscriber builder.emitBranch(firstBlock); } + void cleanUpUnusedPrimalIntermediate(IRInst* func, IRInst* primalFunc, IRInst* intermediateType) + { + IRStructType* structType = as<IRStructType>(intermediateType); + if (!structType) + { + auto genType = as<IRGeneric>(intermediateType); + structType = as<IRStructType>(findGenericReturnVal(genType)); + SLANG_RELEASE_ASSERT(structType); + } + + // Collect fields that are never fetched by reverse func. + OrderedHashSet<IRStructKey*> fieldsToCleanup; + for (auto children : structType->getChildren()) + { + if (auto field = as<IRStructField>(children)) + { + auto structKey = field->getKey(); + bool usedByRevFunc = false; + for (auto use = structKey->firstUse; use; use = use->nextUse) + { + if (isChildInstOf(use->getUser(), func)) + { + usedByRevFunc = true; + break; + } + } + if (!usedByRevFunc) + { + List<IRInst*> users; + for (auto use = structKey->firstUse; use; use = use->nextUse) + { + users.add(use->getUser()); + } + for (auto user : users) + { + if (!isChildInstOf(user, primalFunc)) + continue; + if (auto addr = as<IRFieldAddress>(user)) + { + if (addr->hasMoreThanOneUse()) + continue; + if (addr->firstUse) + { + if (addr->firstUse->getUser()->getOp() == kIROp_Store) + { + addr->firstUse->getUser()->removeAndDeallocate(); + } + addr->removeAndDeallocate(); + } + } + } + + bool hasNonTrivialUse = false; + for (auto use = structKey->firstUse; use; use = use->nextUse) + { + switch (use->getUser()->getOp()) + { + case kIROp_PrimalValueStructKeyDecoration: + case kIROp_StructField: + continue; + default: + hasNonTrivialUse = true; + break; + } + } + if (!hasNonTrivialUse) + { + fieldsToCleanup.Add(structKey); + } + } + } + } + + // 
Actually remove fields from struct. + for (auto children : structType->getChildren()) + { + if (auto field = as<IRStructField>(children)) + { + if (fieldsToCleanup.Contains(field->getKey())) + { + auto key = field->getKey(); + List<IRInst*> keyUsers; + for (auto use = key->firstUse; use; use = use->nextUse) + keyUsers.add(use->getUser()); + for (auto keyUser : keyUsers) + keyUser->removeAndDeallocate(); + key->removeAndDeallocate(); + } + } + } + } + // Transcribe a function definition. InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) { @@ -520,12 +612,9 @@ struct BackwardDiffTranscriber // second block of the unzipped function. // IRFunc* unzippedFwdDiffFunc = diffUnzipPass->unzipDiffInsts(fwdDiffFunc); - + // Clone the primal blocks from unzippedFwdDiffFunc // to the reverse-mode function. - // TODO: This is the spot where we can make a decision to split - // the primal and differential into two different funcitons - // instead of two blocks in the same function. // // Special care needs to be taken for the first block since it holds the parameters @@ -547,6 +636,11 @@ struct BackwardDiffTranscriber block->insertAtEnd(diffFunc); } + // Extracts the primal computations into its own func, and replace the primal insts + // with the intermediate results computed from the extracted func. 
+ IRInst* intermediateType = nullptr; + auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc(diffFunc, unzippedFwdDiffFunc, intermediateType); + // Transpose the first block (parameter block) transcribeParameterBlock(builder, diffFunc); @@ -563,6 +657,9 @@ struct BackwardDiffTranscriber unzippedFwdDiffFunc->removeAndDeallocate(); fwdDiffFunc->removeAndDeallocate(); + eliminateDeadCode(diffFunc); + cleanUpUnusedPrimalIntermediate(diffFunc, extractedPrimalFunc, intermediateType); + return InstPair(primalFunc, diffFunc); } diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp new file mode 100644 index 000000000..8dfedcb94 --- /dev/null +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -0,0 +1,552 @@ +#include "slang-ir-autodiff-unzip.h" +#include "slang-ir-ssa-simplification.h" +#include "slang-ir-util.h" + +namespace Slang +{ +struct ExtractPrimalFuncContext +{ + SharedIRBuilder* sharedBuilder; + + void init(SharedIRBuilder* inSharedBuilder) + { + sharedBuilder = inSharedBuilder; + } + + IRInst* cloneGenericHeader(IRBuilder& builder, IRCloneEnv& cloneEnv, IRGeneric* gen) + { + auto newGeneric = builder.emitGeneric(); + newGeneric->setFullType(builder.getTypeKind()); + for (auto decor : gen->getDecorations()) + cloneDecoration(decor, newGeneric); + builder.emitBlock(); + auto originalBlock = gen->getFirstBlock(); + for (auto child = originalBlock->getFirstChild(); child != originalBlock->getLastParam(); + child = child->getNextInst()) + { + cloneInst(&cloneEnv, &builder, child); + } + return newGeneric; + } + + IRInst* createGenericIntermediateType(IRGeneric* gen) + { + IRBuilder builder(sharedBuilder); + builder.setInsertBefore(gen); + IRCloneEnv intermediateTypeCloneEnv; + auto clonedGen = cloneGenericHeader(builder, intermediateTypeCloneEnv, gen); + auto structType = builder.createStructType(); + builder.emitReturn(structType); + auto func = findGenericReturnVal(gen); + if (auto nameHint = 
func->findDecoration<IRNameHintDecoration>()) + { + StringBuilder newName; + newName << nameHint->getName() << "_Intermediates"; + builder.addNameHintDecoration(structType, UnownedStringSlice(newName.getBuffer())); + } + return clonedGen; + } + + IRInst* createIntermediateType(IRGlobalValueWithCode* func) + { + if (func->getOp() == kIROp_Generic) + return createGenericIntermediateType(as<IRGeneric>(func)); + IRBuilder builder(sharedBuilder); + builder.setInsertBefore(func); + auto intermediateType = builder.createStructType(); + if (auto nameHint = func->findDecoration<IRNameHintDecoration>()) + { + StringBuilder newName; + newName << nameHint->getName() << "_Intermediates"; + builder.addNameHintDecoration( + intermediateType, UnownedStringSlice(newName.getBuffer())); + } + return intermediateType; + } + + // Specialize `genericToSpecialize` with the generic parameters defined in `userGeneric`. + // For example: + // ``` + // int f<T>(T a); + // ``` + // will be extended into + // ``` + // struct IntermediateFor_f<T> { T t0; } + // int f_primal<T>(T a, IntermediateFor_f<T> imm); + // ``` + // Given a user generic `f_primal<T>` and a used value parameterized on the same set of generic parameters + // `IntermediateFor_f`, `genericToSpecialize` constructs `IntermediateFor_f<T>` (using the parameter list + // from user generic). 
+ // + IRInst* specializeWithGeneric( + IRBuilder& builder, IRInst* genericToSpecialize, IRGeneric* userGeneric) + { + List<IRInst*> genArgs; + for (auto param : userGeneric->getFirstBlock()->getParams()) + { + genArgs.add(param); + } + return builder.emitSpecializeInst( + builder.getTypeKind(), + genericToSpecialize, + (UInt)genArgs.getCount(), + genArgs.getBuffer()); + } + + IRInst* generatePrimalFuncType( + IRGlobalValueWithCode* destFunc, IRGlobalValueWithCode* fwdFunc, IRInst*& outIntermediateType) + { + IRBuilder builder(sharedBuilder); + builder.setInsertBefore(destFunc); + IRFuncType* originalFuncType = nullptr; + outIntermediateType = createIntermediateType(destFunc); + + if (auto gen = as<IRGeneric>(destFunc)) + { + auto func = findGenericReturnVal(gen); + builder.setInsertBefore(func); + outIntermediateType = + specializeWithGeneric(builder, outIntermediateType, gen); + SLANG_RELEASE_ASSERT(func); + originalFuncType = as<IRFuncType>(as<IRGeneric>(fwdFunc)->getDataType()); + } + else + { + originalFuncType = as<IRFuncType>(fwdFunc->getDataType()); + } + + SLANG_RELEASE_ASSERT(originalFuncType); + List<IRType*> paramTypes; + for (UInt i = 0; i < originalFuncType->getParamCount(); i++) + paramTypes.add(originalFuncType->getParamType(i)); + paramTypes.add(builder.getInOutType((IRType*)outIntermediateType)); + auto newFuncType = builder.getFuncType(paramTypes, originalFuncType->getResultType()); + return newFuncType; + } + + bool isDiffInst(IRInst* inst) + { + if (inst->findDecoration<IRDifferentialInstDecoration>() || + inst->findDecoration<IRMixedDifferentialInstDecoration>()) + return true; + return false; + } + + IRInst* insertIntoReturnBlock(IRBuilder& builder, IRInst* inst) + { + if (!isDiffInst(inst)) + return inst; + + switch (inst->getOp()) + { + case kIROp_Return: + { + IRInst* val = builder.getVoidValue(); + if (inst->getOperandCount() != 0) + { + val = insertIntoReturnBlock(builder, inst->getOperand(0)); + } + return builder.emitReturn(val); + } + 
case kIROp_MakeDifferentialPair: + { + auto diff = builder.emitDefaultConstruct(inst->getOperand(1)->getDataType()); + auto primal = insertIntoReturnBlock(builder, inst->getOperand(0)); + return builder.emitMakeDifferentialPair(inst->getDataType(), primal, diff); + } + default: + SLANG_UNREACHABLE("unknown case of mixed inst."); + } + } + + bool shouldStoreInst(IRInst* inst) + { + if (!inst->getDataType()) + { + return false; + } + + // Only store allowed types. + if (isScalarIntegerType(inst->getDataType())) + { + } + else if (as<IRResourceTypeBase>(inst->getDataType())) + { + } + else + { + switch (inst->getDataType()->getOp()) + { + case kIROp_StructType: + case kIROp_OptionalType: + case kIROp_TupleType: + case kIROp_ArrayType: + case kIROp_DifferentialPairType: + case kIROp_InterfaceType: + case kIROp_AnyValueType: + case kIROp_ClassType: + case kIROp_FloatType: + case kIROp_HalfType: + case kIROp_DoubleType: + case kIROp_VectorType: + case kIROp_MatrixType: + case kIROp_Param: + case kIROp_Specialize: + case kIROp_LookupWitness: + break; + default: + return false; + } + } + + // Never store certain opcodes. + switch (inst->getOp()) + { + case kIROp_CastFloatToInt: + case kIROp_CastIntToFloat: + case kIROp_IntCast: + case kIROp_FloatCast: + case kIROp_MakeVectorFromScalar: + case kIROp_MakeMatrixFromScalar: + case kIROp_Reinterpret: + case kIROp_BitCast: + case kIROp_DefaultConstruct: + case kIROp_MakeStruct: + case kIROp_MakeTuple: + case kIROp_MakeArray: + case kIROp_MakeDifferentialPair: + case kIROp_MakeOptionalNone: + case kIROp_MakeOptionalValue: + case kIROp_DifferentialPairGetDifferential: + case kIROp_DifferentialPairGetPrimal: + return false; + case kIROp_GetElement: + case kIROp_FieldExtract: + case kIROp_swizzle: + case kIROp_OptionalHasValue: + case kIROp_GetOptionalValue: + case kIROp_MatrixReshape: + case kIROp_VectorReshape: + // If the operand is already stored, don't store the result of these insts. 
+ if (inst->getOperand(0)->findDecoration<IRPrimalValueStructKeyDecoration>()) + { + return false; + } + break; + default: + break; + } + + // Only store if the inst has differential inst user. + bool hasDiffUser = false; + for (auto use = inst->firstUse; use; use = use->nextUse) + { + auto user = use->getUser(); + if (isDiffInst(user)) + { + // Ignore uses that is a return or MakeDiffPair + switch (user->getOp()) + { + case kIROp_Return: + continue; + case kIROp_MakeDifferentialPair: + if (!user->hasMoreThanOneUse() && user->firstUse && + user->firstUse->getUser()->getOp() == kIROp_Return) + continue; + break; + default: + break; + } + hasDiffUser = true; + break; + } + } + if (!hasDiffUser) + return false; + + return true; + } + + // Given a `genericA<Param1, Param1,...> { instX(Param1, Param2) }`, + // and a clone of it `genericB<ParamB_1, ParamB_2,...> { }`. + // `GenericChildrenMigrationContext(genericA, genericB)::getCorrespondingInst(instX)` + // returns a clone of `instX` in `genericB` that references the new generic params + // as `instX_clone` in `genericB<ParamB_1, ParamB_2,...> { instX_clone(ParamB_1, ParamB_2) }`. 
+ struct GenericChildrenMigrationContext + { + IRCloneEnv cloneEnv; + IRGeneric* oldGeneric = nullptr; + IRGeneric* newGeneric = nullptr; + IRInst* newGenericRetVal = nullptr; + + void init(IRGeneric* oldGen, IRGeneric* newGen) + { + oldGeneric = oldGen; + newGeneric = newGen; + newGenericRetVal = findGenericReturnVal(newGen); + + IRInst* oldParam = oldGen->getFirstParam(); + IRInst* newParam = newGen->getFirstParam(); + while (oldParam) + { + oldParam = as<IRParam>(oldParam->getNextInst()); + newParam = as<IRParam>(newParam->getNextInst()); + if (!oldParam) + { + SLANG_RELEASE_ASSERT(!newParam); + break; + } + SLANG_RELEASE_ASSERT(newParam); + cloneEnv.mapOldValToNew[oldParam] = newParam; + } + } + IRInst* getCorrespondingInst(IRBuilder& builder, IRInst* oldChild) + { + if (!oldGeneric) + return oldChild; + auto parent = oldChild->getParent(); + bool found = false; + while (parent) + { + if (parent == oldGeneric) + { + found = true; + break; + } + parent = parent->getParent(); + } + if (!found) + return oldChild; + for (UInt i = 0; i < oldChild->getOperandCount(); i++) + { + auto operand = oldChild->getOperand(i); + if (cloneEnv.mapOldValToNew.ContainsKey(operand)) + {} + else + { + getCorrespondingInst(builder, operand); + } + } + auto cloned = cloneInst(&cloneEnv, &builder, oldChild); + return cloned; + } + }; + + void storeInst( + IRBuilder& builder, + IRInst* inst, + GenericChildrenMigrationContext& genericContext, + IRInst* intermediateOutput) + { + IRBuilder genTypeBuilder(sharedBuilder); + auto ptrStructType = as<IRPtrTypeBase>(intermediateOutput->getDataType() ); + SLANG_RELEASE_ASSERT(ptrStructType); + auto structType = as<IRStructType>(ptrStructType->getValueType()); + genTypeBuilder.setInsertBefore(structType); + auto fieldType = genericContext.getCorrespondingInst(genTypeBuilder, inst->getDataType()); + SLANG_RELEASE_ASSERT(structType); + auto structKey = genTypeBuilder.createStructKey(); + if (auto nameHint = 
inst->findDecoration<IRNameHintDecoration>()) + cloneDecoration(nameHint, structKey); + genTypeBuilder.setInsertInto(structType); + genTypeBuilder.createStructField(structType, structKey, (IRType*)fieldType); + builder.addPrimalValueStructKeyDecoration(inst, structKey); + builder.emitStore( + builder.emitFieldAddress( + builder.getPtrType(inst->getFullType()), intermediateOutput, structKey), + inst); + } + + IRGlobalValueWithCode* turnUnzippedFuncIntoPrimalFunc(IRGlobalValueWithCode* unzippedFunc, IRGlobalValueWithCode* fwdFunc, IRInst*& outIntermediateType) + { + // Note: this transformation assumes the original func has only one return. + + IRBuilder builder(sharedBuilder); + + IRFunc* func = nullptr; + IRInst* intermediateType = nullptr; + auto newFuncType = generatePrimalFuncType(unzippedFunc, fwdFunc, intermediateType); + if (auto gen = as<IRGeneric>(unzippedFunc)) + { + func = as<IRFunc>(findGenericReturnVal(gen)); + SLANG_RELEASE_ASSERT(func); + builder.setInsertBefore(func); + auto spec = as<IRSpecialize>(intermediateType); + SLANG_RELEASE_ASSERT(spec); + outIntermediateType = spec->getBase(); + } + else + { + func = as<IRFunc>(unzippedFunc); + SLANG_RELEASE_ASSERT(func); + outIntermediateType = intermediateType; + } + func->setFullType((IRType*)newFuncType); + + // Go through all the insts and preserve the primal blocks. + // Create a return block to replace all branches into a non-primal block. 
+ builder.setInsertInto(func); + auto returnBlock = builder.emitBlock(); + for (auto block : func->getBlocks()) + { + auto term = block->getTerminator(); + if (auto ret = as<IRReturn>(term)) + { + insertIntoReturnBlock(builder, ret); + break; + } + } + + auto paramBlock = func->getFirstBlock(); + builder.setInsertInto(paramBlock); + auto outIntermediary = + builder.emitParam(builder.getInOutType((IRType*)intermediateType)); + + auto firstBlock = *(paramBlock->getSuccessors().begin()); + + GenericChildrenMigrationContext genericMigrationContext; + if (auto gen = as<IRGeneric>(unzippedFunc)) + { + auto spec = as<IRSpecialize>(intermediateType); + SLANG_RELEASE_ASSERT(spec); + genericMigrationContext.init(gen, as<IRGeneric>(spec->getBase())); + } + + for (auto block : func->getBlocks()) + { + if (block == paramBlock) + continue; + if (block->findDecoration<IRDifferentialInstDecoration>() || + block->findDecoration<IRMixedDifferentialInstDecoration>()) + { + if (block->getFirstParam() == nullptr) + { + // If the block does not have any PHI nodes, just remove it and + // replace all its uses with returnBlock. + block->replaceUsesWith(returnBlock); + block->removeAndDeallocate(); + } + else + { + // If the block has Phi nodes, we can't directly replace it with + // `returnBlock`, but we can turn the block into a trivial branch + // into `returnBlock` to safely preserve the invariants of Phi nodes. + auto inst = block->getLastParam()->getNextInst(); + for (; inst; inst = inst->getNextInst()) + inst->removeAndDeallocate(); + builder.setInsertInto(block); + builder.emitBranch(returnBlock); + } + } + else + { + // For primal insts, decide whether or not to store its result in + // output intermediary struct. 
+ for (auto inst : block->getChildren()) + { + if (shouldStoreInst(inst)) + { + builder.setInsertAfter(inst); + storeInst(builder, inst, genericMigrationContext, outIntermediary); + } + } + } + } + + builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); + auto defVal = builder.emitDefaultConstructRaw((IRType*)intermediateType); + builder.emitStore(outIntermediary, defVal); + return unzippedFunc; + } +}; + +static void copyPrimalValueStructKeyDecorations(IRInst* inst, IRCloneEnv& cloneEnv) +{ + IRInst* newInst = nullptr; + if (cloneEnv.mapOldValToNew.TryGetValue(inst, newInst)) + { + if (auto decor = newInst->findDecoration<IRPrimalValueStructKeyDecoration>()) + { + cloneDecoration(decor, inst); + } + } + + for (auto child : inst->getChildren()) + { + copyPrimalValueStructKeyDecorations(child, cloneEnv); + } +} + +IRGlobalValueWithCode* DiffUnzipPass::extractPrimalFunc( + IRGlobalValueWithCode* func, IRGlobalValueWithCode* fwdFunc, IRInst*& intermediateType) +{ + IRBuilder builder(this->autodiffContext->sharedBuilder); + builder.setInsertBefore(func); + + IRCloneEnv subEnv; + subEnv.squashChildrenMapping = true; + subEnv.parent = &cloneEnv; + auto clonedFunc = as<IRGlobalValueWithCode>(cloneInst(&subEnv, &builder, func)); + + ExtractPrimalFuncContext context; + context.init(autodiffContext->sharedBuilder); + + intermediateType = nullptr; + auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, fwdFunc, intermediateType); + IRInst* specializedPrimalFunc = primalFunc; + + // Copy PrimalValueStructKey decorations from primal func. 
+ copyPrimalValueStructKeyDecorations(func, subEnv); + + IRInst* specializedIntermediateType = intermediateType; + auto innerFunc = as<IRFunc>(func); + + if (auto genFunc = as<IRGeneric>(func)) + { + innerFunc = as<IRFunc>(findGenericReturnVal(genFunc)); + builder.setInsertBefore(innerFunc); + specializedIntermediateType = context.specializeWithGeneric(builder, intermediateType, genFunc); + specializedPrimalFunc = context.specializeWithGeneric(builder, primalFunc, genFunc); + } + SLANG_RELEASE_ASSERT(innerFunc); + + // Insert a call to primal func at start of the function. + auto paramBlock = innerFunc->getFirstBlock(); + auto firstBlock = *(paramBlock->getSuccessors().begin()); + builder.setInsertBefore(firstBlock->getFirstInst()); + auto intermediateVar = builder.emitVar((IRType*)specializedIntermediateType); + List<IRInst*> args; + for (auto param : paramBlock->getParams()) + { + args.add(param); + } + args.add(intermediateVar); + builder.emitCallInst(innerFunc->getResultType(), specializedPrimalFunc, args); + + // Replace all insts that has intermediate results with a load of the intermediate. + List<IRInst*> instsToRemove; + for (auto block : innerFunc->getBlocks()) + { + for (auto inst : block->getOrdinaryInsts()) + { + if (auto structKeyDecor = inst->findDecoration<IRPrimalValueStructKeyDecoration>()) + { + builder.setInsertBefore(inst); + auto addr = builder.emitFieldAddress(builder.getPtrType(inst->getDataType()), intermediateVar, structKeyDecor->getStructKey()); + auto val = builder.emitLoad(addr); + inst->replaceUsesWith(val); + instsToRemove.add(inst); + } + } + } + for (auto inst : instsToRemove) + { + inst->removeAndDeallocate(); + } + + // Run simplification to DCE unnecessary insts. 
+ eliminateDeadCode(innerFunc); + + return primalFunc; +} +} // namespace Slang diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index 2bfe972ec..35aa55dd3 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -6,6 +6,7 @@ #include "slang-compiler.h" #include "slang-ir-autodiff.h" +#include "slang-ir-autodiff-fwd.h" #include "slang-ir-autodiff-propagate.h" namespace Slang @@ -51,7 +52,9 @@ struct DiffUnzipPass // Clone the entire function. // TODO: Maybe don't clone? The reverse-mode process seems to clone several times. // TODO: Looks like we get a copy of the decorations? - IRFunc* unzippedFunc = as<IRFunc>(cloneInst(&cloneEnv, builder, func)); + IRCloneEnv subEnv; + subEnv.parent = &cloneEnv; + IRFunc* unzippedFunc = as<IRFunc>(cloneInst(&subEnv, builder, func)); builder->setInsertInto(unzippedFunc); @@ -86,6 +89,8 @@ struct DiffUnzipPass return unzippedFunc; } + IRGlobalValueWithCode* extractPrimalFunc(IRGlobalValueWithCode* func, IRGlobalValueWithCode* fwdFunc, IRInst*& intermediateType); + bool isRelevantDifferentialPair(IRType* type) { if (as<IRDifferentialPairType>(type)) diff --git a/source/slang/slang-ir-clone.cpp b/source/slang/slang-ir-clone.cpp index 634aff75d..5b5ace64b 100644 --- a/source/slang/slang-ir-clone.cpp +++ b/source/slang/slang-ir-clone.cpp @@ -208,8 +208,9 @@ static void _cloneInstDecorationsAndChildren( // The public version of `cloneInstDecorationsAndChildren` is then // just a wrapper over the internal one that sets up a temporary -// environment to use for the cloning process, so that we do -// not leave any lasting changes in the user-provided `env`. +// environment to use for the cloning process when `env->squashChildrenMapping` is false (default), +// so that we do not leave any lasting changes in the user-provided `env` unless the caller +// explicitly asks for it. 
// void cloneInstDecorationsAndChildren( IRCloneEnv* env, @@ -221,10 +222,17 @@ void cloneInstDecorationsAndChildren( SLANG_ASSERT(oldInst); SLANG_ASSERT(newInst); + IRCloneEnv* subEnv = nullptr; IRCloneEnv subEnvStorage; - auto subEnv = &subEnvStorage; - subEnv->parent = env; - + if (env->squashChildrenMapping) + { + subEnv = env; + } + else + { + subEnv = &subEnvStorage; + subEnv->parent = env; + } _cloneInstDecorationsAndChildren(subEnv, sharedBuilder, oldInst, newInst); } diff --git a/source/slang/slang-ir-clone.h b/source/slang/slang-ir-clone.h index b483b5fcb..824806d57 100644 --- a/source/slang/slang-ir-clone.h +++ b/source/slang/slang-ir-clone.h @@ -38,6 +38,9 @@ struct IRCloneEnv /// A parent environment to fall back to if `mapOldValToNew` doesn't contain a key. IRCloneEnv* parent = nullptr; + + /// Should `mapOldValToNew` keep a copy of children's oldToNew mapping? + bool squashChildrenMapping = false; }; /// Look up the replacement for `oldVal`, if any, registered in `env`. diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp index 14edf21d7..ae04acbd1 100644 --- a/source/slang/slang-ir-dce.cpp +++ b/source/slang/slang-ir-dce.cpp @@ -103,18 +103,14 @@ struct DeadCodeEliminationContext return undefInst; } - // Given the basic infrastructrure above, let's - // dive into the task of actually finding all - // the live code in a module. - // - bool processModule() + bool processInst(IRInst* root) { - // First of all, we know that the root module instruction + // First of all, we know that the root instruction // should be considered as live, because otherwise // we'd end up eliminating it, so that is a // good place to start. // - markInstAsLive(module->getModuleInst()); + markInstAsLive(root); // Ensure there is a global undef inst that is always alive. // This undef inst will be used to fill in weak-referencing uses @@ -128,7 +124,7 @@ struct DeadCodeEliminationContext // processing entries off of our work list // until it goes dry. 
// - while( workList.getCount() ) + while (workList.getCount()) { auto inst = workList.getLast(); workList.removeLast(); @@ -151,7 +147,7 @@ struct DeadCodeEliminationContext // markInstAsLive(inst->getFullType()); UInt operandCount = inst->getOperandCount(); - for( UInt ii = 0; ii < operandCount; ++ii ) + for (UInt ii = 0; ii < operandCount; ++ii) { // There are some type of operands that needs to be treated as // "weak" references -- they can never hold things alive, and @@ -182,9 +178,9 @@ struct DeadCodeEliminationContext // decision of whether a child (or decoration) // should be live when its parent is to a subroutine. // - for( auto child : inst->getDecorationsAndChildren() ) + for (auto child : inst->getDecorationsAndChildren()) { - if(shouldInstBeLiveIfParentIsLive(child)) + if (shouldInstBeLiveIfParentIsLive(child)) { // In this case, we know `inst` is live and // its `child` should be live if its parent is, @@ -203,7 +199,16 @@ struct DeadCodeEliminationContext // recursively and eliminate those that are "dead" by // virtue of not having been found live. // - return eliminateDeadInstsRec(module->getModuleInst()); + return eliminateDeadInstsRec(root); + } + + // Given the basic infrastructrure above, let's + // dive into the task of actually finding all + // the live code in a module. 
+ // + bool processModule() + { + return processInst(module->getModuleInst()); } bool eliminateDeadInstsRec(IRInst* inst) @@ -421,4 +426,15 @@ bool eliminateDeadCode( return context.processModule(); } +bool eliminateDeadCode( + IRInst* root, + IRDeadCodeEliminationOptions const& options) +{ + DeadCodeEliminationContext context; + context.module = root->getModule(); + context.options = options; + + return context.processInst(root); +} + } diff --git a/source/slang/slang-ir-dce.h b/source/slang/slang-ir-dce.h index 0abf17f9f..d8819e042 100644 --- a/source/slang/slang-ir-dce.h +++ b/source/slang/slang-ir-dce.h @@ -24,6 +24,10 @@ namespace Slang IRModule* module, IRDeadCodeEliminationOptions const& options = IRDeadCodeEliminationOptions()); + bool eliminateDeadCode( + IRInst* root, + IRDeadCodeEliminationOptions const& options = IRDeadCodeEliminationOptions()); + bool shouldInstBeLiveIfParentIsLive(IRInst* inst, IRDeadCodeEliminationOptions options); bool isWeakReferenceOperand(IRInst* inst, UInt operandIndex); diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index c74388406..8440f4181 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -741,6 +741,10 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// BOTH a differential and a primal value. INST(MixedDifferentialInstDecoration, mixedDiffInstDecoration, 1, 0) + /// Used by the auto-diff pass to mark insts whose result is stored + /// in an intermediary struct for reuse in backward propagation phase. + INST(PrimalValueStructKeyDecoration, primalValueKey, 1, 0) + /// Used by the auto-diff pass to hold a reference to a /// differential member of a type in its associated differential type. 
INST(DerivativeMemberDecoration, derivativeMemberDecoration, 1, 0) diff --git a/source/slang/slang-ir-inst-pass-base.h b/source/slang/slang-ir-inst-pass-base.h index b5a1f168a..86c2cb0fe 100644 --- a/source/slang/slang-ir-inst-pass-base.h +++ b/source/slang/slang-ir-inst-pass-base.h @@ -89,12 +89,12 @@ namespace Slang } template <typename Func> - void processAllInsts(const Func& f) + void processChildInsts(IRInst* root, const Func& f) { workList.clear(); workListSet.Clear(); - addToWorkList(module->getModuleInst()); + addToWorkList(root); while (workList.getCount() != 0) { @@ -108,6 +108,13 @@ namespace Slang } } } + + template <typename Func> + void processAllInsts(const Func& f) + { + processChildInsts(module->getModuleInst(), f); + } + }; } diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 67f17f5b2..6373334bf 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -610,6 +610,17 @@ struct IRDifferentialInstDecoration : IRDecoration IRType* getPrimalType() { return as<IRType>(getOperand(0)); } }; +struct IRPrimalValueStructKeyDecoration : IRDecoration +{ + enum + { + kOp = kIROp_PrimalValueStructKeyDecoration + }; + + IR_LEAF_ISA(PrimalValueStructKeyDecoration) + + IRStructKey* getStructKey() { return as<IRStructKey>(getOperand(0)); } +}; struct IRMixedDifferentialInstDecoration : IRDecoration { @@ -2657,6 +2668,8 @@ public: IRInst* addDifferentiableTypeDictionaryDecoration(IRInst* target); + IRInst* addPrimalValueStructKeyDecoration(IRInst* target, IRStructKey* key); + // Add a differentiable type entry to the appropriate dictionary. IRInst* addDifferentiableTypeEntry(IRInst* dictDecoration, IRInst* irType, IRInst* conformanceWitness); @@ -2718,6 +2731,10 @@ public: /// Otherwise, returns nullptr if we can't materialize the inst. IRInst* emitDefaultConstruct(IRType* type, bool fallback = true); + /// Emits a raw `DefaultConstruct` opcode without attempting to fold/materialize + /// the inst. 
+ IRInst* emitDefaultConstructRaw(IRType* type); + IRInst* emitCast( IRType* type, IRInst* value); diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index 66cde68de..21e17b546 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -241,16 +241,21 @@ struct PeepholeContext : InstPassBase } } - bool processModule() + bool processFunc(IRInst* func) { SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; sharedBuilder->init(module); sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap(); changed = false; - processAllInsts([this](IRInst* inst) { processInst(inst); }); + processChildInsts(func, [this](IRInst* inst) { processInst(inst); }); return changed; } + + bool processModule() + { + return processFunc(module->getModuleInst()); + } }; bool peepholeOptimize(IRModule* module) @@ -259,4 +264,10 @@ bool peepholeOptimize(IRModule* module) return context.processModule(); } +bool peepholeOptimize(IRInst* func) +{ + PeepholeContext context = PeepholeContext(func->getModule()); + return context.processFunc(func); +} + } // namespace Slang diff --git a/source/slang/slang-ir-peephole.h b/source/slang/slang-ir-peephole.h index e05c533eb..dc1b5527a 100644 --- a/source/slang/slang-ir-peephole.h +++ b/source/slang/slang-ir-peephole.h @@ -5,7 +5,9 @@ namespace Slang { struct IRModule; struct IRCall; + struct IRInst; /// Apply peephole optimizations. 
bool peepholeOptimize(IRModule* module); + bool peepholeOptimize(IRInst* func); } diff --git a/source/slang/slang-ir-sccp.cpp b/source/slang/slang-ir-sccp.cpp index fbc00848b..c03eee695 100644 --- a/source/slang/slang-ir-sccp.cpp +++ b/source/slang/slang-ir-sccp.cpp @@ -1678,5 +1678,21 @@ bool applySparseConditionalConstantPropagation( return changed; } + +bool applySparseConditionalConstantPropagation(IRInst* func) +{ + SharedSCCPContext shared; + shared.module = func->getModule(); + shared.sharedBuilder.init(shared.module); + shared.sharedBuilder.deduplicateAndRebuildGlobalNumberingMap(); + + SCCPContext globalContext; + globalContext.shared = &shared; + globalContext.code = nullptr; + + // Run recursive SCCP passes on each child code block. + return applySparseConditionalConstantPropagationRec(globalContext, func); +} + } diff --git a/source/slang/slang-ir-sccp.h b/source/slang/slang-ir-sccp.h index 06c5769c8..23c903eeb 100644 --- a/source/slang/slang-ir-sccp.h +++ b/source/slang/slang-ir-sccp.h @@ -4,6 +4,7 @@ namespace Slang { struct IRModule; + struct IRInst; /// Apply Sparse Conditional Constant Propagation (SCCP) to a module. /// @@ -15,5 +16,7 @@ namespace Slang /// Returns true if IR is changed. bool applySparseConditionalConstantPropagation( IRModule* module); + + bool applySparseConditionalConstantPropagation(IRInst* func); } diff --git a/source/slang/slang-ir-ssa-simplification.cpp b/source/slang/slang-ir-ssa-simplification.cpp index f723325c4..4b604e03a 100644 --- a/source/slang/slang-ir-ssa-simplification.cpp +++ b/source/slang/slang-ir-ssa-simplification.cpp @@ -10,8 +10,6 @@ namespace Slang { - struct IRModule; - // Run a combination of SSA, SCCP, SimplifyCFG, and DeadCodeElimination pass // until no more changes are possible. 
void simplifyIR(IRModule* module) @@ -37,4 +35,27 @@ namespace Slang iterationCounter++; } } + + void simplifyFunc(IRGlobalValueWithCode* func) + { + bool changed = true; + const int kMaxIterations = 8; + int iterationCounter = 0; + while (changed && iterationCounter < kMaxIterations) + { + changed = false; + changed |= applySparseConditionalConstantPropagation(func); + changed |= peepholeOptimize(func); + changed |= simplifyCFG(func); + + // Note: we disregard the `changed` state from dead code elimination pass since + // SCCP pass could be generating temporarily evaluated constant values and never actually use them. + // DCE will always remove those newly generated consts and always returns true here. + eliminateDeadCode(func); + + changed |= constructSSA(func); + + iterationCounter++; + } + } } diff --git a/source/slang/slang-ir-ssa-simplification.h b/source/slang/slang-ir-ssa-simplification.h index 19a39e8d4..ee8343003 100644 --- a/source/slang/slang-ir-ssa-simplification.h +++ b/source/slang/slang-ir-ssa-simplification.h @@ -4,8 +4,11 @@ namespace Slang { struct IRModule; + struct IRGlobalValueWithCode; // Run a combination of SSA, SCCP, SimplifyCFG, and DeadCodeElimination pass // until no more changes are possible. 
void simplifyIR(IRModule* module); + + void simplifyFunc(IRGlobalValueWithCode* func); } diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp index d84b48c3d..2dee189dc 100644 --- a/source/slang/slang-ir-ssa.cpp +++ b/source/slang/slang-ir-ssa.cpp @@ -1237,4 +1237,18 @@ bool constructSSA(IRModule* module) return changed; } +bool constructSSA(IRInst* globalVal) +{ + switch (globalVal->getOp()) + { + case kIROp_Func: + case kIROp_GlobalVar: + return constructSSA(globalVal->getModule(), (IRGlobalValueWithCode*)globalVal); + + default: + break; + } + return false; +} + } diff --git a/source/slang/slang-ir-ssa.h b/source/slang/slang-ir-ssa.h index b327802a1..d455439df 100644 --- a/source/slang/slang-ir-ssa.h +++ b/source/slang/slang-ir-ssa.h @@ -5,6 +5,8 @@ namespace Slang { struct IRModule; struct IRGlobalValueWithCode; + struct IRInst; bool constructSSA(IRModule* module, IRGlobalValueWithCode* globalVal); bool constructSSA(IRModule* module); + bool constructSSA(IRInst* globalVal); } diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index b09b2f4d0..385d05b28 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -33,6 +33,17 @@ inline bool isScalarIntegerType(IRType* type) return getTypeStyle(type->getOp()) == kIROp_IntType; } +inline bool isChildInstOf(IRInst* inst, IRInst* parent) +{ + while (inst) + { + if (inst == parent) + return true; + inst = inst->getParent(); + } + return false; +} + } #endif diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index b4528a452..33130cfb3 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3301,6 +3301,11 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitDefaultConstructRaw(IRType* type) + { + return emitIntrinsicInst(type, kIROp_DefaultConstruct, 0, nullptr); + } + IRInst* IRBuilder::emitDefaultConstruct(IRType* type, bool fallback) { IRType* actualType = type; @@ -3809,6 +3814,11 @@ namespace Slang return inst; 
} + IRInst* IRBuilder::addPrimalValueStructKeyDecoration(IRInst* target, IRStructKey* key) + { + return addDecoration(target, kIROp_PrimalValueStructKeyDecoration, key); + } + RefPtr<IRModule> IRModule::create(Session* session) { RefPtr<IRModule> module = new IRModule(session); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 7fd0bbee7..a84cf9b8d 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -7861,25 +7861,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> return as<IRStringLit>(builder->getStringValue(stringLitExpr->value.getUnownedSlice())); } - IRInst* lowerFuncType(FunctionDeclBase* decl) - { - NestedContext nestedContextFuncType(this); - auto funcTypeBuilder = nestedContextFuncType.getBuilder(); - auto funcTypeContext = nestedContextFuncType.getContext(); - - auto outerGenerics = emitOuterGenerics(funcTypeContext, decl, decl); - - FuncDeclBaseTypeInfo info; - _lowerFuncDeclBaseTypeInfo( - funcTypeContext, - createDefaultSpecializedDeclRef(funcTypeContext, nullptr, decl), - info); - - auto irFuncType = info.type; - - return finishOuterGenerics(funcTypeBuilder, irFuncType, outerGenerics); - } - bool isClassType(IRType* type) { if (auto specialize = as<IRSpecialize>(type)) |
