summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-21 21:29:13 -0700
committerGitHub <noreply@github.com>2023-03-21 21:29:13 -0700
commitd8a40abba5223fbcb56c52b04ccb88c02bbaf79f (patch)
tree3207babbce41957fbd01c3c791fe9957c81f6a09 /source/slang
parent83876733d69582eec6bad26af64a651d40fa43aa (diff)
[TreatAsDifferentiable] functions. (#2720)
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/slang-ast-modifier.h16
-rw-r--r--source/slang/slang-check-decl.cpp20
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp75
-rw-r--r--source/slang/slang-ir-autodiff-fwd.h2
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp3
-rw-r--r--source/slang/slang-lower-to-ir.cpp4
6 files changed, 104 insertions, 16 deletions
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 26303d6ad..0d2e27e5f 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -998,14 +998,6 @@ class ForceInlineAttribute : public Attribute
};
-// A `[TreatAsDifferentiableAttribute]` attribute indicates that a function or an interface
-// should be treated as differentiable in IR validation step.
-//
-class TreatAsDifferentiableAttribute : public Attribute
-{
- SLANG_AST_CLASS(TreatAsDifferentiableAttribute)
-};
-
/// An attribute that marks a type declaration as either allowing or
/// disallowing the type to be inherited from in other modules.
class InheritanceControlAttribute : public Attribute { SLANG_AST_CLASS(InheritanceControlAttribute) };
@@ -1108,6 +1100,14 @@ class AlwaysFoldIntoUseSiteAttribute :public Attribute
SLANG_AST_CLASS(AlwaysFoldIntoUseSiteAttribute)
};
+// A `[TreatAsDifferentiableAttribute]` attribute indicates that a function or an interface
+// should be treated as differentiable in IR validation step.
+//
+class TreatAsDifferentiableAttribute : public DifferentiableAttribute
+{
+ SLANG_AST_CLASS(TreatAsDifferentiableAttribute)
+};
+
/// The `[ForwardDifferentiable]` attribute indicates that a function can be forward-differentiated.
class ForwardDifferentiableAttribute : public DifferentiableAttribute
{
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index c0253fd2c..eaab43ef8 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -1955,8 +1955,11 @@ namespace Slang
bool hasForwardDerivative = false;
if (requiredMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>())
{
- if (!satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>()
- && !satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDerivativeAttribute>())
+ auto funcDecl = as<FunctionDeclBase>(satisfyingMemberDeclRef.getDecl());
+ if (!funcDecl)
+ return false;
+
+ if (getShared()->getFuncDifferentiableLevel(funcDecl) != FunctionDifferentiableLevel::Backward)
{
// A non-`BackwardDifferentiable` method can't satisfy a `BackwardDifferentiable` requirement and vice versa.
return false;
@@ -1966,12 +1969,12 @@ namespace Slang
}
else if (requiredMemberDeclRef.getDecl()->hasModifier<ForwardDifferentiableAttribute>())
{
- if (!satisfyingMemberDeclRef.getDecl()->hasModifier<ForwardDifferentiableAttribute>()
- && !satisfyingMemberDeclRef.getDecl()->hasModifier<ForwardDerivativeAttribute>()
- && !satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>()
- && !satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDerivativeAttribute>())
+ auto funcDecl = as<FunctionDeclBase>(satisfyingMemberDeclRef.getDecl());
+ if (!funcDecl)
+ return false;
+ if (getShared()->getFuncDifferentiableLevel(funcDecl) == FunctionDifferentiableLevel::None)
{
- // A non-`ForwardDifferentiable` method can't satisfy a `ForwardDifferentiable` requirement and vice versa.
+ // A non-`BackwardDifferentiable` method can't satisfy a `BackwardDifferentiable` requirement and vice versa.
return false;
}
hasForwardDerivative = true;
@@ -6674,6 +6677,9 @@ namespace Slang
if (func->findModifier<BackwardDerivativeAttribute>())
return FunctionDifferentiableLevel::Backward;
+ if (func->findModifier<TreatAsDifferentiableAttribute>())
+ return FunctionDifferentiableLevel::Backward;
+
FunctionDifferentiableLevel diffLevel = FunctionDifferentiableLevel::None;
if (func->findModifier<DifferentiableAttribute>())
diffLevel = FunctionDifferentiableLevel::Forward;
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index ac4e3825a..bc7e03ad3 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -44,6 +44,74 @@ IRFuncType* ForwardDiffTranscriber::differentiateFunctionType(IRBuilder* builder
return builder->getFuncType(newParameterTypes, diffReturnType);
}
+void ForwardDiffTranscriber::generateTrivialFwdDiffFunc(IRFunc* primalFunc, IRFunc* diffFunc)
+{
+ IRBuilder builder(diffFunc);
+ builder.setInsertInto(diffFunc);
+ auto block = builder.emitBlock();
+ builder.markInstAsMixedDifferential(block);
+
+ for (auto param : primalFunc->getParams())
+ {
+ transcribeFuncParam(&builder, param, param->getFullType()).differential;
+ }
+ List<IRParam*> diffParams;
+ for (auto param : diffFunc->getParams())
+ {
+ diffParams.add(param);
+ }
+ auto emitDiffPairVal = [&](IRDifferentialPairTypeBase* pairType)
+ {
+ auto primal = builder.emitDefaultConstruct(pairType->getValueType());
+ builder.markInstAsPrimal(primal);
+ auto diff = getDifferentialZeroOfType(&builder, pairType->getValueType());
+ builder.markInstAsDifferential(primal);
+
+ auto val = builder.emitMakeDifferentialPair(pairType, primal, diff);
+ builder.markInstAsMixedDifferential(val);
+
+ return val;
+
+ };
+ for (auto param : diffParams)
+ {
+ if (auto outType = as<IROutTypeBase>(param->getFullType()))
+ {
+ if (isRelevantDifferentialPair(outType))
+ {
+ auto pairType = as<IRDifferentialPairTypeBase>(outType->getValueType());
+ auto val = emitDiffPairVal(pairType);
+ auto store = builder.emitStore(param, val);
+ builder.markInstAsMixedDifferential(store);
+ }
+ else
+ {
+ auto val = builder.emitDefaultConstruct(outType->getValueType());
+ builder.markInstAsPrimal(val);
+
+ auto store = builder.emitStore(param, val);
+ builder.markInstAsPrimal(store);
+
+ }
+ }
+ }
+ if (isRelevantDifferentialPair(diffFunc->getResultType()))
+ {
+ auto pairType = as<IRDifferentialPairTypeBase>(diffFunc->getResultType());
+ auto val = emitDiffPairVal(pairType);
+ auto returnInst = builder.emitReturn(val);
+ builder.markInstAsMixedDifferential(val);
+ builder.markInstAsMixedDifferential(returnInst);
+ }
+ else
+ {
+ auto retVal = builder.emitDefaultConstruct(diffFunc->getResultType());
+ auto returnInst = builder.emitReturn(retVal);
+ builder.markInstAsPrimal(retVal);
+ builder.markInstAsPrimal(returnInst);
+ }
+}
+
// Returns "d<var-name>" to use as a name hint for variables and parameters.
// If no primal name is available, returns a blank string.
//
@@ -1500,6 +1568,13 @@ SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func)
// Transcribe a function definition.
InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc)
{
+ if (primalFunc->findDecoration<IRTreatAsDifferentiableDecoration>())
+ {
+ // Generate a trivial implementation for [TreatAsDifferentiable] functions.
+ generateTrivialFwdDiffFunc(primalFunc, diffFunc);
+ return InstPair(primalFunc, diffFunc);
+ }
+
IRBuilder builder = *inBuilder;
builder.setInsertBefore(primalFunc);
diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h
index e9774be49..8fd271fd8 100644
--- a/source/slang/slang-ir-autodiff-fwd.h
+++ b/source/slang/slang-ir-autodiff-fwd.h
@@ -88,6 +88,8 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase
virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) override;
+ void generateTrivialFwdDiffFunc(IRFunc* primalFunc, IRFunc* diffFunc);
+
// Transcribe a function definition.
InstPair transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc);
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 7c11a1286..e01d65f4f 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -295,7 +295,8 @@ namespace Slang
// Create an empty func to represent the transcribed func of `origFunc`.
InstPair BackwardDiffTranscriberBase::transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc)
{
- if (!isBackwardDifferentiableFunc(origFunc))
+ if (!isBackwardDifferentiableFunc(origFunc) &&
+ !origFunc->findDecoration<IRTreatAsDifferentiableDecoration>())
return InstPair(nullptr, nullptr);
IRBuilder builder = *inBuilder;
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 8164723b6..86df89702 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -8502,6 +8502,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
getBuilder()->addBackwardDifferentiableDecoration(irFunc);
}
+ else if (as<TreatAsDifferentiableAttribute>(modifier))
+ {
+ getBuilder()->addDecoration(irFunc, kIROp_TreatAsDifferentiableDecoration);
+ }
}
// For convenience, ensure that any additional global
// values that were emitted while outputting the function