diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-23 17:50:02 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-23 17:50:02 -0800 |
| commit | 1b40fe56725eeefe9c601461278376b697d4d35a (patch) | |
| tree | 2bdd321eed24e6e313839fe45aa84b23daa643fe /source | |
| parent | d4787e92253cf963f590d62522e82ce8285fc751 (diff) | |
Make differentiable data-flow pass recognize interface methods. (#2530)
* Make differentiable data-flow pass recognize interface methods.
* Make existing test to work with `[TreatAsDifferentiable]`.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/core.meta.slang | 4 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 9 | ||||
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 35 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 14 |
6 files changed, 64 insertions, 8 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index d1328b72a..8cfa85983 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -2539,6 +2539,10 @@ interface IComparable bool lessThanOrEquals(This other); } +__attributeTarget(DeclBase) +attribute_syntax [TreatAsDifferentiable] : TreatAsDifferentiableAttribute; + +[TreatAsDifferentiable] interface IArithmetic : IComparable { This add(This other); diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 69f39efb6..f9a3fc393 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -974,6 +974,15 @@ class ForceInlineAttribute : public Attribute SLANG_AST_CLASS(ForceInlineAttribute) }; + +// A `[TreatAsDifferentiableAttribute]` attribute indicates that a function or an interface +// should be treated as differentiable in IR validation step. +// +class TreatAsDifferentiableAttribute : public Attribute +{ + SLANG_AST_CLASS(TreatAsDifferentiableAttribute) +}; + /// An attribute that marks a type declaration as either allowing or /// disallowing the type to be inherited from in other modules. class InheritanceControlAttribute : public Attribute { SLANG_AST_CLASS(InheritanceControlAttribute) }; diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index f4f61d7e9..83351d07b 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -97,6 +97,39 @@ public: if (differentiableFunctions.Contains(func)) return true; + if (func->findDecoration<IRTreatAsDifferentiableDecoration>()) + return true; + + if (auto lookupInterfaceMethod = as<IRLookupWitnessMethod>(func)) + { + auto wit = lookupInterfaceMethod->getWitnessTable(); + if (!wit) + return false; + auto witType = as<IRWitnessTableTypeBase>(wit->getDataType()); + if (!witType) + return false; + auto interfaceType = witType->getConformanceType(); + if (!interfaceType) + return false; + if (interfaceType->findDecoration<IRTreatAsDifferentiableDecoration>()) + return true; + if (sharedContext.differentiableInterfaceType && interfaceType == sharedContext.differentiableInterfaceType) + return true; + auto dictDecor = interfaceType->findDecoration<IRDifferentiableMethodRequirementDictionaryDecoration>(); + if (!dictDecor) + return false; + for (auto child : dictDecor->getChildren()) + { + if (auto entry = as<IRDifferentiableMethodRequirementDictionaryItem>(child)) + { + if (entry->getOperand(0) == lookupInterfaceMethod->getRequirementKey()) + { + return true; + } + } + } + } + for (; func; func = func->parent) { if (as<IRGeneric>(func)) @@ -222,7 +255,7 @@ public: case kIROp_FloatLit: return true; case kIROp_Call: - return inst->findDecoration<IRTreatAsDifferentiableCallDecoration>() || isDifferentiableFunc(as<IRCall>(inst)->getCallee()); + return inst->findDecoration<IRTreatAsDifferentiableDecoration>() || isDifferentiableFunc(as<IRCall>(inst)->getCallee()); case kIROp_Load: // We don't have more knowledge on whether diff is available at the destination address. // Just assume it is producing diff. diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index cc5261d14..9233972ad 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -736,8 +736,8 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// differential member of a type in its associated differential type. INST(DerivativeMemberDecoration, derivativeMemberDecoration, 1, 0) - /// Treat the IRCall as a call to a differentiable function. - INST(TreatAsDifferentiableCallDecoration, treatAsDifferentiableCallDecoration, 0, 0) + /// Treat a function as differentiable function, or an IRCall as a call to a differentiable function. + INST(TreatAsDifferentiableDecoration, treatAsDifferentiableDecoration, 0, 0) /// Marks a class type as a COM interface implementation, which enables /// the witness table to be easily picked up by emit. diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 5c0401cc2..250088f96 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -593,13 +593,13 @@ struct IRBackwardDifferentiableDecoration : IRDecoration IR_LEAF_ISA(BackwardDifferentiableDecoration) }; -struct IRTreatAsDifferentiableCallDecoration : IRDecoration +struct IRTreatAsDifferentiableDecoration : IRDecoration { enum { - kOp = kIROp_TreatAsDifferentiableCallDecoration + kOp = kIROp_TreatAsDifferentiableDecoration }; - IR_LEAF_ISA(TreatAsDifferentiableCallDecoration) + IR_LEAF_ISA(TreatAsDifferentiableDecoration) }; struct IRDerivativeMemberDecoration : IRDecoration diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 09dacc20d..4db9a479b 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -3139,7 +3139,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> { auto baseVal = lowerSubExpr(expr->innerExpr); SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple); - getBuilder()->addDecoration(baseVal.val, kIROp_TreatAsDifferentiableCallDecoration); + getBuilder()->addDecoration(baseVal.val, kIROp_TreatAsDifferentiableDecoration); return baseVal; } @@ -6867,7 +6867,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // Allocate an IRInterfaceType with the `operandCount` operands. IRInterfaceType* irInterface = subBuilder->createInterfaceType(operandCount, nullptr); - + // Add `irInterface` to decl mapping now to prevent cyclic lowering. setValue(context, decl, LoweredValInfo::simple(irInterface)); @@ -6981,6 +6981,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { subBuilder->addBuiltinDecoration(irInterface); } + if (decl->hasModifier<TreatAsDifferentiableAttribute>()) + { + subBuilder->addDecoration(irInterface, kIROp_TreatAsDifferentiableDecoration); + } + subBuilder->setInsertInto(irInterface); // TODO: are there any interface members that should be // nested inside the interface type itself? @@ -8307,6 +8312,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> getBuilder()->addDecoration(irFunc, kIROp_ForceInlineDecoration); } + if (decl->findModifier<TreatAsDifferentiableAttribute>()) + { + getBuilder()->addDecoration(irFunc, kIROp_TreatAsDifferentiableDecoration); + } + // Register the value now, to avoid any possible infinite recursion when lowering ForwardDerivativeAttribute setGlobalValue(context, decl, LoweredValInfo::simple(findOuterMostGeneric(irFunc))); |
