summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ast-decl.h3
-rw-r--r--source/slang/slang-check-decl.cpp8
-rw-r--r--source/slang/slang-diagnostic-defs.h2
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp2
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp2
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h13
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h15
-rw-r--r--source/slang/slang-ir-autodiff.cpp12
-rw-r--r--source/slang/slang-ir-autodiff.h4
-rw-r--r--source/slang/slang-ir-lower-witness-lookup.cpp1
-rw-r--r--source/slang/slang-ir.cpp3
-rw-r--r--source/slang/slang-lower-to-ir.cpp14
-rw-r--r--tests/autodiff/dynamic-dispatch-generic-member-2.slang51
-rw-r--r--tests/autodiff/dynamic-dispatch-generic-member-2.slang.expected.txt5
-rw-r--r--tests/autodiff/dynamic-object-bwd-diff-2.slang61
-rw-r--r--tests/autodiff/dynamic-object-bwd-diff-2.slang.expected.txt6
16 files changed, 174 insertions, 28 deletions
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h
index ccbac0286..e75660c7b 100644
--- a/source/slang/slang-ast-decl.h
+++ b/source/slang/slang-ast-decl.h
@@ -526,6 +526,9 @@ class AttributeDecl : public ContainerDecl
class DerivativeRequirementDecl : public FunctionDeclBase
{
SLANG_AST_CLASS(DerivativeRequirementDecl)
+
+ // The original requirement decl.
+ Decl* originalRequirementDecl = nullptr;
};
// A reference to a synthesized decl representing a differentiable function requirement, this decl will
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 0901d2026..b3470e882 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -1427,6 +1427,7 @@ namespace Slang
varDecl->initExpr = CompleteOverloadCandidate(overloadContext, *overloadContext.bestCandidate);
}
}
+ maybeRegisterDifferentiableType(getASTBuilder(), varDecl->getType());
}
// Fill in default substitutions for the 'subtype' part of a type constraint decl
@@ -4738,7 +4739,6 @@ namespace Slang
void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl)
{
auto newContext = withParentFunc(decl);
-
if (newContext.getParentDifferentiableAttribute())
{
// Register additional types outside the function body first.
@@ -5638,11 +5638,8 @@ namespace Slang
bool isDiffFunc = false;
if (decl->hasModifier<ForwardDifferentiableAttribute>() || decl->hasModifier<BackwardDifferentiableAttribute>())
{
- if (GetOuterGeneric(decl))
- {
- getSink()->diagnose(decl, Diagnostics::differentiableGenericInterfaceMethodNotSupported);
- }
auto reqDecl = m_astBuilder->create<ForwardDerivativeRequirementDecl>();
+ reqDecl->originalRequirementDecl = decl;
cloneModifiers(reqDecl, decl);
auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl));
auto diffFuncType = getForwardDiffFuncType(getFuncType(m_astBuilder, declRef));
@@ -5664,6 +5661,7 @@ namespace Slang
auto diffFuncType = as<FuncType>(getBackwardDiffFuncType(originalFuncType));
{
auto reqDecl = m_astBuilder->create<BackwardDerivativeRequirementDecl>();
+ reqDecl->originalRequirementDecl = decl;
cloneModifiers(reqDecl, decl);
setFuncTypeIntoRequirementDecl(reqDecl, diffFuncType);
interfaceDecl->members.add(reqDecl);
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index ec8131824..cb441ade8 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -359,8 +359,6 @@ DIAGNOSTIC(31146, Error, declAlreadyHasAttribute, "'$0' already has attribute '[
DIAGNOSTIC(31147, Error, cannotResolveOriginalFunctionForDerivative, "cannot resolve the original function for the the custom derivative.")
DIAGNOSTIC(31148, Error, cannotResolveDerivativeFunction, "cannot resolve the custom derivative function")
-DIAGNOSTIC(31149, Error, differentiableGenericInterfaceMethodNotSupported, "`[ForwardDifferentiable] and [BackwardDifferentiable] are not supported on generic interface requirements.")
-
DIAGNOSTIC(31200, Warning, deprecatedUsage, "$0 has been deprecated: $1")
// Enums
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index e0b916090..819c6bc57 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -955,7 +955,7 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize(IRBuilder* builder, IRSpec
builder->getTypeKind(), diffBaseSpecialize->getBase(), args.getCount(), args.getBuffer());
return InstPair(primalSpecialize, diffSpecialize);
}
- else if (_isDifferentiableFunc(genericInnerVal))
+ else if (_isDifferentiableFunc(genericInnerVal) || as<IRFuncType>(genericInnerVal))
{
List<IRInst*> args;
for (UInt i = 0; i < primalSpecialize->getArgCount(); i++)
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 2994a8c31..e5735b831 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -1273,7 +1273,7 @@ namespace Slang
return InstPair(primalSpecialize, diffSpecialize);
}
- else if (isBackwardDifferentiableFunc(genericInnerVal))
+ else if (isBackwardDifferentiableFunc(genericInnerVal) || as<IRFuncType>(genericInnerVal))
{
List<IRInst*> args;
for (UInt i = 0; i < primalSpecialize->getArgCount(); i++)
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index 8a734446d..910c23708 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -655,13 +655,6 @@ struct DiffTransposePass
subBuilder.addBackwardDerivativePrimalReturnDecoration(branch, retVal);
}
- // TODO: Should move this to before all the transposition, but a lot of the
- // transposition logic seems to access the parent of blocks to find the func.
- // Replace those uses.
- //
- for (auto block : workList)
- block->removeFromParent();
-
// At this point, the only block left without terminator insts
// should be the last one. Add a void return to complete it.
//
@@ -1101,7 +1094,7 @@ struct DiffTransposePass
};
List<DiffValWriteBack> writebacks;
- auto baseFnType = as<IRFuncType>(baseFn->getDataType());
+ auto baseFnType = as<IRFuncType>(getResolvedInstForDecorations(baseFn->getDataType()));
SLANG_RELEASE_ASSERT(baseFnType);
SLANG_RELEASE_ASSERT(fwdCall->getArgCount() == baseFnType->getParamCount());
@@ -1151,8 +1144,8 @@ struct DiffTransposePass
auto pairType = as<IRDifferentialPairType>(arg->getDataType());
auto var = builder->emitVar(arg->getDataType());
- auto diffType = (IRType*)diffTypeContext.getDifferentialForType(builder, pairType->getValueType());
- auto zeroMethod = diffTypeContext.getZeroMethodForType(builder, pairType->getValueType());
+ auto diffType = (IRType*)diffTypeContext.getDiffTypeFromPairType(builder, pairType);
+ auto zeroMethod = diffTypeContext.getDiffZeroMethodFromPairType(builder, pairType);
SLANG_ASSERT(zeroMethod);
auto diffZero = builder->emitCallInst(
diffType,
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index 34f0f6c9b..63b46f779 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -210,8 +210,8 @@ struct DiffUnzipPass
auto baseFn = _getOriginalFunc(mixedCall);
SLANG_RELEASE_ASSERT(baseFn);
- auto primalFuncType = autodiffContext->transcriberSet.primalTranscriber->differentiateFunctionType(
- primalBuilder, baseFn, as<IRFuncType>(baseFn->getDataType()));
+ auto primalFuncType = autodiffContext->transcriberSet.primalTranscriber->transcribe(
+ primalBuilder, baseFn->getDataType());
IRInst* intermediateType = nullptr;
@@ -251,12 +251,12 @@ struct DiffUnzipPass
intermediateVar = primalBuilder->emitVar((IRType*)intermediateType);
primalBuilder->markInstAsPrimal(intermediateVar);
}
-
+
IRInst* primalFn = nullptr;
if (intermediateVar)
{
primalBuilder->addBackwardDerivativePrimalContextDecoration(intermediateVar, intermediateVar);
- primalFn = primalBuilder->emitBackwardDifferentiatePrimalInst(primalFuncType, baseFn);
+ primalFn = primalBuilder->emitBackwardDifferentiatePrimalInst((IRType*)primalFuncType, baseFn);
}
else
{
@@ -298,7 +298,10 @@ struct DiffUnzipPass
primalBuilder->addBackwardDerivativePrimalContextDecoration(primalVal, intermediateVar);
primalBuilder->markInstAsPrimal(primalVal);
- SLANG_RELEASE_ASSERT(mixedCall->getArgCount() <= primalFuncType->getParamCount());
+ auto resolvedPrimalFuncType = as<IRFuncType>(getResolvedInstForDecorations(primalFuncType));
+ SLANG_RELEASE_ASSERT(resolvedPrimalFuncType);
+
+ SLANG_RELEASE_ASSERT(mixedCall->getArgCount() <= resolvedPrimalFuncType->getParamCount());
List<IRInst*> diffArgs;
for (UIndex ii = 0; ii < mixedCall->getArgCount(); ii++)
@@ -316,7 +319,7 @@ struct DiffUnzipPass
// If arg is a mixed differential (pair), it should have already been split.
SLANG_ASSERT(primalArg);
SLANG_ASSERT(diffArg);
- auto primalParamType = primalFuncType->getParamType(ii);
+ auto primalParamType = resolvedPrimalFuncType->getParamType(ii);
if (auto outType = as<IROutType>(primalParamType))
{
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 4188d2ec8..4e33a01ab 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -458,6 +458,18 @@ IRInst* DifferentiableTypeConformanceContext::getDiffTypeWitnessFromPairType(IRB
return _getDiffTypeWitnessFromPairType(sharedContext, builder, type);
}
+IRInst* DifferentiableTypeConformanceContext::getDiffZeroMethodFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type)
+{
+ auto witnessTable = type->getWitness();
+ return _lookupWitness(builder, witnessTable, sharedContext->zeroMethodStructKey);
+}
+
+IRInst* DifferentiableTypeConformanceContext::getDiffAddMethodFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type)
+{
+ auto witnessTable = type->getWitness();
+ return _lookupWitness(builder, witnessTable, sharedContext->addMethodStructKey);
+}
+
void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary()
{
for (auto globalInst : sharedContext->moduleInst->getChildren())
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index 52cf346b3..91b45c5be 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -177,6 +177,10 @@ struct DifferentiableTypeConformanceContext
IRInst* getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type);
+ IRInst* getDiffZeroMethodFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type);
+
+ IRInst* getDiffAddMethodFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type);
+
// Lookup and return the 'Differential' type declared in the concrete type
// in order to conform to the IDifferentiable interface.
// Note that inside a generic block, this will be a witness table lookup instruction
diff --git a/source/slang/slang-ir-lower-witness-lookup.cpp b/source/slang/slang-ir-lower-witness-lookup.cpp
index c1ee204b0..0e46987c7 100644
--- a/source/slang/slang-ir-lower-witness-lookup.cpp
+++ b/source/slang/slang-ir-lower-witness-lookup.cpp
@@ -350,6 +350,7 @@ struct WitnessLookupLoweringContext
{
if (auto specialize = as<IRSpecialize>(use->getUser()))
{
+ builder.setInsertBefore(use->getUser());
List<IRInst*> args;
for (UInt i = 0; i < specialize->getArgCount(); i++)
args.add(specialize->getArg(i));
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index e74a57424..eefcb9eea 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -7067,6 +7067,8 @@ namespace Slang
// and then destroy it (it had better have no uses!)
void IRInst::removeAndDeallocate()
{
+ removeAndDeallocateAllDecorationsAndChildren();
+
if (auto module = getModule())
{
if (getIROpInfo(getOp()).isHoistable())
@@ -7080,7 +7082,6 @@ namespace Slang
module->getDeduplicationContext()->getInstReplacementMap().remove(this);
}
removeArguments();
- removeAndDeallocateAllDecorationsAndChildren();
removeFromParent();
// Run destructor to be sure...
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index d644d01c7..c8a41c7c7 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -7429,7 +7429,12 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
else
{
- if (auto callableDecl = as<CallableDecl>(requirementDecl))
+ CallableDecl* callableDecl = nullptr;
+ if (auto genDecl = as<GenericDecl>(requirementDecl))
+ callableDecl = as<CallableDecl>(genDecl->inner);
+ else
+ callableDecl = as<CallableDecl>(requirementDecl);
+ if (callableDecl)
{
// Differentiable functions has additional requirements for the derivatives.
for (auto diffDecl : callableDecl->getMembersOfType<DerivativeRequirementReferenceDecl>())
@@ -8369,7 +8374,12 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
LoweredValInfo lowerFuncDeclInContext(IRGenContext* subContext, IRBuilder* subBuilder, FunctionDeclBase* decl, bool emitBody = true)
{
- auto outerGeneric = emitOuterGenerics(subContext, decl, decl);
+ IRGeneric* outerGeneric = nullptr;
+
+ if (auto derivativeRequirement = as<DerivativeRequirementDecl>(decl))
+ outerGeneric = emitOuterGenerics(subContext, derivativeRequirement->originalRequirementDecl, derivativeRequirement->originalRequirementDecl);
+ else
+ outerGeneric = emitOuterGenerics(subContext, decl, decl);
// need to create an IR function here
diff --git a/tests/autodiff/dynamic-dispatch-generic-member-2.slang b/tests/autodiff/dynamic-dispatch-generic-member-2.slang
new file mode 100644
index 000000000..ac6758c4d
--- /dev/null
+++ b/tests/autodiff/dynamic-dispatch-generic-member-2.slang
@@ -0,0 +1,51 @@
+// Test calling dynamic dispatched generic function from differentiable function.
+
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+interface IFoo
+{
+ float f();
+}
+
+interface IInterface
+{
+ [BackwardDifferentiable]
+ float calc<T:IFoo>(T t, float x);
+}
+
+struct A : IFoo
+{
+ float f() { return 5.0; }
+};
+
+struct B : IInterface
+{
+ [BackwardDifferentiable]
+ float calc<T : IFoo>(T t, float x)
+ {
+ return t.f() * x;
+ }
+};
+
+[BackwardDifferentiable]
+float test(IFoo foo, IInterface obj, float x)
+{
+ return obj.calc(foo, x) * x;
+}
+
+//TEST_INPUT: type_conformance A:IFoo = 0
+//TEST_INPUT: type_conformance B:IInterface = 1
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ var obj = createDynamicObject<IInterface>(dispatchThreadID.x, 1); // B
+ var foo = createDynamicObject<IFoo>(0, 0); // A
+ var p = diffPair(3.0);
+ __bwd_diff(test)(foo, obj, p, 1.0);
+ outputBuffer[0] = p.d;
+}
diff --git a/tests/autodiff/dynamic-dispatch-generic-member-2.slang.expected.txt b/tests/autodiff/dynamic-dispatch-generic-member-2.slang.expected.txt
new file mode 100644
index 000000000..1ce9558c3
--- /dev/null
+++ b/tests/autodiff/dynamic-dispatch-generic-member-2.slang.expected.txt
@@ -0,0 +1,5 @@
+type: float
+30.000000
+0.000000
+0.000000
+0.000000
diff --git a/tests/autodiff/dynamic-object-bwd-diff-2.slang b/tests/autodiff/dynamic-object-bwd-diff-2.slang
new file mode 100644
index 000000000..bb0a69b28
--- /dev/null
+++ b/tests/autodiff/dynamic-object-bwd-diff-2.slang
@@ -0,0 +1,61 @@
+// Test calling backward differentiable function through dynamic dispatch, where the interface
+// being dispatched inherits from IDifferentiable, so that `this` is differentiable.
+
+//DISABLED_TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//DISABLED_TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+[anyValueSize(16)]
+interface IInterface : IDifferentiable
+{
+ [BackwardDifferentiable]
+ float calc(float x);
+}
+
+struct C : IInterface
+{
+ [BackwardDifferentiable]
+ float calc(float x) { return 2 * x; }
+}
+
+struct A : IInterface
+{
+ float a;
+ [BackwardDifferentiable]
+ float calc(float x)
+ {
+ return a * x * x;
+ }
+};
+
+
+[BackwardDifferentiable]
+float run(int id, float x, no_diff float y)
+{
+ IInterface obj = createDynamicObject<IInterface>(id, y);
+ C c = {};
+ return obj.calc(x);
+}
+
+//TEST_INPUT: type_conformance A:IInterface = 0
+//TEST_INPUT: type_conformance C:IInterface = 1
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ {
+ var p = diffPair(3.0);
+
+ __bwd_diff(run)(0, p, 0.5, 1.0f);
+ outputBuffer[0] = p.d; // A.calc, expect 3
+ }
+
+ {
+ var p = diffPair(3.0);
+
+ __bwd_diff(run)(1, p, 1.5, 1.0f);
+ outputBuffer[1] = p.d; // c.calc, expect 2
+ }
+}
diff --git a/tests/autodiff/dynamic-object-bwd-diff-2.slang.expected.txt b/tests/autodiff/dynamic-object-bwd-diff-2.slang.expected.txt
new file mode 100644
index 000000000..3d273506f
--- /dev/null
+++ b/tests/autodiff/dynamic-object-bwd-diff-2.slang.expected.txt
@@ -0,0 +1,6 @@
+type: float
+3.000000
+2.000000
+0.000000
+0.000000
+0.000000