summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-23 17:50:02 -0800
committerGitHub <noreply@github.com>2022-11-23 17:50:02 -0800
commit1b40fe56725eeefe9c601461278376b697d4d35a (patch)
tree2bdd321eed24e6e313839fe45aa84b23daa643fe
parentd4787e92253cf963f590d62522e82ce8285fc751 (diff)
Make differentiable data-flow pass recognize interface methods. (#2530)
* Make differentiable data-flow pass recognize interface methods. * Make existing test to work with `[TreatAsDifferentiable]`. Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--source/slang/core.meta.slang4
-rw-r--r--source/slang/slang-ast-modifier.h9
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp35
-rw-r--r--source/slang/slang-ir-inst-defs.h4
-rw-r--r--source/slang/slang-ir-insts.h6
-rw-r--r--source/slang/slang-lower-to-ir.cpp14
-rw-r--r--tests/autodiff/generic-autodiff-1.slang2
-rw-r--r--tests/autodiff/generic-impl-jvp.slang1
-rw-r--r--tests/autodiff/generic-jvp.slang1
9 files changed, 67 insertions, 9 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index d1328b72a..8cfa85983 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -2539,6 +2539,10 @@ interface IComparable
bool lessThanOrEquals(This other);
}
+__attributeTarget(DeclBase)
+attribute_syntax [TreatAsDifferentiable] : TreatAsDifferentiableAttribute;
+
+[TreatAsDifferentiable]
interface IArithmetic : IComparable
{
This add(This other);
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 69f39efb6..f9a3fc393 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -974,6 +974,15 @@ class ForceInlineAttribute : public Attribute
SLANG_AST_CLASS(ForceInlineAttribute)
};
+
+// A `[TreatAsDifferentiableAttribute]` attribute indicates that a function or an interface
+// should be treated as differentiable in IR validation step.
+//
+class TreatAsDifferentiableAttribute : public Attribute
+{
+ SLANG_AST_CLASS(TreatAsDifferentiableAttribute)
+};
+
/// An attribute that marks a type declaration as either allowing or
/// disallowing the type to be inherited from in other modules.
class InheritanceControlAttribute : public Attribute { SLANG_AST_CLASS(InheritanceControlAttribute) };
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index f4f61d7e9..83351d07b 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -97,6 +97,39 @@ public:
if (differentiableFunctions.Contains(func))
return true;
+ if (func->findDecoration<IRTreatAsDifferentiableDecoration>())
+ return true;
+
+ if (auto lookupInterfaceMethod = as<IRLookupWitnessMethod>(func))
+ {
+ auto wit = lookupInterfaceMethod->getWitnessTable();
+ if (!wit)
+ return false;
+ auto witType = as<IRWitnessTableTypeBase>(wit->getDataType());
+ if (!witType)
+ return false;
+ auto interfaceType = witType->getConformanceType();
+ if (!interfaceType)
+ return false;
+ if (interfaceType->findDecoration<IRTreatAsDifferentiableDecoration>())
+ return true;
+ if (sharedContext.differentiableInterfaceType && interfaceType == sharedContext.differentiableInterfaceType)
+ return true;
+ auto dictDecor = interfaceType->findDecoration<IRDifferentiableMethodRequirementDictionaryDecoration>();
+ if (!dictDecor)
+ return false;
+ for (auto child : dictDecor->getChildren())
+ {
+ if (auto entry = as<IRDifferentiableMethodRequirementDictionaryItem>(child))
+ {
+ if (entry->getOperand(0) == lookupInterfaceMethod->getRequirementKey())
+ {
+ return true;
+ }
+ }
+ }
+ }
+
for (; func; func = func->parent)
{
if (as<IRGeneric>(func))
@@ -222,7 +255,7 @@ public:
case kIROp_FloatLit:
return true;
case kIROp_Call:
- return inst->findDecoration<IRTreatAsDifferentiableCallDecoration>() || isDifferentiableFunc(as<IRCall>(inst)->getCallee());
+ return inst->findDecoration<IRTreatAsDifferentiableDecoration>() || isDifferentiableFunc(as<IRCall>(inst)->getCallee());
case kIROp_Load:
// We don't have more knowledge on whether diff is available at the destination address.
// Just assume it is producing diff.
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index cc5261d14..9233972ad 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -736,8 +736,8 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/// differential member of a type in its associated differential type.
INST(DerivativeMemberDecoration, derivativeMemberDecoration, 1, 0)
- /// Treat the IRCall as a call to a differentiable function.
- INST(TreatAsDifferentiableCallDecoration, treatAsDifferentiableCallDecoration, 0, 0)
+ /// Treat a function as differentiable function, or an IRCall as a call to a differentiable function.
+ INST(TreatAsDifferentiableDecoration, treatAsDifferentiableDecoration, 0, 0)
/// Marks a class type as a COM interface implementation, which enables
/// the witness table to be easily picked up by emit.
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 5c0401cc2..250088f96 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -593,13 +593,13 @@ struct IRBackwardDifferentiableDecoration : IRDecoration
IR_LEAF_ISA(BackwardDifferentiableDecoration)
};
-struct IRTreatAsDifferentiableCallDecoration : IRDecoration
+struct IRTreatAsDifferentiableDecoration : IRDecoration
{
enum
{
- kOp = kIROp_TreatAsDifferentiableCallDecoration
+ kOp = kIROp_TreatAsDifferentiableDecoration
};
- IR_LEAF_ISA(TreatAsDifferentiableCallDecoration)
+ IR_LEAF_ISA(TreatAsDifferentiableDecoration)
};
struct IRDerivativeMemberDecoration : IRDecoration
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 09dacc20d..4db9a479b 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -3139,7 +3139,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
{
auto baseVal = lowerSubExpr(expr->innerExpr);
SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple);
- getBuilder()->addDecoration(baseVal.val, kIROp_TreatAsDifferentiableCallDecoration);
+ getBuilder()->addDecoration(baseVal.val, kIROp_TreatAsDifferentiableDecoration);
return baseVal;
}
@@ -6867,7 +6867,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// Allocate an IRInterfaceType with the `operandCount` operands.
IRInterfaceType* irInterface = subBuilder->createInterfaceType(operandCount, nullptr);
-
+
// Add `irInterface` to decl mapping now to prevent cyclic lowering.
setValue(context, decl, LoweredValInfo::simple(irInterface));
@@ -6981,6 +6981,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
subBuilder->addBuiltinDecoration(irInterface);
}
+ if (decl->hasModifier<TreatAsDifferentiableAttribute>())
+ {
+ subBuilder->addDecoration(irInterface, kIROp_TreatAsDifferentiableDecoration);
+ }
+
subBuilder->setInsertInto(irInterface);
// TODO: are there any interface members that should be
// nested inside the interface type itself?
@@ -8307,6 +8312,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
getBuilder()->addDecoration(irFunc, kIROp_ForceInlineDecoration);
}
+ if (decl->findModifier<TreatAsDifferentiableAttribute>())
+ {
+ getBuilder()->addDecoration(irFunc, kIROp_TreatAsDifferentiableDecoration);
+ }
+
// Register the value now, to avoid any possible infinite recursion when lowering ForwardDerivativeAttribute
setGlobalValue(context, decl, LoweredValInfo::simple(findOuterMostGeneric(irFunc)));
diff --git a/tests/autodiff/generic-autodiff-1.slang b/tests/autodiff/generic-autodiff-1.slang
index 43a6d3b10..9ab0d5fef 100644
--- a/tests/autodiff/generic-autodiff-1.slang
+++ b/tests/autodiff/generic-autodiff-1.slang
@@ -23,7 +23,7 @@ struct A : IInterface
[ForwardDifferentiable]
float sqr<T:IInterface>(inout T obj, float x)
{
- return obj.sample() + x*x;
+ return (no_diff obj.sample()) + x*x;
}
[numthreads(1, 1, 1)]
diff --git a/tests/autodiff/generic-impl-jvp.slang b/tests/autodiff/generic-impl-jvp.slang
index a1bc18252..332833fff 100644
--- a/tests/autodiff/generic-impl-jvp.slang
+++ b/tests/autodiff/generic-impl-jvp.slang
@@ -225,6 +225,7 @@ typedef lineardvector<4> mydfloat4;
typedef DifferentialPair<Real> dpfloat;
+[TreatAsDifferentiable]
interface MyLinearArithmeticType
{
static This ladd(This a, This b);
diff --git a/tests/autodiff/generic-jvp.slang b/tests/autodiff/generic-jvp.slang
index 61ec077f4..2be0045d4 100644
--- a/tests/autodiff/generic-jvp.slang
+++ b/tests/autodiff/generic-jvp.slang
@@ -85,6 +85,7 @@ typedef myvector<4> myfloat4;
typedef DifferentialPair<Real> dpfloat;
+[TreatAsDifferentiable]
interface MyLinearArithmeticType
{
static This ladd(This a, This b);