summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-check-decl.cpp4
-rw-r--r--source/slang/slang-diagnostic-defs.h2
-rw-r--r--source/slang/slang-ir-addr-inst-elimination.cpp21
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp40
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp8
-rw-r--r--source/slang/slang-ir-autodiff.cpp1
-rw-r--r--source/slang/slang-lower-to-ir.cpp11
-rw-r--r--tests/autodiff/dynamic-dispatch-generic-member.slang49
-rw-r--r--tests/autodiff/dynamic-dispatch-generic-member.slang.expected.txt5
-rw-r--r--tests/autodiff/member-func-custom-derivative-2.slang49
-rw-r--r--tests/autodiff/member-func-custom-derivative-2.slang.expected.txt2
-rw-r--r--tests/autodiff/member-func-custom-derivative.slang36
-rw-r--r--tests/autodiff/member-func-custom-derivative.slang.expected.txt2
13 files changed, 197 insertions, 33 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 6083ce9c0..6a32f59d3 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -5638,6 +5638,10 @@ namespace Slang
bool isDiffFunc = false;
if (decl->hasModifier<ForwardDifferentiableAttribute>() || decl->hasModifier<BackwardDifferentiableAttribute>())
{
+ if (GetOuterGeneric(decl))
+ {
+ getSink()->diagnose(decl, Diagnostics::differentiableGenericInterfaceMethodNotSupported);
+ }
auto reqDecl = m_astBuilder->create<ForwardDerivativeRequirementDecl>();
cloneModifiers(reqDecl, decl);
auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl));
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 2401b6e58..e3e9cfc44 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -352,6 +352,8 @@ DIAGNOSTIC(31145, Error, invalidCustomDerivative, "invalid custom derivative att
DIAGNOSTIC(31146, Error, declAlreadyHasAttribute, "'$0' already has attribute '[$1]'.")
DIAGNOSTIC(31147, Error, cannotResolveOriginalFunctionForDerivative, "cannot resolve the original function for the the custom derivative.")
+DIAGNOSTIC(31148, Error, differentiableGenericInterfaceMethodNotSupported, "`[ForwardDifferentiable] and [BackwardDifferentiable] are not supported on generic interface requirements.")
+
// Enums
DIAGNOSTIC(32000, Error, invalidEnumTagType, "invalid tag type for 'enum': '$0'")
diff --git a/source/slang/slang-ir-addr-inst-elimination.cpp b/source/slang/slang-ir-addr-inst-elimination.cpp
index 6715f2c6a..16bd67f66 100644
--- a/source/slang/slang-ir-addr-inst-elimination.cpp
+++ b/source/slang/slang-ir-addr-inst-elimination.cpp
@@ -99,22 +99,11 @@ struct AddressInstEliminationContext
IRBuilder builder(module);
builder.setInsertBefore(call);
auto tempVar = builder.emitVar(cast<IRPtrTypeBase>(addr->getFullType())->getValueType());
- auto callee = getResolvedInstForDecorations(call->getCallee());
- auto funcType = as<IRFuncType>(callee->getFullType());
- SLANG_RELEASE_ASSERT(funcType);
- UInt paramIndex = (UInt)(use - call->getOperands() - 1);
- SLANG_RELEASE_ASSERT(call->getArg(paramIndex) == addr);
- if (!as<IROutType>(funcType->getParamType(paramIndex)))
- {
- builder.emitStore(tempVar, getValue(builder, addr));
- }
- else
- {
- builder.emitStore(
- tempVar,
- builder.emitDefaultConstruct(
- as<IRPtrTypeBase>(tempVar->getDataType())->getValueType()));
- }
+
+ // Store the initial value of the mutable argument into temp var.
+ // If this is an `out` var, the initial value will be undefined,
+ // which will get cleaned up later into a `defaultConstruct`.
+ builder.emitStore(tempVar, getValue(builder, addr));
builder.setInsertAfter(call);
storeValue(builder, addr, builder.emitLoad(tempVar));
use->set(tempVar);
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 3f31f1463..869f8920c 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -510,7 +510,7 @@ IRInst* tryFindPrimalSubstitute(IRBuilder* builder, IRInst* callee)
{
auto innerGen = as<IRGeneric>(specialize->getBase());
if (!innerGen)
- return nullptr;
+ return callee;
auto innerFunc = findGenericReturnVal(innerGen);
if (auto decor = innerFunc->findDecoration<IRPrimalSubstituteDecoration>())
{
@@ -553,7 +553,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
return InstPair(nullptr, nullptr);
}
- auto primalCallee = lookupPrimalInst(builder, origCallee, origCallee);
+ auto primalCallee = findOrTranscribePrimalInst(builder, origCallee);
auto substPrimalCallee = tryFindPrimalSubstitute(builder, primalCallee);
IRInst* diffCallee = nullptr;
@@ -563,7 +563,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
}
else
{
- instMapD.TryGetValue(substPrimalCallee, diffCallee);
+ diffCallee = findOrTranscribeDiffInst(builder, origCallee);
primalCallee = substPrimalCallee;
}
@@ -904,17 +904,32 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize(IRBuilder* builder, IRSpec
IRInst* diffBase = nullptr;
if (instMapD.TryGetValue(origSpecialize->getBase(), diffBase))
{
- List<IRInst*> args;
- for (UInt i = 0; i < primalSpecialize->getArgCount(); i++)
+ if (diffBase)
{
- args.add(primalSpecialize->getArg(i));
+ List<IRInst*> args;
+ for (UInt i = 0; i < primalSpecialize->getArgCount(); i++)
+ {
+ args.add(primalSpecialize->getArg(i));
+ }
+ auto diffSpecialize = builder->emitSpecializeInst(
+ builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer());
+ return InstPair(primalSpecialize, diffSpecialize);
+ }
+ else
+ {
+ return InstPair(primalSpecialize, nullptr);
}
- auto diffSpecialize = builder->emitSpecializeInst(
- builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer());
- return InstPair(primalSpecialize, diffSpecialize);
}
auto genericInnerVal = findInnerMostGenericReturnVal(as<IRGeneric>(origSpecialize->getBase()));
+
+ // Right now we don't support transcribing a differentiable callee that is a specialize of a interface lookup
+ // (calling differentiable generic interface method). To support it, we need to recursively transcribe the
+ // specialization base here.
+
+ if (!genericInnerVal)
+ return InstPair(primalSpecialize, nullptr);
+
// Look for an IRForwardDerivativeDecoration on the specialize inst.
// (Normally, this would be on the inner IRFunc, but in this case only the JVP func
// can be specialized, so we put a decoration on the IRSpecialize)
@@ -963,10 +978,7 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize(IRBuilder* builder, IRSpec
builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer());
return InstPair(primalSpecialize, diffSpecialize);
}
- else
- {
- return InstPair(primalSpecialize, nullptr);
- }
+ return InstPair(primalSpecialize, nullptr);
}
InstPair ForwardDiffTranscriber::transcribeFieldExtract(IRBuilder* builder, IRInst* originalInst)
@@ -1433,6 +1445,8 @@ IRFunc* ForwardDiffTranscriber::transcribeFuncHeaderImpl(IRBuilder* inBuilder, I
IRFunc* primalFunc = origFunc;
+ maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, origFunc);
+
differentiableTypeConformanceContext.setFunc(origFunc);
primalFunc = origFunc;
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 552ac762c..9cbea7873 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -594,7 +594,11 @@ void AutoDiffTranscriberBase::maybeMigrateDifferentiableDictionaryFromDerivative
}
else
{
- cloneDecoration(udfDecor, origFunc);
+ auto udfDictDecor = derivative->findDecoration< IRDifferentiableTypeDictionaryDecoration>();
+ if (udfDictDecor)
+ {
+ cloneDecoration(udfDictDecor, origFunc);
+ }
}
}
@@ -977,6 +981,8 @@ InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGene
if (auto innerFunc = as<IRFunc>(innerVal))
{
maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, innerFunc);
+ if (!innerFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>())
+ return InstPair(origGeneric, nullptr);
differentiableTypeConformanceContext.setFunc(innerFunc);
}
else if (auto funcType = as<IRFuncType>(innerVal))
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index f173aaa8b..1909f860c 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -368,6 +368,7 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
{
parentFunc = func;
+
auto decor = func->findDecoration<IRDifferentiableTypeDictionaryDecoration>();
SLANG_RELEASE_ASSERT(decor);
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index f84f17886..9c27beb58 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -8484,10 +8484,15 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
funcExpr = udAttr->funcExpr;
else if (auto primalAttr = as<PrimalSubstituteAttribute>(modifier))
funcExpr = primalAttr->funcExpr;
+ DeclRefExpr* declRefExpr = as<DeclRefExpr>(funcExpr);
+ auto funcType = lowerType(subContext, funcExpr->type);
+ auto loweredVal = emitDeclRef(
+ subContext,
+ declRefExpr->declRef,
+ funcType);
+
+ SLANG_RELEASE_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple);
- auto loweredVal = lowerRValueExpr(subContext, funcExpr);
-
- SLANG_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple);
IRInst* derivativeFunc = loweredVal.val;
if (as<ForwardDerivativeAttribute>(modifier))
diff --git a/tests/autodiff/dynamic-dispatch-generic-member.slang b/tests/autodiff/dynamic-dispatch-generic-member.slang
new file mode 100644
index 000000000..83c3aee7c
--- /dev/null
+++ b/tests/autodiff/dynamic-dispatch-generic-member.slang
@@ -0,0 +1,49 @@
+// Test calling dynamic dispatched generic function from differentiable function.
+
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+interface IFoo
+{
+ float f();
+}
+
+interface IInterface
+{
+ float calc<T:IFoo>(T t, float x);
+}
+
+struct A : IFoo
+{
+ float f() { return 1.0; }
+};
+
+struct B : IInterface
+{
+ float calc<T : IFoo>(T t, float x)
+ {
+ return t.f() * x;
+ }
+};
+
+[BackwardDifferentiable]
+float test(IInterface obj, float x)
+{
+ A objA;
+ return no_diff(obj.calc(objA, x)) * x;
+}
+
+//TEST_INPUT: type_conformance A:IFoo = 0
+//TEST_INPUT: type_conformance B:IInterface = 1
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ var obj = createDynamicObject<IInterface>(dispatchThreadID.x, 1); // B
+ var p = diffPair(3.0);
+ __bwd_diff(test)(obj, p, 1.0);
+ outputBuffer[0] = p.d;
+}
diff --git a/tests/autodiff/dynamic-dispatch-generic-member.slang.expected.txt b/tests/autodiff/dynamic-dispatch-generic-member.slang.expected.txt
new file mode 100644
index 000000000..857cebc03
--- /dev/null
+++ b/tests/autodiff/dynamic-dispatch-generic-member.slang.expected.txt
@@ -0,0 +1,5 @@
+type: float
+3.000000
+0.000000
+0.000000
+0.000000
diff --git a/tests/autodiff/member-func-custom-derivative-2.slang b/tests/autodiff/member-func-custom-derivative-2.slang
new file mode 100644
index 000000000..329f3ade8
--- /dev/null
+++ b/tests/autodiff/member-func-custom-derivative-2.slang
@@ -0,0 +1,49 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+interface IFoo
+{
+ [BackwardDifferentiable]
+ float3 test(float v, uint offset);
+}
+struct A : IFoo
+{
+ float x;
+
+ float3 f(float v, uint offset)
+ {
+ return v * v;
+ }
+
+ // Provide a backward diff, but leave out forward diff.
+ [BackwardDerivativeOf(f)]
+ [TreatAsDifferentiable]
+ void diff_f(inout DifferentialPair<float> v, uint offset, float3 dOut)
+ {
+ v = diffPair(v.p, 2 * v.p * dOut.x);
+ }
+
+ [BackwardDifferentiable]
+ float3 test(float v, uint offset)
+ {
+ return f(v, 0);
+ }
+}
+
+[BackwardDifferentiable]
+float3 test(IFoo obj, float v)
+{
+ return obj.test(v, 0);
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ A a = {0.0};
+ var p = diffPair(3.0, 0.0);
+ let rs = __bwd_diff(test)(a, p, 1.0);
+ outputBuffer[0] = p.d;
+}
diff --git a/tests/autodiff/member-func-custom-derivative-2.slang.expected.txt b/tests/autodiff/member-func-custom-derivative-2.slang.expected.txt
new file mode 100644
index 000000000..253df0793
--- /dev/null
+++ b/tests/autodiff/member-func-custom-derivative-2.slang.expected.txt
@@ -0,0 +1,2 @@
+type: float
+6.0
diff --git a/tests/autodiff/member-func-custom-derivative.slang b/tests/autodiff/member-func-custom-derivative.slang
new file mode 100644
index 000000000..3ec44e690
--- /dev/null
+++ b/tests/autodiff/member-func-custom-derivative.slang
@@ -0,0 +1,36 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+struct A
+{
+ float x;
+
+ [ForwardDerivative(diff_f)]
+ float f(float v)
+ {
+ return v * v;
+ }
+
+ DifferentialPair<float> diff_f(DifferentialPair<float> v)
+ {
+ return diffPair(v.p * v.p, v.p * v.d * 2.0);
+ }
+}
+
+[ForwardDifferentiable]
+float test(A obj, float v)
+{
+ return obj.f(v);
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ A a = {0.0};
+ var p = diffPair(3.0, 1.0);
+ let rs = __fwd_diff(test)(a, p);
+ outputBuffer[0] = rs.d;
+}
diff --git a/tests/autodiff/member-func-custom-derivative.slang.expected.txt b/tests/autodiff/member-func-custom-derivative.slang.expected.txt
new file mode 100644
index 000000000..253df0793
--- /dev/null
+++ b/tests/autodiff/member-func-custom-derivative.slang.expected.txt
@@ -0,0 +1,2 @@
+type: float
+6.0