summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJay Kwak <82421531+jkwak-work@users.noreply.github.com>2024-05-02 11:56:13 -0700
committerGitHub <noreply@github.com>2024-05-02 11:56:13 -0700
commitf7d54af67e026feb2546af1deaf2513a36f8516e (patch)
treec751573fbfceb114eeac6217bb3ccc4c13f19ef2
parent679a457940027420817a85070b3fdb9bfc0cca2e (diff)
Fix fmod behavior targetting GLSL and SPIR-V (#4080)
* Fix fmod behavior targetting GLSL and SPIR-V The default implementation of fmod was doing "Modulo" operation when "fmod" in HLSL should do "remainder" operation. * Fix a mistake in `fmod` GLSL target When using __intrinsic_asm, the "if" logic wasn't emitted. "__intrinsic_asm" had to be called from a new function and `fmod` had to call it. Alternatively, I am using `operator?()` to workaround. A similar modification is made to `roundEven()` hoping for a better performance.
-rw-r--r--source/slang/glsl.meta.slang58
-rw-r--r--source/slang/hlsl.meta.slang55
-rw-r--r--tests/cross-compile/fmod.slang3
-rw-r--r--tests/glsl-intrinsic/intrinsic-basic.slang25
4 files changed, 113 insertions, 28 deletions
diff --git a/source/slang/glsl.meta.slang b/source/slang/glsl.meta.slang
index 0ba6c17aa..9715a44ce 100644
--- a/source/slang/glsl.meta.slang
+++ b/source/slang/glsl.meta.slang
@@ -461,16 +461,28 @@ __generic<T : __BuiltinFloatingPointType>
[require(cpp_cuda_glsl_hlsl_spirv, GLSL_130)]
public T roundEven(T x)
{
- T i;
- if (T(0.5) <= fmod(x, i))
+ __target_switch
{
- bool evenInteger = (fmod(i, T(2)) == T(0));
- if (!evenInteger)
+ case glsl: __intrinsic_asm "roundEven";
+ case spirv: return spirv_asm {
+ OpExtInst $$T result glsl450 RoundEven $x
+ };
+ default:
+ T nearest = round(x);
+
+ // Check if the value is exactly halfway between two integers
+ if (abs(x - nearest) == T(0.5))
{
- x += T(0.1);
+ // If halfway, choose the even number
+ if (mod(nearest, T(2)) != T(0))
+ {
+ // If the nearest number is odd,
+ // move to the closest even number
+ nearest -= ((x < nearest) ? T(1) : T(-1));
+ }
}
+ return nearest;
}
- return round(x);
}
__generic<T : __BuiltinFloatingPointType, let N:int>
@@ -479,7 +491,15 @@ __generic<T : __BuiltinFloatingPointType, let N:int>
[require(cpp_cuda_glsl_hlsl_spirv, GLSL_130)]
public vector<T,N> roundEven(vector<T,N> x)
{
- VECTOR_MAP_UNARY(T, N, roundEven, x);
+ __target_switch
+ {
+ case glsl: __intrinsic_asm "roundEven";
+ case spirv: return spirv_asm {
+ OpExtInst $$vector<T,N> result glsl450 RoundEven $x
+ };
+ default:
+ VECTOR_MAP_UNARY(T, N, roundEven, x);
+ }
}
__generic<T : __BuiltinFloatingPointType>
@@ -506,7 +526,15 @@ __generic<T : __BuiltinFloatingPointType>
[require(cpp_cuda_glsl_hlsl_spirv, sm_2_0_GLSL_140)]
public T mod(T x, T y)
{
- return fmod(x, y);
+ // SPIR-V doesn't have "modulus".
+ // All of Op?Mod and OpFRem are "remainder".
+
+ __target_switch
+ {
+ case glsl: __intrinsic_asm "mod";
+ default:
+ return x - y * floor(x / y);
+ }
}
__generic<T : __BuiltinFloatingPointType, let N:int>
@@ -515,7 +543,12 @@ __generic<T : __BuiltinFloatingPointType, let N:int>
[require(cpp_cuda_glsl_hlsl_spirv, sm_2_0_GLSL_140)]
public vector<T, N> mod(vector<T, N> x, T y)
{
- return fmod(x, vector<T, N>(y));
+ __target_switch
+ {
+ case glsl: __intrinsic_asm "mod";
+ default:
+ return x - y * floor(x / y);
+ }
}
__generic<T : __BuiltinFloatingPointType, let N:int>
@@ -524,7 +557,12 @@ __generic<T : __BuiltinFloatingPointType, let N:int>
[require(cpp_cuda_glsl_hlsl_spirv, sm_2_0_GLSL_140)]
public vector<T, N> mod(vector<T, N> x, vector<T, N> y)
{
- return fmod(x, y);
+ __target_switch
+ {
+ case glsl: __intrinsic_asm "mod";
+ default:
+ return x - y * floor(x / y);
+ }
}
__generic<T : __BuiltinFloatingPointType, let N : int>
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 7cafe764f..ae81289d1 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -6504,13 +6504,61 @@ __generic<T : __BuiltinFloatingPointType>
[require(cpp_cuda_glsl_hlsl_spirv, sm_2_0_GLSL_140)]
T fmod(T x, T y)
{
+ // In HLSL, fmod returns a remainder.
+ // Definition of `fmod` in HLSL is,
+ // "The floating-point remainder is calculated such that x = i * y + f,
+ // where i is an integer, f has the same sign as x, and the absolute value
+ // of f is less than the absolute value of y."
+ //
+ // In GLSL, mod is a Modulus function.
+ // OpenGL document defines "Modulus" as "Returns x - y * floor(x / y)".
+ // The use of "Floor()" makes the difference.
+ //
+ // The tricky ones are when x or y is a negative value.
+ //
+ // | Remainder | Modulus
+ // x y | x= i*y +f | x-y*floor(x/y)
+ // ------+-----------+------------------------------
+ // 4 3 | 4= 1*3 +1 | 4-3*floor( 4/3) = 4-3* 1 = 1
+ // 3 3 | 3= 1*3 +0 | 3-3*floor( 3/3) = 3-3* 1 = 0
+ // 2 3 | 2= 0*3 +2 | 2-3*floor( 2/3) = 2-3* 0 = 2
+ // 1 3 | 1= 0*3 +1 | 1-3*floor( 1/3) = 1-3* 0 = 1
+ // 0 3 | 0= 0*3 +0 | 0-3*floor( 0/3) = 0-3* 0 = 0
+ // -1 3 |-1= 0*3 -1 |-1-3*floor(-1/3) =-1-3*-1 = 2
+ // -2 3 |-2= 0*3 -2 |-2-3*floor(-2/3) =-2-3*-1 = 1
+ // -3 3 |-3=-1*3 0 |-3-3*floor(-3/3) =-3-3*-1 = 0
+ // -4 3 |-4=-1*3 -1 |-4-3*floor(-4/3) =-4-3*-2 = 2
+ //
+ // When y is a negative value,
+ //
+ // | Remainder | Modulus
+ // x y | x= i*y +f | x-y*floor(x/y)
+ // ------+-----------+------------------------------
+ // 4 -3 | 4=-1*-3+1 | 4+3*floor( 4/-3) = 4+3*-2 =-2
+ // 3 -3 | 3=-1*-3+0 | 3+3*floor( 3/-3) = 3+3*-1 = 0
+ // 2 -3 | 2= 0*-3+2 | 2+3*floor( 2/-3) = 2+3*-1 =-1
+ // 1 -3 | 1= 0*-3+1 | 1+3*floor( 1/-3) = 1+3*-1 =-2
+ // 0 -3 | 0= 0*-3+0 | 0+3*floor( 0/-3) = 0+3* 0 = 0
+ // -1 -3 |-1= 0*-3-1 |-1+3*floor(-1/-3) =-1+3* 0 =-1
+ // -2 -3 |-2= 0*-3-2 |-2+3*floor(-2/-3) =-2+3* 0 =-2
+ // -3 -3 |-3= 1*-3 0 |-3+3*floor(-3/-3) =-3+3* 1 = 0
+ // -4 -3 |-4= 1*-3-1 |-4+3*floor(-4/-3) =-4+3* 1 =-1
+
__target_switch
{
case cpp: __intrinsic_asm "$P_fmod($0, $1)";
case cuda: __intrinsic_asm "$P_fmod($0, $1)";
case hlsl: __intrinsic_asm "fmod";
- default:
- return x - y * trunc(x/y);
+ case glsl:
+ // GLSL doesn't have a function for remainder.
+ __intrinsic_asm "(($0 < 0) ? -mod(-$0,abs($1)) : mod($0,abs($1)))";
+ case spirv:
+ // OpFRem return "The floating-point remainder whose sign
+ // matches the sign of Operand 1", where Operand 1 is "x".
+ return spirv_asm
+ {
+ result:$$T = OpFRem $x $y
+ };
}
}
@@ -6522,6 +6570,9 @@ vector<T, N> fmod(vector<T, N> x, vector<T, N> y)
__target_switch
{
case hlsl: __intrinsic_asm "fmod";
+ case spirv: return spirv_asm {
+ result:$$vector<T,N> = OpFRem $x $y
+ };
default:
VECTOR_MAP_BINARY(T, N, fmod, x, y);
}
diff --git a/tests/cross-compile/fmod.slang b/tests/cross-compile/fmod.slang
index 6efff35b3..16ecf072a 100644
--- a/tests/cross-compile/fmod.slang
+++ b/tests/cross-compile/fmod.slang
@@ -4,7 +4,8 @@
// expected output on Vulkan/GLSL.
//TEST(compute):COMPARE_COMPUTE:-dx11 -compute -shaderobj
-//TEST(compute):COMPARE_COMPUTE:-vk -compute -shaderobj
+//TEST(compute):COMPARE_COMPUTE:-vk -compute -shaderobj -emit-spirv-via-glsl
+//TEST(compute):COMPARE_COMPUTE:-vk -compute -shaderobj -emit-spirv-directly
//TEST_INPUT:cbuffer(data=[4 0 0 0]):name=C
cbuffer C
diff --git a/tests/glsl-intrinsic/intrinsic-basic.slang b/tests/glsl-intrinsic/intrinsic-basic.slang
index 42d416c1e..fbe29a000 100644
--- a/tests/glsl-intrinsic/intrinsic-basic.slang
+++ b/tests/glsl-intrinsic/intrinsic-basic.slang
@@ -207,12 +207,10 @@ bool Test_ScalarType()
&& genFType(0) == round(genFType(zero))
&& genDType(0) == round(genDType(zero))
-#if 0
- // C-HECK_GLSL-COUNT-2: roundEven(
- // C-HECK_SPIR-COUNT-2: RoundEven{{ }}
+ // CHECK_GLSL-COUNT-2: roundEven(
+ // CHECK_SPIR-COUNT-2: RoundEven{{ }}
&& genFType(0) == roundEven(genFType(zero))
&& genDType(0) == roundEven(genDType(zero))
-#endif
// CHECK_GLSL-COUNT-2: ceil(
// CHECK_SPIR-COUNT-2: Ceil{{ }}
@@ -224,11 +222,10 @@ bool Test_ScalarType()
&& genFType(0) == fract(genFType(zero))
&& genDType(0) == fract(genDType(zero))
-#if 0
- // C-HECK_GLSL-COUNT-2: mod(
+ // CHECK_GLSL-COUNT-2: mod(
+ // CHECK_SPIR-COUNT-2: Floor{{ }}
&& genFType(0) == mod(genFType(zero), genFType(one))
&& genDType(0) == mod(genDType(zero), genDType(one))
-#endif
// CHECK_GLSL-COUNT-2: modf(
// CHECK_SPIR-COUNT-2: Modf{{ }}
@@ -733,13 +730,11 @@ bool Test_VectorType()
&& genFType(0) == round(genFType(zero))
&& genDType(0) == round(genDType(zero))
-#if 0
- // C-HECK_GLSL-COUNT-2: roundEven(
- // C-HECK_SPIR-COUNT-2: RoundEven{{ }}
- // C-HECK_SPIR-NOT: RoundEven{{ }}
+ // CHECK_GLSL-COUNT-2: roundEven(
+ // CHECK_SPIR-COUNT-2: RoundEven{{ }}
+ // CHECK_SPIR-NOT: RoundEven{{ }}
&& genFType(0) == roundEven(genFType(zero))
&& genDType(0) == roundEven(genDType(zero))
-#endif
// CHECK_GLSL-COUNT-2: ceil(
// CHECK_SPIR-COUNT-2: Ceil{{ }}
@@ -753,13 +748,13 @@ bool Test_VectorType()
&& genFType(0) == fract(genFType(zero))
&& genDType(0) == fract(genDType(zero))
-#if 0
- // C-HECK_GLSL-COUNT-4: mod(
+ // CHECK_GLSL-COUNT-4: mod(
+ // CHECK_SPIR-COUNT-4: Floor{{ }}
+ // CHECK_SPIR-NOT: Floor{{ }}
&& genFType(0) == mod(genFType(zero), float(one))
&& genFType(0) == mod(genFType(zero), genFType(one))
&& genDType(0) == mod(genDType(zero), double(one))
&& genDType(0) == mod(genDType(zero), genDType(one))
-#endif
// CHECK_GLSL-COUNT-2: modf(
// CHECK_SPIR-COUNT-2: Modf{{ }}