summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ast-decl.h15
-rw-r--r--source/slang/slang-check-decl.cpp69
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h56
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h70
-rw-r--r--source/slang/slang-lower-to-ir.cpp14
-rw-r--r--source/slang/slang-mangle.cpp6
-rw-r--r--tests/autodiff/dynamic-dispatch-bwd-diff.slang52
-rw-r--r--tests/autodiff/dynamic-dispatch-bwd-diff.slang.expected.txt6
8 files changed, 151 insertions, 137 deletions
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h
index 81a6e3f7d..ccbac0286 100644
--- a/source/slang/slang-ast-decl.h
+++ b/source/slang/slang-ast-decl.h
@@ -546,21 +546,6 @@ class BackwardDerivativeRequirementDecl : public DerivativeRequirementDecl
SLANG_AST_CLASS(BackwardDerivativeRequirementDecl)
};
-class BackwardDerivativePrimalRequirementDecl : public DerivativeRequirementDecl
-{
- SLANG_AST_CLASS(BackwardDerivativePrimalRequirementDecl)
-};
-
-class BackwardDerivativePropagateRequirementDecl : public DerivativeRequirementDecl
-{
- SLANG_AST_CLASS(BackwardDerivativePropagateRequirementDecl)
-};
-
-class BackwardDerivativeIntermediateTypeRequirementDecl : public DerivativeRequirementDecl
-{
- SLANG_AST_CLASS(BackwardDerivativeIntermediateTypeRequirementDecl)
-};
-
bool isInterfaceRequirement(Decl* decl);
InterfaceDecl* findParentInterfaceDecl(Decl* decl);
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 142842e12..a1d5acfb0 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -2677,24 +2677,6 @@ namespace Slang
val->func = satisfyingMemberDeclRef;
witnessTable->add(bwdReq, RequirementWitness(val));
}
- else if (auto primalReq = as<BackwardDerivativePrimalRequirementDecl>(reqRefDecl->referencedDecl))
- {
- DifferentiateVal* val = m_astBuilder->create<BackwardDifferentiatePrimalVal>();
- val->func = satisfyingMemberDeclRef;
- witnessTable->add(primalReq, RequirementWitness(val));
- }
- else if (auto propReq = as<BackwardDerivativePropagateRequirementDecl>(reqRefDecl->referencedDecl))
- {
- DifferentiateVal* val = m_astBuilder->create<BackwardDifferentiatePropagateVal>();
- val->func = satisfyingMemberDeclRef;
- witnessTable->add(propReq, RequirementWitness(val));
- }
- else if (auto itypeReq = as<BackwardDerivativeIntermediateTypeRequirementDecl>(reqRefDecl->referencedDecl))
- {
- DifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateIntermediateTypeVal>();
- val->func = satisfyingMemberDeclRef;
- witnessTable->add(itypeReq, RequirementWitness(val));
- }
}
witnessTable->add(requiredMemberDeclRef, RequirementWitness(satisfyingMemberDeclRef));
}
@@ -5920,7 +5902,7 @@ namespace Slang
if (auto interfaceDecl = findParentInterfaceDecl(decl))
{
bool isDiffFunc = false;
- if (decl->hasModifier<ForwardDifferentiableAttribute>())
+ if (decl->hasModifier<ForwardDifferentiableAttribute>() || decl->hasModifier<BackwardDifferentiableAttribute>())
{
auto reqDecl = m_astBuilder->create<ForwardDerivativeRequirementDecl>();
cloneModifiers(reqDecl, decl);
@@ -5954,55 +5936,6 @@ namespace Slang
reqRef->parentDecl = decl;
decl->members.add(reqRef);
}
- // Requirement for backward derivative intermediate type.
- auto intermediateTypeReqDecl = m_astBuilder->create<BackwardDerivativeIntermediateTypeRequirementDecl>();
- auto intermediateType = m_astBuilder->getOrCreateDeclRefType(
- intermediateTypeReqDecl, createDefaultSubstitutions(m_astBuilder, this, decl));
- {
- cloneModifiers(intermediateTypeReqDecl, decl);
- interfaceDecl->members.add(intermediateTypeReqDecl);
- intermediateTypeReqDecl->parentDecl = interfaceDecl;
-
- auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>();
- reqRef->referencedDecl = intermediateTypeReqDecl;
- reqRef->parentDecl = decl;
- decl->members.add(reqRef);
- }
- // Requirement for backward derivative primal func.
- {
- auto reqDecl = m_astBuilder->create<BackwardDerivativePrimalRequirementDecl>();
- cloneModifiers(reqDecl, decl);
- FuncType* primalFuncType = m_astBuilder->create<FuncType>();
- primalFuncType->resultType = originalFuncType->resultType;
- primalFuncType->paramTypes.addRange(originalFuncType->paramTypes);
- auto outType = m_astBuilder->getOutType(intermediateType);
- primalFuncType->paramTypes.add(outType);
- setFuncTypeIntoRequirementDecl(reqDecl, primalFuncType);
- interfaceDecl->members.add(reqDecl);
- reqDecl->parentDecl = interfaceDecl;
-
- auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>();
- reqRef->referencedDecl = reqDecl;
- reqRef->parentDecl = decl;
- decl->members.add(reqRef);
- }
- // Requirement for backward derivative propagate func.
- {
- auto reqDecl = m_astBuilder->create<BackwardDerivativePropagateRequirementDecl>();
- cloneModifiers(reqDecl, decl);
- interfaceDecl->members.add(reqDecl);
- reqDecl->parentDecl = interfaceDecl;
- FuncType* propagateFuncType = m_astBuilder->create<FuncType>();
- propagateFuncType->resultType = diffFuncType->resultType;
- propagateFuncType->paramTypes.addRange(diffFuncType->paramTypes);
- propagateFuncType->paramTypes.add(intermediateType);
- setFuncTypeIntoRequirementDecl(reqDecl, propagateFuncType);
- auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>();
- reqRef->referencedDecl = reqDecl;
- reqRef->parentDecl = decl;
- decl->members.add(reqRef);
- }
-
isDiffFunc = true;
}
if (isDiffFunc)
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index a4c79d09a..e1832b9eb 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -851,6 +851,8 @@ struct DiffTransposePass
if (as<IRDecoration>(child) || as<IRParam>(child))
continue;
+ if (as<IRType>(child))
+ continue;
if (isDifferentialInst(child))
transposeInst(&builder, child);
@@ -1332,10 +1334,6 @@ struct DiffTransposePass
}
}
- // The call must have been decorated with the continuation context after splitting.
- auto primalContextDecor = fwdCall->findDecoration<IRBackwardDerivativePrimalContextDecoration>();
- SLANG_RELEASE_ASSERT(primalContextDecor);
-
auto baseFn = fwdDiffCallee->getBaseFn();
List<IRInst*> args;
@@ -1453,20 +1451,52 @@ struct DiffTransposePass
argRequiresLoad.add(false);
}
- // Ensure availability of the primal context var
- auto primalContextVar = hoistPrimalInst(builder, primalContextDecor->getBackwardDerivativePrimalContextVar());
- SLANG_RELEASE_ASSERT(primalContextVar);
+ // If the callee provides a primal implementation that produces continuation context for propagation phase
+ // we grab it and pass it as argument to the propagation function.
+ if (auto primalContextDecor = fwdCall->findDecoration<IRBackwardDerivativePrimalContextDecoration>())
+ {
+ // Ensure availability of the primal context var
+ auto primalContextVar = hoistPrimalInst(builder, primalContextDecor->getBackwardDerivativePrimalContextVar());
+ SLANG_RELEASE_ASSERT(primalContextVar);
- args.add(builder->emitLoad(primalContextVar));
- argTypes.add(as<IRPtrTypeBase>(
+ args.add(builder->emitLoad(primalContextVar));
+ argTypes.add(as<IRPtrTypeBase>(
primalContextVar->getDataType())
->getValueType());
- argRequiresLoad.add(false);
+ argRequiresLoad.add(false);
+ }
auto revFnType = builder->getFuncType(argTypes, builder->getVoidType());
- auto revCallee = builder->emitBackwardDifferentiatePropagateInst(
- revFnType,
- baseFn);
+ IRInst* revCallee = nullptr;
+ if (getResolvedInstForDecorations(baseFn)->getOp() == kIROp_LookupWitness)
+ {
+ // This is an interface method call, we can simply transcribe it here.
+ auto specialize = as<IRSpecialize>(baseFn);
+ auto innerFn = baseFn;
+ if (specialize)
+ innerFn = specialize->getBase();
+ auto lookupWitness = as<IRLookupWitnessMethod>(innerFn);
+ SLANG_RELEASE_ASSERT(lookupWitness);
+ auto diffDecor = lookupWitness->getRequirementKey()->findDecoration<IRBackwardDerivativeDecoration>();
+ SLANG_RELEASE_ASSERT(diffDecor);
+ auto diffKey = diffDecor->getBackwardDerivativeFunc();
+ revCallee = builder->emitLookupInterfaceMethodInst(builder->getTypeKind(), lookupWitness->getWitnessTable(), diffKey);
+ if (specialize)
+ {
+ List<IRInst*> specArgs;
+ for (UInt i = 0; i < specialize->getArgCount(); i++)
+ specArgs.add(specialize->getArg(i));
+ revCallee = builder->emitSpecializeInst(builder->getTypeKind(), revCallee, specArgs.getCount(), specArgs.getBuffer());
+ }
+ revCallee->setFullType(revFnType);
+ }
+ else
+ {
+ // All other calls, we insert a `backwardDifferentiate` inst so we will process it in a follow-up iteration.
+ revCallee = builder->emitBackwardDifferentiatePropagateInst(
+ revFnType,
+ baseFn);
+ }
List<IRInst*> callArgs;
for (auto arg : args)
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index 2ebc330f0..a30826370 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -830,28 +830,51 @@ struct DiffUnzipPass
{
auto func = findSpecializeReturnVal(specialize);
auto outerGen = findOuterGeneric(func);
- intermediateType = primalBuilder->getBackwardDiffIntermediateContextType(outerGen);
- List<IRInst*> args;
- for (UInt i = 0; i < specialize->getArgCount(); i++)
- args.add(specialize->getArg(i));
- intermediateType = primalBuilder->emitSpecializeInst(
- primalBuilder->getTypeKind(),
- intermediateType,
- args.getCount(),
- args.getBuffer());
+ if (func->getOp() == kIROp_LookupWitness)
+ {
+ // An interface method won't have intermediate type.
+ intermediateType = primalBuilder->getVoidType();
+ }
+ else
+ {
+ intermediateType = primalBuilder->getBackwardDiffIntermediateContextType(outerGen);
+ List<IRInst*> args;
+ for (UInt i = 0; i < specialize->getArgCount(); i++)
+ args.add(specialize->getArg(i));
+ intermediateType = primalBuilder->emitSpecializeInst(
+ primalBuilder->getTypeKind(),
+ intermediateType,
+ args.getCount(),
+ args.getBuffer());
+ }
}
else
{
- intermediateType = primalBuilder->getBackwardDiffIntermediateContextType(baseFn);
+ if (baseFn->getOp() == kIROp_LookupWitness)
+ intermediateType = primalBuilder->getVoidType();
+ else
+ intermediateType = primalBuilder->getBackwardDiffIntermediateContextType(baseFn);
}
- auto intermediateVar = primalBuilder->emitVar((IRType*)intermediateType);
- primalBuilder->markInstAsPrimal(intermediateVar);
+ IRVar* intermediateVar = nullptr;
+ if (!as<IRVoidType>(intermediateType))
+ {
+ intermediateVar = primalBuilder->emitVar((IRType*)intermediateType);
+ primalBuilder->markInstAsPrimal(intermediateVar);
+ }
- primalBuilder->addBackwardDerivativePrimalContextDecoration(intermediateVar, intermediateVar);
-
- auto primalFn = primalBuilder->emitBackwardDifferentiatePrimalInst(primalFuncType, baseFn);
-
+ IRInst* primalFn = nullptr;
+ if (intermediateVar)
+ {
+ primalBuilder->addBackwardDerivativePrimalContextDecoration(intermediateVar, intermediateVar);
+ primalFn = primalBuilder->emitBackwardDifferentiatePrimalInst(primalFuncType, baseFn);
+ }
+ else
+ {
+ // If we decided not to use diff-primal func that stores an reuse context,
+ // we can just call the original function instead.
+ primalFn = baseFn;
+ }
List<IRInst*> primalArgs;
for (UIndex ii = 0; ii < mixedCall->getArgCount(); ii++)
{
@@ -865,7 +888,8 @@ struct DiffUnzipPass
primalArgs.add(arg);
}
}
- primalArgs.add(intermediateVar);
+ if (intermediateType->getOp() != kIROp_VoidType)
+ primalArgs.add(intermediateVar);
auto mixedDecoration = mixedCall->findDecoration<IRMixedDifferentialInstDecoration>();
SLANG_ASSERT(mixedDecoration);
@@ -881,7 +905,8 @@ struct DiffUnzipPass
}
auto primalVal = primalBuilder->emitCallInst(primalType, primalFn, primalArgs);
- primalBuilder->addBackwardDerivativePrimalContextDecoration(primalVal, intermediateVar);
+ if (intermediateVar)
+ primalBuilder->addBackwardDerivativePrimalContextDecoration(primalVal, intermediateVar);
primalBuilder->markInstAsPrimal(primalVal);
SLANG_RELEASE_ASSERT(mixedCall->getArgCount() <= primalFuncType->getParamCount());
@@ -960,9 +985,12 @@ struct DiffUnzipPass
diffArgs);
diffBuilder->markInstAsDifferential(callInst, primalType);
- disableIRValidationAtInsert();
- diffBuilder->addBackwardDerivativePrimalContextDecoration(callInst, intermediateVar);
- enableIRValidationAtInsert();
+ if (intermediateVar)
+ {
+ disableIRValidationAtInsert();
+ diffBuilder->addBackwardDerivativePrimalContextDecoration(callInst, intermediateVar);
+ enableIRValidationAtInsert();
+ }
IRInst* diffVal = nullptr;
if (as<IRDifferentialPairType>(callInst->getDataType()))
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index d09c35eea..261e08168 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -6899,14 +6899,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
op = kIROp_BackwardDerivativeDecoration;
}
- else if (as<BackwardDerivativePropagateRequirementDecl>(requirementDecl))
- {
- op = kIROp_BackwardDerivativePropagateDecoration;
- }
- else if (as<BackwardDerivativePrimalRequirementDecl>(requirementDecl))
- {
- op = kIROp_BackwardDerivativePrimalDecoration;
- }
else if (as<ForwardDerivativeRequirementDecl>(requirementDecl))
{
op = kIROp_ForwardDerivativeDecoration;
@@ -8534,12 +8526,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
UNREACHABLE_RETURN(LoweredValInfo());
}
- LoweredValInfo visitBackwardDerivativeIntermediateTypeRequirementDecl(BackwardDerivativeIntermediateTypeRequirementDecl* decl)
- {
- SLANG_UNUSED(decl);
- return LoweredValInfo(getBuilder()->getTypeKind());
- }
-
LoweredValInfo visitFunctionDeclBase(FunctionDeclBase* decl)
{
// A function declaration may have multiple, target-specific
diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp
index da5099934..a7d047a0c 100644
--- a/source/slang/slang-mangle.cpp
+++ b/source/slang/slang-mangle.cpp
@@ -521,12 +521,6 @@ namespace Slang
emitRaw(context, "FwdReq_");
else if (as<BackwardDerivativeRequirementDecl>(decl))
emitRaw(context, "BwdReq_");
- else if (as<BackwardDerivativePropagateRequirementDecl>(decl))
- emitRaw(context, "BwdReq_Prop_");
- else if (as<BackwardDerivativePrimalRequirementDecl>(decl))
- emitRaw(context, "BwdReq_Primal_");
- else if (as<BackwardDerivativeIntermediateTypeRequirementDecl>(decl))
- emitRaw(context, "BwdReq_CtxType_");
else
{
// TODO: handle other cases
diff --git a/tests/autodiff/dynamic-dispatch-bwd-diff.slang b/tests/autodiff/dynamic-dispatch-bwd-diff.slang
new file mode 100644
index 000000000..5945c22cd
--- /dev/null
+++ b/tests/autodiff/dynamic-dispatch-bwd-diff.slang
@@ -0,0 +1,52 @@
+// Test calling backward differentiable function through dynamic dispatch.
+
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+[anyValueSize(16)]
+interface IInterface
+{
+ [BackwardDifferentiable]
+ float calc(float x);
+}
+
+struct A : IInterface
+{
+ float a;
+ [BackwardDifferentiable]
+ float calc(float x) { return a*x*x; }
+};
+
+struct B : IInterface
+{
+ float a;
+ [BackwardDifferentiable]
+ float calc(float x) { return a*x*x*x; }
+};
+
+[BackwardDifferentiable]
+float run(IInterface obj, float x)
+{
+ return obj.calc(x);
+}
+
+//TEST_INPUT: type_conformance A:IInterface = 0
+//TEST_INPUT: type_conformance B:IInterface = 1
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ var obj = createDynamicObject<IInterface>(dispatchThreadID.x, 0.5f); // A
+ var p = diffPair(3.0);
+
+ __bwd_diff(run)(obj, p, 1.0f);
+ outputBuffer[0] = p.d; // A.calc, expect 3
+
+ obj = createDynamicObject<IInterface>(dispatchThreadID.x + 1, 1.5f); // B
+ p = diffPair(3.0);
+ __bwd_diff(run)(obj, p, 1.0f);
+ outputBuffer[1] = p.d; // B.calc, expect 40.5
+}
diff --git a/tests/autodiff/dynamic-dispatch-bwd-diff.slang.expected.txt b/tests/autodiff/dynamic-dispatch-bwd-diff.slang.expected.txt
new file mode 100644
index 000000000..57bb1ee65
--- /dev/null
+++ b/tests/autodiff/dynamic-dispatch-bwd-diff.slang.expected.txt
@@ -0,0 +1,6 @@
+type: float
+3.000000
+40.500000
+0.000000
+0.000000
+0.000000