summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-23 09:39:08 -0800
committerGitHub <noreply@github.com>2022-11-23 09:39:08 -0800
commit97cb4851eed7a43f10196971b08d3d311386ce9f (patch)
tree99ba81368068b3345fa23b749108265aa753ed2b
parent6178cb601368e977c4aa82e0ae25b8eb1e875d84 (diff)
Autodiff through simple dynamic dispatch. (#2527)
* Autodiff through simple dynamic dispatch. * Revert changes. * Fix. Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--source/slang/slang-ast-decl.h19
-rw-r--r--source/slang/slang-ast-modifier.h19
-rw-r--r--source/slang/slang-ast-val.cpp40
-rw-r--r--source/slang/slang-ast-val.h23
-rw-r--r--source/slang/slang-check-decl.cpp108
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp204
-rw-r--r--source/slang/slang-ir-inst-defs.h8
-rw-r--r--source/slang/slang-ir-insts.h19
-rw-r--r--source/slang/slang-lower-to-ir.cpp80
-rw-r--r--source/slang/slang-mangle.cpp4
-rw-r--r--source/slang/slang-syntax.h4
-rw-r--r--tests/autodiff/dynamic-dispatch-autodiff-simple.slang48
-rw-r--r--tests/autodiff/dynamic-dispatch-autodiff-simple.slang.expected.txt6
13 files changed, 522 insertions, 60 deletions
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h
index 4da832d11..e7dc73a85 100644
--- a/source/slang/slang-ast-decl.h
+++ b/source/slang/slang-ast-decl.h
@@ -518,6 +518,25 @@ class AttributeDecl : public ContainerDecl
SyntaxClass<NodeBase> syntaxClass;
};
+// A synthesized decl used as a placeholder for a differentiable function requirement.
+// This allows us to form an interface requirement key for the derivative of an interface function.
+// The synthesized `DerivativeRequirementDecl` will be a child of the original function requirement
+// decl after an interface type is checked.
+class DerivativeRequirementDecl : public FunctionDeclBase
+{
+ SLANG_AST_CLASS(DerivativeRequirementDecl)
+};
+
+class ForwardDerivativeRequirementDecl : public DerivativeRequirementDecl
+{
+ SLANG_AST_CLASS(ForwardDerivativeRequirementDecl)
+};
+
+class BackwardDerivativeRequirementDecl : public DerivativeRequirementDecl
+{
+ SLANG_AST_CLASS(BackwardDerivativeRequirementDecl)
+};
+
bool isInterfaceRequirement(Decl* decl);
} // namespace Slang
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 04af66b50..69f39efb6 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -1085,6 +1085,25 @@ class BackwardDifferentiableAttribute : public DifferentiableAttribute
SLANG_AST_CLASS(BackwardDifferentiableAttribute)
};
+ /// The `[BackwardDerivative(function)]` attribute specifies a custom function that should
+ /// be used as the backward-derivative for the decorated function.
+class BackwardDerivativeAttribute : public DifferentiableAttribute
+{
+ SLANG_AST_CLASS(BackwardDerivativeAttribute)
+ Expr* funcExpr;
+};
+
+ /// The `[BackwardDerivativeOf(primalFunction)]` attribute marks the decorated function as custom
+ /// backward-derivative implementation for `primalFunction`.
+class BackwardDerivativeOfAttribute : public DifferentiableAttribute
+{
+ SLANG_AST_CLASS(BackwardDerivativeOfAttribute)
+
+ Expr* funcExpr;
+
+ Expr* backDeclRef; // DeclRef to this derivative function when initiated from primalFunction.
+};
+
/// Indicates that the modified declaration is one of the "magic" declarations
/// that NVAPI uses to communicate extended operations. When NVAPI is being included
/// via the prelude for downstream compilation, declarations with this modifier
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp
index 87e89ef18..a0f0552c6 100644
--- a/source/slang/slang-ast-val.cpp
+++ b/source/slang/slang-ast-val.cpp
@@ -1516,4 +1516,44 @@ Val* WitnessLookupIntVal::tryFold(ASTBuilder* astBuilder, SubtypeWitness* witnes
return witnessResult;
}
+
+bool DifferentiateVal::_equalsValOverride(Val* val)
+{
+ if (auto other = as<DifferentiateVal>(val))
+ {
+ return other->astNodeType == astNodeType && other->func == func;
+ }
+ return false;
+}
+
+void DifferentiateVal::_toTextOverride(StringBuilder& out)
+{
+ out << "DifferentiateVal(";
+ out << func;
+ out << ")";
+}
+
+HashCode DifferentiateVal::_getHashCodeOverride()
+{
+ HashCode result = (HashCode)astNodeType;
+ result = combineHash(result, func.getHashCode());
+ return result;
+}
+
+Val* DifferentiateVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
+{
+ int diff = 0;
+ auto newFunc = func.substituteImpl(astBuilder, subst, &diff);
+ *ioDiff += diff;
+ if (diff)
+ {
+ auto result = as<DifferentiateVal>(astBuilder->createByNodeType(astNodeType));
+ result->func = newFunc;
+ return result;
+ }
+ // Nothing found: don't substitute.
+ return this;
+}
+
+
} // namespace Slang
diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h
index b52984f8b..31b74a499 100644
--- a/source/slang/slang-ast-val.h
+++ b/source/slang/slang-ast-val.h
@@ -490,4 +490,27 @@ class SNormModifierVal : public ResourceFormatModifierVal
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
};
+ /// Represents the result of differentiating a function.
+class DifferentiateVal : public Val
+{
+ SLANG_AST_CLASS(DifferentiateVal)
+
+ DeclRef<Decl> func;
+
+ bool _equalsValOverride(Val* val);
+ void _toTextOverride(StringBuilder& out);
+ HashCode _getHashCodeOverride();
+ Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
+};
+
+class ForwardDifferentiateVal : public DifferentiateVal
+{
+ SLANG_AST_CLASS(ForwardDifferentiateVal)
+};
+
+class BackwardDifferentiateVal : public DifferentiateVal
+{
+ SLANG_AST_CLASS(BackwardDifferentiateVal)
+};
+
} // namespace Slang
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 5a1218abe..36a1061c9 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -107,6 +107,9 @@ namespace Slang
void visitAccessorDecl(AccessorDecl* decl);
void visitSetterDecl(SetterDecl* decl);
+
+ void cloneModifiers(Decl* dest, Decl* src);
+ void setFuncTypeIntoRequirementDecl(CallableDecl* decl, FuncType* funcType);
};
struct SemanticsDeclRedeclarationVisitor
@@ -1866,6 +1869,32 @@ namespace Slang
return false;
}
+ bool hasBackwardDerivative = false;
+ bool hasForwardDerivative = false;
+ if (requiredMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>())
+ {
+ if (!satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>()
+ && !satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDerivativeAttribute>())
+ {
+ // A non-`BackwardDifferentiable` method can't satisfy a `BackwardDifferentiable` requirement and vice versa.
+ return false;
+ }
+ hasBackwardDerivative = true;
+ hasForwardDerivative = true;
+ }
+ else if (requiredMemberDeclRef.getDecl()->hasModifier<ForwardDifferentiableAttribute>())
+ {
+ if (!satisfyingMemberDeclRef.getDecl()->hasModifier<ForwardDifferentiableAttribute>()
+ && !satisfyingMemberDeclRef.getDecl()->hasModifier<ForwardDerivativeAttribute>()
+ && !satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>()
+ && !satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDerivativeAttribute>())
+ {
+ // A non-`ForwardDifferentiable` method can't satisfy a `ForwardDifferentiable` requirement and vice versa.
+ return false;
+ }
+ hasForwardDerivative = true;
+ }
+
// A signature matches the required one if it has the right number of parameters,
// and those parameters have the right types, and also the result/return type
// is the required one.
@@ -1896,6 +1925,24 @@ namespace Slang
witnessTable->add(
requiredMemberDeclRef.getDecl(),
RequirementWitness(satisfyingMemberDeclRef));
+
+ if (hasForwardDerivative)
+ {
+ auto reqDecl = requiredMemberDeclRef.getDecl()->getMembersOfType<ForwardDerivativeRequirementDecl>();
+ SLANG_RELEASE_ASSERT(reqDecl.isNonEmpty());
+ ForwardDifferentiateVal* val = m_astBuilder->create<ForwardDifferentiateVal>();
+ val->func = satisfyingMemberDeclRef;
+ witnessTable->add(reqDecl.getFirst(), RequirementWitness(val));
+ }
+
+ if (hasBackwardDerivative)
+ {
+ auto reqDecl = requiredMemberDeclRef.getDecl()->getMembersOfType<BackwardDerivativeRequirementDecl>();
+ SLANG_RELEASE_ASSERT(reqDecl.isNonEmpty());
+ BackwardDifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateVal>();
+ val->func = satisfyingMemberDeclRef;
+ witnessTable->add(reqDecl.getFirst(), RequirementWitness(val));
+ }
return true;
}
@@ -5515,6 +5562,43 @@ namespace Slang
}
}
+ void SemanticsDeclHeaderVisitor::cloneModifiers(Decl* dest, Decl* src)
+ {
+ dest->modifiers = src->modifiers;
+ }
+ void SemanticsDeclHeaderVisitor::setFuncTypeIntoRequirementDecl(CallableDecl* decl, FuncType* funcType)
+ {
+ if (!funcType)
+ return;
+ decl->returnType.type = funcType->getResultType();
+ decl->errorType.type = funcType->getErrorType();
+ for (UInt i = 0; i < funcType->getParamCount(); i++)
+ {
+ auto paramType = funcType->getParamType(i);
+ if (auto dirType = as<ParamDirectionType>(paramType))
+ paramType = dirType->getValueType();
+ auto param = m_astBuilder->create<ParamDecl>();
+ param->type.type = paramType;
+ auto paramDir = funcType->getParamDirection(i);
+ switch (paramDir)
+ {
+ case ParameterDirection::kParameterDirection_InOut:
+ addModifier(param, m_astBuilder->create<InOutModifier>());
+ break;
+ case ParameterDirection::kParameterDirection_Out:
+ addModifier(param, m_astBuilder->create<OutModifier>());
+ break;
+ case ParameterDirection::kParameterDirection_Ref:
+ addModifier(param, m_astBuilder->create<RefModifier>());
+ break;
+ default:
+ break;
+ }
+ decl->members.add(param);
+ param->parentDecl = decl;
+ }
+ }
+
void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl)
{
for(auto paramDecl : decl->getParameters())
@@ -5532,6 +5616,30 @@ namespace Slang
errorType = TypeExp(m_astBuilder->getBottomType());
}
decl->errorType = errorType;
+
+ if (isInterfaceRequirement(decl))
+ {
+ if (decl->hasModifier<ForwardDifferentiableAttribute>())
+ {
+ auto reqDecl = m_astBuilder->create<ForwardDerivativeRequirementDecl>();
+ cloneModifiers(reqDecl, decl);
+ auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl));
+ auto diffFuncType = getForwardDiffFuncType(getFuncType(m_astBuilder, declRef));
+ setFuncTypeIntoRequirementDecl(reqDecl, as<FuncType>(diffFuncType));
+ decl->members.add(reqDecl);
+ reqDecl->parentDecl = decl;
+ }
+ if (decl->hasModifier<BackwardDifferentiableAttribute>())
+ {
+ auto reqDecl = m_astBuilder->create<BackwardDerivativeRequirementDecl>();
+ cloneModifiers(reqDecl, decl);
+ auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl));
+ auto diffFuncType = getBackwardDiffFuncType(getFuncType(m_astBuilder, declRef));
+ setFuncTypeIntoRequirementDecl(reqDecl, as<FuncType>(diffFuncType));
+ decl->members.add(reqDecl);
+ reqDecl->parentDecl = decl;
+ }
+ }
}
void SemanticsDeclHeaderVisitor::visitFuncDecl(FuncDecl* funcDecl)
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 03e81c5b5..04a898ea9 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -100,8 +100,19 @@ struct JVPTranscriber
return instMapD.ContainsKey(origInst);
}
+ bool shouldUseOriginalAsPrimal(IRInst* origInst)
+ {
+ if (as<IRGlobalValueWithCode>(origInst))
+ return true;
+ if (origInst->parent && origInst->parent->getOp() == kIROp_Module)
+ return true;
+ return false;
+ }
+
IRInst* lookupPrimalInst(IRInst* origInst)
{
+ if (shouldUseOriginalAsPrimal(origInst))
+ return origInst;
return cloneEnv.mapOldValToNew[origInst];
}
@@ -112,9 +123,11 @@ struct JVPTranscriber
bool hasPrimalInst(IRInst* origInst)
{
+ if (shouldUseOriginalAsPrimal(origInst))
+ return true;
return cloneEnv.mapOldValToNew.ContainsKey(origInst);
}
-
+
IRInst* findOrTranscribeDiffInst(IRBuilder* builder, IRInst* origInst)
{
if (!hasDifferentialInst(origInst))
@@ -128,6 +141,9 @@ struct JVPTranscriber
IRInst* findOrTranscribePrimalInst(IRBuilder* builder, IRInst* origInst)
{
+ if (shouldUseOriginalAsPrimal(origInst))
+ return origInst;
+
if (!hasPrimalInst(origInst))
{
transcribe(builder, origInst);
@@ -687,7 +703,10 @@ struct JVPTranscriber
IRInst* diffCallee = nullptr;
- if (auto derivativeReferenceDecor = primalCallee->findDecoration<IRForwardDerivativeDecoration>())
+ if (instMapD.TryGetValue(origCallee, diffCallee))
+ {
+ }
+ else if (auto derivativeReferenceDecor = primalCallee->findDecoration<IRForwardDerivativeDecoration>())
{
// If the user has already provided an differentiated implementation, use that.
diffCallee = derivativeReferenceDecor->getForwardDerivativeFunc();
@@ -901,48 +920,111 @@ struct JVPTranscriber
return InstPair(nullptr, nullptr);
}
- InstPair transcribeSpecialize(IRBuilder*, IRSpecialize* origSpecialize)
+ IRInst* findInterfaceRequirement(IRInterfaceType* type, IRInst* key)
{
- // In general, we should not see any specialize insts at this stage.
- // The exceptions are target intrinsics.
- auto genericInnerVal = findInnerMostGenericReturnVal(as<IRGeneric>(origSpecialize->getBase()));
- if (genericInnerVal->findDecoration<IRTargetIntrinsicDecoration>())
+ for (UInt i = 0; i < type->getOperandCount(); i++)
{
- // Look for an IRForwardDerivativeDecoration on the specialize inst.
- // (Normally, this would be on the inner IRFunc, but in this case only the JVP func
- // can be specialized, so we put a decoration on the IRSpecialize)
- //
- if (auto jvpFuncDecoration = origSpecialize->findDecoration<IRForwardDerivativeDecoration>())
+ if (auto req = as<IRInterfaceRequirementEntry>(type->getOperand(i)))
{
- auto jvpFunc = jvpFuncDecoration->getForwardDerivativeFunc();
+ if (req->getRequirementKey() == key)
+ return req->getRequirementVal();
+ }
+ }
+ return nullptr;
+ }
- // Make sure this isn't itself a specialize .
- SLANG_RELEASE_ASSERT(!as<IRSpecialize>(jvpFunc));
+ InstPair transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize)
+ {
+ auto primalBase = findOrTranscribePrimalInst(builder, origSpecialize->getBase());
+ List<IRInst*> primalArgs;
+ for (UInt i = 0; i < origSpecialize->getArgCount(); i++)
+ {
+ primalArgs.add(findOrTranscribePrimalInst(builder, origSpecialize->getArg(i)));
+ }
+ auto primalType = findOrTranscribePrimalInst(builder, origSpecialize->getFullType());
+ auto primalSpecialize = (IRSpecialize*)builder->emitSpecializeInst(
+ (IRType*)primalType, primalBase, primalArgs.getCount(), primalArgs.getBuffer());
- return InstPair(jvpFunc, jvpFunc);
+ IRInst* diffBase = nullptr;
+ if (instMapD.TryGetValue(origSpecialize->getBase(), diffBase))
+ {
+ List<IRInst*> args;
+ for (UInt i = 0; i < primalSpecialize->getArgCount(); i++)
+ {
+ args.add(primalSpecialize->getArg(i));
}
+ auto diffSpecialize = builder->emitSpecializeInst(
+ builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer());
+ return InstPair(primalSpecialize, diffSpecialize);
+ }
+
+ auto genericInnerVal = findInnerMostGenericReturnVal(as<IRGeneric>(origSpecialize->getBase()));
+ // Look for an IRForwardDerivativeDecoration on the specialize inst.
+ // (Normally, this would be on the inner IRFunc, but in this case only the JVP func
+ // can be specialized, so we put a decoration on the IRSpecialize)
+ //
+ if (auto jvpFuncDecoration = origSpecialize->findDecoration<IRForwardDerivativeDecoration>())
+ {
+ auto jvpFunc = jvpFuncDecoration->getForwardDerivativeFunc();
+
+ // Make sure this isn't itself a specialize .
+ SLANG_RELEASE_ASSERT(!as<IRSpecialize>(jvpFunc));
+
+ return InstPair(primalSpecialize, jvpFunc);
+ }
+ else if (auto derivativeDecoration = genericInnerVal->findDecoration<IRForwardDerivativeDecoration>())
+ {
+ diffBase = derivativeDecoration->getForwardDerivativeFunc();
+ List<IRInst*> args;
+ for (UInt i = 0; i < primalSpecialize->getArgCount(); i++)
+ {
+ args.add(primalSpecialize->getArg(i));
+ }
+ auto diffSpecialize = builder->emitSpecializeInst(
+ builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer());
+ return InstPair(primalSpecialize, diffSpecialize);
}
else
{
- getSink()->diagnose(origSpecialize->sourceLoc,
- Diagnostics::unexpected,
- "should not be attempting to differentiate anything specialized here.");
+ return InstPair(primalSpecialize, nullptr);
}
-
- return InstPair(nullptr, nullptr);
}
- InstPair transcibeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* origLookup)
+ InstPair transcribeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* lookupInst)
{
- // This is slightly counter-intuitive, but we don't perform any differentiation
- // logic here. We simple clone the original lookup which points to the original function,
- // or the cloned version in case we're inside a generic scope.
- // The differentiation logic is inserted later when this is used in an IRCall.
- // This decision is mostly to maintain a uniform convention of ForwardDifferentiate(Lookup(Table))
- // rather than have Lookup(ForwardDifferentiate(Table))
- //
- auto diffLookup = cloneInst(&cloneEnv, builder, origLookup);
- return InstPair(diffLookup, diffLookup);
+ auto primalWt = findOrTranscribePrimalInst(builder, lookupInst->getWitnessTable());
+ auto primalKey = findOrTranscribePrimalInst(builder, lookupInst->getRequirementKey());
+ auto primalType = findOrTranscribePrimalInst(builder, lookupInst->getFullType());
+ auto primal = (IRSpecialize*)builder->emitLookupInterfaceMethodInst((IRType*)primalType, primalWt, primalKey);
+
+ auto interfaceType = as<IRInterfaceType>(as<IRWitnessTableTypeBase>(lookupInst->getWitnessTable()->getDataType())->getConformanceType());
+ if (!interfaceType)
+ {
+ return InstPair(primal, nullptr);
+ }
+ auto dict = interfaceType->findDecoration<IRDifferentiableMethodRequirementDictionaryDecoration>();
+ if (!dict)
+ {
+ return InstPair(primal, nullptr);
+ }
+
+ for (auto child : dict->getChildren())
+ {
+ if (auto item = as<IRForwardDifferentiableMethodRequirementDictionaryItem>(child))
+ {
+ if (item->getOperand(0) == lookupInst->getRequirementKey())
+ {
+ auto diffKey = item->getOperand(1);
+ if (auto diffType = findInterfaceRequirement(interfaceType, diffKey))
+ {
+ auto diff = builder->emitLookupInterfaceMethodInst((IRType*)diffType, primalWt, diffKey);
+ return InstPair(primal, diff);
+ }
+ break;
+ }
+ }
+ }
+ return InstPair(primal, nullptr);
}
// In differential computation, the 'default' differential value is always zero.
@@ -1188,6 +1270,12 @@ struct JVPTranscriber
return InstPair(primalPair, diffPair);
}
+ InstPair trascribeNonDiffInst(IRBuilder* builder, IRInst* origInst)
+ {
+ auto primal = cloneInst(&cloneEnv, builder, origInst);
+ return InstPair(primal, nullptr);
+ }
+
InstPair transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst)
{
SLANG_ASSERT(
@@ -1335,6 +1423,7 @@ struct JVPTranscriber
//
instsInProgress.Add(origInst);
InstPair pair = transcribeInst(builder, origInst);
+ instsInProgress.Remove(origInst);
if (auto primalInst = pair.primal)
{
@@ -1363,11 +1452,9 @@ struct JVPTranscriber
}
return pair.differential;
}
- instsInProgress.Remove(origInst);
-
getSink()->diagnose(origInst->sourceLoc,
- Diagnostics::internalCompilerError,
- "failed to transcibe instruction");
+ Diagnostics::internalCompilerError,
+ "failed to transcibe instruction");
return nullptr;
}
@@ -1407,7 +1494,10 @@ struct JVPTranscriber
case kIROp_Construct:
return transcribeConstruct(builder, origInst);
-
+
+ case kIROp_lookup_interface_method:
+ return transcribeLookupInterfaceMethod(builder, as<IRLookupWitnessMethod>(origInst));
+
case kIROp_Call:
return transcribeCall(builder, as<IRCall>(origInst));
@@ -1429,9 +1519,6 @@ struct JVPTranscriber
case kIROp_Specialize:
return transcribeSpecialize(builder, as<IRSpecialize>(origInst));
- case kIROp_lookup_interface_method:
- return transcibeLookupInterfaceMethod(builder, as<IRLookupWitnessMethod>(origInst));
-
case kIROp_FieldExtract:
case kIROp_FieldAddress:
return transcribeFieldExtract(builder, origInst);
@@ -1450,6 +1537,15 @@ struct JVPTranscriber
case kIROp_DifferentialPairGetPrimal:
case kIROp_DifferentialPairGetDifferential:
return transcribeDifferentialPairGetElement(builder, origInst);
+ case kIROp_ExtractExistentialWitnessTable:
+ case kIROp_ExtractExistentialType:
+ case kIROp_ExtractExistentialValue:
+ case kIROp_WrapExistential:
+ case kIROp_MakeExistential:
+ case kIROp_MakeExistentialWithRTTI:
+ return trascribeNonDiffInst(builder, origInst);
+ case kIROp_StructKey:
+ return InstPair(origInst, nullptr);
}
// If none of the cases have been hit, check if the instruction is a
@@ -1496,7 +1592,6 @@ struct JVPTranscriber
return transcribeGeneric(builder, as<IRGeneric>(origInst));
}
-
// If we reach this statement, the instruction type is likely unhandled.
getSink()->diagnose(origInst->sourceLoc,
Diagnostics::unimplemented,
@@ -1558,6 +1653,7 @@ struct ForwardDerivativePass : public InstPassBase
{
case kIROp_Func:
case kIROp_Specialize:
+ case kIROp_lookup_interface_method:
autoDiffWorkList.add(inst);
break;
default:
@@ -1587,29 +1683,15 @@ struct ForwardDerivativePass : public InstPassBase
differentiateInst->replaceUsesWith(existingDiffFunc);
differentiateInst->removeAndDeallocate();
}
- else if (isMarkedForForwardDifferentiation(baseInst))
- {
- if (as<IRFunc>(baseInst) || as<IRGeneric>(baseInst))
- {
- IRInst* diffFunc = transcriberStorage.transcribe(builder, baseInst);
- SLANG_ASSERT(diffFunc);
- differentiateInst->replaceUsesWith(diffFunc);
- differentiateInst->removeAndDeallocate();
- }
- else
- {
- getSink()->diagnose(differentiateInst->sourceLoc,
- Diagnostics::internalCompilerError,
- "Unexpected instruction. Expected func or generic");
- }
- }
else
{
- getSink()->diagnose(differentiateInst->sourceLoc,
- Diagnostics::internalCompilerError,
- "Requested differentiation on a function that isn't marked as differentiable.");
+ IRBuilder subBuilder(*builder);
+ subBuilder.setInsertBefore(differentiateInst);
+ IRInst* diffFunc = transcriberStorage.transcribe(&subBuilder, baseInst);
+ SLANG_ASSERT(diffFunc);
+ differentiateInst->replaceUsesWith(diffFunc);
+ differentiateInst->removeAndDeallocate();
}
-
}
}
// Actually synthesize the derivatives.
@@ -1638,6 +1720,8 @@ struct ForwardDerivativePass : public InstPassBase
//
bool isMarkedForForwardDifferentiation(IRInst* callable)
{
+ if (auto gen = as<IRGeneric>(callable))
+ callable = findGenericReturnVal(gen);
return callable->findDecoration<IRForwardDifferentiableDecoration>() != nullptr;
}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index e11f98dcd..cc5261d14 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -746,6 +746,9 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/* Differentiable Type Dictionary */
INST(DifferentiableTypeDictionaryDecoration, DifferentiableTypeDictionaryDecoration, 0, PARENT)
+ /// Decorates an interface type and stores the mapping from a normal function requirement key to its derivative requirement key.
+ INST(DifferentiableMethodRequirementDictionaryDecoration, DifferentiableMethodRequirementDictionaryDecoration, 0, PARENT)
+
/// Marks a struct type as being used as a structured buffer block.
/// Recognized by SPIRV-emit pass so we can emit a SPIRV `BufferBlock` decoration.
INST(SPIRVBufferBlockDecoration, spvBufferBlock, 0, 0)
@@ -846,6 +849,11 @@ INST(ExistentialTypeSpecializationDictionary, ExistentialTypeSpecializationDicti
/* Differentiable Type Dictionary */
INST(DifferentiableTypeDictionaryItem, DifferentiableTypeDictionaryItem, 0, 0)
+/* DifferentiableMethodRequirementDictionaryItem */
+ INST(ForwardDifferentiableMethodRequirementDictionaryItem, DifferentiableMethodRequirementDictionaryItem, 0, 0)
+ INST(BackwardDifferentiableMethodRequirementDictionaryItem, DifferentiableMethodRequirementDictionaryItem, 0, 0)
+INST_RANGE(DifferentiableMethodRequirementDictionaryItem, ForwardDifferentiableMethodRequirementDictionaryItem, BackwardDifferentiableMethodRequirementDictionaryItem)
+
#undef PARENT
#undef USE_OTHER
#undef INST_RANGE
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 4434210c9..5c0401cc2 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -659,6 +659,25 @@ struct IRDifferentiableTypeDictionaryDecoration : IRDecoration
IR_LEAF_ISA(DifferentiableTypeDictionaryDecoration)
};
+struct IRDifferentiableMethodRequirementDictionaryDecoration : IRDecoration
+{
+ IR_LEAF_ISA(DifferentiableMethodRequirementDictionaryDecoration)
+};
+
+struct IRDifferentiableMethodRequirementDictionaryItem : IRInst
+{
+ IR_PARENT_ISA(DifferentiableMethodRequirementDictionaryItem)
+};
+
+struct IRForwardDifferentiableMethodRequirementDictionaryItem : IRDifferentiableMethodRequirementDictionaryItem
+{
+ IR_LEAF_ISA(ForwardDifferentiableMethodRequirementDictionaryItem)
+};
+
+struct IRBackwardDifferentiableMethodRequirementDictionaryItem : IRDifferentiableMethodRequirementDictionaryItem
+{
+ IR_LEAF_ISA(BackwardDifferentiableMethodRequirementDictionaryItem)
+};
// An instruction that specializes another IR value
// (representing a generic) to a particular set of generic arguments
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 57267a9ea..a0becdafa 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1398,6 +1398,27 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
midToSup));
}
+ LoweredValInfo visitForwardDifferentiateVal(ForwardDifferentiateVal* val)
+ {
+ // TODO: properly fill in type info here.
+ // We should consider fold all cases of witness table entries to `Val`, and make the `DeclRef` case a `DeclRefVal`.
+ // So that we can hold the type in `DeclRefVal`.
+ auto funcVal = emitDeclRef(context, val->func, context->irBuilder->getTypeKind());
+ SLANG_RELEASE_ASSERT(funcVal.flavor == LoweredValInfo::Flavor::Simple);
+
+ auto diff = getBuilder()->emitForwardDifferentiateInst(getBuilder()->getTypeKind(), funcVal.val);
+ return LoweredValInfo::simple(diff);
+ }
+
+ LoweredValInfo visitBackwardDifferentiateVal(BackwardDifferentiateVal* val)
+ {
+ auto funcVal = emitDeclRef(context, val->func, context->irBuilder->getTypeKind());
+ SLANG_RELEASE_ASSERT(funcVal.flavor == LoweredValInfo::Flavor::Simple);
+
+ auto diff = getBuilder()->emitBackwardDifferentiateInst(getBuilder()->getTypeKind(), funcVal.val);
+ return LoweredValInfo::simple(diff);
+ }
+
LoweredValInfo visitDifferentialBottomSubtypeWitness(DifferentialBottomSubtypeWitness*)
{
return LoweredValInfo();
@@ -6786,6 +6807,24 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
return LoweredValInfo::simple(assocType);
}
+ void insertRequirementKeyAssociation(IRInterfaceType* interfaceType, Decl* requirementDecl, IRInst* originalKey, IRInst* associatedKey)
+ {
+ auto decor = interfaceType->findDecoration<IRDifferentiableMethodRequirementDictionaryDecoration>();
+ if (!decor)
+ {
+ decor =
+ (IRDifferentiableMethodRequirementDictionaryDecoration*)
+ context->irBuilder->addDecoration(
+ interfaceType, kIROp_DifferentiableMethodRequirementDictionaryDecoration);
+ }
+ auto op = as<ForwardDerivativeRequirementDecl>(requirementDecl)
+ ? kIROp_ForwardDifferentiableMethodRequirementDictionaryItem
+ : kIROp_BackwardDifferentiableMethodRequirementDictionaryItem;
+ IRInst* args[] = {originalKey, associatedKey};
+ auto assoc = context->irBuilder->emitIntrinsicInst(nullptr, op, 2, args);
+ assoc->insertAtEnd(decor);
+ }
+
LoweredValInfo visitInterfaceDecl(InterfaceDecl* decl)
{
// The members of an interface will turn into the keys that will
@@ -6824,6 +6863,14 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
operandCount += associatedTypeDecl->getMembersOfType<TypeConstraintDecl>().getCount();
}
+ else if (auto callableDecl = as<CallableDecl>(requirementDecl))
+ {
+ // Differentiable functions has additional requirements for the derivatives.
+ if (callableDecl->getMembersOfType<ForwardDerivativeRequirementDecl>().getCount())
+ operandCount++;
+ if (callableDecl->getMembersOfType<BackwardDerivativeRequirementDecl>().getCount())
+ operandCount++;
+ }
}
// Allocate an IRInterfaceType with the `operandCount` operands.
@@ -6907,6 +6954,38 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
else
{
+ if (auto callableDecl = as<CallableDecl>(requirementDecl))
+ {
+ // Differentiable functions has additional requirements for the derivatives.
+ for (auto diffDecl : callableDecl->getMembersOfType<DerivativeRequirementDecl>())
+ {
+ auto diffKey = getInterfaceRequirementKey(diffDecl);
+ IRInst* diffVal = ensureDecl(subContext, diffDecl).val;
+ auto diffEntry = subBuilder->createInterfaceRequirementEntry(diffKey, diffVal);
+ if (diffVal)
+ {
+ switch (diffVal->getOp())
+ {
+ case kIROp_Func:
+ case kIROp_Generic:
+ {
+ // Remove lowered `IRFunc`s since we only care about
+ // function types.
+ auto reqType = diffVal->getFullType();
+ diffEntry->setRequirementVal(reqType);
+ break;
+ }
+ default:
+ break;
+ }
+ }
+ irInterface->setOperand(entryIndex, diffEntry);
+ entryIndex++;
+
+ setValue(context, diffDecl, LoweredValInfo::simple(diffEntry));
+ insertRequirementKeyAssociation(irInterface, diffDecl, requirementKey, diffKey);
+ }
+ }
// Add lowered requirement entry to current decl mapping to prevent
// the function requirements from being lowered again when we get to
// `ensureAllDeclsRec`.
@@ -6914,6 +6993,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
}
+
addNameHint(context, irInterface, decl);
addLinkageDecoration(context, irInterface, decl);
if (auto anyValueSizeAttr = decl->findModifier<AnyValueSizeAttribute>())
diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp
index bdb5465e9..6ea9ea01e 100644
--- a/source/slang/slang-mangle.cpp
+++ b/source/slang/slang-mangle.cpp
@@ -517,6 +517,10 @@ namespace Slang
emitQualifiedName(context, innerDecl);
return;
}
+ else if (as<ForwardDerivativeRequirementDecl>(decl))
+ emitRaw(context, "FwdReq_");
+ else if (as<BackwardDerivativeRequirementDecl>(decl))
+ emitRaw(context, "BwdReq_");
else
{
// TODO: handle other cases
diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h
index 8f88ddb2d..2ceb7a9fd 100644
--- a/source/slang/slang-syntax.h
+++ b/source/slang/slang-syntax.h
@@ -309,6 +309,10 @@ namespace Slang
GenericSubstitution* findInnerMostGenericSubstitution(Substitutions* subst);
+ ThisTypeSubstitution* findThisTypeSubstitution(
+ const Substitutions* substs,
+ InterfaceDecl* interfaceDecl);
+
enum class UserDefinedAttributeTargets
{
None = 0,
diff --git a/tests/autodiff/dynamic-dispatch-autodiff-simple.slang b/tests/autodiff/dynamic-dispatch-autodiff-simple.slang
new file mode 100644
index 000000000..1247253f9
--- /dev/null
+++ b/tests/autodiff/dynamic-dispatch-autodiff-simple.slang
@@ -0,0 +1,48 @@
+// Test calling differentiable function through dynamic dispatch.
+
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+[anyValueSize(16)]
+interface IInterface
+{
+ [ForwardDifferentiable]
+ static float calc(float x);
+}
+
+struct A : IInterface
+{
+ [ForwardDifferentiable]
+ static float calc(float x) { return x * x * x; }
+};
+
+struct B : IInterface
+{
+ [ForwardDifferentiable]
+ static float calc(float x) { return x * x; }
+};
+
+[ForwardDifferentiable]
+float sqr(IInterface obj, float x)
+{
+ return obj.calc(x);
+}
+
+//TEST_INPUT: type_conformance A:IInterface = 0
+//TEST_INPUT: type_conformance B:IInterface = 1
+
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ var obj = createDynamicObject<IInterface>(dispatchThreadID.x, 0); // A
+ outputBuffer[0] = __fwd_diff(sqr)(obj, DifferentialPair<float>(2.0, 1.0)).d; // A.calc, expect 12
+
+ obj = createDynamicObject<IInterface>(dispatchThreadID.x + 1, 0); // B
+ outputBuffer[1] = __fwd_diff(sqr)(obj, DifferentialPair<float>(1.5, 1.0)).d; // B.calc, expect 3
+
+ outputBuffer[2] = __fwd_diff(obj.calc)(DifferentialPair<float>(1.5, 1.0)).d; // B.calc, expect 3
+}
diff --git a/tests/autodiff/dynamic-dispatch-autodiff-simple.slang.expected.txt b/tests/autodiff/dynamic-dispatch-autodiff-simple.slang.expected.txt
new file mode 100644
index 000000000..1b1844a5d
--- /dev/null
+++ b/tests/autodiff/dynamic-dispatch-autodiff-simple.slang.expected.txt
@@ -0,0 +1,6 @@
+type: float
+12.000000
+3.000000
+3.000000
+0.000000
+0.000000