diff options
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-addr-inst-elimination.cpp | 21 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 40 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 8 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 11 | ||||
| -rw-r--r-- | tests/autodiff/dynamic-dispatch-generic-member.slang | 49 | ||||
| -rw-r--r-- | tests/autodiff/dynamic-dispatch-generic-member.slang.expected.txt | 5 | ||||
| -rw-r--r-- | tests/autodiff/member-func-custom-derivative-2.slang | 49 | ||||
| -rw-r--r-- | tests/autodiff/member-func-custom-derivative-2.slang.expected.txt | 2 | ||||
| -rw-r--r-- | tests/autodiff/member-func-custom-derivative.slang | 36 | ||||
| -rw-r--r-- | tests/autodiff/member-func-custom-derivative.slang.expected.txt | 2 |
13 files changed, 197 insertions, 33 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 6083ce9c0..6a32f59d3 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -5638,6 +5638,10 @@ namespace Slang bool isDiffFunc = false; if (decl->hasModifier<ForwardDifferentiableAttribute>() || decl->hasModifier<BackwardDifferentiableAttribute>()) { + if (GetOuterGeneric(decl)) + { + getSink()->diagnose(decl, Diagnostics::differentiableGenericInterfaceMethodNotSupported); + } auto reqDecl = m_astBuilder->create<ForwardDerivativeRequirementDecl>(); cloneModifiers(reqDecl, decl); auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl)); diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 2401b6e58..e3e9cfc44 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -352,6 +352,8 @@ DIAGNOSTIC(31145, Error, invalidCustomDerivative, "invalid custom derivative att DIAGNOSTIC(31146, Error, declAlreadyHasAttribute, "'$0' already has attribute '[$1]'.") DIAGNOSTIC(31147, Error, cannotResolveOriginalFunctionForDerivative, "cannot resolve the original function for the the custom derivative.") +DIAGNOSTIC(31148, Error, differentiableGenericInterfaceMethodNotSupported, "`[ForwardDifferentiable] and [BackwardDifferentiable] are not supported on generic interface requirements.") + // Enums DIAGNOSTIC(32000, Error, invalidEnumTagType, "invalid tag type for 'enum': '$0'") diff --git a/source/slang/slang-ir-addr-inst-elimination.cpp b/source/slang/slang-ir-addr-inst-elimination.cpp index 6715f2c6a..16bd67f66 100644 --- a/source/slang/slang-ir-addr-inst-elimination.cpp +++ b/source/slang/slang-ir-addr-inst-elimination.cpp @@ -99,22 +99,11 @@ struct AddressInstEliminationContext IRBuilder builder(module); builder.setInsertBefore(call); auto tempVar = builder.emitVar(cast<IRPtrTypeBase>(addr->getFullType())->getValueType()); - auto callee = getResolvedInstForDecorations(call->getCallee()); - auto funcType = as<IRFuncType>(callee->getFullType()); - SLANG_RELEASE_ASSERT(funcType); - UInt paramIndex = (UInt)(use - call->getOperands() - 1); - SLANG_RELEASE_ASSERT(call->getArg(paramIndex) == addr); - if (!as<IROutType>(funcType->getParamType(paramIndex))) - { - builder.emitStore(tempVar, getValue(builder, addr)); - } - else - { - builder.emitStore( - tempVar, - builder.emitDefaultConstruct( - as<IRPtrTypeBase>(tempVar->getDataType())->getValueType())); - } + + // Store the initial value of the mutable argument into temp var. + // If this is an `out` var, the initial value will be undefined, + // which will get cleaned up later into a `defaultConstruct`. + builder.emitStore(tempVar, getValue(builder, addr)); builder.setInsertAfter(call); storeValue(builder, addr, builder.emitLoad(tempVar)); use->set(tempVar); diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 3f31f1463..869f8920c 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -510,7 +510,7 @@ IRInst* tryFindPrimalSubstitute(IRBuilder* builder, IRInst* callee) { auto innerGen = as<IRGeneric>(specialize->getBase()); if (!innerGen) - return nullptr; + return callee; auto innerFunc = findGenericReturnVal(innerGen); if (auto decor = innerFunc->findDecoration<IRPrimalSubstituteDecoration>()) { @@ -553,7 +553,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig return InstPair(nullptr, nullptr); } - auto primalCallee = lookupPrimalInst(builder, origCallee, origCallee); + auto primalCallee = findOrTranscribePrimalInst(builder, origCallee); auto substPrimalCallee = tryFindPrimalSubstitute(builder, primalCallee); IRInst* diffCallee = nullptr; @@ -563,7 +563,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig } else { - instMapD.TryGetValue(substPrimalCallee, diffCallee); + diffCallee = findOrTranscribeDiffInst(builder, origCallee); primalCallee = substPrimalCallee; } @@ -904,17 +904,32 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize(IRBuilder* builder, IRSpec IRInst* diffBase = nullptr; if (instMapD.TryGetValue(origSpecialize->getBase(), diffBase)) { - List<IRInst*> args; - for (UInt i = 0; i < primalSpecialize->getArgCount(); i++) + if (diffBase) { - args.add(primalSpecialize->getArg(i)); + List<IRInst*> args; + for (UInt i = 0; i < primalSpecialize->getArgCount(); i++) + { + args.add(primalSpecialize->getArg(i)); + } + auto diffSpecialize = builder->emitSpecializeInst( + builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer()); + return InstPair(primalSpecialize, diffSpecialize); + } + else + { + return InstPair(primalSpecialize, nullptr); } - auto diffSpecialize = builder->emitSpecializeInst( - builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer()); - return InstPair(primalSpecialize, diffSpecialize); } auto genericInnerVal = findInnerMostGenericReturnVal(as<IRGeneric>(origSpecialize->getBase())); + + // Right now we don't support transcribing a differentiable callee that is a specialize of a interface lookup + // (calling differentiable generic interface method). To support it, we need to recursively transcribe the + // specialization base here. + + if (!genericInnerVal) + return InstPair(primalSpecialize, nullptr); + // Look for an IRForwardDerivativeDecoration on the specialize inst. // (Normally, this would be on the inner IRFunc, but in this case only the JVP func // can be specialized, so we put a decoration on the IRSpecialize) @@ -963,10 +978,7 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize(IRBuilder* builder, IRSpec builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer()); return InstPair(primalSpecialize, diffSpecialize); } - else - { - return InstPair(primalSpecialize, nullptr); - } + return InstPair(primalSpecialize, nullptr); } InstPair ForwardDiffTranscriber::transcribeFieldExtract(IRBuilder* builder, IRInst* originalInst) @@ -1433,6 +1445,8 @@ IRFunc* ForwardDiffTranscriber::transcribeFuncHeaderImpl(IRBuilder* inBuilder, I IRFunc* primalFunc = origFunc; + maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, origFunc); + differentiableTypeConformanceContext.setFunc(origFunc); primalFunc = origFunc; diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 552ac762c..9cbea7873 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -594,7 +594,11 @@ void AutoDiffTranscriberBase::maybeMigrateDifferentiableDictionaryFromDerivative } else { - cloneDecoration(udfDecor, origFunc); + auto udfDictDecor = derivative->findDecoration< IRDifferentiableTypeDictionaryDecoration>(); + if (udfDictDecor) + { + cloneDecoration(udfDictDecor, origFunc); + } } } @@ -977,6 +981,8 @@ InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGene if (auto innerFunc = as<IRFunc>(innerVal)) { maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, innerFunc); + if (!innerFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>()) + return InstPair(origGeneric, nullptr); differentiableTypeConformanceContext.setFunc(innerFunc); } else if (auto funcType = as<IRFuncType>(innerVal)) diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index f173aaa8b..1909f860c 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -368,6 +368,7 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) { parentFunc = func; + auto decor = func->findDecoration<IRDifferentiableTypeDictionaryDecoration>(); SLANG_RELEASE_ASSERT(decor); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index f84f17886..9c27beb58 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -8484,10 +8484,15 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> funcExpr = udAttr->funcExpr; else if (auto primalAttr = as<PrimalSubstituteAttribute>(modifier)) funcExpr = primalAttr->funcExpr; + DeclRefExpr* declRefExpr = as<DeclRefExpr>(funcExpr); + auto funcType = lowerType(subContext, funcExpr->type); + auto loweredVal = emitDeclRef( + subContext, + declRefExpr->declRef, + funcType); + + SLANG_RELEASE_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple); - auto loweredVal = lowerRValueExpr(subContext, funcExpr); - - SLANG_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple); IRInst* derivativeFunc = loweredVal.val; if (as<ForwardDerivativeAttribute>(modifier)) diff --git a/tests/autodiff/dynamic-dispatch-generic-member.slang b/tests/autodiff/dynamic-dispatch-generic-member.slang new file mode 100644 index 000000000..83c3aee7c --- /dev/null +++ b/tests/autodiff/dynamic-dispatch-generic-member.slang @@ -0,0 +1,49 @@ +// Test calling dynamic dispatched generic function from differentiable function. + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +interface IFoo +{ + float f(); +} + +interface IInterface +{ + float calc<T:IFoo>(T t, float x); +} + +struct A : IFoo +{ + float f() { return 1.0; } +}; + +struct B : IInterface +{ + float calc<T : IFoo>(T t, float x) + { + return t.f() * x; + } +}; + +[BackwardDifferentiable] +float test(IInterface obj, float x) +{ + A objA; + return no_diff(obj.calc(objA, x)) * x; +} + +//TEST_INPUT: type_conformance A:IFoo = 0 +//TEST_INPUT: type_conformance B:IInterface = 1 + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + var obj = createDynamicObject<IInterface>(dispatchThreadID.x, 1); // B + var p = diffPair(3.0); + __bwd_diff(test)(obj, p, 1.0); + outputBuffer[0] = p.d; +} diff --git a/tests/autodiff/dynamic-dispatch-generic-member.slang.expected.txt b/tests/autodiff/dynamic-dispatch-generic-member.slang.expected.txt new file mode 100644 index 000000000..857cebc03 --- /dev/null +++ b/tests/autodiff/dynamic-dispatch-generic-member.slang.expected.txt @@ -0,0 +1,5 @@ +type: float +3.000000 +0.000000 +0.000000 +0.000000 diff --git a/tests/autodiff/member-func-custom-derivative-2.slang b/tests/autodiff/member-func-custom-derivative-2.slang new file mode 100644 index 000000000..329f3ade8 --- /dev/null +++ b/tests/autodiff/member-func-custom-derivative-2.slang @@ -0,0 +1,49 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +interface IFoo +{ + [BackwardDifferentiable] + float3 test(float v, uint offset); +} +struct A : IFoo +{ + float x; + + float3 f(float v, uint offset) + { + return v * v; + } + + // Provide a backward diff, but leave out forward diff. + [BackwardDerivativeOf(f)] + [TreatAsDifferentiable] + void diff_f(inout DifferentialPair<float> v, uint offset, float3 dOut) + { + v = diffPair(v.p, 2 * v.p * dOut.x); + } + + [BackwardDifferentiable] + float3 test(float v, uint offset) + { + return f(v, 0); + } +} + +[BackwardDifferentiable] +float3 test(IFoo obj, float v) +{ + return obj.test(v, 0); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + A a = {0.0}; + var p = diffPair(3.0, 0.0); + let rs = __bwd_diff(test)(a, p, 1.0); + outputBuffer[0] = p.d; +} diff --git a/tests/autodiff/member-func-custom-derivative-2.slang.expected.txt b/tests/autodiff/member-func-custom-derivative-2.slang.expected.txt new file mode 100644 index 000000000..253df0793 --- /dev/null +++ b/tests/autodiff/member-func-custom-derivative-2.slang.expected.txt @@ -0,0 +1,2 @@ +type: float +6.0 diff --git a/tests/autodiff/member-func-custom-derivative.slang b/tests/autodiff/member-func-custom-derivative.slang new file mode 100644 index 000000000..3ec44e690 --- /dev/null +++ b/tests/autodiff/member-func-custom-derivative.slang @@ -0,0 +1,36 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +struct A +{ + float x; + + [ForwardDerivative(diff_f)] + float f(float v) + { + return v * v; + } + + DifferentialPair<float> diff_f(DifferentialPair<float> v) + { + return diffPair(v.p * v.p, v.p * v.d * 2.0); + } +} + +[ForwardDifferentiable] +float test(A obj, float v) +{ + return obj.f(v); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + A a = {0.0}; + var p = diffPair(3.0, 1.0); + let rs = __fwd_diff(test)(a, p); + outputBuffer[0] = rs.d; +} diff --git a/tests/autodiff/member-func-custom-derivative.slang.expected.txt b/tests/autodiff/member-func-custom-derivative.slang.expected.txt new file mode 100644 index 000000000..253df0793 --- /dev/null +++ b/tests/autodiff/member-func-custom-derivative.slang.expected.txt @@ -0,0 +1,2 @@ +type: float +6.0 |
