diff options
| author | Yong He <yonghe@outlook.com> | 2023-01-17 22:19:10 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-17 22:19:10 -0800 |
| commit | 86ddb9c452c4f33d09b4f7d4f90a9abad4984071 (patch) | |
| tree | 833f8bb0fd5df8f0328e20a4568a19081593cedb /source | |
| parent | a0994a8da142e54362e9ec1fdb5e5abc708ec3d2 (diff) | |
First custom backward-derivative test case working. (#2598)
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/diff.meta.slang | 18 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 10 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-pairs.cpp | 11 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 30 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 176 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.h | 16 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.h | 43 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-dce.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-deduplicate-generic-children.cpp | 43 | ||||
| -rw-r--r-- | source/slang/slang-ir-deduplicate-generic-children.h | 12 | ||||
| -rw-r--r-- | source/slang/slang-ir-link.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize.cpp | 96 | ||||
| -rw-r--r-- | source/slang/slang-ir-ssa-simplification.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.cpp | 116 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.h | 23 |
17 files changed, 451 insertions, 156 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index e19923c80..c732d1a5e 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -223,6 +223,24 @@ DifferentialPair<vector<T, N>> __d_exp_vector(DifferentialPair<vector<T, N>> dpx VECTOR_MAP_D_UNARY(T, N, __d_exp, dpx); } +__generic<T : __BuiltinFloatingPointType> +[BackwardDerivativeOf(exp)] +void __d_exp(inout DifferentialPair<T> dpx, T.Differential dOut) +{ + dpx = diffPair( + dpx.p, + T.dmul(exp(dpx.p), dOut)); +} + +__generic<T : __BuiltinFloatingPointType, let N : int> +[BackwardDerivativeOf(exp)] +void __d_exp_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut) +{ + dpx = diffPair( + dpx.p, + vector<T, N>.dmul(exp(dpx.p), dOut)); +} + // Absolute value __generic<T : __BuiltinFloatingPointType> diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index e8fe2beac..ee159b80b 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -880,10 +880,12 @@ InstPair ForwardDiffTranscriber::transcribeMakeDifferentialPair(IRBuilder* build auto diffDiffVal = findOrTranscribeDiffInst(builder, origInst->getDifferentialValue()); SLANG_ASSERT(diffDiffVal); + auto primalPairType = findOrTranscribePrimalInst(builder, origInst->getFullType()); + auto diffPairType = findOrTranscribeDiffInst(builder, origInst->getFullType()); auto primalPair = builder->emitMakeDifferentialPair( - tryGetDiffPairType(builder, primalVal->getDataType()), primalVal, diffPrimalVal); + (IRType*)primalPairType, primalVal, diffPrimalVal); auto diffPair = builder->emitMakeDifferentialPair( - tryGetDiffPairType(builder, differentiateType(builder, origInst->getPrimalValue()->getDataType())), + (IRType*)diffPairType, primalDiffVal, diffDiffVal); return InstPair(primalPair, diffPair); @@ -901,7 +903,9 @@ InstPair ForwardDiffTranscriber::transcribeDifferentialPairGetElement(IRBuilder* auto diffVal = findOrTranscribeDiffInst(builder, origInst->getOperand(0)); SLANG_ASSERT(diffVal); - auto primalResult = builder->emitIntrinsicInst(origInst->getFullType(), origInst->getOp(), 1, &primalVal); + auto primalType = findOrTranscribePrimalInst(builder, origInst->getFullType()); + + auto primalResult = builder->emitIntrinsicInst((IRType*)primalType, origInst->getOp(), 1, &primalVal); auto diffValPairType = as<IRDifferentialPairType>(diffVal->getDataType()); IRInst* diffResultType = nullptr; diff --git a/source/slang/slang-ir-autodiff-pairs.cpp b/source/slang/slang-ir-autodiff-pairs.cpp index dc72ed44a..b3665f27c 100644 --- a/source/slang/slang-ir-autodiff-pairs.cpp +++ b/source/slang/slang-ir-autodiff-pairs.cpp @@ -124,15 +124,22 @@ struct DiffPairLoweringPass : InstPassBase } }); + OrderedDictionary<IRInst*, IRInst*> pendingReplacements; processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRDifferentialPairType* inst) { if (auto loweredType = lowerPairType(builder, inst)) { - inst->replaceUsesWith(loweredType); - inst->removeAndDeallocate(); + pendingReplacements.Add(inst, loweredType); modified = true; } }); + for (auto replacement : pendingReplacements) + { + replacement.Key->replaceUsesWith(replacement.Value); + replacement.Key->removeAndDeallocate(); + } + autodiffContext->sharedBuilder->deduplicateAndRebuildGlobalNumberingMap(); + return modified; } diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index de4fbe182..d3a6137c1 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -36,7 +36,7 @@ namespace Slang } else { - if (auto diffPairType = tryGetDiffPairType(builder, primalType)) + if (auto diffPairType = tryGetDiffPairType(builder, origType)) { auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType); newParameterTypes.add(inoutDiffPairType); @@ -311,6 +311,7 @@ namespace Slang IRFunc* primalFunc = origFunc; + maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, origFunc); differentiableTypeConformanceContext.setFunc(origFunc); auto diffFunc = builder.createFunc(); @@ -337,7 +338,8 @@ namespace Slang // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc. if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>()) { - cloneDecoration(dictDecor, diffFunc); + builder.setInsertBefore(diffFunc->getFirstDecorationOrChild()); + cloneInst(&cloneEnv, &builder, dictDecor); } return InstPair(primalFunc, diffFunc); @@ -526,7 +528,7 @@ namespace Slang auto diffOuterGeneric = as<IRGeneric>(findOuterGeneric(diffPropagateFunc)); SLANG_RELEASE_ASSERT(diffOuterGeneric); - migrationContext.init(fwdParentGeneric, diffOuterGeneric); + migrationContext.init(fwdParentGeneric, diffOuterGeneric, diffPropagateFunc); auto inst = fwdParentGeneric->getFirstBlock()->getFirstOrdinaryInst(); builder->setInsertBefore(diffPropagateFunc); while (inst) @@ -580,24 +582,12 @@ namespace Slang // IRFunc* unzippedFwdDiffFunc = diffUnzipPass->unzipDiffInsts(fwdDiffFunc); - // Clone the primal blocks from unzippedFwdDiffFunc - // to the reverse-mode function. - // - // Special care needs to be taken for the first block since it holds the parameters - - // Clone all blocks into a temporary diff func. - // We're using a temporary sice we don't want to clone decorations, - // only blocks, and right now there's no provision in slang-ir-clone.h - // for that. - // + + // Move blocks from `unzippedFwdDiffFunc` to the `diffPropagateFunc` shell. builder->setInsertInto(diffPropagateFunc->getParent()); - IRCloneEnv subCloneEnv; - auto tempDiffFunc = as<IRFunc>(cloneInst(&subCloneEnv, builder, unzippedFwdDiffFunc)); - - // Move blocks to the diffFunc shell. { List<IRBlock*> workList; - for (auto block = tempDiffFunc->getFirstBlock(); block; block = block->getNextBlock()) + for (auto block = unzippedFwdDiffFunc->getFirstBlock(); block; block = block->getNextBlock()) workList.add(block); for (auto block : workList) @@ -623,9 +613,7 @@ namespace Slang auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc( diffPropagateFunc, primalFunc, intermediateType); - // Clean up by deallocating intermediate versions. - tempDiffFunc->removeAndDeallocate(); - unzippedFwdDiffFunc->removeAndDeallocate(); + // Clean up by deallocating the tempoarary forward derivative func. fwdDiffFunc->removeAndDeallocate(); // If primal function is nested in a generic, we want to create separate generics for all the associated things diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 275b40b46..cfbc9638a 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -176,61 +176,56 @@ IRInst* AutoDiffTranscriberBase::maybeCloneForPrimalInst(IRBuilder* builder, IRI } // Get or construct `:IDifferentiable` conformance for a DifferentiablePair. -IRWitnessTable* AutoDiffTranscriberBase::getDifferentialPairWitness(IRInst* inDiffPairType) +IRWitnessTable* AutoDiffTranscriberBase::getDifferentialPairWitness(IRBuilder* builder, IRInst* inOriginalDiffPairType, IRInst* inPrimalDiffPairType) { - IRBuilder builder(sharedBuilder); - builder.setInsertInto(inDiffPairType->parent); - auto diffPairType = as<IRDifferentialPairType>(inDiffPairType); - SLANG_ASSERT(diffPairType); - - auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType); - // Differentiate the pair type to get it's differential (which is itself a pair) - auto diffDiffPairType = differentiateType(&builder, diffPairType); + auto diffDiffPairType = differentiateType(builder, (IRType*)inOriginalDiffPairType); + + auto table = builder->createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, (IRType*)inPrimalDiffPairType); // And place it in the synthesized witness table. - builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, diffDiffPairType); + builder->createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, diffDiffPairType); + // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`. // Record this in the context for future lookups - differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table; + differentiableTypeConformanceContext.differentiableWitnessDictionary[(IRType*)inOriginalDiffPairType] = table; return table; } -IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRInst* primalType, IRInst* witness) +IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness) { - IRBuilder builder(sharedBuilder); - builder.setInsertInto(primalType->parent); - return builder.getDifferentialPairType( + return builder->getDifferentialPairType( (IRType*)primalType, witness); } -IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRInst* primalType) +IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRInst* originalType) { - IRBuilder builder(sharedBuilder); - if (!primalType->next) - builder.setInsertInto(primalType->parent); - else - builder.setInsertBefore(primalType->next); - - IRInst* witness = as<IRWitnessTable>( - differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType)); + auto primalType = lookupPrimalInst(builder, originalType, nullptr); + SLANG_RELEASE_ASSERT(primalType); + IRInst* witness = + differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)originalType); + if (witness) + { + witness = lookupPrimalInst(builder, witness, nullptr); + SLANG_RELEASE_ASSERT(witness); + } if (!witness) { if (auto primalPairType = as<IRDifferentialPairType>(primalType)) { - witness = getDifferentialPairWitness(primalPairType); + witness = getDifferentialPairWitness(builder, originalType, primalPairType); } - else if (auto extractExistential = as<IRExtractExistentialType>(primalType)) + else if (auto extractExistential = as<IRExtractExistentialType>(originalType)) { - differentiateExtractExistentialType(&builder, extractExistential, witness); + differentiateExtractExistentialType(builder, extractExistential, witness); } } - return builder.getDifferentialPairType( + return builder->getDifferentialPairType( (IRType*)primalType, witness); } @@ -242,7 +237,8 @@ IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* o primalType->getParent() && primalType->getParent()->getParent() && primalType->getParent()->getParent()->getOp() == kIROp_Generic) { - return (IRType*)differentiableTypeConformanceContext.getDifferentialForType(builder, origType); + auto diffType = (IRType*)differentiableTypeConformanceContext.getDifferentialForType(builder, origType); + return (IRType*)findOrTranscribePrimalInst(builder, diffType); } return (IRType*)transcribe(builder, origType); } @@ -254,10 +250,7 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy origType->getOp(), differentiateType(builder, ptrType->getValueType())); - // If there is an explicit primal version of this type in the local scope, load that - // otherwise use the original type. - // - IRInst* primalType = lookupPrimalInst(builder, origType, origType); + auto primalType = maybeCloneForPrimalInst(builder, origType); // Special case certain compound types (PtrType, FuncType, etc..) // otherwise try to lookup a differential definition for the given type. @@ -288,6 +281,7 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy { auto primalPairType = as<IRDifferentialPairType>(primalType); return getOrCreateDiffPairType( + builder, pairBuilder->getDiffTypeFromPairType(builder, primalPairType), pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType)); } @@ -409,6 +403,44 @@ InstPair AutoDiffTranscriberBase::transcribeExtractExistentialWitnessTable(IRBui return InstPair(primalResult, nullptr); } +void AutoDiffTranscriberBase::maybeMigrateDifferentiableDictionaryFromDerivativeFunc(IRBuilder* builder, IRInst* origFunc) +{ + auto decor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>(); + if (decor) + return; + // A differentiable func must have `IRDifferentiableTypeDictionaryDecoration`, except it has a + // `IRUserDefinedBackwardDerivativeDecoration`. + auto udfDecor = origFunc->findDecoration<IRUserDefinedBackwardDerivativeDecoration>(); + SLANG_RELEASE_ASSERT(udfDecor); + // We need to migrate the dictionary from the backward derivative func so we can properly + // differentiate the function header. + IRBuilder subBuilder = *builder; + subBuilder.setInsertBefore(origFunc); + + auto derivative = udfDecor->getBackwardDerivativeFunc(); + if (auto specialize = as<IRSpecialize>(derivative)) + { + auto derivativeGeneric = cast<IRGeneric>(specialize->getBase()); + GenericChildrenMigrationContext migrationContext; + migrationContext.init(derivativeGeneric, cast<IRGeneric>(findOuterGeneric(origFunc)), origFunc); + auto derivativeFunc = findGenericReturnVal(derivativeGeneric); + auto derivativeBlock = cast<IRBlock>(derivativeFunc->getParent()); + for (auto dInst = derivativeBlock->getFirstOrdinaryInst(); dInst != derivativeFunc; + dInst = dInst->getNextInst()) + { + migrationContext.cloneInst(&subBuilder, dInst); + } + auto udfDictDecor = derivativeFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>(); + SLANG_RELEASE_ASSERT(udfDictDecor); + subBuilder.setInsertBefore(origFunc->getFirstDecorationOrChild()); + migrationContext.cloneInst(&subBuilder, udfDictDecor); + eliminateDeadCode(origFunc->getParent()); + } + else + { + cloneDecoration(udfDecor, origFunc); + } +} IRType* AutoDiffTranscriberBase::differentiateExtractExistentialType(IRBuilder* builder, IRExtractExistentialType* origType, IRInst*& outWitnessTable) { @@ -435,21 +467,21 @@ IRType* AutoDiffTranscriberBase::differentiateExtractExistentialType(IRBuilder* return nullptr; } -IRType* AutoDiffTranscriberBase::tryGetDiffPairType(IRBuilder* builder, IRType* primalType) +IRType* AutoDiffTranscriberBase::tryGetDiffPairType(IRBuilder* builder, IRType* originalType) { // If this is a PtrType (out, inout, etc..), then create diff pair from // value type and re-apply the appropropriate PtrType wrapper. // - if (auto origPtrType = as<IRPtrTypeBase>(primalType)) + if (auto origPtrType = as<IRPtrTypeBase>(originalType)) { if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType())) - return builder->getPtrType(primalType->getOp(), diffPairValueType); + return builder->getPtrType(originalType->getOp(), diffPairValueType); else return nullptr; } - auto diffType = differentiateType(builder, primalType); + auto diffType = differentiateType(builder, originalType); if (diffType) - return (IRType*)getOrCreateDiffPairType(primalType); + return (IRType*)getOrCreateDiffPairType(builder, originalType); return nullptr; } @@ -628,7 +660,7 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I } } -InstPair AutoDiffTranscriberBase::transcribeBlock(IRBuilder* builder, IRBlock* origBlock) +InstPair AutoDiffTranscriberBase::transcribeBlockImpl(IRBuilder* builder, IRBlock* origBlock, HashSet<IRInst*>& instsToSkip) { IRBuilder subBuilder(builder->getSharedBuilder()); subBuilder.setInsertLoc(builder->getInsertLoc()); @@ -653,7 +685,14 @@ InstPair AutoDiffTranscriberBase::transcribeBlock(IRBuilder* builder, IRBlock* o // derivative code. // for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst()) + { + if (instsToSkip.Contains(child)) + { + continue; + } + this->transcribe(&subBuilder, child); + } return InstPair(diffBlock, diffBlock); } @@ -713,12 +752,63 @@ InstPair AutoDiffTranscriberBase::transcribeReturn(IRBuilder* builder, IRReturn* } } +static void _markGenericChildrenWithoutRelaventUse(IRGeneric* origGeneric, HashSet<IRInst*>& outInstsToSkip) +{ + for (;;) + { + bool changed = false; + for (auto inst = origGeneric->getFirstBlock()->getFirstOrdinaryInst(); inst; + inst = inst->getNextInst()) + { + // If an inst is only referenced by a UserDefinedDerivativeDecoration, we need to skip + // its transcription. + switch (inst->getOp()) + { + case kIROp_Return: + continue; + default: + break; + } + + bool hasRelaventUse = false; + for (auto use = inst->firstUse; use; use = use->nextUse) + { + switch (use->getUser()->getOp()) + { + case kIROp_UserDefinedBackwardDerivativeDecoration: + case kIROp_ForwardDerivativeDecoration: + case kIROp_BackwardDerivativeDecoration: + case kIROp_BackwardDerivativePrimalDecoration: + case kIROp_BackwardDerivativePropagateDecoration: + break; + default: + if (!outInstsToSkip.Contains(use->getUser())) + { + hasRelaventUse = true; + } + break; + } + } + if (!hasRelaventUse) + { + if (outInstsToSkip.Add(inst)) + { + changed = true; + } + } + } + if (!changed) + break; + } +} + // Transcribe a generic definition InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGeneric* origGeneric) { auto innerVal = findInnerMostGenericReturnVal(origGeneric); if (auto innerFunc = as<IRFunc>(innerVal)) { + maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, innerFunc); differentiableTypeConformanceContext.setFunc(innerFunc); } else if (auto funcType = as<IRFuncType>(innerVal)) @@ -752,10 +842,14 @@ InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGene diffGeneric->setFullType(diffType); + HashSet<IRInst*> instsToSkip; + _markGenericChildrenWithoutRelaventUse(origGeneric, instsToSkip); + // Transcribe children from origFunc into diffFunc. builder.setInsertInto(diffGeneric); - for (auto block = origGeneric->getFirstBlock(); block; block = block->getNextBlock()) - this->transcribe(&builder, block); + auto transcribedBlock = transcribeBlockImpl(&builder, origGeneric->getFirstBlock(), instsToSkip); + mapPrimalInst(origGeneric->getFirstBlock(), transcribedBlock.primal); + mapDifferentialInst(origGeneric->getFirstBlock(), transcribedBlock.differential); return InstPair(primalGeneric, diffGeneric); } diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h index 2d980145e..b0397069b 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.h +++ b/source/slang/slang-ir-autodiff-transcriber-base.h @@ -96,12 +96,14 @@ struct AutoDiffTranscriberBase InstPair transcribeExtractExistentialWitnessTable(IRBuilder* builder, IRInst* origInst); + void maybeMigrateDifferentiableDictionaryFromDerivativeFunc(IRBuilder* builder, IRInst* origFunc); + // Get or construct `:IDifferentiable` conformance for a DifferentiablePair. - IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType); + IRWitnessTable* getDifferentialPairWitness(IRBuilder* builder, IRInst* inOriginalDiffPairType, IRInst* inPrimalDiffPairType); - IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness); + IRType* getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness); - IRType* getOrCreateDiffPairType(IRInst* primalType); + IRType* getOrCreateDiffPairType(IRBuilder* builder, IRInst* originalType); IRType* differentiateType(IRBuilder* builder, IRType* origType); @@ -121,7 +123,13 @@ struct AutoDiffTranscriberBase InstPair transcribeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* lookupInst); - InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock); + InstPair transcribeBlockImpl(IRBuilder* builder, IRBlock* origBlock, HashSet<IRInst*>& instsToSkip); + + InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock) + { + HashSet<IRInst*> emptySet; + return transcribeBlockImpl(builder, origBlock, emptySet); + } // Transcribe a generic definition InstPair transcribeGeneric(IRBuilder* inBuilder, IRGeneric* origGeneric); diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index a95fd7b9b..640041ecf 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -76,7 +76,7 @@ struct ExtractPrimalFuncContext outIntermediateType = createIntermediateType(destFunc); GenericChildrenMigrationContext migrationContext; - migrationContext.init(as<IRGeneric>(findOuterGeneric(originalFunc)), as<IRGeneric>(findOuterGeneric(destFunc))); + migrationContext.init(as<IRGeneric>(findOuterGeneric(originalFunc)), as<IRGeneric>(findOuterGeneric(destFunc)), destFunc); originalFuncType = as<IRFuncType>(originalFunc->getDataType()); diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index e616578c1..c525191a3 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -14,39 +14,6 @@ namespace Slang { -struct GenericChildrenMigrationContext -{ - IRCloneEnv cloneEnv; - IRGeneric* srcGeneric; - void init(IRGeneric* genericSrc, IRGeneric* genericDst) - { - srcGeneric = genericSrc; - if (!genericSrc) - return; - auto srcParam = genericSrc->getFirstBlock()->getFirstParam(); - auto dstParam = genericDst->getFirstBlock()->getFirstParam(); - while (srcParam && dstParam) - { - cloneEnv.mapOldValToNew[srcParam] = dstParam; - srcParam = srcParam->getNextParam(); - dstParam = dstParam->getNextParam(); - } - cloneEnv.mapOldValToNew[genericSrc] = genericDst; - cloneEnv.mapOldValToNew[genericSrc->getFirstBlock()] = genericDst->getFirstBlock(); - } - - IRInst* cloneInst(IRBuilder* builder, IRInst* src) - { - if (!srcGeneric) - return src; - if (findOuterGeneric(src) == srcGeneric) - { - return Slang::cloneInst(&cloneEnv, builder, src); - } - return src; - } -}; - struct DiffUnzipPass { AutoDiffSharedContext* autodiffContext; @@ -210,10 +177,14 @@ struct DiffUnzipPass auto func = findSpecializeReturnVal(specialize); auto outerGen = findOuterGeneric(func); intermediateType = primalBuilder->getBackwardDiffIntermediateContextType(outerGen); - intermediateType = specializeWithGeneric( - *primalBuilder, + List<IRInst*> args; + for (UInt i = 0; i < specialize->getArgCount(); i++) + args.add(specialize->getArg(i)); + intermediateType = primalBuilder->emitSpecializeInst( + primalBuilder->getTypeKind(), intermediateType, - as<IRGeneric>(findOuterGeneric(primalBuilder->getInsertLoc().getParent()))); + args.getCount(), + args.getBuffer()); } else { diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 7182375de..363006f58 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -31,11 +31,11 @@ static IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requi return entry->getSatisfyingVal(); } } - else if (auto witnessTableParam = as<IRParam>(witness)) + else { return builder->emitLookupInterfaceMethodInst( builder->getTypeKind(), - witnessTableParam, + witness, requirementKey); } return nullptr; diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp index ae04acbd1..33e5b3cb4 100644 --- a/source/slang/slang-ir-dce.cpp +++ b/source/slang/slang-ir-dce.cpp @@ -314,11 +314,15 @@ bool shouldInstBeLiveIfParentIsLive(IRInst* inst, IRDeadCodeEliminationOptions o { if (inst->findDecoration<IRForwardDerivativeDecoration>()) return true; + if (inst->findDecoration<IRUserDefinedBackwardDerivativeDecoration>()) + return true; if (auto genInst = as<IRGeneric>(inst)) { auto inner = findInnerMostGenericReturnVal(genInst); if (inner->findDecoration<IRForwardDerivativeDecoration>()) return true; + if (inner->findDecoration<IRUserDefinedBackwardDerivativeDecoration>()) + return true; } } } diff --git a/source/slang/slang-ir-deduplicate-generic-children.cpp b/source/slang/slang-ir-deduplicate-generic-children.cpp new file mode 100644 index 000000000..f933f77cc --- /dev/null +++ b/source/slang/slang-ir-deduplicate-generic-children.cpp @@ -0,0 +1,43 @@ +#include "slang-ir-deduplicate-generic-children.h" +#include "slang-ir.h" +#include "slang-ir-clone.h" +#include "slang-ir-util.h" + +namespace Slang +{ + +bool deduplicateGenericChildren(IRGeneric* genericInst) +{ + bool changed = false; + GenericChildrenMigrationContext ctx; + ctx.init(genericInst, genericInst, nullptr); + List<IRInst*> instsToRemove; + for (auto inst = genericInst->getFirstBlock()->getFirstInst(); inst; inst = inst->getNextInst()) + { + auto deduped = ctx.deduplicate(inst); + if (deduped != inst) + { + inst->replaceUsesWith(deduped); + instsToRemove.add(inst); + changed = true; + } + } + for (auto inst : instsToRemove) + inst->removeAndDeallocate(); + return changed; +} + +bool deduplicateGenericChildren(IRModule* module) +{ + bool changed = false; + for (auto inst : module->getGlobalInsts()) + { + if (auto gen = as<IRGeneric>(inst)) + { + changed |= deduplicateGenericChildren(gen); + } + } + return changed; +} + +} diff --git a/source/slang/slang-ir-deduplicate-generic-children.h b/source/slang/slang-ir-deduplicate-generic-children.h new file mode 100644 index 000000000..71cfafd4e --- /dev/null +++ b/source/slang/slang-ir-deduplicate-generic-children.h @@ -0,0 +1,12 @@ +// slang-ir-deduplicate-generic-children.h +#pragma once + +namespace Slang +{ + struct IRModule; + struct IRGeneric; + + // Deduplicate insts inside a generic. + bool deduplicateGenericChildren(IRModule* module); + bool deduplicateGenericChildren(IRGeneric* genericInst); +} diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index ea4bfb7b9..80f974536 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -432,6 +432,7 @@ static void cloneExtraDecorationsFromInst( case kIROp_PublicDecoration: case kIROp_SequentialIDDecoration: case kIROp_ForwardDerivativeDecoration: + case kIROp_UserDefinedBackwardDerivativeDecoration: case kIROp_IntrinsicOpDecoration: if (!clonedInst->findDecorationImpl(decoration->getOp())) { diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index cbb5ccf09..747d0ccdd 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -390,57 +390,61 @@ struct SpecializationContext auto genericReturnVal = findInnerMostGenericReturnVal(genericVal); if (genericReturnVal->findDecoration<IRTargetIntrinsicDecoration>()) { - if (auto customDiffRef = genericReturnVal->findDecoration<IRForwardDerivativeDecoration>()) + for (auto decor : genericReturnVal->getDecorations()) { - // If we already have a diff func on this specialize, skip. - if (auto specDiffRef = specInst->findDecoration<IRForwardDerivativeDecoration>()) + if (decor->getOp() == kIROp_ForwardDerivativeDecoration || + decor->getOp() == kIROp_UserDefinedBackwardDerivativeDecoration) { - return false; - } - - auto specDiffFunc = as<IRSpecialize>(customDiffRef->getForwardDerivativeFunc()); + // If we already have a diff func on this specialize, skip. + if (auto specDiffRef = specInst->findDecorationImpl(decor->getOp())) + { + return false; + } - // If the base is specialized, the JVP version must be also be a specialized - // generic. - // - SLANG_RELEASE_ASSERT(specDiffFunc); + auto specDiffFunc = as<IRSpecialize>(decor->getOperand(0)); - // Build specialization arguments from specInst. - // Note that if we've reached this point, we can safely assume - // that our args are fully specialized/concrete. - // - UCount argCount = specInst->getArgCount(); - List<IRInst*> args; - for (UIndex ii = 0; ii < argCount; ii++) - args.add(specInst->getArg(ii)); - - IRBuilder builder(&sharedBuilderStorage); - - // Specialize the custom JVP function type with the original arguments. - builder.setInsertInto(module); - auto newDiffFuncType = builder.emitSpecializeInst( - builder.getTypeKind(), - specDiffFunc->getBase()->getDataType(), - argCount, - args.getBuffer()); - - // Specialize the custom JVP function with the original arguments. - builder.setInsertBefore(specInst); - auto newDiffFunc = builder.emitSpecializeInst( - (IRType*) newDiffFuncType, - specDiffFunc->getBase(), - argCount, - args.getBuffer()); - - // Add the new spec insts to the list so they get specialized with - // the usual logic. - // - addToWorkList(newDiffFuncType); - addToWorkList(newDiffFunc); - - builder.addForwardDerivativeDecoration(specInst, newDiffFunc); + // If the base is specialized, the JVP version must be also be a specialized + // generic. + // + SLANG_RELEASE_ASSERT(specDiffFunc); - return true; + // Build specialization arguments from specInst. + // Note that if we've reached this point, we can safely assume + // that our args are fully specialized/concrete. + // + UCount argCount = specInst->getArgCount(); + List<IRInst*> args; + for (UIndex ii = 0; ii < argCount; ii++) + args.add(specInst->getArg(ii)); + + IRBuilder builder(&sharedBuilderStorage); + + // Specialize the custom derivative function type with the original arguments. + builder.setInsertInto(module); + auto newDiffFuncType = builder.emitSpecializeInst( + builder.getTypeKind(), + specDiffFunc->getBase()->getDataType(), + argCount, + args.getBuffer()); + + // Specialize the custom derivative function with the original arguments. + builder.setInsertBefore(specInst); + auto newDiffFunc = builder.emitSpecializeInst( + (IRType*)newDiffFuncType, + specDiffFunc->getBase(), + argCount, + args.getBuffer()); + + // Add the new spec insts to the list so they get specialized with + // the usual logic. + // + addToWorkList(newDiffFuncType); + addToWorkList(newDiffFunc); + + builder.addDecoration(specInst, decor->getOp(), newDiffFunc); + + return true; + } } } return false; diff --git a/source/slang/slang-ir-ssa-simplification.cpp b/source/slang/slang-ir-ssa-simplification.cpp index 938094551..fd5f41f49 100644 --- a/source/slang/slang-ir-ssa-simplification.cpp +++ b/source/slang/slang-ir-ssa-simplification.cpp @@ -7,6 +7,7 @@ #include "slang-ir-simplify-cfg.h" #include "slang-ir-peephole.h" #include "slang-ir-hoist-constants.h" +#include "slang-ir-deduplicate-generic-children.h" #include "slang-ir-remove-unused-generic-param.h" namespace Slang @@ -22,6 +23,7 @@ namespace Slang { changed = false; changed |= hoistConstants(module); + changed |= deduplicateGenericChildren(module); changed |= applySparseConditionalConstantPropagation(module); changed |= peepholeOptimize(module); changed |= simplifyCFG(module); diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index fb465f638..881f041c0 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -219,4 +219,120 @@ void moveInstChildren(IRInst* dest, IRInst* src) } } +struct GenericChildrenMigrationContextImpl +{ + IRCloneEnv cloneEnv; + IRGeneric* srcGeneric; + IRGeneric* dstGeneric; + Dictionary<IRInstKey, IRInst*> deduplicateMap; + + void init(IRGeneric* genericSrc, IRGeneric* genericDst, IRInst* insertBefore) + { + srcGeneric = genericSrc; + dstGeneric = genericDst; + + if (!genericSrc) + return; + auto srcParam = genericSrc->getFirstBlock()->getFirstParam(); + auto dstParam = genericDst->getFirstBlock()->getFirstParam(); + while (srcParam && dstParam) + { + cloneEnv.mapOldValToNew[srcParam] = dstParam; + srcParam = srcParam->getNextParam(); + dstParam = dstParam->getNextParam(); + } + cloneEnv.mapOldValToNew[genericSrc] = genericDst; + cloneEnv.mapOldValToNew[genericSrc->getFirstBlock()] = genericDst->getFirstBlock(); + + if (insertBefore) + { + for (auto inst = genericDst->getFirstBlock()->getFirstOrdinaryInst(); + inst && inst != insertBefore; + inst = inst->getNextInst()) + { + IRInstKey key = { inst }; + deduplicateMap.AddIfNotExists(key, inst); + } + } + } + + IRInst* deduplicate(IRInst* value) + { + if (!value) return nullptr; + if (value->getParent() != dstGeneric->getFirstBlock()) + return value; + switch (value->getOp()) + { + case kIROp_Param: + case kIROp_StructType: + case kIROp_StructKey: + case kIROp_InterfaceType: + case kIROp_ClassType: + case kIROp_Func: + case kIROp_Generic: + return value; + default: + break; + } + if (as<IRConstant>(value)) + return value; + + for (UInt i = 0; i < value->getOperandCount(); i++) + { + value->setOperand(i, deduplicate(value->getOperand(i))); + } + value->setFullType((IRType*)deduplicate(value->getFullType())); + IRInstKey key = { value }; + if (auto newValue = deduplicateMap.TryGetValue(key)) + return *newValue; + deduplicateMap[key] = value; + return value; + } + + IRInst* cloneInst(IRBuilder* builder, IRInst* src) + { + if (!srcGeneric) + return src; + if (findOuterGeneric(src) == srcGeneric) + { + auto cloned = Slang::cloneInst(&cloneEnv, builder, src); + auto deduplicated = deduplicate(cloned); + if (deduplicated != cloned) + cloneEnv.mapOldValToNew[src] = deduplicated; + return deduplicated; + } + return src; + } +}; + +GenericChildrenMigrationContext::GenericChildrenMigrationContext() +{ + impl = new GenericChildrenMigrationContextImpl(); +} + +GenericChildrenMigrationContext::~GenericChildrenMigrationContext() +{ + delete impl; +} + +IRCloneEnv* GenericChildrenMigrationContext::getCloneEnv() +{ + return &impl->cloneEnv; +} + +void GenericChildrenMigrationContext::init(IRGeneric* genericSrc, IRGeneric* genericDst, IRInst* insertBefore) +{ + impl->init(genericSrc, genericDst, insertBefore); +} + +IRInst* GenericChildrenMigrationContext::deduplicate(IRInst* value) +{ + return impl->deduplicate(value); +} + +IRInst* GenericChildrenMigrationContext::cloneInst(IRBuilder* builder, IRInst* src) +{ + return impl->cloneInst(builder, src); +} + } diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 4885dcd96..92446138f 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -8,6 +8,29 @@ namespace Slang { +struct GenericChildrenMigrationContextImpl; +struct IRCloneEnv; + +// A helper class to clone children insts to a different generic parent that has equivalent set of +// generic parameters. The clone will take care of substitution of equivalent generic parameters and +// intermediate values between the two generic parents. +struct GenericChildrenMigrationContext : public RefObject +{ +private: + GenericChildrenMigrationContextImpl* impl; + +public: + IRCloneEnv* getCloneEnv(); + + GenericChildrenMigrationContext(); + ~GenericChildrenMigrationContext(); + + void init(IRGeneric* genericSrc, IRGeneric* genericDst, IRInst* insertBefore); + + IRInst* deduplicate(IRInst* value); + + IRInst* cloneInst(IRBuilder* builder, IRInst* src); +}; bool isPtrToClassType(IRInst* type); |
