From 1817e9530989072ac34ff16d11fcc570f7862998 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 1 Nov 2022 12:55:36 -0400 Subject: Added a vector intrinsic definition for exp (to serve as template for other vector intrinsics) (#2481) * Added vector exp definition * Naming --- source/slang/diff.meta.slang | 82 ++++++++++++++++++-------------------------- 1 file changed, 33 insertions(+), 49 deletions(-) (limited to 'source') diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 674531048..1c3066e1d 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -59,60 +59,15 @@ extension float : IDifferentiable } } -extension vector : IDifferentiable +__generic +extension vector : IDifferentiable { - typedef vector Differential; + typedef vector Differential; [__unsafeForceInlineEarly] static Differential dzero() { - return vector(0.f); - } - - [__unsafeForceInlineEarly] - static Differential dadd(Differential a, Differential b) - { - return a + b; - } - - [__unsafeForceInlineEarly] - static Differential dmul(This a, Differential b) - { - return a * b; - } -} - -extension vector : IDifferentiable -{ - typedef vector Differential; - - [__unsafeForceInlineEarly] - static Differential dzero() - { - return vector(0.f); - } - - [__unsafeForceInlineEarly] - static Differential dadd(Differential a, Differential b) - { - return a + b; - } - - [__unsafeForceInlineEarly] - static Differential dmul(This a, Differential b) - { - return a * b; - } -} - -extension vector : IDifferentiable -{ - typedef vector Differential; - - [__unsafeForceInlineEarly] - static Differential dzero() - { - return vector(0.f); + return vector(0.f); } [__unsafeForceInlineEarly] @@ -207,6 +162,9 @@ struct DifferentialPair : IDifferentiable typealias IDFloat = IFloat & IDifferentiable; +#define VECTOR_MAP_UNARY(TYPE, COUNT, FUNC, VALUE) \ + vector result; for(int i = 0; i < COUNT; ++i) { result[i] = FUNC(VALUE[i]); } return result + namespace dstd { // Natural Exponent @@ -262,4 +220,30 @@ namespace dstd cos(dpx.p()), T.dmul(-sin(dpx.p()), dpx.d())); } + + __generic + __target_intrinsic(hlsl) + __target_intrinsic(glsl) + __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 27 _0") + [ForwardDerivative(d_exp_vector)] + vector exp(vector x) + { + VECTOR_MAP_UNARY(float, N, dstd.exp, x); + } + + __generic + DifferentialPair> d_exp_vector(DifferentialPair> dpx) + { + vector result; + vector.Differential d_result; + for(int i = 0; i < N; ++i) + { + DifferentialPair dpexp = dstd.d_exp(DifferentialPair(dpx.p()[i], dpx.d()[i])); + result[i] = dpexp.p(); + d_result[i] = dpexp.d(); + } + + return DifferentialPair>(result, d_result); + } + }; -- cgit v1.2.3