summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--docs/user-guide/07-autodiff.md7
-rw-r--r--source/slang/diff.meta.slang195
-rw-r--r--source/slang/hlsl.meta.slang2
-rw-r--r--source/slang/slang-ast-expr.h2
-rw-r--r--source/slang/slang-ast-support-types.cpp7
-rw-r--r--source/slang/slang-ast-support-types.h15
-rw-r--r--source/slang/slang-check-decl.cpp192
-rw-r--r--source/slang/slang-check-expr.cpp25
-rw-r--r--source/slang/slang-check-impl.h6
-rw-r--r--source/slang/slang-check-modifier.cpp11
-rw-r--r--source/slang/slang-check-overload.cpp2
-rw-r--r--tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang12
-rw-r--r--tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang.expected.txt5
-rw-r--r--tests/autodiff-dstdlib/dstdlib-sqrt.slang11
-rw-r--r--tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt6
15 files changed, 306 insertions, 192 deletions
diff --git a/docs/user-guide/07-autodiff.md b/docs/user-guide/07-autodiff.md
index 244ebb47b..6b21ee28c 100644
--- a/docs/user-guide/07-autodiff.md
+++ b/docs/user-guide/07-autodiff.md
@@ -483,10 +483,9 @@ The following builtin functions are backward differentiable and both their forwa
- Hyperbolic functions: `sinh`, `cosh`, `tanh`
- Exponential and logarithmic functions: `exp`, `exp2`, `pow`, `log`, `log2`, `log10`
- Vector functions: `dot`, `cross`, `length`, `distance`, `normalize`, `reflect`, `refract`
-- Matrix transform: `mul(matrix, vector)`, `mul(vector, matrix)`, `mul(matrix, matrix)`, `transpose`
-
-Derivatives for the following legacy HLSL intrinsic functions are not implemented:
-- `dst`, `lit`,
+- Matrix transforms: `mul(matrix, vector)`, `mul(vector, matrix)`, `mul(matrix, matrix)`
+- Matrix operations: `transpose`, `determinant`
+- Legacy blending and lighting intrinsics: `dst`, `lit`
## Excluding Parameters From Differentiation
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index c303b39d9..54f927816 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -380,6 +380,24 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
} \
return DifferentialPair<vector<T, N>>(result, d_result); \
} \
+ __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \
+ [ForwardDerivativeOf(NAME)] \
+ DifferentialPair<matrix<T, M, N>> __d_##NAME##_matrix( \
+ DifferentialPair<matrix<T, M, N>> dpx, DifferentialPair<matrix<T, M, N>> dpy) \
+ { \
+ matrix<T, M, N> result; \
+ matrix<T, M, N>.Differential d_result; \
+ [ForceUnroll] for (int i = 0; i < M; ++i) \
+ [ForceUnroll] for (int j = 0; j < N; ++j) \
+ { \
+ DifferentialPair<T> dp_elem = __d_##NAME( \
+ DifferentialPair<T>(dpx.p[i][j], __slang_noop_cast<T.Differential>(dpx.d[i][j])), \
+ DifferentialPair<T>(dpy.p[i][j], __slang_noop_cast<T.Differential>(dpy.d[i][j]))); \
+ result[i][j] = dp_elem.p; \
+ d_result[i][j] = __slang_noop_cast<T>(dp_elem.d); \
+ } \
+ return DifferentialPair<matrix<T, M, N>>(result, d_result); \
+ } \
__generic<T : __BuiltinFloatingPointType, let N : int> \
[BackwardDerivativeOf(NAME)] \
void __d_##NAME##_vector( \
@@ -398,6 +416,26 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
} \
dpx = diffPair(dpx.p, left_d_result); \
dpy = diffPair(dpy.p, right_d_result); \
+ } \
+ __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \
+ [BackwardDerivativeOf(NAME)] \
+ void __d_##NAME##_matrix( \
+ inout DifferentialPair<matrix<T, M, N>> dpx, \
+ inout DifferentialPair<matrix<T, M, N>> dpy, \
+ matrix<T, M, N>.Differential dOut) \
+ { \
+ matrix<T, M, N>.Differential left_d_result, right_d_result; \
+ [ForceUnroll] for (int i = 0; i < M; ++i) \
+ [ForceUnroll] for (int j = 0; j < N; ++j) \
+ { \
+ DifferentialPair<T> left_dp = diffPair(dpx.p[i][j], T.dzero()); \
+ DifferentialPair<T> right_dp = diffPair(dpy.p[i][j], T.dzero()); \
+ __d_##NAME(left_dp, right_dp, __slang_noop_cast<T.Differential>(dOut[i][j])); \
+ left_d_result[i][j] = __slang_noop_cast<T>(left_dp.d); \
+ right_d_result[i][j] = __slang_noop_cast<T>(right_dp.d); \
+ } \
+ dpx = diffPair(dpx.p, left_d_result); \
+ dpy = diffPair(dpy.p, right_d_result); \
}
#define VECTOR_MATRIX_TERNARY_DIFF_IMPL(NAME) \
@@ -407,7 +445,7 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
DifferentialPair<vector<T, N>> dpx, \
DifferentialPair<vector<T, N>> dpy, \
DifferentialPair<vector<T, N>> dpz) \
-{ \
+ { \
vector<T, N> result; \
vector<T, N>.Differential d_result; \
[ForceUnroll] for (int i = 0; i < N; ++i) \
@@ -421,6 +459,27 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
} \
return DifferentialPair<vector<T, N>>(result, d_result); \
} \
+ __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \
+ [ForwardDerivativeOf(NAME)] \
+ DifferentialPair<matrix<T, M, N>> __d_##NAME##_matrix( \
+ DifferentialPair<matrix<T, M, N>> dpx, \
+ DifferentialPair<matrix<T, M, N>> dpy, \
+ DifferentialPair<matrix<T, M, N>> dpz) \
+ { \
+ matrix<T, M, N> result; \
+ matrix<T, M, N>.Differential d_result; \
+ [ForceUnroll] for (int i = 0; i < M; ++i) \
+ [ForceUnroll] for (int j = 0; j < N; ++j) \
+ { \
+ DifferentialPair<T> dp_elem = __d_##NAME( \
+ DifferentialPair<T>(dpx.p[i][j], __slang_noop_cast<T.Differential>(dpx.d[i][j])), \
+ DifferentialPair<T>(dpy.p[i][j], __slang_noop_cast<T.Differential>(dpy.d[i][j])), \
+ DifferentialPair<T>(dpz.p[i][j], __slang_noop_cast<T.Differential>(dpz.d[i][j]))); \
+ result[i][j] = dp_elem.p; \
+ d_result[i][j] = __slang_noop_cast<T>(dp_elem.d); \
+ } \
+ return DifferentialPair<matrix<T, M, N>>(result, d_result); \
+ } \
__generic<T : __BuiltinFloatingPointType, let N : int> \
[BackwardDerivativeOf(NAME)] \
void __d_##NAME##_vector( \
@@ -444,6 +503,31 @@ void __d_cross(inout DifferentialPair<vector<T, 3>> a, inout DifferentialPair<ve
dpx = diffPair(dpx.p, left_d_result); \
dpy = diffPair(dpy.p, middle_d_result); \
dpz = diffPair(dpz.p, right_d_result); \
+ } \
+ __generic<T : __BuiltinFloatingPointType, let M : int, let N : int> \
+ [BackwardDerivativeOf(NAME)] \
+ void __d_##NAME##_matrix( \
+ inout DifferentialPair<matrix<T, M, N>> dpx, \
+ inout DifferentialPair<matrix<T, M, N>> dpy, \
+ inout DifferentialPair<matrix<T, M, N>> dpz, \
+ matrix<T, M, N>.Differential dOut) \
+ { \
+ matrix<T, M, N>.Differential left_d_result, middle_d_result, right_d_result; \
+ [ForceUnroll] for (int i = 0; i < M; ++i) \
+ [ForceUnroll] for (int j = 0; j < N; ++j) \
+ { \
+ DifferentialPair<T> left_dp = diffPair(dpx.p[i][j], T.dzero()); \
+ DifferentialPair<T> middle_dp = diffPair(dpy.p[i][j], T.dzero()); \
+ DifferentialPair<T> right_dp = diffPair(dpz.p[i][j], T.dzero()); \
+ __d_##NAME(left_dp, middle_dp, right_dp, \
+ __slang_noop_cast<T.Differential>(dOut[i][j])); \
+ left_d_result[i][j] = __slang_noop_cast<T>(left_dp.d); \
+ middle_d_result[i][j] = __slang_noop_cast<T>(middle_dp.d); \
+ right_d_result[i][j] = __slang_noop_cast<T>(right_dp.d); \
+ } \
+ dpx = diffPair(dpx.p, left_d_result); \
+ dpy = diffPair(dpy.p, middle_d_result); \
+ dpz = diffPair(dpz.p, right_d_result); \
}
#define UNARY_DERIVATIVE_IMPL(NAME, FWD_DIFF_FUNC, BWD_DIFF_FUNC) \
@@ -999,24 +1083,19 @@ void __d_sincos(DifferentialPair<matrix<T, N, M>> x, out DifferentialPair<matrix
__fwd_diff(__sincos_impl)(x, s, c);
}
-#if 0
-// TODO: this is not working right now since our type system can't resolve
-// the overload to `sincos` in `[BackwardDerivativeOf]` attribute. We need to implement
-// a proper overload resolver for custom backward derivatives.
-
__generic<T: __BuiltinFloatingPointType>
[BackwardDerivativeOf(sincos)]
[ForceInline]
void __d_sincos(inout DifferentialPair<T> x, T.Differential dS, T.Differential dC)
{
- __bwd_diff(__sincos_impl)(x, s, c);
+ __bwd_diff(__sincos_impl)(x, dS, dC);
}
__generic<T: __BuiltinFloatingPointType, let N : int>
[BackwardDerivativeOf(sincos)]
[ForceInline]
void __d_sincos(inout DifferentialPair<vector<T, N>> x, vector<T, N>.Differential dS, vector<T, N>.Differential dC)
{
- __bwd_diff(__sincos_impl)(x, s, c);
+ __bwd_diff(__sincos_impl)(x, dS, dC);
}
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
@@ -1024,7 +1103,103 @@ __generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[ForceInline]
void __d_sincos(inout DifferentialPair<matrix<T, N, M>> x, matrix<T, N, M>.Differential dS, matrix<T, N, M>.Differential dC)
{
- __bwd_diff(__sincos_impl)(x, s, c);
+ __bwd_diff(__sincos_impl)(x, dS, dC);
+}
+
+// dst (obsolete)
+__generic<T : __BuiltinFloatingPointType>
+[BackwardDifferentiable]
+vector<T, 4> __dst_impl(vector<T, 4> src0, vector<T, 4> src1)
+{
+ vector<T, 4> dest;
+ dest.x = T(1.0);
+ dest.y = src0.y * src1.y;
+ dest.z = src0.z;
+ dest.w = src1.w; ;
+ return dest;
+}
+__generic<T : __BuiltinFloatingPointType>
+[ForwardDerivativeOf(dst)]
+[ForceInline]
+DifferentialPair<vector<T, 4>> __d_dst(DifferentialPair<vector<T, 4>> src0, DifferentialPair<vector<T, 4>> src1)
+{
+ return __fwd_diff(__dst_impl)(src0, src1);
+}
+__generic<T : __BuiltinFloatingPointType>
+[BackwardDerivativeOf(dst)]
+[ForceInline]
+void __d_dst(inout DifferentialPair<vector<T, 4>> src0, inout DifferentialPair<vector<T, 4>> src1, vector<T, 4>.Differential dOut)
+{
+ __bwd_diff(__dst_impl)(src0, src1, dOut);
+}
+
+// Legacy lighting function (obsolete)
+__target_intrinsic(hlsl)
+[__readNone]
+[BackwardDifferentiable]
+float4 __lit_impl(float n_dot_l, float n_dot_h, float m)
+{
+ let ambient = 1.0f;
+ let diffuse = max(n_dot_l, 0.0f);
+ let specular = ((n_dot_l < 0.0f || n_dot_h < 0.0) ? 0.0 : pow(n_dot_h, m));
+ return float4(ambient, diffuse, specular, 1.0f);
+}
+[ForwardDerivativeOf(lit)]
+[ForceInline]
+DifferentialPair<float4> __d_lit(DifferentialPair<float> n_dot_l, DifferentialPair<float> n_dot_h, DifferentialPair<float> m)
+{
+ return __fwd_diff(__lit_impl)(n_dot_l, n_dot_h, m);
+}
+[BackwardDerivativeOf(lit)]
+[ForceInline]
+void __d_lit(inout DifferentialPair<float> n_dot_l, inout DifferentialPair<float> n_dot_h, inout DifferentialPair<float> m, float4 dOut)
+{
+ __bwd_diff(__lit_impl)(n_dot_l, n_dot_h, m, dOut);
}
-#endif \ No newline at end of file
+// Matrix determinant
+__generic<T : __BuiltinFloatingPointType, let N : int>
+[BackwardDifferentiable]
+[__readNone]
+T __determinant_impl(matrix<T,N,N> m)
+{
+ if (N == 1)
+ return m[0][0];
+ else if (N == 2)
+ return m[0][0] * m[1][1] - m[0][1] * m[1][0];
+ else if (N == 3)
+ {
+ return m[0][0] * (m[1][1] * m[2][2] - m[2][1] * m[1][2])
+ - m[1][0] * (m[0][1] * m[2][2] - m[2][1] * m[0][2])
+ + m[2][0] * (m[0][1] * m[1][2] - m[1][1] * m[0][2]);
+ }
+ else if (N == 4)
+ {
+ T s00 = m[2][2] * m[3][3] - m[3][2] * m[2][3];
+ T s01 = m[2][1] * m[3][3] - m[3][1] * m[2][3];
+ T s02 = m[2][1] * m[3][2] - m[3][1] * m[2][2];
+ T s03 = m[2][0] * m[3][3] - m[3][0] * m[2][3];
+ T s04 = m[2][0] * m[3][2] - m[3][0] * m[2][2];
+ T s05 = m[2][0] * m[3][1] - m[3][0] * m[2][1];
+
+ return m[0][0] * (m[1][1] * s00 - m[1][2] * s01 + m[1][3] * s02)
+ - m[0][1] * (m[1][0] * s00 - m[1][2] * s03 + m[1][3] * s04)
+ + m[0][2] * (m[1][0] * s01 - m[1][1] * s03 + m[1][3] * s05)
+ - m[0][3] * (m[1][0] * s02 - m[1][1] * s04 + m[1][2] * s05);
+ }
+ return T(0.0);
+}
+__generic<T : __BuiltinFloatingPointType, let N : int>
+[ForwardDerivativeOf(determinant)]
+[ForceInline]
+DifferentialPair<T> __determinant_impl(DifferentialPair<matrix<T,N,N>> m)
+{
+ return __fwd_diff(__determinant_impl)(m);
+}
+__generic<T : __BuiltinFloatingPointType, let N : int>
+[BackwardDerivativeOf(determinant)]
+[ForceInline]
+void __d_determinant(inout DifferentialPair<matrix<T,N,N>> m, T.Differential dOut)
+{
+ __bwd_diff(__determinant_impl)(m, dOut);
+}
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 5a01bc132..633c3d87e 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -2619,7 +2619,7 @@ float4 lit(float n_dot_l, float n_dot_h, float m)
{
let ambient = 1.0f;
let diffuse = max(n_dot_l, 0.0f);
- let specular = step(0.0f, n_dot_l) * max(n_dot_h * m, 0.0f);
+ let specular = step(0.0f, n_dot_l) * max(pow(n_dot_h, m), 0.0f);
return float4(ambient, diffuse, specular, 1.0f);
}
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h
index d49db89a2..ba0b4ce7a 100644
--- a/source/slang/slang-ast-expr.h
+++ b/source/slang/slang-ast-expr.h
@@ -448,7 +448,7 @@ class OpenRefExpr : public Expr
class HigherOrderInvokeExpr : public Expr
{
SLANG_ABSTRACT_AST_CLASS(HigherOrderInvokeExpr)
- Expr* baseFunction;
+ Expr* baseFunction;
List<Name*> newParameterNames;
};
diff --git a/source/slang/slang-ast-support-types.cpp b/source/slang/slang-ast-support-types.cpp
index aa0513569..3feed7541 100644
--- a/source/slang/slang-ast-support-types.cpp
+++ b/source/slang/slang-ast-support-types.cpp
@@ -36,11 +36,16 @@ void removeModifier(ModifiableSyntaxNode* syntax, Modifier* toRemove)
}
}
-Expr* getInnerMostExprFromHigherOrderExpr(Expr* expr)
+Expr* getInnerMostExprFromHigherOrderExpr(Expr* expr, FunctionDifferentiableLevel& outLevel)
{
HashSet<Expr*> workListSet;
+ outLevel = FunctionDifferentiableLevel::None;
while (auto higherOrder = as<HigherOrderInvokeExpr>(expr))
{
+ if (as<BackwardDifferentiateExpr>(expr))
+ outLevel = FunctionDifferentiableLevel::Backward;
+ else if (as<ForwardDifferentiateExpr>(expr) && outLevel == FunctionDifferentiableLevel::None)
+ outLevel = FunctionDifferentiableLevel::Forward;
if (workListSet.Add(higherOrder))
{
expr = higherOrder->baseFunction;
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h
index 5fd9df400..07b3a5eac 100644
--- a/source/slang/slang-ast-support-types.h
+++ b/source/slang/slang-ast-support-types.h
@@ -1537,9 +1537,22 @@ namespace Slang
DMulFunc, ///< The `IDifferentiable.dmul` function requirement
};
+ enum class FunctionDifferentiableLevel
+ {
+ None,
+ Forward,
+ Backward
+ };
+
/// Get the inner most expr from an higher order expr chain, e.g. `__fwd_diff(__fwd_diff(f))`'s
/// inner most expr is `f`.
- Expr* getInnerMostExprFromHigherOrderExpr(Expr* expr);
+ Expr* getInnerMostExprFromHigherOrderExpr(Expr* expr, FunctionDifferentiableLevel& outDiffLevel);
+ inline Expr* getInnerMostExprFromHigherOrderExpr(Expr* expr)
+ {
+ FunctionDifferentiableLevel level;
+ return getInnerMostExprFromHigherOrderExpr(expr, level);
+ }
+
/// Get the operator name from the higher order invoke expr.
UnownedStringSlice getHigherOrderOperatorName(HigherOrderInvokeExpr* expr);
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index a1d5acfb0..7c42c1892 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -4663,7 +4663,8 @@ namespace Slang
TDerivativeAttr* attr,
const List<Expr*>& imaginaryArguments)
{
- auto invokeExpr = visitor->constructUncheckedInvokeExpr(attr->funcExpr, imaginaryArguments);
+ auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, *visitor);
+ auto invokeExpr = visitor->constructUncheckedInvokeExpr(checkedFuncExpr, imaginaryArguments);
auto resolved = visitor->ResolveInvoke(invokeExpr);
if (auto resolvedInvoke = as<InvokeExpr>(resolved))
{
@@ -4690,38 +4691,34 @@ namespace Slang
return "BackwardDerivative";
}
- List<Expr*> getImaginaryArgsToForwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc)
+ List<Expr*> getImaginaryArgsToFunc(ASTBuilder* astBuilder, FunctionDeclBase* func, SourceLoc loc)
{
List<Expr*> imaginaryArguments;
- for (auto param : originalFuncDecl->getParameters())
+ for (auto param : func->getParameters())
{
- auto arg = visitor->getASTBuilder()->create<VarExpr>();
+ auto arg = astBuilder->create<VarExpr>();
arg->declRef.decl = param;
arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
arg->type.type = param->getType();
arg->loc = loc;
- if (auto pairType = visitor->getDifferentialPairType(param->getType()))
- {
- arg->type.type = pairType;
- }
imaginaryArguments.add(arg);
}
return imaginaryArguments;
}
- List<Expr*> getImaginaryArgsToOriginalFuncFromForwardDerivativeFunc(ASTBuilder* astBuilder, FunctionDeclBase* fwdDiffFunc, SourceLoc loc)
+ List<Expr*> getImaginaryArgsToForwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc)
{
List<Expr*> imaginaryArguments;
- for (auto param : fwdDiffFunc->getParameters())
+ for (auto param : originalFuncDecl->getParameters())
{
- auto arg = astBuilder->create<VarExpr>();
+ auto arg = visitor->getASTBuilder()->create<VarExpr>();
arg->declRef.decl = param;
arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
arg->type.type = param->getType();
arg->loc = loc;
- if (auto pairType = as<DifferentialPairType>(param->getType()))
+ if (auto pairType = visitor->getDifferentialPairType(param->getType()))
{
- arg->type.type = pairType->getPrimalType();
+ arg->type.type = pairType;
}
imaginaryArguments.add(arg);
}
@@ -4731,6 +4728,11 @@ namespace Slang
List<Expr*> getImaginaryArgsToBackwardDerivative(SemanticsVisitor* visitor, FunctionDeclBase* originalFuncDecl, SourceLoc loc)
{
List<Expr*> imaginaryArguments;
+ auto isOutParam = [&](ParamDecl* param)
+ {
+ return param->findModifier<OutModifier>() != nullptr && param->findModifier<InModifier>() == nullptr;
+ };
+
for (auto param : originalFuncDecl->getParameters())
{
auto arg = visitor->getASTBuilder()->create<VarExpr>();
@@ -4738,16 +4740,23 @@ namespace Slang
arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
arg->type.type = param->getType();
arg->loc = loc;
- if (auto pairType = visitor->getDifferentialPairType(param->getType()))
+ if (auto pairType = as<DifferentialPairType>(visitor->getDifferentialPairType(param->getType())))
{
arg->type.type = pairType;
- if (auto diffPairType = as<DifferentialPairType>(pairType))
+ if (isOutParam(param))
{
- if (param->findModifier<OutModifier>() != nullptr && param->findModifier<InModifier>() == nullptr)
- {
- arg->type.isLeftValue = false;
- arg->type.type = diffPairType->getPrimalType();
- }
+ // out T -> in T.Differential
+ arg->type.isLeftValue = false;
+ arg->type.type = visitor->tryGetDifferentialType(
+ visitor->getASTBuilder(), pairType->getPrimalType());
+ }
+ }
+ else
+ {
+ if (isOutParam(param))
+ {
+ // Skip non-differentiable out params.
+ continue;
}
}
imaginaryArguments.add(arg);
@@ -4763,38 +4772,6 @@ namespace Slang
return imaginaryArguments;
}
- List<Expr*> getImaginaryArgsToOriginalFuncFromBackwardDerivativeFunc(ASTBuilder* astBuilder, FunctionDeclBase* bwdDiffFunc, SourceLoc loc)
- {
- // Note: it isn't always possible to construct original arguments from
- // backward propagation arguments because backward propagation function
- // may drop certain parameters.
- List<Expr*> imaginaryArguments;
- for (auto param : bwdDiffFunc->getParameters())
- {
- auto arg = astBuilder->create<VarExpr>();
- arg->declRef.decl = param;
- arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
- arg->type.type = param->getType();
- arg->loc = loc;
- if (auto pairType = as<DifferentialPairType>(param->getType()))
- {
- if (param->findModifier<OutModifier>() != nullptr && param->findModifier<InModifier>() == nullptr)
- {
- arg->type.isLeftValue = false;
- }
- arg->type.type = pairType->getPrimalType();
- }
- imaginaryArguments.add(arg);
- }
- // Assume the last parameter is `dOut`.
- // This is not true if the function returns a non-differentiable value.
- // However in that uncommon case we just fail the overload resolution
- // and require the user to provide disambiguate themselves.
- if (imaginaryArguments.getCount())
- imaginaryArguments.fastRemoveAt(imaginaryArguments.getCount() - 1);
- return imaginaryArguments;
- }
-
// This helper function is needed to workaround a gcc bug.
// Remove when we upgrade to a newer version of gcc.
template <typename T>
@@ -4803,76 +4780,41 @@ namespace Slang
return decl->findModifier<T>();
}
- template <typename TDerivativeAttr, typename TDerivativeOfAttr>
+ template <typename TDerivativeAttr, typename TDifferentiateExpr, typename TDerivativeOfAttr>
void checkDerivativeOfAttributeImpl(
SemanticsVisitor* visitor,
FunctionDeclBase* funcDecl,
TDerivativeOfAttr* derivativeOfAttr,
- DeclAssociationKind assocKind,
- const List<Expr*>& imaginaryArgsToOriginal)
+ DeclAssociationKind assocKind)
{
DeclRef<Decl> calleeDeclRef;
- auto calleeDeclRefExpr = as<DeclRefExpr>(derivativeOfAttr->funcExpr);
- if (!calleeDeclRefExpr)
+ DeclRefExpr* calleeDeclRefExpr = nullptr;
+ DifferentiateExpr* diffFuncExpr = visitor->getASTBuilder()->create<TDifferentiateExpr>();
+ diffFuncExpr->baseFunction = derivativeOfAttr->funcExpr;
+ diffFuncExpr->loc = derivativeOfAttr->loc;
+ Expr* checkedDiffFuncExpr = visitor->dispatchExpr(diffFuncExpr, *visitor);
+ if (!checkedDiffFuncExpr)
{
- auto invokeExpr = visitor->constructUncheckedInvokeExpr(derivativeOfAttr->funcExpr, imaginaryArgsToOriginal);
- auto resolved = visitor->ResolveInvoke(invokeExpr);
- if (auto resolvedInvoke = as<InvokeExpr>(resolved))
- {
- calleeDeclRefExpr = as<DeclRefExpr>(resolvedInvoke->functionExpr);
- }
+ visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative);
+ return;
+ }
+ List<Expr*> imaginaryArgs = getImaginaryArgsToFunc(visitor->getASTBuilder(), funcDecl, derivativeOfAttr->loc);
+ auto invokeExpr = visitor->constructUncheckedInvokeExpr(checkedDiffFuncExpr, imaginaryArgs);
+ auto resolved = visitor->ResolveInvoke(invokeExpr);
+ if (auto resolvedInvoke = as<InvokeExpr>(resolved))
+ {
+ auto resolvedDiffFuncExpr = as<DifferentiateExpr>(resolvedInvoke->functionExpr);
+ if (resolvedDiffFuncExpr)
+ calleeDeclRefExpr = as<DeclRefExpr>(resolvedDiffFuncExpr->baseFunction);
}
+
if (!calleeDeclRefExpr)
{
visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative);
return;
}
calleeDeclRef = calleeDeclRefExpr->declRef;
- if (auto calleeGenDecl = as<GenericDecl>(calleeDeclRef.getDecl()))
- {
- auto parentGenericDecl = as<GenericDecl>(funcDecl->parentDecl);
- if (!parentGenericDecl)
- {
- visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative);
- return;
- }
- FunctionDeclBase* funcReturnVal = nullptr;
- List<Val*> args;
- for (auto mm : parentGenericDecl->members)
- {
- if (auto genericTypeParamDecl = as<GenericTypeParamDecl>(mm))
- {
- args.add(DeclRefType::create(visitor->getASTBuilder(), DeclRef<Decl>(genericTypeParamDecl, nullptr)));
- }
- else if (auto genericValueParamDecl = as<GenericValueParamDecl>(mm))
- {
- args.add(visitor->getASTBuilder()->getOrCreate<GenericParamIntVal>(
- genericValueParamDecl->getType(),
- genericValueParamDecl, nullptr));
- }
- }
- auto funcs = calleeGenDecl->getMembersOfType<FunctionDeclBase>();
- if (funcs.isEmpty())
- {
- visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative);
- return;
- }
- funcReturnVal = funcs.getFirst();
- if (funcReturnVal)
- {
- auto subst = visitor->getASTBuilder()->getOrCreateGenericSubstitution(calleeGenDecl, args, nullptr);
- calleeDeclRef.decl = funcReturnVal;
- calleeDeclRef.substitutions = subst;
- calleeDeclRefExpr = as<DeclRefExpr>(visitor->ConstructDeclRefExpr(
- calleeDeclRef, nullptr, derivativeOfAttr->loc, nullptr));
- }
- else
- {
- calleeDeclRef = DeclRef<Decl>();
- calleeDeclRefExpr = nullptr;
- }
- }
-
+
auto calleeFunc = as<FunctionDeclBase>(calleeDeclRef.getDecl());
if (!calleeFunc)
{
@@ -4953,9 +4895,8 @@ namespace Slang
if (!attr)
return;
- List<Expr*> imaginaryArgsToOriginal = getImaginaryArgsToOriginalFuncFromForwardDerivativeFunc(m_astBuilder, funcDecl, attr->loc);
- checkDerivativeOfAttributeImpl<ForwardDerivativeAttribute>(
- this, funcDecl, attr, DeclAssociationKind::ForwardDerivativeFunc, imaginaryArgsToOriginal);
+ checkDerivativeOfAttributeImpl<ForwardDerivativeAttribute, ForwardDifferentiateExpr>(
+ this, funcDecl, attr, DeclAssociationKind::ForwardDerivativeFunc);
}
void SemanticsDeclBodyVisitor::checkBackwardDerivativeOfAttribute(FunctionDeclBase* funcDecl)
@@ -4964,33 +4905,8 @@ namespace Slang
if (!attr)
return;
- List<Expr*> imaginaryArguments = getImaginaryArgsToOriginalFuncFromBackwardDerivativeFunc(m_astBuilder, funcDecl, attr->loc);
-
- // The tricky part here is that we can't easily derive the arguments to original func just
- // from the definition of a backward derivative function, because we don't know if the last
- // parameter is just a normal parameter of the original func, or if it is the additional
- // derivative of the return value. The solution here is to try to resolve the original
- // function with or without the last argument. However if the type of the last argument
- // isn't differentiable, we know that it can't possibly be the result derivative.
-
- if (imaginaryArguments.getCount() == 0 ||
- !tryGetDifferentialType(m_astBuilder, imaginaryArguments.getLast()->type.type))
- {
- checkDerivativeOfAttributeImpl<BackwardDerivativeAttribute>(
- this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc, imaginaryArguments);
- return;
- }
-
- // Otherwise, try resolve with all the arguments, if failed, resolve without the last
- // argument.
- if (tryCheckDerivativeOfAttributeImpl<BackwardDerivativeAttribute>(this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc, imaginaryArguments))
- {
- return;
- }
-
- imaginaryArguments.removeLast();
- checkDerivativeOfAttributeImpl<BackwardDerivativeAttribute>(
- this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc, imaginaryArguments);
+ checkDerivativeOfAttributeImpl<BackwardDerivativeAttribute, BackwardDifferentiateExpr>(
+ this, funcDecl, attr, DeclAssociationKind::BackwardDerivativeFunc);
}
void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl)
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 2803b5959..bebaa63a2 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1971,31 +1971,26 @@ namespace Slang
if (auto higherOrderInvoke = as<DifferentiateExpr>(invoke->functionExpr))
{
- if (auto funcDeclExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(higherOrderInvoke)))
+ FunctionDifferentiableLevel requiredLevel;
+ if (auto funcDeclExpr = as<DeclRefExpr>(
+ getInnerMostExprFromHigherOrderExpr(higherOrderInvoke, requiredLevel)))
{
auto funcDecl = as<FunctionDeclBase>(funcDeclExpr->declRef.getDecl());
if (funcDecl)
{
- DifferentiateExpr* forwardDiff = nullptr;
- DifferentiateExpr* backwardDiff = nullptr;
- for (auto node = as<DifferentiateExpr>(invoke->functionExpr); node; node = as<DifferentiateExpr>(node->baseFunction))
+ if (requiredLevel == FunctionDifferentiableLevel::Forward &&
+ !getShared()->isDifferentiableFunc(funcDecl))
{
- if (auto fwd = as<ForwardDifferentiateExpr>(node))
- forwardDiff = fwd;
- if (auto bwd = as<BackwardDifferentiateExpr>(node))
- backwardDiff = bwd;
+ getSink()->diagnose(funcDeclExpr, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "forward");
}
- if (forwardDiff && !getShared()->isDifferentiableFunc(funcDecl))
+ if (requiredLevel == FunctionDifferentiableLevel::Backward &&
+ !getShared()->isBackwardDifferentiableFunc(funcDecl))
{
- getSink()->diagnose(forwardDiff, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "forward");
- }
- if (backwardDiff && !getShared()->isBackwardDifferentiableFunc(funcDecl))
- {
- getSink()->diagnose(forwardDiff, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "backward");
+ getSink()->diagnose(funcDeclExpr, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "backward");
}
if (!isEffectivelyStatic(funcDecl) && !isGlobalDecl(funcDecl))
{
- getSink()->diagnose(forwardDiff, Diagnostics::nonStaticMemberFunctionNotAllowedAsDiffOperand, funcDecl);
+ getSink()->diagnose(invoke->functionExpr, Diagnostics::nonStaticMemberFunctionNotAllowedAsDiffOperand, funcDecl);
}
}
}
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 11aacd255..fc1b622cc 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -11,12 +11,6 @@
namespace Slang
{
- enum class FunctionDifferentiableLevel
- {
- None,
- Forward,
- Backward
- };
/// Should the given `decl` be treated as a static rather than instance declaration?
bool isEffectivelyStatic(
Decl* decl);
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index 520d85971..e6a524645 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -680,16 +680,7 @@ namespace Slang
{
SLANG_ASSERT(attr->args.getCount() == 1);
SLANG_ASSERT(as<Decl>(attrTarget));
-
- // Ensure that the argument is a reference to a function definition or declaration.
- auto primalFunc = CheckTerm(attr->args[0]);
- if (primalFunc->type == getASTBuilder()->getErrorType())
- {
- // Could not resolve the term.
- getSink()->diagnose(primalFunc, Slang::Diagnostics::invalidCustomDerivative, as<Decl>(attrTarget));
- return false;
- }
- derivativeOfAttr->funcExpr = primalFunc;
+ derivativeOfAttr->funcExpr = attr->args[0];
}
else if (auto comInterfaceAttr = as<ComInterfaceAttribute>(attr))
{
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index 09ddde2de..91af731ad 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -1622,10 +1622,10 @@ namespace Slang
else
{
// Unhandled case for the inner expr.
- funcExpr->type = this->getASTBuilder()->getErrorType();
getSink()->diagnose(funcExpr->loc,
Diagnostics::expectedFunction,
funcExpr->type);
+ funcExpr->type = this->getASTBuilder()->getErrorType();
}
}
diff --git a/tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang b/tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang
index 379e2c3ef..53972ac2c 100644
--- a/tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang
+++ b/tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang
@@ -2,7 +2,7 @@
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj
-//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0], stride=4):out,name=outputBuffer
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
typedef DifferentialPair<float> dpfloat;
@@ -43,4 +43,14 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
__bwd_diff(diffSin)(dpx, 1.0);
outputBuffer[4] = dpx.d; // Expect: -1.000000
}
+
+ {
+ dpfloat dpx = dpfloat(float.getPi() / 3.0, 1.0);
+ __bwd_diff(sincos)(dpx, 1.0, 0.0);
+ outputBuffer[5] = dpx.d; // Expect: 0.5
+ __bwd_diff(sincos)(dpx, 0.0, 1.0);
+ outputBuffer[6] = dpx.d; // Expect: -0.8660254
+ __bwd_diff(sincos)(dpx, 1.0, 1.0);
+ outputBuffer[7] = dpx.d; // Expect: -0.3660254
+ }
}
diff --git a/tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang.expected.txt
index a4b804cb8..17627df68 100644
--- a/tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang.expected.txt
+++ b/tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang.expected.txt
@@ -3,4 +3,7 @@ type: float
7.389056
0.000000
1.000000
--1.00000 \ No newline at end of file
+-1.000000
+0.500000
+-0.866025
+-0.366025
diff --git a/tests/autodiff-dstdlib/dstdlib-sqrt.slang b/tests/autodiff-dstdlib/dstdlib-sqrt.slang
index 15573c4ef..d68a2697c 100644
--- a/tests/autodiff-dstdlib/dstdlib-sqrt.slang
+++ b/tests/autodiff-dstdlib/dstdlib-sqrt.slang
@@ -1,7 +1,7 @@
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
-//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
typedef DifferentialPair<float> dpfloat;
@@ -50,4 +50,13 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
outputBuffer[6] = dpx.d[0]; // Expect: 0.158114
outputBuffer[7] = dpx.d[1]; // Expect: 0.577350
}
+
+ {
+ var dpx = diffPair(float2x2(4.0, 9.0, 16.0, 25.0), float2x2(0.0, 0.0, 0.0, 0.0));
+ __bwd_diff(sqrt)(dpx, float2x2(1.0, 2.0, 3.0, 4.0));
+ outputBuffer[8] = dpx.d[0][0]; // Expect: 0.25
+ outputBuffer[9] = dpx.d[0][1]; // Expect: 0.3333
+ outputBuffer[10] = dpx.d[1][0]; // Expect: 0.375
+ outputBuffer[11] = dpx.d[1][1]; // Expect: 0.4
+ }
}
diff --git a/tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt
index fe6487fef..7e0fdf02f 100644
--- a/tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt
+++ b/tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt
@@ -6,4 +6,8 @@ type: float
0.000000
0.000000
0.158114
-0.577350 \ No newline at end of file
+0.577350
+0.250000
+0.333333
+0.375000
+0.400000 \ No newline at end of file