diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-04-23 17:12:14 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-04-23 14:12:14 -0700 |
| commit | 94d696801e8b313267e518cb16949d0ec122d46f (patch) | |
| tree | ad9f9628882792dc7f5f0fd4f987a5810ad0040e | |
| parent | e8673a535e91af8fd8d31d6845af1c792f554f05 (diff) | |
Add support for `kIROp_MakeExistential` (#2832)
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 42 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 42 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.h | 3 | ||||
| -rw-r--r-- | tests/autodiff/dynamic-object-bwd-diff.slang | 76 | ||||
| -rw-r--r-- | tests/autodiff/dynamic-object-bwd-diff.slang.expected.txt | 6 |
6 files changed, 157 insertions, 14 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 6025e1ccd..45857ca45 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1341,6 +1341,44 @@ InstPair ForwardDiffTranscriber::transcribeSingleOperandInst(IRBuilder* builder, return InstPair(primalResult, diffResult); } +InstPair ForwardDiffTranscriber::transcribeMakeExistential(IRBuilder* builder, IRMakeExistential* origMakeExistential) +{ + auto origBase = origMakeExistential->getWrappedValue(); + auto origWitnessTable = origMakeExistential->getWitnessTable(); + + auto primalBase = findOrTranscribePrimalInst(builder, origBase); + auto primalWitnessTable = findOrTranscribePrimalInst(builder, origWitnessTable); + auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origMakeExistential->getDataType()); + + IRInst* primalResult = builder->emitMakeExistential( + primalType, + primalBase, + primalWitnessTable); + + IRInst* diffResult = nullptr; + + auto primalInterfaceType = as<IRInterfaceType>(unwrapAttributedType(origMakeExistential->getDataType())); + SLANG_RELEASE_ASSERT(primalInterfaceType); + + // If the interface type of the existential is differentiable, we emit a make existential + // of IDifferentiable interface type and the witness table of the original type's conformance + // to IDifferentiable. + // + if (auto differentialWitnessTable = tryExtractConformanceFromInterfaceType( + builder, primalInterfaceType, (IRWitnessTable*)primalWitnessTable)) + { + if (auto diffBase = findOrTranscribeDiffInst(builder, origBase)) + { + diffResult = builder->emitMakeExistential( + autoDiffSharedContext->differentiableInterfaceType, + diffBase, + differentialWitnessTable); + } + } + + return InstPair(primalResult, diffResult); +} + InstPair ForwardDiffTranscriber::transcribeWrapExistential(IRBuilder* builder, IRInst* origInst) { auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origInst->getDataType()); @@ -1753,8 +1791,10 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* return transcribeDifferentialPairGetElement(builder, origInst); case kIROp_ExtractExistentialValue: - case kIROp_MakeExistential: return transcribeSingleOperandInst(builder, origInst); + + case kIROp_MakeExistential: + return transcribeMakeExistential(builder, as<IRMakeExistential>(origInst)); case kIROp_ExtractExistentialType: { diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h index a9193acbe..91193edc1 100644 --- a/source/slang/slang-ir-autodiff-fwd.h +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -78,6 +78,8 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase InstPair transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPairUserCode* origInst); + InstPair transcribeMakeExistential(IRBuilder* builder, IRMakeExistential* origMakeExistential); + InstPair transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst); InstPair transcribeSingleOperandInst(IRBuilder* builder, IRInst* origInst); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 81c0ab235..98f3aebfa 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -524,10 +524,32 @@ List<IRInterfaceRequirementEntry*> AutoDiffTranscriberBase::findDifferentiableIn return currentPath; } -InstPair AutoDiffTranscriberBase::transcribeExtractExistentialWitnessTable(IRBuilder* builder, IRInst* origInst) +IRInst* AutoDiffTranscriberBase::tryExtractConformanceFromInterfaceType( + IRBuilder* builder, + IRInterfaceType* interfaceType, + IRWitnessTable* witnessTable) { - IRInst* witnessTable = nullptr; + SLANG_RELEASE_ASSERT(interfaceType); + + List<IRInterfaceRequirementEntry*> lookupKeyPath = findDifferentiableInterfaceLookupPath( + autoDiffSharedContext->differentiableInterfaceType, interfaceType); + + IRInst* differentialTypeWitness = witnessTable; + if (lookupKeyPath.getCount()) + { + // `interfaceType` does conform to `IDifferentiable`. + for (auto node : lookupKeyPath) + { + differentialTypeWitness = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), differentialTypeWitness, node->getRequirementKey()); + } + return differentialTypeWitness; + } + return nullptr; +} + +InstPair AutoDiffTranscriberBase::transcribeExtractExistentialWitnessTable(IRBuilder* builder, IRInst* origInst) +{ IRInst* origBase = origInst->getOperand(0); auto primalBase = findOrTranscribePrimalInst(builder, origBase); auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origInst->getDataType()); @@ -541,21 +563,17 @@ InstPair AutoDiffTranscriberBase::transcribeExtractExistentialWitnessTable(IRBui // Search for IDifferentiable conformance. auto interfaceType = as<IRInterfaceType>( unwrapAttributedType(cast<IRWitnessTableType>(origInst->getDataType())->getConformanceType())); + if (!interfaceType) return InstPair(primalResult, nullptr); - List<IRInterfaceRequirementEntry*> lookupKeyPath = findDifferentiableInterfaceLookupPath( - autoDiffSharedContext->differentiableInterfaceType, interfaceType); - - if (lookupKeyPath.getCount()) + + if (auto differentialWitnessTable = tryExtractConformanceFromInterfaceType( + builder, interfaceType, (IRWitnessTable*)primalResult)) { // `interfaceType` does conform to `IDifferentiable`. - witnessTable = primalResult; - for (auto node : lookupKeyPath) - { - witnessTable = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), witnessTable, node->getRequirementKey()); - } - return InstPair(primalResult, witnessTable); + return InstPair(primalResult, differentialWitnessTable); } + return InstPair(primalResult, nullptr); } diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h index 86af2fb8e..d5ad29610 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.h +++ b/source/slang/slang-ir-autodiff-transcriber-base.h @@ -79,7 +79,6 @@ struct AutoDiffTranscriberBase return lookupPrimalInst(builder->getInsertLoc().getParent(), origInst, defaultInst); } - bool hasPrimalInst(IRInst* currentParent, IRInst* origInst); IRInst* findOrTranscribeDiffInst(IRBuilder* builder, IRInst* origInst); @@ -93,6 +92,8 @@ struct AutoDiffTranscriberBase InstPair transcribeExtractExistentialWitnessTable(IRBuilder* builder, IRInst* origInst); + IRInst* tryExtractConformanceFromInterfaceType(IRBuilder* builder, IRInterfaceType* type, IRWitnessTable* WitnessTable); + void maybeMigrateDifferentiableDictionaryFromDerivativeFunc(IRBuilder* builder, IRInst* origFunc); // Get or construct `:IDifferentiable` conformance for a DifferentiablePair. diff --git a/tests/autodiff/dynamic-object-bwd-diff.slang b/tests/autodiff/dynamic-object-bwd-diff.slang new file mode 100644 index 000000000..a10c48f9b --- /dev/null +++ b/tests/autodiff/dynamic-object-bwd-diff.slang @@ -0,0 +1,76 @@ +// Test calling backward differentiable function through dynamic dispatch. + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[anyValueSize(16)] +interface IInterface +{ + [BackwardDifferentiable] + float calc(IInterface2 i2, float x); +} + +interface IInterface2 +{ + float innerCalc(float x); +} + +struct C : IInterface2 +{ + float innerCalc(float x) { return 2 * x; } +} + +struct A : IInterface +{ + float a; + [BackwardDifferentiable] + float calc(IInterface2 i2, float x) + { + float b = no_diff(i2.innerCalc(x)); + return a*b*x; + } +}; + +struct B : IInterface +{ + float a; + [BackwardDifferentiable] + float calc(IInterface2 i2, float x) + { + float b = no_diff(i2.innerCalc(x)); + return a*b*x*x; + } +}; + +[BackwardDifferentiable] +float run(int id, float x, no_diff float y) +{ + IInterface obj = createDynamicObject<IInterface>(id, y); + C c = {}; + return obj.calc(c, x); +} + +//TEST_INPUT: type_conformance A:IInterface = 0 +//TEST_INPUT: type_conformance B:IInterface = 1 +//TEST_INPUT: type_conformance C:IInterface2 = 0 + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + var p = diffPair(3.0); + + __bwd_diff(run)(0, p, 0.5, 1.0f); + outputBuffer[0] = p.d; // A.calc, expect 3 + } + + { + var p = diffPair(3.0); + + __bwd_diff(run)(1, p, 1.5, 1.0f); + outputBuffer[1] = p.d; // B.calc, expect 40.5 + } +} diff --git a/tests/autodiff/dynamic-object-bwd-diff.slang.expected.txt b/tests/autodiff/dynamic-object-bwd-diff.slang.expected.txt new file mode 100644 index 000000000..7c6952bfa --- /dev/null +++ b/tests/autodiff/dynamic-object-bwd-diff.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +3.000000 +54.000000 +0.000000 +0.000000 +0.000000 |
