From d8a40abba5223fbcb56c52b04ccb88c02bbaf79f Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 21 Mar 2023 21:29:13 -0700 Subject: [TreatAsDifferentiable] functions. (#2720) --- source/slang/slang-ast-modifier.h | 16 ++++---- source/slang/slang-check-decl.cpp | 20 +++++---- source/slang/slang-ir-autodiff-fwd.cpp | 75 ++++++++++++++++++++++++++++++++++ source/slang/slang-ir-autodiff-fwd.h | 2 + source/slang/slang-ir-autodiff-rev.cpp | 3 +- source/slang/slang-lower-to-ir.cpp | 4 ++ 6 files changed, 104 insertions(+), 16 deletions(-) (limited to 'source') diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 26303d6ad..0d2e27e5f 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -998,14 +998,6 @@ class ForceInlineAttribute : public Attribute }; -// A `[TreatAsDifferentiableAttribute]` attribute indicates that a function or an interface -// should be treated as differentiable in IR validation step. -// -class TreatAsDifferentiableAttribute : public Attribute -{ - SLANG_AST_CLASS(TreatAsDifferentiableAttribute) -}; - /// An attribute that marks a type declaration as either allowing or /// disallowing the type to be inherited from in other modules. class InheritanceControlAttribute : public Attribute { SLANG_AST_CLASS(InheritanceControlAttribute) }; @@ -1108,6 +1100,14 @@ class AlwaysFoldIntoUseSiteAttribute :public Attribute SLANG_AST_CLASS(AlwaysFoldIntoUseSiteAttribute) }; +// A `[TreatAsDifferentiableAttribute]` attribute indicates that a function or an interface +// should be treated as differentiable in IR validation step. +// +class TreatAsDifferentiableAttribute : public DifferentiableAttribute +{ + SLANG_AST_CLASS(TreatAsDifferentiableAttribute) +}; + /// The `[ForwardDifferentiable]` attribute indicates that a function can be forward-differentiated. class ForwardDifferentiableAttribute : public DifferentiableAttribute { diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index c0253fd2c..eaab43ef8 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1955,8 +1955,11 @@ namespace Slang bool hasForwardDerivative = false; if (requiredMemberDeclRef.getDecl()->hasModifier()) { - if (!satisfyingMemberDeclRef.getDecl()->hasModifier() - && !satisfyingMemberDeclRef.getDecl()->hasModifier()) + auto funcDecl = as(satisfyingMemberDeclRef.getDecl()); + if (!funcDecl) + return false; + + if (getShared()->getFuncDifferentiableLevel(funcDecl) != FunctionDifferentiableLevel::Backward) { // A non-`BackwardDifferentiable` method can't satisfy a `BackwardDifferentiable` requirement and vice versa. return false; @@ -1966,12 +1969,12 @@ namespace Slang } else if (requiredMemberDeclRef.getDecl()->hasModifier()) { - if (!satisfyingMemberDeclRef.getDecl()->hasModifier() - && !satisfyingMemberDeclRef.getDecl()->hasModifier() - && !satisfyingMemberDeclRef.getDecl()->hasModifier() - && !satisfyingMemberDeclRef.getDecl()->hasModifier()) + auto funcDecl = as(satisfyingMemberDeclRef.getDecl()); + if (!funcDecl) + return false; + if (getShared()->getFuncDifferentiableLevel(funcDecl) == FunctionDifferentiableLevel::None) { - // A non-`ForwardDifferentiable` method can't satisfy a `ForwardDifferentiable` requirement and vice versa. + // A non-`BackwardDifferentiable` method can't satisfy a `BackwardDifferentiable` requirement and vice versa. return false; } hasForwardDerivative = true; @@ -6674,6 +6677,9 @@ namespace Slang if (func->findModifier()) return FunctionDifferentiableLevel::Backward; + if (func->findModifier()) + return FunctionDifferentiableLevel::Backward; + FunctionDifferentiableLevel diffLevel = FunctionDifferentiableLevel::None; if (func->findModifier()) diffLevel = FunctionDifferentiableLevel::Forward; diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index ac4e3825a..bc7e03ad3 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -44,6 +44,74 @@ IRFuncType* ForwardDiffTranscriber::differentiateFunctionType(IRBuilder* builder return builder->getFuncType(newParameterTypes, diffReturnType); } +void ForwardDiffTranscriber::generateTrivialFwdDiffFunc(IRFunc* primalFunc, IRFunc* diffFunc) +{ + IRBuilder builder(diffFunc); + builder.setInsertInto(diffFunc); + auto block = builder.emitBlock(); + builder.markInstAsMixedDifferential(block); + + for (auto param : primalFunc->getParams()) + { + transcribeFuncParam(&builder, param, param->getFullType()).differential; + } + List diffParams; + for (auto param : diffFunc->getParams()) + { + diffParams.add(param); + } + auto emitDiffPairVal = [&](IRDifferentialPairTypeBase* pairType) + { + auto primal = builder.emitDefaultConstruct(pairType->getValueType()); + builder.markInstAsPrimal(primal); + auto diff = getDifferentialZeroOfType(&builder, pairType->getValueType()); + builder.markInstAsDifferential(primal); + + auto val = builder.emitMakeDifferentialPair(pairType, primal, diff); + builder.markInstAsMixedDifferential(val); + + return val; + + }; + for (auto param : diffParams) + { + if (auto outType = as(param->getFullType())) + { + if (isRelevantDifferentialPair(outType)) + { + auto pairType = as(outType->getValueType()); + auto val = emitDiffPairVal(pairType); + auto store = builder.emitStore(param, val); + builder.markInstAsMixedDifferential(store); + } + else + { + auto val = builder.emitDefaultConstruct(outType->getValueType()); + builder.markInstAsPrimal(val); + + auto store = builder.emitStore(param, val); + builder.markInstAsPrimal(store); + + } + } + } + if (isRelevantDifferentialPair(diffFunc->getResultType())) + { + auto pairType = as(diffFunc->getResultType()); + auto val = emitDiffPairVal(pairType); + auto returnInst = builder.emitReturn(val); + builder.markInstAsMixedDifferential(val); + builder.markInstAsMixedDifferential(returnInst); + } + else + { + auto retVal = builder.emitDefaultConstruct(diffFunc->getResultType()); + auto returnInst = builder.emitReturn(retVal); + builder.markInstAsPrimal(retVal); + builder.markInstAsPrimal(returnInst); + } +} + // Returns "d" to use as a name hint for variables and parameters. // If no primal name is available, returns a blank string. // @@ -1500,6 +1568,13 @@ SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func) // Transcribe a function definition. InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc) { + if (primalFunc->findDecoration()) + { + // Generate a trivial implementation for [TreatAsDifferentiable] functions. + generateTrivialFwdDiffFunc(primalFunc, diffFunc); + return InstPair(primalFunc, diffFunc); + } + IRBuilder builder = *inBuilder; builder.setInsertBefore(primalFunc); diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h index e9774be49..8fd271fd8 100644 --- a/source/slang/slang-ir-autodiff-fwd.h +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -88,6 +88,8 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) override; + void generateTrivialFwdDiffFunc(IRFunc* primalFunc, IRFunc* diffFunc); + // Transcribe a function definition. InstPair transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc); diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 7c11a1286..e01d65f4f 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -295,7 +295,8 @@ namespace Slang // Create an empty func to represent the transcribed func of `origFunc`. InstPair BackwardDiffTranscriberBase::transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc) { - if (!isBackwardDifferentiableFunc(origFunc)) + if (!isBackwardDifferentiableFunc(origFunc) && + !origFunc->findDecoration()) return InstPair(nullptr, nullptr); IRBuilder builder = *inBuilder; diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 8164723b6..86df89702 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -8502,6 +8502,10 @@ struct DeclLoweringVisitor : DeclVisitor { getBuilder()->addBackwardDifferentiableDecoration(irFunc); } + else if (as(modifier)) + { + getBuilder()->addDecoration(irFunc, kIROp_TreatAsDifferentiableDecoration); + } } // For convenience, ensure that any additional global // values that were emitted while outputting the function -- cgit v1.2.3