diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-04-23 17:12:14 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-04-23 14:12:14 -0700 |
| commit | 94d696801e8b313267e518cb16949d0ec122d46f (patch) | |
| tree | ad9f9628882792dc7f5f0fd4f987a5810ad0040e /source | |
| parent | e8673a535e91af8fd8d31d6845af1c792f554f05 (diff) | |
Add support for `kIROp_MakeExistential` (#2832)
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 42 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 42 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.h | 3 |
4 files changed, 75 insertions, 14 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 6025e1ccd..45857ca45 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1341,6 +1341,44 @@ InstPair ForwardDiffTranscriber::transcribeSingleOperandInst(IRBuilder* builder, return InstPair(primalResult, diffResult); } +InstPair ForwardDiffTranscriber::transcribeMakeExistential(IRBuilder* builder, IRMakeExistential* origMakeExistential) +{ + auto origBase = origMakeExistential->getWrappedValue(); + auto origWitnessTable = origMakeExistential->getWitnessTable(); + + auto primalBase = findOrTranscribePrimalInst(builder, origBase); + auto primalWitnessTable = findOrTranscribePrimalInst(builder, origWitnessTable); + auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origMakeExistential->getDataType()); + + IRInst* primalResult = builder->emitMakeExistential( + primalType, + primalBase, + primalWitnessTable); + + IRInst* diffResult = nullptr; + + auto primalInterfaceType = as<IRInterfaceType>(unwrapAttributedType(origMakeExistential->getDataType())); + SLANG_RELEASE_ASSERT(primalInterfaceType); + + // If the interface type of the existential is differentiable, we emit a make existential + // of IDifferentiable interface type and the witness table of the original type's conformance + // to IDifferentiable. + // + if (auto differentialWitnessTable = tryExtractConformanceFromInterfaceType( + builder, primalInterfaceType, (IRWitnessTable*)primalWitnessTable)) + { + if (auto diffBase = findOrTranscribeDiffInst(builder, origBase)) + { + diffResult = builder->emitMakeExistential( + autoDiffSharedContext->differentiableInterfaceType, + diffBase, + differentialWitnessTable); + } + } + + return InstPair(primalResult, diffResult); +} + InstPair ForwardDiffTranscriber::transcribeWrapExistential(IRBuilder* builder, IRInst* origInst) { auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origInst->getDataType()); @@ -1753,8 +1791,10 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* return transcribeDifferentialPairGetElement(builder, origInst); case kIROp_ExtractExistentialValue: - case kIROp_MakeExistential: return transcribeSingleOperandInst(builder, origInst); + + case kIROp_MakeExistential: + return transcribeMakeExistential(builder, as<IRMakeExistential>(origInst)); case kIROp_ExtractExistentialType: { diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h index a9193acbe..91193edc1 100644 --- a/source/slang/slang-ir-autodiff-fwd.h +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -78,6 +78,8 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase InstPair transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPairUserCode* origInst); + InstPair transcribeMakeExistential(IRBuilder* builder, IRMakeExistential* origMakeExistential); + InstPair transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst); InstPair transcribeSingleOperandInst(IRBuilder* builder, IRInst* origInst); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 81c0ab235..98f3aebfa 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -524,10 +524,32 @@ List<IRInterfaceRequirementEntry*> AutoDiffTranscriberBase::findDifferentiableIn return currentPath; } -InstPair AutoDiffTranscriberBase::transcribeExtractExistentialWitnessTable(IRBuilder* builder, IRInst* origInst) +IRInst* AutoDiffTranscriberBase::tryExtractConformanceFromInterfaceType( + IRBuilder* builder, + IRInterfaceType* interfaceType, + IRWitnessTable* witnessTable) { - IRInst* witnessTable = nullptr; + SLANG_RELEASE_ASSERT(interfaceType); + + List<IRInterfaceRequirementEntry*> lookupKeyPath = findDifferentiableInterfaceLookupPath( + autoDiffSharedContext->differentiableInterfaceType, interfaceType); + + IRInst* differentialTypeWitness = witnessTable; + if (lookupKeyPath.getCount()) + { + // `interfaceType` does conform to `IDifferentiable`. + for (auto node : lookupKeyPath) + { + differentialTypeWitness = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), differentialTypeWitness, node->getRequirementKey()); + } + return differentialTypeWitness; + } + return nullptr; +} + +InstPair AutoDiffTranscriberBase::transcribeExtractExistentialWitnessTable(IRBuilder* builder, IRInst* origInst) +{ IRInst* origBase = origInst->getOperand(0); auto primalBase = findOrTranscribePrimalInst(builder, origBase); auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origInst->getDataType()); @@ -541,21 +563,17 @@ InstPair AutoDiffTranscriberBase::transcribeExtractExistentialWitnessTable(IRBui // Search for IDifferentiable conformance. auto interfaceType = as<IRInterfaceType>( unwrapAttributedType(cast<IRWitnessTableType>(origInst->getDataType())->getConformanceType())); + if (!interfaceType) return InstPair(primalResult, nullptr); - List<IRInterfaceRequirementEntry*> lookupKeyPath = findDifferentiableInterfaceLookupPath( - autoDiffSharedContext->differentiableInterfaceType, interfaceType); - - if (lookupKeyPath.getCount()) + + if (auto differentialWitnessTable = tryExtractConformanceFromInterfaceType( + builder, interfaceType, (IRWitnessTable*)primalResult)) { // `interfaceType` does conform to `IDifferentiable`. - witnessTable = primalResult; - for (auto node : lookupKeyPath) - { - witnessTable = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), witnessTable, node->getRequirementKey()); - } - return InstPair(primalResult, witnessTable); + return InstPair(primalResult, differentialWitnessTable); } + return InstPair(primalResult, nullptr); } diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h index 86af2fb8e..d5ad29610 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.h +++ b/source/slang/slang-ir-autodiff-transcriber-base.h @@ -79,7 +79,6 @@ struct AutoDiffTranscriberBase return lookupPrimalInst(builder->getInsertLoc().getParent(), origInst, defaultInst); } - bool hasPrimalInst(IRInst* currentParent, IRInst* origInst); IRInst* findOrTranscribeDiffInst(IRBuilder* builder, IRInst* origInst); @@ -93,6 +92,8 @@ struct AutoDiffTranscriberBase InstPair transcribeExtractExistentialWitnessTable(IRBuilder* builder, IRInst* origInst); + IRInst* tryExtractConformanceFromInterfaceType(IRBuilder* builder, IRInterfaceType* type, IRWitnessTable* WitnessTable); + void maybeMigrateDifferentiableDictionaryFromDerivativeFunc(IRBuilder* builder, IRInst* origFunc); // Get or construct `:IDifferentiable` conformance for a DifferentiablePair. |
