summaryrefslogtreecommitdiffstats
path: root/source/slang/diff.meta.slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-08 10:07:57 -0800
committerGitHub <noreply@github.com>2022-11-08 10:07:57 -0800
commitbf67309454032b4f92d0bc9735b608e56b16882f (patch)
treea321fe7db0b49fa67608b935c1389354a020f59c /source/slang/diff.meta.slang
parentca882a1ef46a5a8bbff50e3a1a6f973e16358634 (diff)
Make `__BuiltinFloatingPointType` conform to `IDifferentiable`. (#2499)
Diffstat (limited to 'source/slang/diff.meta.slang')
-rw-r--r--source/slang/diff.meta.slang76
1 files changed, 14 insertions, 62 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index 2625d79b0..c95f8e1ac 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -9,56 +9,10 @@ attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute;
__attributeTarget(FunctionDeclBase)
attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute;
-// Add extensions for the standard types
-extension float : IDifferentiable
-{
- typedef float Differential;
-
- [__unsafeForceInlineEarly]
- static Differential dzero()
- {
- return float(0.f);
- }
-
- [__unsafeForceInlineEarly]
- static Differential dadd(Differential a, Differential b)
- {
- return a + b;
- }
-
- [__unsafeForceInlineEarly]
- static Differential dmul(This a, Differential b)
- {
- return a * b;
- }
-}
-
-__generic<let N:int>
-extension vector<float, N> : IDifferentiable
-{
- typedef vector<float, N> Differential;
-
- [__unsafeForceInlineEarly]
- static Differential dzero()
- {
- return vector<float, N>(0.f);
- }
-
- [__unsafeForceInlineEarly]
- static Differential dadd(Differential a, Differential b)
- {
- return a + b;
- }
- [__unsafeForceInlineEarly]
- static Differential dmul(This a, Differential b)
- {
- return a * b;
- }
-}
+/// Pair type that serves to wrap the primal and
+/// differential types of an arbitrary type T.
- /// Pair type that serves to wrap the primal and
- /// differential types of an arbitrary type T.
__generic<T : IDifferentiable>
__magic_type(DifferentialPairType)
__intrinsic_type($(kIROp_DifferentialPairType))
@@ -126,15 +80,13 @@ struct DifferentialPair : IDifferentiable
}
};
-typealias IDFloat = IFloat & IDifferentiable;
-
#define VECTOR_MAP_UNARY(TYPE, COUNT, FUNC, VALUE) \
vector<TYPE,COUNT> result; for(int i = 0; i < COUNT; ++i) { result[i] = FUNC(VALUE[i]); } return result
namespace dstd
{
// Natural Exponent
- __generic<T : IDFloat>
+ __generic<T : __BuiltinFloatingPointType>
__target_intrinsic(hlsl)
__target_intrinsic(glsl)
__target_intrinsic(cuda, "$P_exp($0)")
@@ -143,16 +95,16 @@ namespace dstd
[ForwardDerivative(d_exp<T>)]
T exp(T x);
- __generic<T : IDFloat>
+ __generic<T : __BuiltinFloatingPointType>
DifferentialPair<T> d_exp(DifferentialPair<T> dpx)
{
return DifferentialPair<T>(
- exp(dpx.p),
- T.dmul(exp(dpx.p), dpx.d));
+ dstd.exp(dpx.p),
+ T.dmul(dstd.exp(dpx.p), dpx.d));
}
// Sine
- __generic<T : IDFloat>
+ __generic<T : __BuiltinFloatingPointType>
__target_intrinsic(hlsl)
__target_intrinsic(glsl)
__target_intrinsic(cuda, "$P_sin($0)")
@@ -161,16 +113,16 @@ namespace dstd
[ForwardDerivative(d_sin<T>)]
T sin(T x);
- __generic<T : IDFloat>
+ __generic<T : __BuiltinFloatingPointType>
DifferentialPair<T> d_sin(DifferentialPair<T> dpx)
{
return DifferentialPair<T>(
- sin(dpx.p),
- T.dmul(cos(dpx.p), dpx.d));
+ dstd.sin(dpx.p),
+ T.dmul(dstd.cos(dpx.p), dpx.d));
}
// Cosine
- __generic<T : IDFloat>
+ __generic<T : __BuiltinFloatingPointType>
__target_intrinsic(hlsl)
__target_intrinsic(glsl)
__target_intrinsic(cuda, "$P_cos($0)")
@@ -179,12 +131,12 @@ namespace dstd
[ForwardDerivative(d_cos<T>)]
T cos(T x);
- __generic<T : IDFloat>
+ __generic<T : __BuiltinFloatingPointType>
DifferentialPair<T> d_cos(DifferentialPair<T> dpx)
{
return DifferentialPair<T>(
- cos(dpx.p),
- T.dmul(-sin(dpx.p), dpx.d));
+ dstd.cos(dpx.p),
+ T.dmul(-dstd.sin(dpx.p), dpx.d));
}
__generic<let N : int>