From 94d696801e8b313267e518cb16949d0ec122d46f Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Sun, 23 Apr 2023 17:12:14 -0400 Subject: Add support for `kIROp_MakeExistential` (#2832) --- source/slang/slang-ir-autodiff-fwd.cpp | 42 +++++++++++++++++++++- source/slang/slang-ir-autodiff-fwd.h | 2 ++ .../slang/slang-ir-autodiff-transcriber-base.cpp | 42 +++++++++++++++------- source/slang/slang-ir-autodiff-transcriber-base.h | 3 +- 4 files changed, 75 insertions(+), 14 deletions(-) (limited to 'source') diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 6025e1ccd..45857ca45 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1341,6 +1341,44 @@ InstPair ForwardDiffTranscriber::transcribeSingleOperandInst(IRBuilder* builder, return InstPair(primalResult, diffResult); } +InstPair ForwardDiffTranscriber::transcribeMakeExistential(IRBuilder* builder, IRMakeExistential* origMakeExistential) +{ + auto origBase = origMakeExistential->getWrappedValue(); + auto origWitnessTable = origMakeExistential->getWitnessTable(); + + auto primalBase = findOrTranscribePrimalInst(builder, origBase); + auto primalWitnessTable = findOrTranscribePrimalInst(builder, origWitnessTable); + auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origMakeExistential->getDataType()); + + IRInst* primalResult = builder->emitMakeExistential( + primalType, + primalBase, + primalWitnessTable); + + IRInst* diffResult = nullptr; + + auto primalInterfaceType = as(unwrapAttributedType(origMakeExistential->getDataType())); + SLANG_RELEASE_ASSERT(primalInterfaceType); + + // If the interface type of the existential is differentiable, we emit a make existential + // of IDifferentiable interface type and the witness table of the original type's conformance + // to IDifferentiable. + // + if (auto differentialWitnessTable = tryExtractConformanceFromInterfaceType( + builder, primalInterfaceType, (IRWitnessTable*)primalWitnessTable)) + { + if (auto diffBase = findOrTranscribeDiffInst(builder, origBase)) + { + diffResult = builder->emitMakeExistential( + autoDiffSharedContext->differentiableInterfaceType, + diffBase, + differentialWitnessTable); + } + } + + return InstPair(primalResult, diffResult); +} + InstPair ForwardDiffTranscriber::transcribeWrapExistential(IRBuilder* builder, IRInst* origInst) { auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origInst->getDataType()); @@ -1753,8 +1791,10 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* return transcribeDifferentialPairGetElement(builder, origInst); case kIROp_ExtractExistentialValue: - case kIROp_MakeExistential: return transcribeSingleOperandInst(builder, origInst); + + case kIROp_MakeExistential: + return transcribeMakeExistential(builder, as(origInst)); case kIROp_ExtractExistentialType: { diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h index a9193acbe..91193edc1 100644 --- a/source/slang/slang-ir-autodiff-fwd.h +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -78,6 +78,8 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase InstPair transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPairUserCode* origInst); + InstPair transcribeMakeExistential(IRBuilder* builder, IRMakeExistential* origMakeExistential); + InstPair transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst); InstPair transcribeSingleOperandInst(IRBuilder* builder, IRInst* origInst); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 81c0ab235..98f3aebfa 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -524,10 +524,32 @@ List AutoDiffTranscriberBase::findDifferentiableIn return currentPath; } -InstPair AutoDiffTranscriberBase::transcribeExtractExistentialWitnessTable(IRBuilder* builder, IRInst* origInst) +IRInst* AutoDiffTranscriberBase::tryExtractConformanceFromInterfaceType( + IRBuilder* builder, + IRInterfaceType* interfaceType, + IRWitnessTable* witnessTable) { - IRInst* witnessTable = nullptr; + SLANG_RELEASE_ASSERT(interfaceType); + + List lookupKeyPath = findDifferentiableInterfaceLookupPath( + autoDiffSharedContext->differentiableInterfaceType, interfaceType); + + IRInst* differentialTypeWitness = witnessTable; + if (lookupKeyPath.getCount()) + { + // `interfaceType` does conform to `IDifferentiable`. + for (auto node : lookupKeyPath) + { + differentialTypeWitness = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), differentialTypeWitness, node->getRequirementKey()); + } + return differentialTypeWitness; + } + return nullptr; +} + +InstPair AutoDiffTranscriberBase::transcribeExtractExistentialWitnessTable(IRBuilder* builder, IRInst* origInst) +{ IRInst* origBase = origInst->getOperand(0); auto primalBase = findOrTranscribePrimalInst(builder, origBase); auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origInst->getDataType()); @@ -541,21 +563,17 @@ InstPair AutoDiffTranscriberBase::transcribeExtractExistentialWitnessTable(IRBui // Search for IDifferentiable conformance. auto interfaceType = as( unwrapAttributedType(cast(origInst->getDataType())->getConformanceType())); + if (!interfaceType) return InstPair(primalResult, nullptr); - List lookupKeyPath = findDifferentiableInterfaceLookupPath( - autoDiffSharedContext->differentiableInterfaceType, interfaceType); - - if (lookupKeyPath.getCount()) + + if (auto differentialWitnessTable = tryExtractConformanceFromInterfaceType( + builder, interfaceType, (IRWitnessTable*)primalResult)) { // `interfaceType` does conform to `IDifferentiable`. - witnessTable = primalResult; - for (auto node : lookupKeyPath) - { - witnessTable = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), witnessTable, node->getRequirementKey()); - } - return InstPair(primalResult, witnessTable); + return InstPair(primalResult, differentialWitnessTable); } + return InstPair(primalResult, nullptr); } diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h index 86af2fb8e..d5ad29610 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.h +++ b/source/slang/slang-ir-autodiff-transcriber-base.h @@ -79,7 +79,6 @@ struct AutoDiffTranscriberBase return lookupPrimalInst(builder->getInsertLoc().getParent(), origInst, defaultInst); } - bool hasPrimalInst(IRInst* currentParent, IRInst* origInst); IRInst* findOrTranscribeDiffInst(IRBuilder* builder, IRInst* origInst); @@ -93,6 +92,8 @@ struct AutoDiffTranscriberBase InstPair transcribeExtractExistentialWitnessTable(IRBuilder* builder, IRInst* origInst); + IRInst* tryExtractConformanceFromInterfaceType(IRBuilder* builder, IRInterfaceType* type, IRWitnessTable* WitnessTable); + void maybeMigrateDifferentiableDictionaryFromDerivativeFunc(IRBuilder* builder, IRInst* origFunc); // Get or construct `:IDifferentiable` conformance for a DifferentiablePair. -- cgit v1.2.3