From 86fc50c5092fbccf6072dcf7bbdfafb8915f02c8 Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 8 Mar 2023 21:52:34 -0800 Subject: Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`. (#2691) * Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`. * Fix * Fix. * Cleanup. --------- Co-authored-by: Yong He --- source/slang/diff.meta.slang | 85 ++++++-------------------------------------- 1 file changed, 11 insertions(+), 74 deletions(-) (limited to 'source/slang/diff.meta.slang') diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 54f927816..4301eda94 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -11,6 +11,9 @@ attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [BackwardDerivative(function)] : BackwardDerivativeAttribute; +__attributeTarget(FunctionDeclBase) +attribute_syntax [PrimalSubstitute(function)] : PrimalSubstituteAttribute; + __attributeTarget(FunctionDeclBase) attribute_syntax [BackwardDifferentiable] : BackwardDifferentiableAttribute; @@ -20,6 +23,9 @@ attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [BackwardDerivativeOf(function)] : BackwardDerivativeOfAttribute; +__attributeTarget(FunctionDeclBase) +attribute_syntax [PrimalSubstituteOf(function)] : PrimalSubstituteOfAttribute; + __attributeTarget(DeclBase) attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute; @@ -1037,6 +1043,7 @@ void __d_refract(inout DifferentialPair> i, inout DifferentialPair< // Sine and cosine __generic [BackwardDifferentiable] +[PrimalSubstituteOf(sincos)] void __sincos_impl(T x, out T s, out T c) { s = sin(x); @@ -1045,6 +1052,7 @@ void __sincos_impl(T x, out T s, out T c) __generic [BackwardDifferentiable] +[PrimalSubstituteOf(sincos)] void __sincos_impl(vector x, out vector s, out vector c) { s = sin(x); @@ -1053,62 +1061,18 @@ void __sincos_impl(vector x, out vector s, out vector c) __generic [BackwardDifferentiable] +[PrimalSubstituteOf(sincos)] void __sincos_impl(matrix x, out matrix s, out matrix c) { s = sin(x); c = cos(x); } -__generic -[ForwardDerivativeOf(sincos)] -[ForceInline] -void __d_sincos(DifferentialPair x, out DifferentialPair s, out DifferentialPair c) -{ - __fwd_diff(__sincos_impl)(x, s, c); -} - -__generic -[ForwardDerivativeOf(sincos)] -[ForceInline] -void __d_sincos(DifferentialPair> x, out DifferentialPair> s, out DifferentialPair> c) -{ - __fwd_diff(__sincos_impl)(x, s, c); -} - -__generic -[ForwardDerivativeOf(sincos)] -[ForceInline] -void __d_sincos(DifferentialPair> x, out DifferentialPair> s, out DifferentialPair> c) -{ - __fwd_diff(__sincos_impl)(x, s, c); -} - -__generic -[BackwardDerivativeOf(sincos)] -[ForceInline] -void __d_sincos(inout DifferentialPair x, T.Differential dS, T.Differential dC) -{ - __bwd_diff(__sincos_impl)(x, dS, dC); -} -__generic -[BackwardDerivativeOf(sincos)] -[ForceInline] -void __d_sincos(inout DifferentialPair> x, vector.Differential dS, vector.Differential dC) -{ - __bwd_diff(__sincos_impl)(x, dS, dC); -} - -__generic -[BackwardDerivativeOf(sincos)] -[ForceInline] -void __d_sincos(inout DifferentialPair> x, matrix.Differential dS, matrix.Differential dC) -{ - __bwd_diff(__sincos_impl)(x, dS, dC); -} // dst (obsolete) __generic [BackwardDifferentiable] +[PrimalSubstituteOf(dst)] vector __dst_impl(vector src0, vector src1) { vector dest; @@ -1118,25 +1082,11 @@ vector __dst_impl(vector src0, vector src1) dest.w = src1.w; ; return dest; } -__generic -[ForwardDerivativeOf(dst)] -[ForceInline] -DifferentialPair> __d_dst(DifferentialPair> src0, DifferentialPair> src1) -{ - return __fwd_diff(__dst_impl)(src0, src1); -} -__generic -[BackwardDerivativeOf(dst)] -[ForceInline] -void __d_dst(inout DifferentialPair> src0, inout DifferentialPair> src1, vector.Differential dOut) -{ - __bwd_diff(__dst_impl)(src0, src1, dOut); -} // Legacy lighting function (obsolete) -__target_intrinsic(hlsl) [__readNone] [BackwardDifferentiable] +[PrimalSubstituteOf(lit)] float4 __lit_impl(float n_dot_l, float n_dot_h, float m) { let ambient = 1.0f; @@ -1144,19 +1094,6 @@ float4 __lit_impl(float n_dot_l, float n_dot_h, float m) let specular = ((n_dot_l < 0.0f || n_dot_h < 0.0) ? 0.0 : pow(n_dot_h, m)); return float4(ambient, diffuse, specular, 1.0f); } -[ForwardDerivativeOf(lit)] -[ForceInline] -DifferentialPair __d_lit(DifferentialPair n_dot_l, DifferentialPair n_dot_h, DifferentialPair m) -{ - return __fwd_diff(__lit_impl)(n_dot_l, n_dot_h, m); -} -[BackwardDerivativeOf(lit)] -[ForceInline] -void __d_lit(inout DifferentialPair n_dot_l, inout DifferentialPair n_dot_h, inout DifferentialPair m, float4 dOut) -{ - __bwd_diff(__lit_impl)(n_dot_l, n_dot_h, m, dOut); -} - // Matrix determinant __generic [BackwardDifferentiable] -- cgit v1.2.3