diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-21 21:29:13 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-21 21:29:13 -0700 |
| commit | d8a40abba5223fbcb56c52b04ccb88c02bbaf79f (patch) | |
| tree | 3207babbce41957fbd01c3c791fe9957c81f6a09 | |
| parent | 83876733d69582eec6bad26af64a651d40fa43aa (diff) | |
[TreatAsDifferentiable] functions. (#2720)
| -rw-r--r-- | docs/user-guide/07-autodiff.md | 34 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 16 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 20 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 75 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 4 | ||||
| -rw-r--r-- | tests/autodiff/treat-as-differentiable.slang | 37 | ||||
| -rw-r--r-- | tests/autodiff/treat-as-differentiable.slang.expected.txt | 2 |
9 files changed, 177 insertions, 16 deletions
diff --git a/docs/user-guide/07-autodiff.md b/docs/user-guide/07-autodiff.md index 14b92774d..f2bf4dcc5 100644 --- a/docs/user-guide/07-autodiff.md +++ b/docs/user-guide/07-autodiff.md @@ -687,6 +687,40 @@ float f(float x) However, the `no_diff` keyword is not required in a call if a non-differentiable function does not take any differentiable parameters, or if the result of the differentiable function is not dependant on the derivative being propagated through the call. +### Treat Non-Differentiable Functions as Differentiable +Slang allows functions to be marked with a `[TreatAsDifferentiable]` attribute for them to be considered as differentiable functions by the type-system. When a function is marked as `[TreatAsDifferentiable]`, the compiler will not generate derivative propagation code from the original function body or perform any additional checking on the function definition. Instead, it will generate trivial forward and backward propagation functions that returns 0. + +This feature can be useful if the user marked an `interface` method as forward or backward differentiable, but only wish to provide non-trivial derivative propagation functions for a subset of types that implement the interface. For other types that does not actually need differentiation, the user can simply put `[TreatAsDifferentiable]` on the method implementations for them to satisfy the interface requirement. + +See the following code for an example of `[TreatAsDifferentiable]`: +```csharp +interface IFoo +{ + [BackwardDifferentiable] + float f(float v); +} + +struct B : IFoo +{ + [TreatAsDifferentiable] + float f(float v) + { + return v * v; + } +} + +[BackwardDifferentiable] +float use(IFoo o, float x) +{ + return o.f(x); +} + +// Test: +B obj; +float result = __fwd_diff(use)(obj, diffPair(2.0, 1.0)).d; +// result == 0.0, since `[TreatAsDifferentiable]` causes a trivial derivative implementation +// being generated regardless of the original code. +``` ## Higher Order Differentiation diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 26303d6ad..0d2e27e5f 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -998,14 +998,6 @@ class ForceInlineAttribute : public Attribute }; -// A `[TreatAsDifferentiableAttribute]` attribute indicates that a function or an interface -// should be treated as differentiable in IR validation step. -// -class TreatAsDifferentiableAttribute : public Attribute -{ - SLANG_AST_CLASS(TreatAsDifferentiableAttribute) -}; - /// An attribute that marks a type declaration as either allowing or /// disallowing the type to be inherited from in other modules. class InheritanceControlAttribute : public Attribute { SLANG_AST_CLASS(InheritanceControlAttribute) }; @@ -1108,6 +1100,14 @@ class AlwaysFoldIntoUseSiteAttribute :public Attribute SLANG_AST_CLASS(AlwaysFoldIntoUseSiteAttribute) }; +// A `[TreatAsDifferentiableAttribute]` attribute indicates that a function or an interface +// should be treated as differentiable in IR validation step. +// +class TreatAsDifferentiableAttribute : public DifferentiableAttribute +{ + SLANG_AST_CLASS(TreatAsDifferentiableAttribute) +}; + /// The `[ForwardDifferentiable]` attribute indicates that a function can be forward-differentiated. class ForwardDifferentiableAttribute : public DifferentiableAttribute { diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index c0253fd2c..eaab43ef8 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1955,8 +1955,11 @@ namespace Slang bool hasForwardDerivative = false; if (requiredMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>()) { - if (!satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>() - && !satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDerivativeAttribute>()) + auto funcDecl = as<FunctionDeclBase>(satisfyingMemberDeclRef.getDecl()); + if (!funcDecl) + return false; + + if (getShared()->getFuncDifferentiableLevel(funcDecl) != FunctionDifferentiableLevel::Backward) { // A non-`BackwardDifferentiable` method can't satisfy a `BackwardDifferentiable` requirement and vice versa. return false; @@ -1966,12 +1969,12 @@ namespace Slang } else if (requiredMemberDeclRef.getDecl()->hasModifier<ForwardDifferentiableAttribute>()) { - if (!satisfyingMemberDeclRef.getDecl()->hasModifier<ForwardDifferentiableAttribute>() - && !satisfyingMemberDeclRef.getDecl()->hasModifier<ForwardDerivativeAttribute>() - && !satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>() - && !satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDerivativeAttribute>()) + auto funcDecl = as<FunctionDeclBase>(satisfyingMemberDeclRef.getDecl()); + if (!funcDecl) + return false; + if (getShared()->getFuncDifferentiableLevel(funcDecl) == FunctionDifferentiableLevel::None) { - // A non-`ForwardDifferentiable` method can't satisfy a `ForwardDifferentiable` requirement and vice versa. + // A non-`BackwardDifferentiable` method can't satisfy a `BackwardDifferentiable` requirement and vice versa. return false; } hasForwardDerivative = true; @@ -6674,6 +6677,9 @@ namespace Slang if (func->findModifier<BackwardDerivativeAttribute>()) return FunctionDifferentiableLevel::Backward; + if (func->findModifier<TreatAsDifferentiableAttribute>()) + return FunctionDifferentiableLevel::Backward; + FunctionDifferentiableLevel diffLevel = FunctionDifferentiableLevel::None; if (func->findModifier<DifferentiableAttribute>()) diffLevel = FunctionDifferentiableLevel::Forward; diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index ac4e3825a..bc7e03ad3 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -44,6 +44,74 @@ IRFuncType* ForwardDiffTranscriber::differentiateFunctionType(IRBuilder* builder return builder->getFuncType(newParameterTypes, diffReturnType); } +void ForwardDiffTranscriber::generateTrivialFwdDiffFunc(IRFunc* primalFunc, IRFunc* diffFunc) +{ + IRBuilder builder(diffFunc); + builder.setInsertInto(diffFunc); + auto block = builder.emitBlock(); + builder.markInstAsMixedDifferential(block); + + for (auto param : primalFunc->getParams()) + { + transcribeFuncParam(&builder, param, param->getFullType()).differential; + } + List<IRParam*> diffParams; + for (auto param : diffFunc->getParams()) + { + diffParams.add(param); + } + auto emitDiffPairVal = [&](IRDifferentialPairTypeBase* pairType) + { + auto primal = builder.emitDefaultConstruct(pairType->getValueType()); + builder.markInstAsPrimal(primal); + auto diff = getDifferentialZeroOfType(&builder, pairType->getValueType()); + builder.markInstAsDifferential(primal); + + auto val = builder.emitMakeDifferentialPair(pairType, primal, diff); + builder.markInstAsMixedDifferential(val); + + return val; + + }; + for (auto param : diffParams) + { + if (auto outType = as<IROutTypeBase>(param->getFullType())) + { + if (isRelevantDifferentialPair(outType)) + { + auto pairType = as<IRDifferentialPairTypeBase>(outType->getValueType()); + auto val = emitDiffPairVal(pairType); + auto store = builder.emitStore(param, val); + builder.markInstAsMixedDifferential(store); + } + else + { + auto val = builder.emitDefaultConstruct(outType->getValueType()); + builder.markInstAsPrimal(val); + + auto store = builder.emitStore(param, val); + builder.markInstAsPrimal(store); + + } + } + } + if (isRelevantDifferentialPair(diffFunc->getResultType())) + { + auto pairType = as<IRDifferentialPairTypeBase>(diffFunc->getResultType()); + auto val = emitDiffPairVal(pairType); + auto returnInst = builder.emitReturn(val); + builder.markInstAsMixedDifferential(val); + builder.markInstAsMixedDifferential(returnInst); + } + else + { + auto retVal = builder.emitDefaultConstruct(diffFunc->getResultType()); + auto returnInst = builder.emitReturn(retVal); + builder.markInstAsPrimal(retVal); + builder.markInstAsPrimal(returnInst); + } +} + // Returns "d<var-name>" to use as a name hint for variables and parameters. // If no primal name is available, returns a blank string. // @@ -1500,6 +1568,13 @@ SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func) // Transcribe a function definition. InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc) { + if (primalFunc->findDecoration<IRTreatAsDifferentiableDecoration>()) + { + // Generate a trivial implementation for [TreatAsDifferentiable] functions. + generateTrivialFwdDiffFunc(primalFunc, diffFunc); + return InstPair(primalFunc, diffFunc); + } + IRBuilder builder = *inBuilder; builder.setInsertBefore(primalFunc); diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h index e9774be49..8fd271fd8 100644 --- a/source/slang/slang-ir-autodiff-fwd.h +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -88,6 +88,8 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) override; + void generateTrivialFwdDiffFunc(IRFunc* primalFunc, IRFunc* diffFunc); + // Transcribe a function definition. InstPair transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc); diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 7c11a1286..e01d65f4f 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -295,7 +295,8 @@ namespace Slang // Create an empty func to represent the transcribed func of `origFunc`. InstPair BackwardDiffTranscriberBase::transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc) { - if (!isBackwardDifferentiableFunc(origFunc)) + if (!isBackwardDifferentiableFunc(origFunc) && + !origFunc->findDecoration<IRTreatAsDifferentiableDecoration>()) return InstPair(nullptr, nullptr); IRBuilder builder = *inBuilder; diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 8164723b6..86df89702 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -8502,6 +8502,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { getBuilder()->addBackwardDifferentiableDecoration(irFunc); } + else if (as<TreatAsDifferentiableAttribute>(modifier)) + { + getBuilder()->addDecoration(irFunc, kIROp_TreatAsDifferentiableDecoration); + } } // For convenience, ensure that any additional global // values that were emitted while outputting the function diff --git a/tests/autodiff/treat-as-differentiable.slang b/tests/autodiff/treat-as-differentiable.slang new file mode 100644 index 000000000..95423d978 --- /dev/null +++ b/tests/autodiff/treat-as-differentiable.slang @@ -0,0 +1,37 @@ +// Tests automatic synthesis of Differential type and method requirements. + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +interface IFoo +{ + [BackwardDifferentiable] + float f(float v); +} + +struct B : IFoo +{ + [TreatAsDifferentiable] + float f(float v) + { + return v * v; + } +} + +[BackwardDifferentiable] +float use(IFoo o, float x) +{ + return o.f(x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + B b; + var p = diffPair(1.0); + __bwd_diff(use)(b, p, 1.0); + outputBuffer[0] = p.d; +} diff --git a/tests/autodiff/treat-as-differentiable.slang.expected.txt b/tests/autodiff/treat-as-differentiable.slang.expected.txt new file mode 100644 index 000000000..9d11e5c94 --- /dev/null +++ b/tests/autodiff/treat-as-differentiable.slang.expected.txt @@ -0,0 +1,2 @@ +type: float +0.0
\ No newline at end of file |
