diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-21 21:29:13 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-21 21:29:13 -0700 |
| commit | d8a40abba5223fbcb56c52b04ccb88c02bbaf79f (patch) | |
| tree | 3207babbce41957fbd01c3c791fe9957c81f6a09 /source/slang | |
| parent | 83876733d69582eec6bad26af64a651d40fa43aa (diff) | |
[TreatAsDifferentiable] functions. (#2720)
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 16 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 20 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 75 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 4 |
6 files changed, 104 insertions, 16 deletions
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 26303d6ad..0d2e27e5f 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -998,14 +998,6 @@ class ForceInlineAttribute : public Attribute }; -// A `[TreatAsDifferentiableAttribute]` attribute indicates that a function or an interface -// should be treated as differentiable in IR validation step. -// -class TreatAsDifferentiableAttribute : public Attribute -{ - SLANG_AST_CLASS(TreatAsDifferentiableAttribute) -}; - /// An attribute that marks a type declaration as either allowing or /// disallowing the type to be inherited from in other modules. class InheritanceControlAttribute : public Attribute { SLANG_AST_CLASS(InheritanceControlAttribute) }; @@ -1108,6 +1100,14 @@ class AlwaysFoldIntoUseSiteAttribute :public Attribute SLANG_AST_CLASS(AlwaysFoldIntoUseSiteAttribute) }; +// A `[TreatAsDifferentiableAttribute]` attribute indicates that a function or an interface +// should be treated as differentiable in IR validation step. +// +class TreatAsDifferentiableAttribute : public DifferentiableAttribute +{ + SLANG_AST_CLASS(TreatAsDifferentiableAttribute) +}; + /// The `[ForwardDifferentiable]` attribute indicates that a function can be forward-differentiated. class ForwardDifferentiableAttribute : public DifferentiableAttribute { diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index c0253fd2c..eaab43ef8 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1955,8 +1955,11 @@ namespace Slang bool hasForwardDerivative = false; if (requiredMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>()) { - if (!satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>() - && !satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDerivativeAttribute>()) + auto funcDecl = as<FunctionDeclBase>(satisfyingMemberDeclRef.getDecl()); + if (!funcDecl) + return false; + + if (getShared()->getFuncDifferentiableLevel(funcDecl) != FunctionDifferentiableLevel::Backward) { // A non-`BackwardDifferentiable` method can't satisfy a `BackwardDifferentiable` requirement and vice versa. return false; @@ -1966,12 +1969,12 @@ namespace Slang } else if (requiredMemberDeclRef.getDecl()->hasModifier<ForwardDifferentiableAttribute>()) { - if (!satisfyingMemberDeclRef.getDecl()->hasModifier<ForwardDifferentiableAttribute>() - && !satisfyingMemberDeclRef.getDecl()->hasModifier<ForwardDerivativeAttribute>() - && !satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>() - && !satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDerivativeAttribute>()) + auto funcDecl = as<FunctionDeclBase>(satisfyingMemberDeclRef.getDecl()); + if (!funcDecl) + return false; + if (getShared()->getFuncDifferentiableLevel(funcDecl) == FunctionDifferentiableLevel::None) { - // A non-`ForwardDifferentiable` method can't satisfy a `ForwardDifferentiable` requirement and vice versa. + // A non-`BackwardDifferentiable` method can't satisfy a `BackwardDifferentiable` requirement and vice versa. return false; } hasForwardDerivative = true; @@ -6674,6 +6677,9 @@ namespace Slang if (func->findModifier<BackwardDerivativeAttribute>()) return FunctionDifferentiableLevel::Backward; + if (func->findModifier<TreatAsDifferentiableAttribute>()) + return FunctionDifferentiableLevel::Backward; + FunctionDifferentiableLevel diffLevel = FunctionDifferentiableLevel::None; if (func->findModifier<DifferentiableAttribute>()) diffLevel = FunctionDifferentiableLevel::Forward; diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index ac4e3825a..bc7e03ad3 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -44,6 +44,74 @@ IRFuncType* ForwardDiffTranscriber::differentiateFunctionType(IRBuilder* builder return builder->getFuncType(newParameterTypes, diffReturnType); } +void ForwardDiffTranscriber::generateTrivialFwdDiffFunc(IRFunc* primalFunc, IRFunc* diffFunc) +{ + IRBuilder builder(diffFunc); + builder.setInsertInto(diffFunc); + auto block = builder.emitBlock(); + builder.markInstAsMixedDifferential(block); + + for (auto param : primalFunc->getParams()) + { + transcribeFuncParam(&builder, param, param->getFullType()).differential; + } + List<IRParam*> diffParams; + for (auto param : diffFunc->getParams()) + { + diffParams.add(param); + } + auto emitDiffPairVal = [&](IRDifferentialPairTypeBase* pairType) + { + auto primal = builder.emitDefaultConstruct(pairType->getValueType()); + builder.markInstAsPrimal(primal); + auto diff = getDifferentialZeroOfType(&builder, pairType->getValueType()); + builder.markInstAsDifferential(primal); + + auto val = builder.emitMakeDifferentialPair(pairType, primal, diff); + builder.markInstAsMixedDifferential(val); + + return val; + + }; + for (auto param : diffParams) + { + if (auto outType = as<IROutTypeBase>(param->getFullType())) + { + if (isRelevantDifferentialPair(outType)) + { + auto pairType = as<IRDifferentialPairTypeBase>(outType->getValueType()); + auto val = emitDiffPairVal(pairType); + auto store = builder.emitStore(param, val); + builder.markInstAsMixedDifferential(store); + } + else + { + auto val = builder.emitDefaultConstruct(outType->getValueType()); + builder.markInstAsPrimal(val); + + auto store = builder.emitStore(param, val); + builder.markInstAsPrimal(store); + + } + } + } + if (isRelevantDifferentialPair(diffFunc->getResultType())) + { + auto pairType = as<IRDifferentialPairTypeBase>(diffFunc->getResultType()); + auto val = emitDiffPairVal(pairType); + auto returnInst = builder.emitReturn(val); + builder.markInstAsMixedDifferential(val); + builder.markInstAsMixedDifferential(returnInst); + } + else + { + auto retVal = builder.emitDefaultConstruct(diffFunc->getResultType()); + auto returnInst = builder.emitReturn(retVal); + builder.markInstAsPrimal(retVal); + builder.markInstAsPrimal(returnInst); + } +} + // Returns "d<var-name>" to use as a name hint for variables and parameters. // If no primal name is available, returns a blank string. // @@ -1500,6 +1568,13 @@ SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func) // Transcribe a function definition. InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc) { + if (primalFunc->findDecoration<IRTreatAsDifferentiableDecoration>()) + { + // Generate a trivial implementation for [TreatAsDifferentiable] functions. + generateTrivialFwdDiffFunc(primalFunc, diffFunc); + return InstPair(primalFunc, diffFunc); + } + IRBuilder builder = *inBuilder; builder.setInsertBefore(primalFunc); diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h index e9774be49..8fd271fd8 100644 --- a/source/slang/slang-ir-autodiff-fwd.h +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -88,6 +88,8 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) override; + void generateTrivialFwdDiffFunc(IRFunc* primalFunc, IRFunc* diffFunc); + // Transcribe a function definition. InstPair transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc); diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 7c11a1286..e01d65f4f 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -295,7 +295,8 @@ namespace Slang // Create an empty func to represent the transcribed func of `origFunc`. InstPair BackwardDiffTranscriberBase::transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc) { - if (!isBackwardDifferentiableFunc(origFunc)) + if (!isBackwardDifferentiableFunc(origFunc) && + !origFunc->findDecoration<IRTreatAsDifferentiableDecoration>()) return InstPair(nullptr, nullptr); IRBuilder builder = *inBuilder; diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 8164723b6..86df89702 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -8502,6 +8502,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { getBuilder()->addBackwardDifferentiableDecoration(irFunc); } + else if (as<TreatAsDifferentiableAttribute>(modifier)) + { + getBuilder()->addDecoration(irFunc, kIROp_TreatAsDifferentiableDecoration); + } } // For convenience, ensure that any additional global // values that were emitted while outputting the function |
