summaryrefslogtreecommitdiffstats
path: root/source/slang/diff.meta.slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-08 21:52:34 -0800
committerGitHub <noreply@github.com>2023-03-08 21:52:34 -0800
commit86fc50c5092fbccf6072dcf7bbdfafb8915f02c8 (patch)
treeb4f9eb6cb1eea88145fde0bd1f670a8803120257 /source/slang/diff.meta.slang
parent257733f328f38a763c8b0c8830ff4c0d34ec9491 (diff)
Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`. (#2691)
* Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`. * Fix * Fix. * Cleanup. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/diff.meta.slang')
-rw-r--r--source/slang/diff.meta.slang85
1 files changed, 11 insertions, 74 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index 54f927816..4301eda94 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -12,6 +12,9 @@ __attributeTarget(FunctionDeclBase)
attribute_syntax [BackwardDerivative(function)] : BackwardDerivativeAttribute;
__attributeTarget(FunctionDeclBase)
+attribute_syntax [PrimalSubstitute(function)] : PrimalSubstituteAttribute;
+
+__attributeTarget(FunctionDeclBase)
attribute_syntax [BackwardDifferentiable] : BackwardDifferentiableAttribute;
__attributeTarget(FunctionDeclBase)
@@ -20,6 +23,9 @@ attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute;
__attributeTarget(FunctionDeclBase)
attribute_syntax [BackwardDerivativeOf(function)] : BackwardDerivativeOfAttribute;
+__attributeTarget(FunctionDeclBase)
+attribute_syntax [PrimalSubstituteOf(function)] : PrimalSubstituteOfAttribute;
+
__attributeTarget(DeclBase)
attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute;
@@ -1037,6 +1043,7 @@ void __d_refract(inout DifferentialPair<vector<T, N>> i, inout DifferentialPair<
// Sine and cosine
__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
+[PrimalSubstituteOf(sincos)]
void __sincos_impl(T x, out T s, out T c)
{
s = sin(x);
@@ -1045,6 +1052,7 @@ void __sincos_impl(T x, out T s, out T c)
__generic<T : __BuiltinFloatingPointType, let N : int>
[BackwardDifferentiable]
+[PrimalSubstituteOf(sincos)]
void __sincos_impl(vector<T, N> x, out vector<T, N> s, out vector<T, N> c)
{
s = sin(x);
@@ -1053,62 +1061,18 @@ void __sincos_impl(vector<T, N> x, out vector<T, N> s, out vector<T, N> c)
__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
[BackwardDifferentiable]
+[PrimalSubstituteOf(sincos)]
void __sincos_impl(matrix<T, N, M> x, out matrix<T, N, M> s, out matrix<T, N, M> c)
{
s = sin(x);
c = cos(x);
}
-__generic<T: __BuiltinFloatingPointType>
-[ForwardDerivativeOf(sincos)]
-[ForceInline]
-void __d_sincos(DifferentialPair<T> x, out DifferentialPair<T> s, out DifferentialPair<T> c)
-{
- __fwd_diff(__sincos_impl)(x, s, c);
-}
-
-__generic<T : __BuiltinFloatingPointType, let N : int>
-[ForwardDerivativeOf(sincos)]
-[ForceInline]
-void __d_sincos(DifferentialPair<vector<T, N>> x, out DifferentialPair<vector<T, N>> s, out DifferentialPair<vector<T, N>> c)
-{
- __fwd_diff(__sincos_impl)(x, s, c);
-}
-
-__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
-[ForwardDerivativeOf(sincos)]
-[ForceInline]
-void __d_sincos(DifferentialPair<matrix<T, N, M>> x, out DifferentialPair<matrix<T, N, M>> s, out DifferentialPair<matrix<T, N, M>> c)
-{
- __fwd_diff(__sincos_impl)(x, s, c);
-}
-
-__generic<T: __BuiltinFloatingPointType>
-[BackwardDerivativeOf(sincos)]
-[ForceInline]
-void __d_sincos(inout DifferentialPair<T> x, T.Differential dS, T.Differential dC)
-{
- __bwd_diff(__sincos_impl)(x, dS, dC);
-}
-__generic<T: __BuiltinFloatingPointType, let N : int>
-[BackwardDerivativeOf(sincos)]
-[ForceInline]
-void __d_sincos(inout DifferentialPair<vector<T, N>> x, vector<T, N>.Differential dS, vector<T, N>.Differential dC)
-{
- __bwd_diff(__sincos_impl)(x, dS, dC);
-}
-
-__generic<T : __BuiltinFloatingPointType, let N : int, let M : int>
-[BackwardDerivativeOf(sincos)]
-[ForceInline]
-void __d_sincos(inout DifferentialPair<matrix<T, N, M>> x, matrix<T, N, M>.Differential dS, matrix<T, N, M>.Differential dC)
-{
- __bwd_diff(__sincos_impl)(x, dS, dC);
-}
// dst (obsolete)
__generic<T : __BuiltinFloatingPointType>
[BackwardDifferentiable]
+[PrimalSubstituteOf(dst)]
vector<T, 4> __dst_impl(vector<T, 4> src0, vector<T, 4> src1)
{
vector<T, 4> dest;
@@ -1118,25 +1082,11 @@ vector<T, 4> __dst_impl(vector<T, 4> src0, vector<T, 4> src1)
dest.w = src1.w; ;
return dest;
}
-__generic<T : __BuiltinFloatingPointType>
-[ForwardDerivativeOf(dst)]
-[ForceInline]
-DifferentialPair<vector<T, 4>> __d_dst(DifferentialPair<vector<T, 4>> src0, DifferentialPair<vector<T, 4>> src1)
-{
- return __fwd_diff(__dst_impl)(src0, src1);
-}
-__generic<T : __BuiltinFloatingPointType>
-[BackwardDerivativeOf(dst)]
-[ForceInline]
-void __d_dst(inout DifferentialPair<vector<T, 4>> src0, inout DifferentialPair<vector<T, 4>> src1, vector<T, 4>.Differential dOut)
-{
- __bwd_diff(__dst_impl)(src0, src1, dOut);
-}
// Legacy lighting function (obsolete)
-__target_intrinsic(hlsl)
[__readNone]
[BackwardDifferentiable]
+[PrimalSubstituteOf(lit)]
float4 __lit_impl(float n_dot_l, float n_dot_h, float m)
{
let ambient = 1.0f;
@@ -1144,19 +1094,6 @@ float4 __lit_impl(float n_dot_l, float n_dot_h, float m)
let specular = ((n_dot_l < 0.0f || n_dot_h < 0.0) ? 0.0 : pow(n_dot_h, m));
return float4(ambient, diffuse, specular, 1.0f);
}
-[ForwardDerivativeOf(lit)]
-[ForceInline]
-DifferentialPair<float4> __d_lit(DifferentialPair<float> n_dot_l, DifferentialPair<float> n_dot_h, DifferentialPair<float> m)
-{
- return __fwd_diff(__lit_impl)(n_dot_l, n_dot_h, m);
-}
-[BackwardDerivativeOf(lit)]
-[ForceInline]
-void __d_lit(inout DifferentialPair<float> n_dot_l, inout DifferentialPair<float> n_dot_h, inout DifferentialPair<float> m, float4 dOut)
-{
- __bwd_diff(__lit_impl)(n_dot_l, n_dot_h, m, dOut);
-}
-
// Matrix determinant
__generic<T : __BuiltinFloatingPointType, let N : int>
[BackwardDifferentiable]