From 4b350645042b8e8fbdad19784ee745d11c7bc616 Mon Sep 17 00:00:00 2001 From: kaizhangNV <149626564+kaizhangNV@users.noreply.github.com> Date: Wed, 5 Feb 2025 12:12:49 -0600 Subject: Fix autodiff issue for vector (#6275) * Fix autodiff issue for vector Close #6154 We didn't implement correctly for vector regarding the differentiablity. As we check differentiable before specialization, however according to the definition of vector, it has to be specialized to IFloat to know it's conformed to IDifferential type. Therefore for parameter type vector will become no_diff. Therefore, we change the implementation a to make it explicit conform to IDifferential type. * fix typo --- source/slang/core.meta.slang | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) (limited to 'source') diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 36e9d6885..19421735c 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -1988,8 +1988,18 @@ extension vector : IFloat [OverloadRank(-1)] [__unsafeForceInlineEarly] __init(float v) { this = vector(T(v)); } - // IDifferentiable +} + +__intrinsic_op($(kIROp_Add)) +T __internal_add(T a, T b); +__intrinsic_op($(kIROp_Mul)) +T __internal_mul(U a, T b); + +__generic +extension vector : IDifferentiable +{ + // IDifferentiable typedef vector Differential; [__unsafeForceInlineEarly] @@ -2003,7 +2013,7 @@ extension vector : IFloat [BackwardDifferentiable] static Differential dadd(Differential a, Differential b) { - return a + b; + return __internal_add(a, b); } __generic @@ -2011,7 +2021,7 @@ extension vector : IFloat [BackwardDifferentiable] static Differential dmul(U a, Differential b) { - return __realCast(a) * b; + return __internal_mul(__realCast(a), b); } } @@ -2042,7 +2052,11 @@ extension matrix : IFloat [__unsafeForceInlineEarly] __implicit_conversion($(kConversionCost_ScalarToMatrix)) __init(float v) { this = matrix(T(v)); } +} +__generic +extension matrix : IDifferentiable +{ // IDifferentiable. typedef matrix Differential; @@ -2057,15 +2071,15 @@ extension matrix : IFloat [BackwardDifferentiable] static Differential dadd(Differential a, Differential b) { - return a + b; + return __internal_add(a, b); } - + __generic [__unsafeForceInlineEarly] [BackwardDifferentiable] static Differential dmul(U a, Differential b) { - return __realCast(a) * b; + return __internal_mul(__realCast(a), b); } } -- cgit v1.2.3