diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-23 17:50:02 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-23 17:50:02 -0800 |
| commit | 1b40fe56725eeefe9c601461278376b697d4d35a (patch) | |
| tree | 2bdd321eed24e6e313839fe45aa84b23daa643fe | |
| parent | d4787e92253cf963f590d62522e82ce8285fc751 (diff) | |
Make differentiable data-flow pass recognize interface methods. (#2530)
* Make differentiable data-flow pass recognize interface methods.
* Make existing test to work with `[TreatAsDifferentiable]`.
Co-authored-by: Yong He <yhe@nvidia.com>
| -rw-r--r-- | source/slang/core.meta.slang | 4 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 9 | ||||
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 35 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 14 | ||||
| -rw-r--r-- | tests/autodiff/generic-autodiff-1.slang | 2 | ||||
| -rw-r--r-- | tests/autodiff/generic-impl-jvp.slang | 1 | ||||
| -rw-r--r-- | tests/autodiff/generic-jvp.slang | 1 |
9 files changed, 67 insertions, 9 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index d1328b72a..8cfa85983 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -2539,6 +2539,10 @@ interface IComparable bool lessThanOrEquals(This other); } +__attributeTarget(DeclBase) +attribute_syntax [TreatAsDifferentiable] : TreatAsDifferentiableAttribute; + +[TreatAsDifferentiable] interface IArithmetic : IComparable { This add(This other); diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 69f39efb6..f9a3fc393 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -974,6 +974,15 @@ class ForceInlineAttribute : public Attribute SLANG_AST_CLASS(ForceInlineAttribute) }; + +// A `[TreatAsDifferentiableAttribute]` attribute indicates that a function or an interface +// should be treated as differentiable in IR validation step. +// +class TreatAsDifferentiableAttribute : public Attribute +{ + SLANG_AST_CLASS(TreatAsDifferentiableAttribute) +}; + /// An attribute that marks a type declaration as either allowing or /// disallowing the type to be inherited from in other modules. class InheritanceControlAttribute : public Attribute { SLANG_AST_CLASS(InheritanceControlAttribute) }; diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index f4f61d7e9..83351d07b 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -97,6 +97,39 @@ public: if (differentiableFunctions.Contains(func)) return true; + if (func->findDecoration<IRTreatAsDifferentiableDecoration>()) + return true; + + if (auto lookupInterfaceMethod = as<IRLookupWitnessMethod>(func)) + { + auto wit = lookupInterfaceMethod->getWitnessTable(); + if (!wit) + return false; + auto witType = as<IRWitnessTableTypeBase>(wit->getDataType()); + if (!witType) + return false; + auto interfaceType = witType->getConformanceType(); + if (!interfaceType) + return false; + if (interfaceType->findDecoration<IRTreatAsDifferentiableDecoration>()) + return true; + if (sharedContext.differentiableInterfaceType && interfaceType == sharedContext.differentiableInterfaceType) + return true; + auto dictDecor = interfaceType->findDecoration<IRDifferentiableMethodRequirementDictionaryDecoration>(); + if (!dictDecor) + return false; + for (auto child : dictDecor->getChildren()) + { + if (auto entry = as<IRDifferentiableMethodRequirementDictionaryItem>(child)) + { + if (entry->getOperand(0) == lookupInterfaceMethod->getRequirementKey()) + { + return true; + } + } + } + } + for (; func; func = func->parent) { if (as<IRGeneric>(func)) @@ -222,7 +255,7 @@ public: case kIROp_FloatLit: return true; case kIROp_Call: - return inst->findDecoration<IRTreatAsDifferentiableCallDecoration>() || isDifferentiableFunc(as<IRCall>(inst)->getCallee()); + return inst->findDecoration<IRTreatAsDifferentiableDecoration>() || isDifferentiableFunc(as<IRCall>(inst)->getCallee()); case kIROp_Load: // We don't have more knowledge on whether diff is available at the destination address. // Just assume it is producing diff. diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index cc5261d14..9233972ad 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -736,8 +736,8 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// differential member of a type in its associated differential type. INST(DerivativeMemberDecoration, derivativeMemberDecoration, 1, 0) - /// Treat the IRCall as a call to a differentiable function. - INST(TreatAsDifferentiableCallDecoration, treatAsDifferentiableCallDecoration, 0, 0) + /// Treat a function as differentiable function, or an IRCall as a call to a differentiable function. + INST(TreatAsDifferentiableDecoration, treatAsDifferentiableDecoration, 0, 0) /// Marks a class type as a COM interface implementation, which enables /// the witness table to be easily picked up by emit. diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 5c0401cc2..250088f96 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -593,13 +593,13 @@ struct IRBackwardDifferentiableDecoration : IRDecoration IR_LEAF_ISA(BackwardDifferentiableDecoration) }; -struct IRTreatAsDifferentiableCallDecoration : IRDecoration +struct IRTreatAsDifferentiableDecoration : IRDecoration { enum { - kOp = kIROp_TreatAsDifferentiableCallDecoration + kOp = kIROp_TreatAsDifferentiableDecoration }; - IR_LEAF_ISA(TreatAsDifferentiableCallDecoration) + IR_LEAF_ISA(TreatAsDifferentiableDecoration) }; struct IRDerivativeMemberDecoration : IRDecoration diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 09dacc20d..4db9a479b 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -3139,7 +3139,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> { auto baseVal = lowerSubExpr(expr->innerExpr); SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple); - getBuilder()->addDecoration(baseVal.val, kIROp_TreatAsDifferentiableCallDecoration); + getBuilder()->addDecoration(baseVal.val, kIROp_TreatAsDifferentiableDecoration); return baseVal; } @@ -6867,7 +6867,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // Allocate an IRInterfaceType with the `operandCount` operands. IRInterfaceType* irInterface = subBuilder->createInterfaceType(operandCount, nullptr); - + // Add `irInterface` to decl mapping now to prevent cyclic lowering. setValue(context, decl, LoweredValInfo::simple(irInterface)); @@ -6981,6 +6981,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { subBuilder->addBuiltinDecoration(irInterface); } + if (decl->hasModifier<TreatAsDifferentiableAttribute>()) + { + subBuilder->addDecoration(irInterface, kIROp_TreatAsDifferentiableDecoration); + } + subBuilder->setInsertInto(irInterface); // TODO: are there any interface members that should be // nested inside the interface type itself? @@ -8307,6 +8312,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> getBuilder()->addDecoration(irFunc, kIROp_ForceInlineDecoration); } + if (decl->findModifier<TreatAsDifferentiableAttribute>()) + { + getBuilder()->addDecoration(irFunc, kIROp_TreatAsDifferentiableDecoration); + } + // Register the value now, to avoid any possible infinite recursion when lowering ForwardDerivativeAttribute setGlobalValue(context, decl, LoweredValInfo::simple(findOuterMostGeneric(irFunc))); diff --git a/tests/autodiff/generic-autodiff-1.slang b/tests/autodiff/generic-autodiff-1.slang index 43a6d3b10..9ab0d5fef 100644 --- a/tests/autodiff/generic-autodiff-1.slang +++ b/tests/autodiff/generic-autodiff-1.slang @@ -23,7 +23,7 @@ struct A : IInterface [ForwardDifferentiable] float sqr<T:IInterface>(inout T obj, float x) { - return obj.sample() + x*x; + return (no_diff obj.sample()) + x*x; } [numthreads(1, 1, 1)] diff --git a/tests/autodiff/generic-impl-jvp.slang b/tests/autodiff/generic-impl-jvp.slang index a1bc18252..332833fff 100644 --- a/tests/autodiff/generic-impl-jvp.slang +++ b/tests/autodiff/generic-impl-jvp.slang @@ -225,6 +225,7 @@ typedef lineardvector<4> mydfloat4; typedef DifferentialPair<Real> dpfloat; +[TreatAsDifferentiable] interface MyLinearArithmeticType { static This ladd(This a, This b); diff --git a/tests/autodiff/generic-jvp.slang b/tests/autodiff/generic-jvp.slang index 61ec077f4..2be0045d4 100644 --- a/tests/autodiff/generic-jvp.slang +++ b/tests/autodiff/generic-jvp.slang @@ -85,6 +85,7 @@ typedef myvector<4> myfloat4; typedef DifferentialPair<Real> dpfloat; +[TreatAsDifferentiable] interface MyLinearArithmeticType { static This ladd(This a, This b); |
