From 8899c149b05def1cce626ea649012c4c974861de Mon Sep 17 00:00:00 2001 From: jsmall-nvidia Date: Mon, 2 Mar 2020 16:18:20 -0500 Subject: Additional Wave Intrinsic Support (#1252) * Test for some wave intrinsics. More wave intrinsic support on CUDA. * Use shfl_xor_sync. * Improvements around wave intrinsics. Fix built in integer types belong to __BuiltinIntegerType. * Improvements and fixes around Wave intrinsics. * Added WaveIsFirstLane test. No longer use __wavemask_lt, as appears not available as an intrinsic. * Small fixes to CUDA prelude. * Add wave-active-product test. Handle the special case for arbitray sums. * Used macro to implement CUDA wave intrinsics. --- source/core/slang-nvrtc-compiler.cpp | 3 ++ source/slang/core.meta.slang | 8 +++-- source/slang/core.meta.slang.h | 16 ++++++---- source/slang/hlsl.meta.slang | 58 ++++++++++++++++++++++++--------- source/slang/hlsl.meta.slang.h | 62 ++++++++++++++++++++++++++---------- 5 files changed, 108 insertions(+), 39 deletions(-) (limited to 'source') diff --git a/source/core/slang-nvrtc-compiler.cpp b/source/core/slang-nvrtc-compiler.cpp index f68c4dc01..27d269125 100644 --- a/source/core/slang-nvrtc-compiler.cpp +++ b/source/core/slang-nvrtc-compiler.cpp @@ -307,6 +307,9 @@ SlangResult NVRTCDownstreamCompiler::compile(const CompileOptions& options, RefP // This is arguably too much - but nvrtc does not appear to have a mechanism to switch off individual warnings. // I tried the -Xcudafe mechanism but that does not appear to work for nvrtc cmdLine.addArg("-w"); + + // + cmdLine.addArg("-arch=compute_70"); } nvrtcProgram program = nullptr; diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 70bc90392..6822d304b 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -132,9 +132,12 @@ for (int tt = 0; tt < kBaseTypeCount; ++tt) case BaseType::Half: case BaseType::Float: case BaseType::Double: - sb << "\n , __BuiltinFloatingPointType\n"; + sb << "\n , __BuiltinFloatingPointType\n"; sb << "\n , __BuiltinRealType\n"; - ; // fall through to: + sb << "\n , __BuiltinSignedArithmeticType\n"; + sb << "\n , __BuiltinArithmeticType\n"; + sb << "\n , __BuiltinType\n"; + break; case BaseType::Int8: case BaseType::Int16: case BaseType::Int: @@ -146,6 +149,7 @@ for (int tt = 0; tt < kBaseTypeCount; ++tt) case BaseType::UInt: case BaseType::UInt64: sb << "\n , __BuiltinArithmeticType\n"; + sb << "\n , __BuiltinIntegerType\n"; ; // fall through to: case BaseType::Bool: sb << "\n , __BuiltinType\n"; diff --git a/source/slang/core.meta.slang.h b/source/slang/core.meta.slang.h index 4c8da2a9a..3ff1fd243 100644 --- a/source/slang/core.meta.slang.h +++ b/source/slang/core.meta.slang.h @@ -135,9 +135,12 @@ for (int tt = 0; tt < kBaseTypeCount; ++tt) case BaseType::Half: case BaseType::Float: case BaseType::Double: - sb << "\n , __BuiltinFloatingPointType\n"; + sb << "\n , __BuiltinFloatingPointType\n"; sb << "\n , __BuiltinRealType\n"; - ; // fall through to: + sb << "\n , __BuiltinSignedArithmeticType\n"; + sb << "\n , __BuiltinArithmeticType\n"; + sb << "\n , __BuiltinType\n"; + break; case BaseType::Int8: case BaseType::Int16: case BaseType::Int: @@ -149,6 +152,7 @@ for (int tt = 0; tt < kBaseTypeCount; ++tt) case BaseType::UInt: case BaseType::UInt64: sb << "\n , __BuiltinArithmeticType\n"; + sb << "\n , __BuiltinIntegerType\n"; ; // fall through to: case BaseType::Bool: sb << "\n , __BuiltinType\n"; @@ -195,7 +199,7 @@ for (int tt = 0; tt < kBaseTypeCount; ++tt) // TODO: should this cover the full gamut of integer types? case BaseType::Int: case BaseType::UInt: -SLANG_RAW("#line 195 \"core.meta.slang\"") +SLANG_RAW("#line 199 \"core.meta.slang\"") SLANG_RAW("\n") SLANG_RAW(" __generic\n") SLANG_RAW(" __init(T value);\n") @@ -211,7 +215,7 @@ SLANG_RAW(" __init(T value);\n") // Declare built-in pointer type // (eventually we can have the traditional syntax sugar for this) -SLANG_RAW("#line 210 \"core.meta.slang\"") +SLANG_RAW("#line 214 \"core.meta.slang\"") SLANG_RAW("\n") SLANG_RAW("\n") SLANG_RAW("__generic\n") @@ -273,7 +277,7 @@ sb << " __init(T value);\n"; sb << " __init(vector value);\n"; sb << "};\n"; -SLANG_RAW("#line 256 \"core.meta.slang\"") +SLANG_RAW("#line 260 \"core.meta.slang\"") SLANG_RAW("\n") SLANG_RAW("\n") SLANG_RAW("__generic\n") @@ -1509,7 +1513,7 @@ for (auto op : binaryOps) sb << "__intrinsic_op(" << int(op.opCode) << ") matrix<" << resultType << ",N,M> operator" << op.opName << "(" << leftQual << "matrix<" << leftType << ",N,M> left, " << rightType << " right);\n"; } } -SLANG_RAW("#line 1491 \"core.meta.slang\"") +SLANG_RAW("#line 1495 \"core.meta.slang\"") SLANG_RAW("\n") SLANG_RAW("\n") SLANG_RAW("// Specialized function\n") diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index c8aae9158..edb678ad6 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -1395,35 +1395,51 @@ __generic T QuadReadAcrossDiagonal(T localValue); __generic vector QuadReadAcrossDiagonal(vector localValue); __generic matrix QuadReadAcrossDiagonal(matrix localValue); -__generic T WaveActiveBitAnd(T expr); +__generic +__target_intrinsic(cuda, "_waveAnd(__activemask(), $0)") +T WaveActiveBitAnd(T expr); __generic vector WaveActiveBitAnd(vector expr); __generic matrix WaveActiveBitAnd(matrix expr); -__generic T WaveActiveBitOr(T expr); +__generic +__target_intrinsic(cuda, "_waveOr(__activemask(), $0)") +T WaveActiveBitOr(T expr); __generic vector WaveActiveBitOr(vector expr); __generic matrix WaveActiveBitOr(matrix expr); -__generic T WaveActiveBitXor(T expr); +__generic +__target_intrinsic(cuda, "_waveXor(__activemask(), $0)") +T WaveActiveBitXor(T expr); __generic vector WaveActiveBitXor(vector expr); __generic matrix WaveActiveBitXor(matrix expr); -__generic T WaveActiveMax(T expr); +__generic +__target_intrinsic(cuda, "_waveMax(__activemask(), $0)") +T WaveActiveMax(T expr); __generic vector WaveActiveMax(vector expr); __generic matrix WaveActiveMax(matrix expr); -__generic T WaveActiveMin(T expr); +__generic +__target_intrinsic(cuda, "_waveMin(__activemask(), $0)") +T WaveActiveMin(T expr); __generic vector WaveActiveMin(vector expr); __generic matrix WaveActiveMin(matrix expr); -__generic T WaveActiveProduct(T expr); +__generic +__target_intrinsic(cuda, "_waveProduct(__activemask(), $0)") +T WaveActiveProduct(T expr); __generic vector WaveActiveProduct(vector expr); __generic matrix WaveActiveProduct(matrix expr); -__generic T WaveActiveSum(T expr); +__generic +__target_intrinsic(cuda, "_waveSum(__activemask(), $0)") +T WaveActiveSum(T expr); __generic vector WaveActiveSum(vector expr); __generic matrix WaveActiveSum(matrix expr); -__generic bool WaveActiveAllEqual(T value); +__generic +__target_intrinsic(cuda, "_waveAllEqual(__activemask(), $0)") +bool WaveActiveAllEqual(T value); __generic vector WaveActiveAllEqual(vector value); __generic matrix WaveActiveAllEqual(matrix value); @@ -1438,7 +1454,7 @@ __generic uint4 WaveMatch(matrix T WavePrefixProduct(T expr); +// TODO(JS): We cannot calculate prefix sums using a mask of __activemask() & __lanemask_lt(), because (amongst other reasons) +// that would mean different lanes having a different mask, and they all have to have the same mask. + +__generic +T WavePrefixProduct(T expr); __generic vector WavePrefixProduct(vector expr); __generic matrix WavePrefixProduct(matrix expr); -__generic T WavePrefixSum(T expr); +__generic +T WavePrefixSum(T expr); __generic vector WavePrefixSum(vector expr); __generic matrix WavePrefixSum(matrix expr); @@ -1473,11 +1494,14 @@ __generic T WaveMultiPrefixBitOr(T expr); __generic vector WaveMultiPrefixBitOr(vector expr); __generic matrix WaveMultiPrefixBitOr(matrix expr); -__generic T WaveMultiPrefixBitXor(T expr); +__generic +T WaveMultiPrefixBitXor(T expr); __generic vector WaveMultiPrefixBitXor(vector expr); __generic matrix WaveMultiPrefixBitXor(matrix expr); +__target_intrinsic(cuda, "__popc(__ballot_sync(__activemask(), $0) & __lanemask_lt())") uint WavePrefixCountBits(bool value); + uint WaveMultiPrefixCountBits(bool value, uint4 mask); __generic T WaveMultiPrefixProduct(T value, uint4 mask); @@ -1488,11 +1512,15 @@ __generic T WaveMultiPrefixSum(T value, uint4 mask) __generic vector WaveMultiPrefixSum(vector value, uint4 mask); __generic matrix WaveMultiPrefixSum(matrix value, uint4 mask); -__generic T WaveReadLaneFirst(T expr); +__generic +__target_intrinsic(cuda, "_waveReadFirst($0)") +T WaveReadLaneFirst(T expr); __generic vector WaveReadLaneFirst(vector expr); __generic matrix WaveReadLaneFirst(matrix expr); -__generic T WaveReadLaneAt(T value, int lane); +__generic +__target_intrinsic(cuda, "__shfl_sync(SLANG_CUDA_WARP_MASK, $0, $1)") +T WaveReadLaneAt(T value, int lane); __generic vector WaveReadLaneAt(vector value, int lane); __generic matrix WaveReadLaneAt(matrix value, int lane); diff --git a/source/slang/hlsl.meta.slang.h b/source/slang/hlsl.meta.slang.h index 69349d9dc..16a3244ab 100644 --- a/source/slang/hlsl.meta.slang.h +++ b/source/slang/hlsl.meta.slang.h @@ -1471,35 +1471,51 @@ SLANG_RAW("__generic T QuadReadAcrossDiagonal(T localValue);\ SLANG_RAW("__generic vector QuadReadAcrossDiagonal(vector localValue);\n") SLANG_RAW("__generic matrix QuadReadAcrossDiagonal(matrix localValue);\n") SLANG_RAW("\n") -SLANG_RAW("__generic T WaveActiveBitAnd(T expr);\n") +SLANG_RAW("__generic\n") +SLANG_RAW("__target_intrinsic(cuda, \"_waveAnd(__activemask(), $0)\")\n") +SLANG_RAW("T WaveActiveBitAnd(T expr);\n") SLANG_RAW("__generic vector WaveActiveBitAnd(vector expr);\n") SLANG_RAW("__generic matrix WaveActiveBitAnd(matrix expr);\n") SLANG_RAW("\n") -SLANG_RAW("__generic T WaveActiveBitOr(T expr);\n") +SLANG_RAW("__generic\n") +SLANG_RAW("__target_intrinsic(cuda, \"_waveOr(__activemask(), $0)\")\n") +SLANG_RAW("T WaveActiveBitOr(T expr);\n") SLANG_RAW("__generic vector WaveActiveBitOr(vector expr);\n") SLANG_RAW("__generic matrix WaveActiveBitOr(matrix expr);\n") SLANG_RAW("\n") -SLANG_RAW("__generic T WaveActiveBitXor(T expr);\n") +SLANG_RAW("__generic\n") +SLANG_RAW("__target_intrinsic(cuda, \"_waveXor(__activemask(), $0)\")\n") +SLANG_RAW("T WaveActiveBitXor(T expr);\n") SLANG_RAW("__generic vector WaveActiveBitXor(vector expr);\n") SLANG_RAW("__generic matrix WaveActiveBitXor(matrix expr);\n") SLANG_RAW("\n") -SLANG_RAW("__generic T WaveActiveMax(T expr);\n") +SLANG_RAW("__generic\n") +SLANG_RAW("__target_intrinsic(cuda, \"_waveMax(__activemask(), $0)\")\n") +SLANG_RAW("T WaveActiveMax(T expr);\n") SLANG_RAW("__generic vector WaveActiveMax(vector expr);\n") SLANG_RAW("__generic matrix WaveActiveMax(matrix expr);\n") SLANG_RAW("\n") -SLANG_RAW("__generic T WaveActiveMin(T expr);\n") +SLANG_RAW("__generic\n") +SLANG_RAW("__target_intrinsic(cuda, \"_waveMin(__activemask(), $0)\")\n") +SLANG_RAW("T WaveActiveMin(T expr);\n") SLANG_RAW("__generic vector WaveActiveMin(vector expr);\n") SLANG_RAW("__generic matrix WaveActiveMin(matrix expr);\n") SLANG_RAW("\n") -SLANG_RAW("__generic T WaveActiveProduct(T expr);\n") +SLANG_RAW("__generic\n") +SLANG_RAW("__target_intrinsic(cuda, \"_waveProduct(__activemask(), $0)\")\n") +SLANG_RAW("T WaveActiveProduct(T expr);\n") SLANG_RAW("__generic vector WaveActiveProduct(vector expr);\n") SLANG_RAW("__generic matrix WaveActiveProduct(matrix expr);\n") SLANG_RAW("\n") -SLANG_RAW("__generic T WaveActiveSum(T expr);\n") +SLANG_RAW("__generic\n") +SLANG_RAW("__target_intrinsic(cuda, \"_waveSum(__activemask(), $0)\")\n") +SLANG_RAW("T WaveActiveSum(T expr);\n") SLANG_RAW("__generic vector WaveActiveSum(vector expr);\n") SLANG_RAW("__generic matrix WaveActiveSum(matrix expr);\n") SLANG_RAW("\n") -SLANG_RAW("__generic bool WaveActiveAllEqual(T value);\n") +SLANG_RAW("__generic\n") +SLANG_RAW("__target_intrinsic(cuda, \"_waveAllEqual(__activemask(), $0)\")\n") +SLANG_RAW("bool WaveActiveAllEqual(T value);\n") SLANG_RAW("__generic vector WaveActiveAllEqual(vector value);\n") SLANG_RAW("__generic matrix WaveActiveAllEqual(matrix value);\n") SLANG_RAW("\n") @@ -1514,7 +1530,7 @@ SLANG_RAW("// seems to be appropriate.\n") SLANG_RAW("\n") SLANG_RAW("__target_intrinsic(cuda, \"(__all_sync(__activemask(), $0) != 0)\") \n") SLANG_RAW("bool WaveActiveAllTrue(bool condition);\n") -SLANG_RAW("__target_intrinsic(cuda, \"(_any_sync(__activemask(), $0) != 0)\")\n") +SLANG_RAW("__target_intrinsic(cuda, \"(__any_sync(__activemask(), $0) != 0)\")\n") SLANG_RAW("bool WaveActiveAnyTrue(bool condition);\n") SLANG_RAW("\n") SLANG_RAW("__target_intrinsic(cuda, \"make_uint4(__ballot_sync(__activemask(), $0), 0, 0, 0)\")\n") @@ -1530,14 +1546,19 @@ SLANG_RAW("__target_intrinsic(cuda, \"_getLaneId()\")\n") SLANG_RAW("uint WaveGetLaneIndex();\n") SLANG_RAW("\n") SLANG_RAW("// If there are no *active* lanes less than this one, we must be the lowest lane\n") -SLANG_RAW("__target_intrinsic(cuda, \"((__activemask() & __lanemask_lt()) == 0)\")\n") +SLANG_RAW("__target_intrinsic(cuda, \"_waveIsFirstLane()\")\n") SLANG_RAW("bool WaveIsFirstLane();\n") SLANG_RAW("\n") -SLANG_RAW("__generic T WavePrefixProduct(T expr);\n") +SLANG_RAW("// TODO(JS): We cannot calculate prefix sums using a mask of __activemask() & __lanemask_lt(), because (amongst other reasons)\n") +SLANG_RAW("// that would mean different lanes having a different mask, and they all have to have the same mask.\n") +SLANG_RAW("\n") +SLANG_RAW("__generic\n") +SLANG_RAW("T WavePrefixProduct(T expr);\n") SLANG_RAW("__generic vector WavePrefixProduct(vector expr);\n") SLANG_RAW("__generic matrix WavePrefixProduct(matrix expr);\n") SLANG_RAW("\n") -SLANG_RAW("__generic T WavePrefixSum(T expr);\n") +SLANG_RAW("__generic\n") +SLANG_RAW("T WavePrefixSum(T expr);\n") SLANG_RAW("__generic vector WavePrefixSum(vector expr);\n") SLANG_RAW("__generic matrix WavePrefixSum(matrix expr);\n") SLANG_RAW("\n") @@ -1549,11 +1570,14 @@ SLANG_RAW("__generic T WaveMultiPrefixBitOr(T expr) SLANG_RAW("__generic vector WaveMultiPrefixBitOr(vector expr);\n") SLANG_RAW("__generic matrix WaveMultiPrefixBitOr(matrix expr);\n") SLANG_RAW("\n") -SLANG_RAW("__generic T WaveMultiPrefixBitXor(T expr);\n") +SLANG_RAW("__generic\n") +SLANG_RAW("T WaveMultiPrefixBitXor(T expr);\n") SLANG_RAW("__generic vector WaveMultiPrefixBitXor(vector expr);\n") SLANG_RAW("__generic matrix WaveMultiPrefixBitXor(matrix expr);\n") SLANG_RAW("\n") +SLANG_RAW("__target_intrinsic(cuda, \"__popc(__ballot_sync(__activemask(), $0) & __lanemask_lt())\")\n") SLANG_RAW("uint WavePrefixCountBits(bool value);\n") +SLANG_RAW("\n") SLANG_RAW("uint WaveMultiPrefixCountBits(bool value, uint4 mask);\n") SLANG_RAW("\n") SLANG_RAW("__generic T WaveMultiPrefixProduct(T value, uint4 mask);\n") @@ -1564,11 +1588,15 @@ SLANG_RAW("__generic T WaveMultiPrefixSum(T value, SLANG_RAW("__generic vector WaveMultiPrefixSum(vector value, uint4 mask);\n") SLANG_RAW("__generic matrix WaveMultiPrefixSum(matrix value, uint4 mask);\n") SLANG_RAW("\n") -SLANG_RAW("__generic T WaveReadLaneFirst(T expr);\n") +SLANG_RAW("__generic\n") +SLANG_RAW("__target_intrinsic(cuda, \"_waveReadFirst($0)\")\n") +SLANG_RAW("T WaveReadLaneFirst(T expr);\n") SLANG_RAW("__generic vector WaveReadLaneFirst(vector expr);\n") SLANG_RAW("__generic matrix WaveReadLaneFirst(matrix expr);\n") SLANG_RAW("\n") -SLANG_RAW("__generic T WaveReadLaneAt(T value, int lane);\n") +SLANG_RAW("__generic\n") +SLANG_RAW("__target_intrinsic(cuda, \"__shfl_sync(SLANG_CUDA_WARP_MASK, $0, $1)\")\n") +SLANG_RAW("T WaveReadLaneAt(T value, int lane);\n") SLANG_RAW("__generic vector WaveReadLaneAt(vector value, int lane);\n") SLANG_RAW("__generic matrix WaveReadLaneAt(matrix value, int lane);\n") SLANG_RAW("\n") @@ -1658,7 +1686,7 @@ for (int aa = 0; aa < kBaseBufferAccessLevelCount; ++aa) sb << "};\n"; } -SLANG_RAW("#line 1585 \"hlsl.meta.slang\"") +SLANG_RAW("#line 1613 \"hlsl.meta.slang\"") SLANG_RAW("\n") SLANG_RAW("\n") SLANG_RAW("\n") @@ -1682,6 +1710,8 @@ SLANG_RAW("static const RAY_FLAG RAY_FLAG_CULL_BACK_FACING_TRIANGLES = 0x1 SLANG_RAW("static const RAY_FLAG RAY_FLAG_CULL_FRONT_FACING_TRIANGLES = 0x20;\n") SLANG_RAW("static const RAY_FLAG RAY_FLAG_CULL_OPAQUE = 0x40;\n") SLANG_RAW("static const RAY_FLAG RAY_FLAG_CULL_NON_OPAQUE = 0x80;\n") +SLANG_RAW("static const RAY_FLAG RAY_FLAG_SKIP_TRIANGLES = 0x100;\n") +SLANG_RAW("static const RAY_FLAG RAY_FLAG_SKIP_PROCEDURAL_PRIMITIVES = 0x200;\n") SLANG_RAW("\n") SLANG_RAW("// 10.1.2 - Ray Description Structure\n") SLANG_RAW("\n") -- cgit v1.2.3