diff options
| author | Yong He <yonghe@outlook.com> | 2021-08-17 09:39:02 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2021-08-17 09:39:02 -0700 |
| commit | 858c7c57b125afed9b5b2329d6b02477284e4803 (patch) | |
| tree | 49f67b342448dcfb19913d8ccc089d956de14462 | |
| parent | 6406523511037987d8b8ab881aea41389afd57eb (diff) | |
Add GLSL450 intrinsics to SPIRV direct emit. (#1921)
* Add GLSL450 intrinsics to SPIRV direct emit.
* Fix.
* Fix compiler error.
* Fix.
* Fix compiler error.
* Make direct-spirv tests actually run.
25 files changed, 1092 insertions, 232 deletions
@@ -613,7 +613,10 @@ extern "C" SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM = 1 << 8, /* When set, will dump out the IR between intermediate compilation steps.*/ - SLANG_TARGET_FLAG_DUMP_IR = 1 << 9 + SLANG_TARGET_FLAG_DUMP_IR = 1 << 9, + + /* When set, will generate SPIRV directly instead of going through glslang. */ + SLANG_TARGET_FLAG_GENERATE_SPIRV_DIRECTLY = 1 << 10, }; /*! diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index dd4f95cf5..aa833a5e7 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -200,14 +200,14 @@ struct StructuredBuffer out uint stride); __target_intrinsic(glsl, "$0._data[$1]") - __target_intrinsic(spirv_direct, "%addr = 65 resultType*StorageBuffer resultId _0 _1; 61 resultType resultId %addr;") + __target_intrinsic(spirv_direct, "%addr = 65 resultType*StorageBuffer resultId _0 const(int, 0) _1; 61 resultType resultId %addr;") T Load(int location); T Load(int location, out uint status); __subscript(uint index) -> T { __target_intrinsic(glsl, "$0._data[$1]") - __target_intrinsic(spirv_direct, "%addr = 65 resultType*StorageBuffer resultId _0 _1; 61 resultType resultId %addr;") + __target_intrinsic(spirv_direct, "%addr = 65 resultType*StorageBuffer resultId _0 const(int, 0) _1; 61 resultType resultId %addr;") get; }; }; @@ -631,14 +631,14 @@ struct $(item.name) uint IncrementCounter(); __target_intrinsic(glsl, "$0._data[$1]") - __target_intrinsic(spirv_direct, "%addr = 65 resultType*StorageBuffer resultId _0 _1; 61 resultType resultId %addr;") + __target_intrinsic(spirv_direct, "%addr = 65 resultType*StorageBuffer resultId _0 const(int, 0) _1; 61 resultType resultId %addr;") T Load(int location); T Load(int location, out uint status); __subscript(uint index) -> T { __target_intrinsic(glsl, "$0._data[$1]") - __target_intrinsic(spirv_direct, "*StorageBuffer 65 resultType resultId _0 _1") + __target_intrinsic(spirv_direct, "*StorageBuffer 65 resultType resultId _0 const(int, 0) _1") ref; } }; @@ -711,6 +711,7 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_abs($0)") __target_intrinsic(cpp, "$P_abs($0)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 fi(4,5) _0") T abs(T x); /*{ // Note: this simple definition may not be appropriate for floating-point inputs @@ -720,6 +721,7 @@ T abs(T x); __generic<T : __BuiltinSignedArithmeticType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 fi(4,5) _0") vector<T, N> abs(vector<T, N> x) { VECTOR_MAP_UNARY(T, N, abs, x); @@ -739,11 +741,13 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_acos($0)") __target_intrinsic(cpp, "$P_acos($0)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 17 _0") T acos(T x); __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 17 _0") vector<T, N> acos(vector<T, N> x) { VECTOR_MAP_UNARY(T, N, acos, x); @@ -838,6 +842,7 @@ bool any(matrix<T, N, M> x); __target_intrinsic(hlsl) __target_intrinsic(glsl, "packDouble2x32(uvec2($0, $1))") +__target_intrinsic(spirv_direct, "%v = 80 _type(uint2) resultId _0 _1; 12 resultType resultId glsl450 59 %v") __glsl_extension(GL_ARB_gpu_shader5) double asdouble(uint lowbits, uint highbits); @@ -845,15 +850,18 @@ double asdouble(uint lowbits, uint highbits); __target_intrinsic(hlsl) __target_intrinsic(glsl, "intBitsToFloat") +__target_intrinsic(spirv_direct, "124 resultType resultId _0") float asfloat(int x); __target_intrinsic(hlsl) __target_intrinsic(glsl, "uintBitsToFloat") +__target_intrinsic(spirv_direct, "124 resultType resultId _0") float asfloat(uint x); __generic<let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl, "intBitsToFloat") +__target_intrinsic(spirv_direct, "124 resultType resultId _0") vector<float, N> asfloat(vector< int, N> x) { VECTOR_MAP_UNARY(float, N, asfloat, x); @@ -862,6 +870,7 @@ vector<float, N> asfloat(vector< int, N> x) __generic<let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl, "uintBitsToFloat") +__target_intrinsic(spirv_direct, "124 resultType resultId _0") vector<float,N> asfloat(vector<uint,N> x) { VECTOR_MAP_UNARY(float, N, asfloat, x); @@ -902,11 +911,13 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_asin($0)") __target_intrinsic(cpp, "$P_asin($0)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 16 _0") T asin(T x); __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 16 _0") vector<T, N> asin(vector<T, N> x) { VECTOR_MAP_UNARY(T,N,asin,x); @@ -923,15 +934,18 @@ matrix<T, N, M> asin(matrix<T, N, M> x) __target_intrinsic(hlsl) __target_intrinsic(glsl, "floatBitsToInt") +__target_intrinsic(spirv_direct, "124 resultType resultId _0") int asint(float x); __target_intrinsic(hlsl) __target_intrinsic(glsl, "int($0)") +__target_intrinsic(spirv_direct, "124 resultType resultId _0") int asint(uint x); __generic<let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl, "floatBitsToInt") +__target_intrinsic(spirv_direct, "124 resultType resultId _0") vector<int, N> asint(vector<float, N> x) { VECTOR_MAP_UNARY(int, N, asint, x); @@ -940,6 +954,7 @@ vector<int, N> asint(vector<float, N> x) __generic<let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl, "ivec$N0($0)") +__target_intrinsic(spirv_direct, "124 resultType resultId _0") vector<int, N> asint(vector<uint, N> x) { VECTOR_MAP_UNARY(int, N, asint, x); @@ -985,15 +1000,18 @@ void asuint(double value, out uint lowbits, out uint highbits); __target_intrinsic(hlsl) __target_intrinsic(glsl, "floatBitsToUint") +__target_intrinsic(spirv_direct, "124 resultType resultId _0") uint asuint(float x); __target_intrinsic(hlsl) __target_intrinsic(glsl, "uint($0)") +__target_intrinsic(spirv_direct, "124 resultType resultId _0") uint asuint(int x); __generic<let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl, "floatBitsToUint") +__target_intrinsic(spirv_direct, "124 resultType resultId _0") vector<uint,N> asuint(vector<float,N> x) { VECTOR_MAP_UNARY(uint, N, asuint, x); @@ -1002,6 +1020,7 @@ vector<uint,N> asuint(vector<float,N> x) __generic<let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl, "uvec$N0($0)") +__target_intrinsic(spirv_direct, "124 resultType resultId _0") vector<uint, N> asuint(vector<int, N> x) { VECTOR_MAP_UNARY(uint, N, asuint, x); @@ -1113,11 +1132,13 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_atan($0)") __target_intrinsic(cpp, "$P_atan($0)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 18 _0") T atan(T x); __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 18 _0") vector<T, N> atan(vector<T, N> x) { VECTOR_MAP_UNARY(T, N, atan, x); @@ -1135,11 +1156,13 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl,"atan($0,$1)") __target_intrinsic(cuda, "$P_atan2($0, $1)") __target_intrinsic(cpp, "$P_atan2($0, $1)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 25 _0 _1") T atan2(T y, T x); __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl,"atan($0,$1)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 25 _0 _1") vector<T, N> atan2(vector<T, N> y, vector<T, N> x) { VECTOR_MAP_BINARY(T, N, atan2, y, x); @@ -1158,11 +1181,13 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_ceil($0)") __target_intrinsic(cpp, "$P_ceil($0)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 9 _0") T ceil(T x); __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 9 _0") vector<T, N> ceil(vector<T, N> x) { VECTOR_MAP_UNARY(T, N, ceil, x); @@ -1183,6 +1208,7 @@ bool CheckAccessFullyMapped(uint status); __generic<T : __BuiltinArithmeticType> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 fus(43,44,45) _0 _1 _2") T clamp(T x, T minBound, T maxBound) { return min(max(x, minBound), maxBound); @@ -1191,6 +1217,7 @@ T clamp(T x, T minBound, T maxBound) __generic<T : __BuiltinArithmeticType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 fus(43,44,45) _0 _1 _2") vector<T, N> clamp(vector<T, N> x, vector<T, N> minBound, vector<T, N> maxBound) { return min(max(x, minBound), maxBound); @@ -1231,11 +1258,13 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_cos($0)") __target_intrinsic(cpp, "$P_cos($0)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 14 _0") T cos(T x); __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 14 _0") vector<T, N> cos(vector<T, N> x) { VECTOR_MAP_UNARY(T,N, cos, x); @@ -1254,11 +1283,13 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_cosh($0)") __target_intrinsic(cpp, "$P_cosh($0)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 20 _0") T cosh(T x); __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 20 _0") vector<T,N> cosh(vector<T,N> x) { VECTOR_MAP_UNARY(T,N, cosh, x); @@ -1279,9 +1310,11 @@ __target_intrinsic(cpp, "$P_countbits($0)") uint countbits(uint value); // Cross product +// TODO: SPIRV does not support integer vectors. __generic<T : __BuiltinArithmeticType> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 68 _0 _1") vector<T,3> cross(vector<T,3> left, vector<T,3> right) { return vector<T,3>( @@ -1431,6 +1464,7 @@ matrix<T, N, M> ddy_fine(matrix<T, N, M> x) __generic<T : __BuiltinFloatingPointType> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 12 _0") T degrees(T x) { return x * (T(180) / T.getPi()); @@ -1439,6 +1473,7 @@ T degrees(T x) __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 12 _0") vector<T, N> degrees(vector<T, N> x) { VECTOR_MAP_UNARY(T, N, degrees, x); @@ -1453,7 +1488,11 @@ matrix<T, N, M> degrees(matrix<T, N, M> x) // Matrix determinant -__generic<T : __BuiltinFloatingPointType, let N : int> T determinant(matrix<T,N,N> m); +__generic<T : __BuiltinFloatingPointType, let N : int> +__target_intrinsic(hlsl) +__target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 33 _0") +T determinant(matrix<T,N,N> m); // Barrier for device memory __target_intrinsic(glsl, "memoryBarrier(), memoryBarrierImage(), memoryBarrierBuffer()") @@ -1469,6 +1508,7 @@ void DeviceMemoryBarrierWithGroupSync(); __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 67 _0 _1") T distance(vector<T, N> x, vector<T, N> y) { return length(x - y); @@ -1505,42 +1545,68 @@ RWStructuredBuffer<T> __getEquivalentStructuredBuffer<T>(RWByteAddressBuffer b); // when compiled to GLSL, since they only support scalar/vector // TODO: Should these be constrains to `__BuiltinFloatingPointType`? +// TODO: SPIRV-direct does not support non-floating-point types. __generic<T : __BuiltinArithmeticType> __target_intrinsic(glsl, interpolateAtCentroid) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 76 _0") T EvaluateAttributeAtCentroid(T x); __generic<T : __BuiltinArithmeticType, let N : int> __target_intrinsic(glsl, interpolateAtCentroid) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 76 _0") vector<T,N> EvaluateAttributeAtCentroid(vector<T,N> x); __generic<T : __BuiltinArithmeticType, let N : int, let M : int> __target_intrinsic(glsl, interpolateAtCentroid) -matrix<T,N,M> EvaluateAttributeAtCentroid(matrix<T,N,M> x); +matrix<T,N,M> EvaluateAttributeAtCentroid(matrix<T,N,M> x) +{ + MATRIX_MAP_UNARY(T, N, M, EvaluateAttributeAtCentroid, x); +} __generic<T : __BuiltinArithmeticType> __target_intrinsic(glsl, "interpolateAtSample($0, int($1))") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 77 _0 _1") T EvaluateAttributeAtSample(T x, uint sampleindex); __generic<T : __BuiltinArithmeticType, let N : int> __target_intrinsic(glsl, "interpolateAtSample($0, int($1))") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 77 _0 _1") vector<T,N> EvaluateAttributeAtSample(vector<T,N> x, uint sampleindex); __generic<T : __BuiltinArithmeticType, let N : int, let M : int> __target_intrinsic(glsl, "interpolateAtSample($0, int($1))") -matrix<T,N,M> EvaluateAttributeAtSample(matrix<T,N,M> x, uint sampleindex); +matrix<T,N,M> EvaluateAttributeAtSample(matrix<T,N,M> x, uint sampleindex) +{ + matrix<T,N,M> result; + for(int i = 0; i < N; ++i) + { + result[i] = EvaluateAttributeAtSample(x[i], sampleindex); + } + return result; +} __generic<T : __BuiltinArithmeticType> __target_intrinsic(glsl, "interpolateAtOffset($0, vec2($1) / 16.0f)") +__target_intrinsic(spirv_direct, "%foffset = 111 _type(float2) resultId _1; %offsetdiv16 = 136 _type(float2) resultId %foffset const(float2, 16.0, 16.0); 12 resultType resultId glsl450 78 _0 %offsetdiv16") T EvaluateAttributeSnapped(T x, int2 offset); __generic<T : __BuiltinArithmeticType, let N : int> __target_intrinsic(glsl, "interpolateAtOffset($0, vec2($1) / 16.0f)") +__target_intrinsic(spirv_direct, "%foffset = 111 _type(float2) resultId _1; %offsetdiv16 = 136 _type(float2) resultId %foffset const(float2, 16.0, 16.0); 12 resultType resultId glsl450 78 _0 %offsetdiv16") vector<T,N> EvaluateAttributeSnapped(vector<T,N> x, int2 offset); __generic<T : __BuiltinArithmeticType, let N : int, let M : int> __target_intrinsic(glsl, "interpolateAtOffset($0, vec2($1) / 16.0f)") -matrix<T,N,M> EvaluateAttributeSnapped(matrix<T,N,M> x, int2 offset); +matrix<T,N,M> EvaluateAttributeSnapped(matrix<T,N,M> x, int2 offset) +{ + matrix<T,N,M> result; + for(int i = 0; i < N; ++i) + { + result[i] = EvaluateAttributeSnapped(x[i], offset); + } + return result; +} // Base-e exponent @@ -1549,11 +1615,13 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_exp($0)") __target_intrinsic(cpp, "$P_exp($0)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 27 _0") T exp(T x); __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 27 _0") vector<T, N> exp(vector<T, N> x) { VECTOR_MAP_UNARY(T, N, exp, x); @@ -1573,11 +1641,13 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_exp2($0)") __target_intrinsic(cpp, "$P_exp2($0)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 29 _0") T exp2(T x); __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 29 _0") vector<T,N> exp2(vector<T,N> x) { VECTOR_MAP_UNARY(T, N, exp2, x); @@ -1666,10 +1736,12 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl,"findMSB") __target_intrinsic(cuda, "$P_firstbithigh($0)") __target_intrinsic(cpp, "$P_firstbithigh($0)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 74 _0") int firstbithigh(int value); __target_intrinsic(hlsl) __target_intrinsic(glsl,"findMSB") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 74 _0") __generic<let N : int> vector<int, N> firstbithigh(vector<int, N> value) { @@ -1680,10 +1752,12 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl,"findMSB") __target_intrinsic(cuda, "$P_firstbithigh($0)") __target_intrinsic(cpp, "$P_firstbithigh($0)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 75 _0") uint firstbithigh(uint value); __target_intrinsic(hlsl) __target_intrinsic(glsl,"findMSB") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 75 _0") __generic<let N : int> vector<uint,N> firstbithigh(vector<uint,N> value) { @@ -1695,10 +1769,12 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl,"findLSB") __target_intrinsic(cuda, "$P_firstbitlow($0)") __target_intrinsic(cpp, "$P_firstbitlow($0)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 73 _0") int firstbitlow(int value); __target_intrinsic(hlsl) __target_intrinsic(glsl,"findLSB") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 73 _0") __generic<let N : int> vector<int,N> firstbitlow(vector<int,N> value) { @@ -1709,11 +1785,13 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl,"findLSB") __target_intrinsic(cuda, "$P_firstbitlow($0)") __target_intrinsic(cpp, "$P_firstbitlow($0)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 73 _0") uint firstbitlow(uint value); __target_intrinsic(hlsl) __target_intrinsic(glsl,"findLSB") __generic<let N : int> +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 73 _0") vector<uint,N> firstbitlow(vector<uint,N> value) { VECTOR_MAP_UNARY(uint, N, firstbitlow, value); @@ -1726,11 +1804,13 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_floor($0)") __target_intrinsic(cpp, "$P_floor($0)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 8 _0") T floor(T x); __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 8 _0") vector<T, N> floor(vector<T, N> x) { VECTOR_MAP_UNARY(T, N, floor, x); @@ -1748,11 +1828,13 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_fma($0, $1, $2)") __target_intrinsic(cpp, "$P_fma($0, $1, $2)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 50 _0 _1 _2") double fma(double a, double b, double c); __generic<let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 50 _0 _1 _2") vector<double, N> fma(vector<double, N> a, vector<double, N> b, vector<double, N> c) { VECTOR_MAP_TRINARY(double, N, fma, a, b, c); @@ -1795,11 +1877,13 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl, fract) __target_intrinsic(cuda, "$P_frac($0)") __target_intrinsic(cpp, "$P_frac($0)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 10 _0") T frac(T x); __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl, fract) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 10 _0") vector<T, N> frac(vector<T, N> x) { VECTOR_MAP_UNARY(T, N, frac, x); @@ -1815,11 +1899,13 @@ matrix<T, N, M> frac(matrix<T, N, M> x) __generic<T : __BuiltinFloatingPointType> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 51 _0 _1") T frexp(T x, out T exp); __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 51 _0 _1") vector<T, N> frexp(vector<T, N> x, out vector<T, N> exp) { VECTOR_MAP_BINARY(T, N, frexp, x, exp); @@ -2131,6 +2217,7 @@ matrix<bool, N, M> isnan(matrix<T, N, M> x) __generic<T : __BuiltinFloatingPointType> __target_intrinsic(hlsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 53 _0 _1") T ldexp(T x, T exp) { return x * exp2(exp); @@ -2138,6 +2225,7 @@ T ldexp(T x, T exp) __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 53 _0 _1") vector<T, N> ldexp(vector<T, N> x, vector<T, N> exp) { return x * exp2(exp); @@ -2154,6 +2242,7 @@ matrix<T, N, M> ldexp(matrix<T, N, M> x, matrix<T, N, M> exp) __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 66 _0") T length(vector<T, N> x) { return sqrt(dot(x, x)); @@ -2163,6 +2252,7 @@ T length(vector<T, N> x) __generic<T : __BuiltinFloatingPointType> __target_intrinsic(hlsl) __target_intrinsic(glsl, mix) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 46 _0 _1 _2") T lerp(T x, T y, T s) { return x * (T(1.0f) - s) + y * s; @@ -2171,6 +2261,7 @@ T lerp(T x, T y, T s) __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl, mix) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 46 _0 _1 _2") vector<T, N> lerp(vector<T, N> x, vector<T, N> y, vector<T, N> s) { return x * (T(1.0f) - s) + y * s; @@ -2199,11 +2290,13 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_log($0)") __target_intrinsic(cpp, "$P_log($0)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 28 _0") T log(T x); __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 28 _0") vector<T, N> log(vector<T, N> x) { VECTOR_MAP_UNARY(T, N, log, x); @@ -2222,11 +2315,13 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl, "(log( $0 ) * $S0( 0.43429448190325182765112891891661) )" ) __target_intrinsic(cuda, "$P_log10($0)") __target_intrinsic(cpp, "$P_log10($0)") +__target_intrinsic(spirv_direct, "%baseElog = 12 resultType resultId glsl450 28 _0; 133 resultType resultId _0 %baseElog const(_p,0.43429448190325182765112891891661)") T log10(T x); __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl, "(log( $0 ) * $S0(0.43429448190325182765112891891661) )" ) +__target_intrinsic(spirv_direct, "%baseElog = 12 resultType resultId glsl450 28 _0; 142 resultType resultId _0 %baseElog const(_p,0.43429448190325182765112891891661)") vector<T,N> log10(vector<T,N> x) { VECTOR_MAP_UNARY(T, N, log10, x); @@ -2245,11 +2340,13 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_log2($0)") __target_intrinsic(cpp, "$P_log2($0)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 30 _0") T log2(T x); __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 30 _0") vector<T,N> log2(vector<T,N> x) { VECTOR_MAP_UNARY(T, N, log2, x); @@ -2269,11 +2366,13 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl, fma) __target_intrinsic(cuda, "$P_fma($0, $1, $2)") __target_intrinsic(cpp, "$P_fma($0, $1, $2)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 50 _0 _1 _2") T mad(T mvalue, T avalue, T bvalue); __generic<T : __BuiltinArithmeticType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl, fma) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 50 _0 _1 _2") vector<T, N> mad(vector<T, N> mvalue, vector<T, N> avalue, vector<T, N> bvalue) { VECTOR_MAP_TRINARY(T, N, mad, mvalue, avalue, bvalue); @@ -2292,6 +2391,7 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_max($0, $1)") __target_intrinsic(cpp, "$P_max($0, $1)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 fus(40,41,42) _0") T max(T x, T y); // Note: a stdlib implementation of `max` (or `min`) will require splitting // floating-point and integer cases apart, because the floating-point @@ -2301,6 +2401,7 @@ T max(T x, T y); __generic<T : __BuiltinArithmeticType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 fus(40,41,42) _0") vector<T, N> max(vector<T, N> x, vector<T, N> y) { VECTOR_MAP_BINARY(T, N, max, x, y); @@ -2319,11 +2420,13 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_min($0, $1)") __target_intrinsic(cpp, "$P_min($0, $1)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 fus(37,38,39) _0") T min(T x, T y); __generic<T : __BuiltinArithmeticType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 fus(37,38,39) _0") vector<T,N> min(vector<T,N> x, vector<T,N> y) { VECTOR_MAP_BINARY(T, N, min, x, y); @@ -2518,6 +2621,7 @@ int NonUniformResourceIndex(int index) __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 69 _0") vector<T,N> normalize(vector<T,N> x) { return x / length(x); @@ -2529,11 +2633,13 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_pow($0, $1)") __target_intrinsic(cpp, "$P_pow($0, $1)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 26 _0 _1") T pow(T x, T y); __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 26 _0 _1") vector<T, N> pow(vector<T, N> x, vector<T, N> y) { VECTOR_MAP_BINARY(T, N, pow, x, y); @@ -2625,6 +2731,7 @@ void ProcessTriTessFactorsMin( __generic<T : __BuiltinFloatingPointType> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 11 _0") T radians(T x) { return x * (T.getPi() / T(180.0f)); @@ -2633,6 +2740,7 @@ T radians(T x) __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 11 _0") vector<T, N> radians(vector<T, N> x) { return x * (T.getPi() / T(180.0f)); @@ -2671,6 +2779,7 @@ matrix<T, N, M> rcp(matrix<T, N, M> x) __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 71 _0 _1") vector<T,N> reflect(vector<T,N> i, vector<T,N> n) { return i - T(2) * dot(n,i) * n; @@ -2680,6 +2789,7 @@ vector<T,N> reflect(vector<T,N> i, vector<T,N> n) __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 72 _0 _1 _2") vector<T,N> refract(vector<T,N> i, vector<T,N> n, T eta) { let dotNI = dot(n,i); @@ -2708,11 +2818,13 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_round($0)") __target_intrinsic(cpp, "$P_round($0)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 1 _0") T round(T x); __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 1 _0") vector<T, N> round(vector<T, N> x) { VECTOR_MAP_UNARY(T, N, round, x); @@ -2731,6 +2843,7 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl, "inversesqrt($0)") __target_intrinsic(cuda, "$P_rsqrt($0)") __target_intrinsic(cpp, "$P_rsqrt($0)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 32 _0") T rsqrt(T x) { return T(1.0) / sqrt(x); @@ -2739,6 +2852,7 @@ T rsqrt(T x) __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl, "inversesqrt($0)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 32 _0") vector<T, N> rsqrt(vector<T, N> x) { VECTOR_MAP_UNARY(T, N, rsqrt, x); @@ -2782,11 +2896,13 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl, "int(sign($0))") __target_intrinsic(cuda, "$P_sign($0)") __target_intrinsic(cpp, "$P_sign($0)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 fi(6,7) _0") int sign(T x); __generic<T : __BuiltinSignedArithmeticType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl, "ivec$N0(sign($0))") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 fi(6,7) _0") vector<int, N> sign(vector<T, N> x) { VECTOR_MAP_UNARY(int, N, sign, x); @@ -2807,11 +2923,13 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_sin($0)") __target_intrinsic(cpp, "$P_sin($0)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 13 _0") T sin(T x); __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 13 _0") vector<T, N> sin(vector<T, N> x) { VECTOR_MAP_UNARY(T, N, sin, x); @@ -2856,11 +2974,13 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_sinh($0)") __target_intrinsic(cpp, "$P_sinh($0)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 19 _0") T sinh(T x); __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 19 _0") vector<T, N> sinh(vector<T, N> x) { VECTOR_MAP_UNARY(T, N, sinh, x); @@ -2877,6 +2997,7 @@ matrix<T, N, M> sinh(matrix<T, N, M> x) __generic<T : __BuiltinFloatingPointType> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 49 _0 _1 _2") T smoothstep(T min, T max, T x) { let t = saturate((x - min) / (max - min)); @@ -2886,6 +3007,7 @@ T smoothstep(T min, T max, T x) __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 49 _0 _1 _2") vector<T, N> smoothstep(vector<T, N> min, vector<T, N> max, vector<T, N> x) { VECTOR_MAP_TRINARY(T, N, smoothstep, min, max, x); @@ -2904,11 +3026,13 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_sqrt($0)") __target_intrinsic(cpp, "$P_sqrt($0)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 31 _0") T sqrt(T x); __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 31 _0") vector<T, N> sqrt(vector<T, N> x) { VECTOR_MAP_UNARY(T, N, sqrt, x); @@ -2925,6 +3049,7 @@ matrix<T, N, M> sqrt(matrix<T, N, M> x) __generic<T : __BuiltinFloatingPointType> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 48 _0 _1") T step(T y, T x) { return x < y ? T(0.0f) : T(1.0f); @@ -2933,6 +3058,7 @@ T step(T y, T x) __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 48 _0 _1") vector<T,N> step(vector<T,N> y, vector<T,N> x) { VECTOR_MAP_BINARY(T, N, step, y, x); @@ -2951,11 +3077,13 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_tan($0)") __target_intrinsic(cpp, "$P_tan($0)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 15 _0") T tan(T x); __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 15 _0") vector<T, N> tan(vector<T, N> x) { VECTOR_MAP_UNARY(T, N, tan, x); @@ -2974,11 +3102,13 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_tanh($0)") __target_intrinsic(cpp, "$P_tanh($0)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 21 _0") T tanh(T x); __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 21 _0") vector<T,N> tanh(vector<T,N> x) { VECTOR_MAP_UNARY(T, N, tanh, x); @@ -3010,11 +3140,13 @@ __target_intrinsic(hlsl) __target_intrinsic(glsl) __target_intrinsic(cuda, "$P_trunc($0)") __target_intrinsic(cpp, "$P_trunc($0)") +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 3 _0") T trunc(T x); __generic<T : __BuiltinFloatingPointType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl) +__target_intrinsic(spirv_direct, "12 resultType resultId glsl450 3 _0") vector<T, N> trunc(vector<T, N> x) { VECTOR_MAP_UNARY(T, N, trunc, x); diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 028886c7e..997f3fd51 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -1442,10 +1442,9 @@ namespace Slang { RefPtr<DownstreamCompileResult> downstreamResult; - if (target == CodeGenTarget::SPIRV && compileRequest->shouldEmitSPIRVDirectly) + if (target == CodeGenTarget::SPIRV && targetReq->shouldEmitSPIRVDirectly()) { List<uint8_t> spirv; - targetReq->setDirectSPIRVEmitMode(); SLANG_RETURN_ON_FAIL(emitSPIRVForEntryPointsDirectly(compileRequest, entryPointIndices, targetReq, spirv)); auto spirvBlob = ListBlob::moveCreate(spirv); downstreamResult = new BlobDownstreamCompileResult(DownstreamDiagnostics(), spirvBlob); diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index b829fd0ee..1724794cf 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -1184,14 +1184,16 @@ namespace Slang } void addCapability(CapabilityAtom capability); + bool shouldEmitSPIRVDirectly() + { + return (targetFlags & SLANG_TARGET_FLAG_GENERATE_SPIRV_DIRECTLY) != 0; + } bool isWholeProgramRequest() { return (targetFlags & SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM) != 0; } - void setDirectSPIRVEmitMode(); - Linkage* getLinkage() { return linkage; } CodeGenTarget getTarget() { return format; } Profile getTargetProfile() { return targetProfile; } @@ -1219,7 +1221,6 @@ namespace Slang List<CapabilityAtom> rawCapabilities; CapabilitySet cookedCapabilities; LineDirectiveMode lineDirectiveMode = LineDirectiveMode::Default; - bool m_emitSPIRVDirectly = false; }; /// Are we generating code for a D3D API? @@ -1898,10 +1899,6 @@ namespace Slang // bool useUnknownImageFormatAsDefault = false; - /// Should SPIR-V be generated directly from Slang IR rather than via translation to GLSL? - bool shouldEmitSPIRVDirectly = false; - - // If true will disable generics/existential value specialization pass. bool disableSpecialization = false; diff --git a/source/slang/slang-emit-base.cpp b/source/slang/slang-emit-base.cpp index d00b723ab..0f1c54e70 100644 --- a/source/slang/slang-emit-base.cpp +++ b/source/slang/slang-emit-base.cpp @@ -52,4 +52,25 @@ IRVarLayout* SourceEmitterBase::getVarLayout(IRInst* var) return as<IRVarLayout>(decoration->getLayout()); } +BaseType SourceEmitterBase::extractBaseType(IRType* inType) +{ + auto type = inType; + for (;;) + { + if (auto irBaseType = as<IRBasicType>(type)) + { + return irBaseType->getBaseType(); + } + else if (auto vecType = as<IRVectorType>(type)) + { + type = vecType->getElementType(); + continue; + } + else + { + return BaseType::Void; + } + } +} + } diff --git a/source/slang/slang-emit-base.h b/source/slang/slang-emit-base.h index ffbf56618..dc8b065e4 100644 --- a/source/slang/slang-emit-base.h +++ b/source/slang/slang-emit-base.h @@ -23,6 +23,8 @@ public: virtual void handleRequiredCapabilitiesImpl(IRInst* inst) { SLANG_UNUSED(inst); } static IRVarLayout* getVarLayout(IRInst* var); + + static BaseType extractBaseType(IRType* inType); }; } diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index f9d71beb9..74c225feb 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -1943,27 +1943,6 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO maybeCloseParens(needClose); } -BaseType CLikeSourceEmitter::extractBaseType(IRType* inType) -{ - auto type = inType; - for(;;) - { - if(auto irBaseType = as<IRBasicType>(type)) - { - return irBaseType->getBaseType(); - } - else if(auto vecType = as<IRVectorType>(type)) - { - type = vecType->getElementType(); - continue; - } - else - { - return BaseType::Void; - } - } -} - void CLikeSourceEmitter::emitInst(IRInst* inst) { try diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index f699ed255..08d24ef04 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -285,9 +285,6 @@ public: void emitInstExpr(IRInst* inst, EmitOpInfo const& inOuterPrec); void defaultEmitInstExpr(IRInst* inst, EmitOpInfo const& inOuterPrec); void diagnoseUnhandledInst(IRInst* inst); - - BaseType extractBaseType(IRType* inType); - void emitInst(IRInst* inst); void emitSemantics(IRInst* inst); diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 37fd673ed..21f5c1bc8 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -266,6 +266,11 @@ void SpvInstParent::dumpTo(List<SpvWord>& ioWords) struct SpvSnippetEmitContext { SpvInst* resultType; + IRType* irResultType; + // True if resultType is float or vector of float. + bool isResultTypeFloat; + // True if resultType is signed. + bool isResultTypeSigned; Dictionary<SpvStorageClass, IRInst*> qualifiedResultTypes; List<SpvWord> argumentIds; }; @@ -401,6 +406,16 @@ struct SPIRVEmitContext void registerInst(IRInst* irInst, SpvInst* spvInst) { m_mapIRInstToSpvInst.Add(irInst, spvInst); + + // If we have reserved an SpvID for `irInst`, make sure to use it. + SpvWord reservedID = 0; + m_mapIRInstToSpvID.TryGetValue(irInst, reservedID); + + if (reservedID) + { + SLANG_ASSERT(spvInst->id == 0); + spvInst->id = reservedID; + } } /// Get or reserve a SpvID for an IR value. @@ -439,18 +454,6 @@ struct SPIRVEmitContext return id; } - struct VectorTypeKey - { - BaseType baseType; - IRIntegerValue elementCount; - HashCode getHashCode() { return combineHash((int)baseType, (HashCode)elementCount); } - bool operator==(const VectorTypeKey& other) - { - return baseType == other.baseType && elementCount == other.elementCount; - } - }; - Dictionary<VectorTypeKey, SpvInst*> m_vectorTypes; - // We will build up `SpvInst`s in a stateful fashion, // mostly for convenience. We could in theory compute // the number of words each instruction needs, then allocate @@ -509,8 +512,6 @@ struct SPIRVEmitContext if(irInst) { registerInst(irInst, spvInst); - // If we have reserved an SpvID for `irInst`, make sure to use it. - m_mapIRInstToSpvID.TryGetValue(irInst, spvInst->id); } // Set up the scope @@ -675,19 +676,93 @@ struct SPIRVEmitContext void emitOperand(SpvBuiltIn builtin) { emitOperand((SpvWord)builtin); } void emitOperand(SpvStorageClass val) { emitOperand((SpvWord)val); } - Dictionary<IRIntegerValue, SpvInst*> m_spvIntConstants; - SpvInst* emitConstant(IRIntegerValue val, IRType* type) + template<typename TConstant> + struct ConstantValueKey + { + IRType* type; + TConstant value; + HashCode getHashCode() const + { + return combineHash(Slang::getHashCode(type), Slang::getHashCode(value)); + } + bool operator==(const ConstantValueKey& other) const + { + return type == other.type && value == other.value; + } + }; + Dictionary<ConstantValueKey<IRIntegerValue>, SpvInst*> m_spvIntConstants; + Dictionary<ConstantValueKey<IRFloatingPointValue>, SpvInst*> m_spvFloatConstants; + SpvInst* emitIntConstant(IRIntegerValue val, IRType* type) { + ConstantValueKey<IRIntegerValue> key; + key.value = val; + key.type = type; SpvInst* result = nullptr; - if (m_spvIntConstants.TryGetValue(val, result)) + if (m_spvIntConstants.TryGetValue(key, result)) return result; - return emitInst( - getSection(SpvLogicalSectionID::Constants), - nullptr, - SpvOpConstant, - type, - kResultID, - (SpvWord)val); + SpvWord valWord; + memcpy(&valWord, &val, sizeof(SpvWord)); + if (type->getOp() == kIROp_Int64Type || type->getOp() == kIROp_UInt64Type) + { + SpvWord valHighWord; + memcpy(&valHighWord, (char*)(&val) + 4, sizeof(SpvWord)); + result = emitInst( + getSection(SpvLogicalSectionID::Constants), + nullptr, + SpvOpConstant, + type, + kResultID, + valWord, + valHighWord); + } + else + { + result = emitInst( + getSection(SpvLogicalSectionID::Constants), + nullptr, + SpvOpConstant, + type, + kResultID, + valWord); + } + m_spvIntConstants[key] = result; + return result; + } + SpvInst* emitFloatConstant(IRFloatingPointValue val, IRType* type) + { + ConstantValueKey<IRFloatingPointValue> key; + key.value = val; + key.type = type; + SpvInst* result = nullptr; + if (m_spvFloatConstants.TryGetValue(key, result)) + return result; + SpvWord valWord; + memcpy(&valWord, &val, sizeof(SpvWord)); + if (type->getOp() == kIROp_DoubleType) + { + SpvWord valHighWord; + memcpy(&valHighWord, (char*)(&val) + 4, sizeof(SpvWord)); + result = emitInst( + getSection(SpvLogicalSectionID::Constants), + nullptr, + SpvOpConstant, + type, + kResultID, + valWord, + valHighWord); + } + else + { + result = emitInst( + getSection(SpvLogicalSectionID::Constants), + nullptr, + SpvOpConstant, + type, + kResultID, + valWord); + } + m_spvFloatConstants[key] = result; + return result; } // As another convenience, there are often cases where // we will want to emit all of the operands of some @@ -812,6 +887,22 @@ struct SPIRVEmitContext return spvInst; } + /// The SPIRV OpExtInstImport inst that represents the GLSL450 + /// extended instruction set. + SpvInst* m_glsl450ExtInst = nullptr; + + SpvInst* getGLSL450ExtInst() + { + if (m_glsl450ExtInst) + return m_glsl450ExtInst; + m_glsl450ExtInst = emitInst( + getSection(SpvLogicalSectionID::ExtIntInstImports), + nullptr, + SpvOpExtInstImport, + UnownedStringSlice("GLSL.std.450")); + return m_glsl450ExtInst; + } + // Now that we've gotten the core infrastructure out of the way, // let's start looking at emitting some instructions that make // up a SPIR-V module. @@ -849,6 +940,66 @@ struct SPIRVEmitContext emitInst(getSection(SpvLogicalSectionID::MemoryModel), nullptr, SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450); } + Dictionary<UnownedStringSlice, SpvInst*> m_extensionInsts; + SpvInst* ensureExtensionDeclaration(UnownedStringSlice name) + { + SpvInst* result = nullptr; + if (m_extensionInsts.TryGetValue(name, result)) + return result; + result = + emitInst(getSection(SpvLogicalSectionID::Extensions), nullptr, SpvOpExtension, name); + m_extensionInsts[name] = result; + return result; + } + + struct SpvTypeInstKey + { + List<SpvWord> words; + bool operator==(const SpvTypeInstKey& other) + { + if (words.getCount() != other.words.getCount()) + return false; + for (Index i = 0; i < words.getCount(); i++) + if (words[i] != other.words[i]) + return false; + return true; + } + HashCode getHashCode() + { + HashCode result = 0; + for (auto word : words) + result = combineHash(result, word); + return result; + } + }; + + Dictionary<SpvTypeInstKey, SpvInst*> m_spvTypeInsts; + + // Emits a SPV Inst that represents a type, with deduplications since + // our IR doesn't currently guarantee types are unique in generated SPV. + SpvInst* emitTypeInst(IRInst* typeInst, SpvOp opcode, ArrayView<SpvWord> operands) + { + SpvTypeInstKey key; + key.words.add((SpvWord)opcode); + for (auto op : operands) + key.words.add(op); + SpvInst* result = nullptr; + if (m_spvTypeInsts.TryGetValue(key, result)) + { + return result; + } + result = emitInstCustomOperandFunc( + getSection(SpvLogicalSectionID::Types), typeInst, opcode, [&]() { + emitOperand(kResultID); + for (auto op : operands) + { + emitOperand(op); + } + }); + m_spvTypeInsts[key] = result; + return result; + } + // Next, let's look at emitting some of the instructions // that can occur at global scope. @@ -864,7 +1015,7 @@ struct SPIRVEmitContext // #define CASE(IROP, SPVOP) \ - case IROP: return emitInst(getSection(SpvLogicalSectionID::Types), inst, SPVOP, kResultID) + case IROP: return emitTypeInst(inst, SPVOP, ArrayView<SpvWord>()); // > OpTypeVoid CASE(kIROp_VoidType, SpvOpTypeVoid); @@ -877,7 +1028,8 @@ struct SPIRVEmitContext // > OpTypeInt #define CASE(IROP, BITS, SIGNED) \ - case IROP: return emitInst(getSection(SpvLogicalSectionID::Types), inst, SpvOpTypeInt, kResultID, BITS, SIGNED) + case IROP: \ + return emitTypeInst(inst, SpvOpTypeInt, makeArray<SpvWord>((SpvWord)BITS, (SpvWord)SIGNED).getView()); CASE(kIROp_IntType, 32, 1); CASE(kIROp_UIntType, 32, 0); @@ -889,7 +1041,9 @@ struct SPIRVEmitContext // > OpTypeFloat #define CASE(IROP, BITS) \ - case IROP: return emitInst(getSection(SpvLogicalSectionID::Types), inst, SpvOpTypeFloat, kResultID, BITS) + case IROP: \ + return emitTypeInst( \ + inst, SpvOpTypeFloat, makeArray<SpvWord>(BITS).getView()); \ CASE(kIROp_HalfType, 16); CASE(kIROp_FloatType, 32); @@ -905,17 +1059,16 @@ struct SPIRVEmitContext auto ptrType = as<IRPtrTypeBase>(inst); if (ptrType->hasAddressSpace()) storageClass = (SpvStorageClass)ptrType->getAddressSpace(); - return emitInst( - getSection(SpvLogicalSectionID::Types), - inst, - SpvOpTypePointer, - kResultID, - storageClass, - inst->getOperand(0)); + if (storageClass == SpvStorageClassStorageBuffer) + ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_storage_buffer_storage_class")); + auto operands = makeArray<SpvWord>( + (SpvWord)storageClass, getID(ensureInst(inst->getOperand(0)))); + return emitTypeInst( + inst, SpvOpTypePointer, operands.getView()); } case kIROp_StructType: { - return emitInstCustomOperandFunc( + auto spvStructType = emitInstCustomOperandFunc( getSection(SpvLogicalSectionID::Types), inst, SpvOpTypeStruct, [&]() { emitOperand(kResultID); for (auto field : static_cast<IRStructType*>(inst)->getFields()) @@ -924,6 +1077,8 @@ struct SPIRVEmitContext // TODO: decorate offset } }); + emitDecorations(inst, getID(spvStructType)); + return spvStructType; } case kIROp_VectorType: { @@ -1012,6 +1167,12 @@ struct SPIRVEmitContext // return emitInst(getSection(SpvLogicalSectionID::Types), inst, SpvOpTypeFunction, kResultID, OperandsOf(inst)); + case kIROp_RateQualifiedType: + { + auto result = emitGlobalInst(as<IRRateQualifiedType>(inst)->getValueType()); + registerInst(inst, result); + return result; + } // > OpTypeForwardPointer case kIROp_Func: @@ -1046,10 +1207,6 @@ struct SPIRVEmitContext // it is nullptr, this function will create one. SpvInst* ensureVectorType(BaseType baseType, IRIntegerValue elementCount, IRVectorType* inst) { - VectorTypeKey key = {baseType, elementCount}; - SpvInst* result = nullptr; - if (m_vectorTypes.TryGetValue(key, result)) - return result; if (!inst) { IRBuilder builder; @@ -1059,14 +1216,9 @@ struct SPIRVEmitContext builder.getBasicType(baseType), builder.getIntValue(builder.getIntType(), elementCount)); } - result = emitInst( - getSection(SpvLogicalSectionID::Types), - inst, - SpvOpTypeVector, - kResultID, - inst->getElementType(), - (SpvWord)elementCount); - m_vectorTypes[key] = result; + auto operands = + makeArray<SpvWord>(getID(ensureInst(inst->getElementType())), (SpvWord)elementCount); + auto result = emitTypeInst(inst, SpvOpTypeVector, operands.getView()); return result; } @@ -1139,16 +1291,13 @@ struct SPIRVEmitContext varInst, SpvDecorationBinding, (SpvWord)index); - if (space) - { - emitInst( - getSection(SpvLogicalSectionID::Annotations), - nullptr, - SpvOpDecorate, - varInst, - SpvDecorationDescriptorSet, - (SpvWord)space); - } + emitInst( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + SpvOpDecorate, + varInst, + SpvDecorationDescriptorSet, + (SpvWord)space); break; default: break; @@ -1165,6 +1314,11 @@ struct SPIRVEmitContext if (ptrType->hasAddressSpace()) storageClass = (SpvStorageClass)ptrType->getAddressSpace(); } + if (auto systemValInst = maybeEmitSystemVal(param)) + { + registerInst(param, systemValInst); + return systemValInst; + } auto varInst = emitInst( getSection(SpvLogicalSectionID::GlobalVariables), param, @@ -1304,11 +1458,25 @@ struct SPIRVEmitContext for( auto irBlock : irFunc->getBlocks() ) { emitInst(spvFunc, irBlock, SpvOpLabel, kResultID); + + // In addition to normal basic blocks, + // all loops gets a header block. + for (auto irInst : irBlock->getChildren()) + { + if (irInst->getOp() == kIROp_loop) + { + emitInst(spvFunc, irInst, SpvOpLabel, kResultID); + } + } } // Once all the basic blocks have had instructions allocated // for them, we go through and fill them in with their bodies. // + // Each loop inst results in a loop header block. + // We will defer the emit of the contents in loop header block + // until all Phi insts are emitted. + List<IRLoop*> pendingLoopInsts; for( auto irBlock : irFunc->getBlocks() ) { // Note: because we already created the block above, @@ -1334,9 +1502,20 @@ struct SPIRVEmitContext // of the block. // emitLocalInst(spvBlock, irInst); + if (irInst->getOp() == kIROp_loop) + pendingLoopInsts.add(as<IRLoop>(irInst)); } } + // Finally, we generate the body of loop header blocks. + for (auto loopInst : pendingLoopInsts) + { + SpvInst* headerBlock = nullptr; + m_mapIRInstToSpvInst.TryGetValue(loopInst, headerBlock); + SLANG_ASSERT(headerBlock); + emitLoopHeaderBlock(loopInst, headerBlock); + } + // [3.32.9. Function Instructions] // // > OpFunctionEnd @@ -1356,6 +1535,21 @@ struct SPIRVEmitContext return spvFunc; } + /// Check if a block is a loop's target block. + bool isLoopTargetBlock(IRInst* block, IRInst*& loopInst) + { + for (auto use = block->firstUse; use; use = use->nextUse) + { + if (use->getUser()->getOp() == kIROp_loop && + as<IRLoop>(use->getUser())->getTargetBlock() == block) + { + loopInst = use->getUser(); + return true; + } + } + return false; + } + // The instructions that appear inside the basic blocks of // functions are what we will call "local" instructions. // @@ -1367,13 +1561,6 @@ struct SPIRVEmitContext /// Emit an instruction that is local to the body of the given `parent`. SpvInst* emitLocalInst(SpvInstParent* parent, IRInst* inst) { - auto getBlockID = [=](IRBlock* block) - { - SpvInst* spvInst = nullptr; - m_mapIRInstToSpvInst.TryGetValue(block, spvInst); - SLANG_ASSERT(spvInst); - return getID(spvInst); - }; switch( inst->getOp() ) { default: @@ -1401,6 +1588,9 @@ struct SPIRVEmitContext return emitSwizzle(parent, as<IRSwizzle>(inst)); case kIROp_Construct: return emitConstruct(parent, inst); + case kIROp_BitCast: + return emitInst( + parent, inst, SpvOpBitcast, inst->getDataType(), kResultID, inst->getOperand(0)); case kIROp_Add: case kIROp_Sub: case kIROp_Mul: @@ -1432,50 +1622,49 @@ struct SPIRVEmitContext case kIROp_discard: return emitInst(parent, inst, SpvOpKill); case kIROp_unconditionalBranch: - return emitInst( - parent, - inst, - SpvOpBranch, - getBlockID(as<IRUnconditionalBranch>(inst)->getTargetBlock())); - case kIROp_loop: { - auto loopInst = as<IRLoop>(inst); - - SpvWord loopControl = 0; - if (auto loopControlDecoration = - loopInst->findDecoration<IRLoopControlDecoration>()) + // If we are jumping to the main block of a loop, + // emit a branch to the loop header instead. + // The SPV id of the resulting loop header block is associated with the loop inst. + auto targetBlock = as<IRUnconditionalBranch>(inst)->getTargetBlock(); + IRInst* loopInst = nullptr; + if (isLoopTargetBlock(targetBlock, loopInst)) { - switch (loopControlDecoration->getMode()) - { - case IRLoopControl::kIRLoopControl_Unroll: - loopControl = 0x1; - break; - case IRLoopControl::kIRLoopControl_Loop: - loopControl = 0x2; - break; - default: - break; - } + return emitInst(parent, inst, SpvOpBranch, getIRInstSpvID(loopInst)); } - emitInst( + // Otherwise, emit a normal branch inst into the target block. + return emitInst( parent, - nullptr, - SpvOpLoopMerge, - getBlockID(loopInst->getBreakBlock()), - getBlockID(loopInst->getContinueBlock()), - loopControl); - - return emitInst(parent, inst, SpvOpBranch, loopInst->getTargetBlock()); + inst, + SpvOpBranch, + getIRInstSpvID(targetBlock)); + } + case kIROp_loop: + { + // Return loop header block in its own block. + auto blockId = getIRInstSpvID(inst); + SpvInst* block = nullptr; + m_mapIRInstToSpvInst.TryGetValue(inst, block); + SLANG_ASSERT(block); + + // Emit a jump to the loop header block. + // Note: the body of the loop header block is emitted + // after everything else to ensure Phi instructions (which come + // from the actual loop target block) are emitted first. + emitInst(parent, nullptr, SpvOpBranch, blockId); + + return block; } case kIROp_ifElse: { auto ifelseInst = as<IRIfElse>(inst); - auto afterBlockID = getBlockID(ifelseInst->getAfterBlock()); + auto afterBlockID = getIRInstSpvID(ifelseInst->getAfterBlock()); emitInst( parent, nullptr, SpvOpSelectionMerge, - afterBlockID); + afterBlockID, + 0); auto falseLabel = ifelseInst->getFalseBlock(); return emitInst( parent, @@ -1488,11 +1677,8 @@ struct SPIRVEmitContext case kIROp_Switch: { auto switchInst = as<IRSwitch>(inst); - auto mergeBlockID = getBlockID(switchInst->getBreakLabel()); - emitInst( - parent, - nullptr, - SpvOpSelectionMerge, mergeBlockID); + auto mergeBlockID = getIRInstSpvID(switchInst->getBreakLabel()); + emitInst(parent, nullptr, SpvOpSelectionMerge, mergeBlockID, 0); return emitInstCustomOperandFunc(parent, inst, SpvOpSwitch, [&]() { emitOperand(switchInst->getCondition()); auto defaultLabel = switchInst->getDefaultLabel(); @@ -1685,7 +1871,23 @@ struct SPIRVEmitContext auto entryPointDecor = cast<IREntryPointDecoration>(decoration); auto spvStage = mapStageToExecutionModel(entryPointDecor->getProfile().getStage()); auto name = entryPointDecor->getName()->getStringSlice(); - emitInst(section, decoration, SpvOpEntryPoint, spvStage, dstID, name); + emitInstCustomOperandFunc(section, decoration, SpvOpEntryPoint, [&]() { + emitOperand(spvStage); + emitOperand(dstID); + emitOperand(name); + // `interface` part: reference all global variables that are used by this entrypoint. + // TODO: we may want to perform more accurate tracking. + for (auto globalInst : m_irModule->getModuleInst()->getChildren()) + { + switch (globalInst->getOp()) + { + case kIROp_GlobalVar: + case kIROp_GlobalParam: + emitOperand(getIRInstSpvID(globalInst)); + break; + } + } + }); } break; @@ -1713,6 +1915,24 @@ struct SPIRVEmitContext } break; + case kIROp_SPIRVBufferBlockDecoration: + { + emitInst( + getSection(SpvLogicalSectionID::Annotations), + decoration, + SpvOpDecorate, + dstID, + SpvDecorationBlock); + emitInst( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + SpvOpMemberDecorate, + dstID, + 0, + SpvDecorationOffset, + 0); + } + break; // ... } } @@ -1742,14 +1962,26 @@ struct SPIRVEmitContext } } - SpvInst* emitBuiltinSystemVal(SpvInstParent* parent, IRInst* inst, SpvBuiltIn builtinVal) + Dictionary<SpvBuiltIn, SpvInst*> m_builtinGlobalVars; + SpvInst* getBuiltinGlobalVar(IRType* type, SpvBuiltIn builtinVal) { + SpvInst* result = nullptr; + if (m_builtinGlobalVars.TryGetValue(builtinVal, result)) + { + return result; + } IRBuilder builder; builder.sharedBuilder = &m_sharedIRBuilder; - builder.setInsertBefore(inst); - - auto ptrIRType = builder.getPtrType(inst->getDataType()); - auto varInst = emitInst(parent, inst, SpvOpVariable, ptrIRType, kResultID); + builder.setInsertBefore(type); + auto ptrType = as<IRPtrTypeBase>(type); + SLANG_ASSERT(ptrType && "`getBuiltinGlobalVar`: `type` must be ptr type."); + auto varInst = emitInst( + getSection(SpvLogicalSectionID::GlobalVariables), + nullptr, + SpvOpVariable, + type, + kResultID, + (SpvStorageClass)ptrType->getAddressSpace()); emitInst( getSection(SpvLogicalSectionID::Annotations), nullptr, @@ -1757,11 +1989,15 @@ struct SPIRVEmitContext varInst, SpvDecorationBuiltIn, builtinVal); + m_builtinGlobalVars[builtinVal] = varInst; return varInst; } - SpvInst* emitParam(SpvInstParent* parent, IRInst* inst) + SpvInst* maybeEmitSystemVal(IRInst* inst) { + IRBuilder builder; + builder.sharedBuilder = &m_sharedIRBuilder; + builder.setInsertBefore(inst); if (auto layout = getVarLayout(inst)) { if (auto systemValueAttr = layout->findAttr<IRSystemValueSemanticAttr>()) @@ -1770,27 +2006,26 @@ struct SPIRVEmitContext semanticName = semanticName.toLower(); if (semanticName == "sv_dispatchthreadid") { - return emitBuiltinSystemVal(parent, inst, SpvBuiltInGlobalInvocationId); + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInGlobalInvocationId); } } } + return nullptr; + } + + SpvInst* emitParam(SpvInstParent* parent, IRInst* inst) + { return emitInst(parent, inst, SpvOpFunctionParameter, inst->getFullType(), kResultID); } SpvInst* emitVar(SpvInstParent* parent, IRInst* inst) { - SpvWord storageClass = SpvStorageClassFunction; - auto rate = inst->getFullType()->getRate(); - if (rate) + auto ptrType = as<IRPtrTypeBase>(inst->getDataType()); + SLANG_ASSERT(ptrType); + SpvStorageClass storageClass = SpvStorageClassFunction; + if (ptrType->hasAddressSpace()) { - switch (rate->getOp()) - { - case kIROp_GroupSharedRate: - storageClass = SpvStorageClassWorkgroup; - break; - default: - break; - } + storageClass = (SpvStorageClass)ptrType->getAddressSpace(); } return emitInst(parent, inst, SpvOpVariable, inst->getFullType(), kResultID, storageClass); } @@ -1828,6 +2063,48 @@ struct SPIRVEmitContext return result; } + bool isGlobalValueInst(IRInst* inst) + { + if (as<IRConstant>(inst)) + return true; + switch (inst->getOp()) + { + case kIROp_Func: + case kIROp_GlobalParam: + case kIROp_GlobalVar: + return true; + default: + return false; + } + } + + void emitLoopHeaderBlock(IRLoop* loopInst, SpvInst* loopHeaderBlock) + { + SpvWord loopControl = 0; + if (auto loopControlDecoration = loopInst->findDecoration<IRLoopControlDecoration>()) + { + switch (loopControlDecoration->getMode()) + { + case IRLoopControl::kIRLoopControl_Unroll: + loopControl = 0x1; + break; + case IRLoopControl::kIRLoopControl_Loop: + loopControl = 0x2; + break; + default: + break; + } + } + emitInst( + loopHeaderBlock, + nullptr, + SpvOpLoopMerge, + getIRInstSpvID(loopInst->getBreakBlock()), + getIRInstSpvID(loopInst->getContinueBlock()), + loopControl); + emitInst(loopHeaderBlock, nullptr, SpvOpBranch, loopInst->getTargetBlock()); + } + SpvInst* emitPhi(SpvInstParent* parent, IRParam* inst) { // An `IRParam` in an ordinary `IRBlock` represents a phi value. @@ -1838,6 +2115,16 @@ struct SPIRVEmitContext // First, we find the index of this param. IRBlock* block = as<IRBlock>(inst->getParent()); + // Special case: if block is a loop's target block, emit phis into the header block instead. + IRInst* loopInst = nullptr; + if (isLoopTargetBlock(block, loopInst)) + { + SpvInst* loopSpvBlockInst = nullptr; + m_mapIRInstToSpvInst.TryGetValue(loopInst, loopSpvBlockInst); + SLANG_ASSERT(loopSpvBlockInst); + parent = loopSpvBlockInst; + } + SLANG_ASSERT(block); int paramIndex = getParamIndexInBlock(block, inst); @@ -1865,7 +2152,9 @@ struct SPIRVEmitContext } SLANG_ASSERT(argStartIndex + paramIndex < branchInst->getOperandCount()); auto valueInst = branchInst->getOperand(argStartIndex + paramIndex); - emitOperand(valueInst); + if (isGlobalValueInst(valueInst)) + ensureInst(valueInst); + emitOperand(getIRInstSpvID(valueInst)); auto sourceBlock = as<IRBlock>(branchInst->getParent()); SLANG_ASSERT(sourceBlock); emitOperand(getIRInstSpvID(sourceBlock)); @@ -1901,7 +2190,10 @@ struct SPIRVEmitContext { SpvSnippet* snippet = getParsedSpvSnippet(intrinsic); SpvSnippetEmitContext context; + context.irResultType = inst->getDataType(); context.resultType = ensureInst(inst->getFullType()); + context.isResultTypeFloat = isFloatType(inst->getDataType()); + context.isResultTypeSigned = isSignedType((IRType*)inst->getDataType()); for (SlangUInt i = 0; i < inst->getArgCount(); i++) { auto argInst = ensureInst(inst->getArg(i)); @@ -1933,6 +2225,89 @@ struct SPIRVEmitContext return emitSpvSnippet(parent, inst, context, snippet); } + Dictionary<SpvSnippet::ASMConstant, SpvInst*> m_spvSnippetConstantInsts; + + // Emit SPV Inst that represents a constant defined in a SpvSnippet. + SpvInst* maybeEmitSpvConstant(SpvSnippet::ASMConstant constant) + { + SpvInst* result = nullptr; + if (m_spvSnippetConstantInsts.TryGetValue(constant, result)) + return result; + + IRBuilder builder; + builder.sharedBuilder = &m_sharedIRBuilder; + builder.setInsertInto(m_irModule->getModuleInst()); + switch (constant.type) + { + case SpvSnippet::ASMType::Float: + result = emitFloatConstant(constant.floatValues[0], builder.getType(kIROp_FloatType)); + break; + case SpvSnippet::ASMType::Float2: + { + auto floatType = builder.getType(kIROp_FloatType); + auto element1 = emitFloatConstant(constant.floatValues[0], floatType); + auto element2 = emitFloatConstant(constant.floatValues[1], floatType); + result = emitInst( + getSection(SpvLogicalSectionID::Constants), + nullptr, + SpvOpConstantComposite, + builder.getVectorType(floatType, builder.getIntValue(builder.getIntType(), 2)), + kResultID, + element1, + element2); + } + case SpvSnippet::ASMType::Int: + result = emitIntConstant((IRIntegerValue)constant.intValues[0], builder.getIntType()); + break; + case SpvSnippet::ASMType::UInt2: + { + auto uintType = builder.getType(kIROp_UIntType); + auto element1 = emitIntConstant((IRIntegerValue)constant.intValues[0], uintType); + auto element2 = emitIntConstant((IRIntegerValue)constant.intValues[1], uintType); + result = emitInst( + getSection(SpvLogicalSectionID::Constants), + nullptr, + SpvOpConstantComposite, + builder.getVectorType(uintType, builder.getIntValue(builder.getIntType(), 2)), + kResultID, + element1, + element2); + } + break; + } + m_spvSnippetConstantInsts[constant] = result; + return result; + } + + // Emit SPV Inst that represents a type defined in a SpvSnippet. + void emitSpvSnippetASMTypeOperand(SpvSnippet::ASMType type) + { + IRBuilder builder; + builder.sharedBuilder = &m_sharedIRBuilder; + builder.setInsertInto(m_irModule->getModuleInst()); + IRType* irType = nullptr; + switch (type) + { + case SpvSnippet::ASMType::Float: + irType = builder.getType(kIROp_FloatType); + break; + case SpvSnippet::ASMType::Int: + irType = builder.getIntType(); + break; + case SpvSnippet::ASMType::Float2: + irType = builder.getVectorType( + builder.getType(kIROp_FloatType), builder.getIntValue(builder.getIntType(), 2)); + break; + case SpvSnippet::ASMType::UInt2: + irType = builder.getVectorType( + builder.getType(kIROp_UIntType), builder.getIntValue(builder.getIntType(), 2)); + break; + default: + break; + } + emitOperand(irType); + } + SpvInst* emitSpvSnippet( SpvInstParent* parent, IRCall* inst, @@ -1950,11 +2325,10 @@ struct SPIRVEmitContext switch (operand.type) { case SpvSnippet::ASMOperandType::SpvWord: - emitOperand((SpvWord)operand.content); + emitOperand(operand.content); break; case SpvSnippet::ASMOperandType::ObjectReference: - SLANG_ASSERT( - operand.content >= 0 && operand.content < context.argumentIds.getCount()); + SLANG_ASSERT(operand.content < (SpvWord)context.argumentIds.getCount()); emitOperand(context.argumentIds[operand.content]); break; case SpvSnippet::ASMOperandType::ResultId: @@ -1972,8 +2346,64 @@ struct SPIRVEmitContext } break; case SpvSnippet::ASMOperandType::InstReference: - SLANG_ASSERT(operand.content >= 0 && operand.content < emittedInsts.getCount()); - emitOperand(getID(emittedInsts[operand.content])); + SLANG_ASSERT(operand.content < (SpvWord)emittedInsts.getCount()); + emitOperand(emittedInsts[operand.content]); + break; + case SpvSnippet::ASMOperandType::GLSL450ExtInstSet: + emitOperand(getGLSL450ExtInst()); + break; + case SpvSnippet::ASMOperandType::FloatIntegerSelection: + if (context.isResultTypeFloat) + { + emitOperand(operand.content); + } + else + { + emitOperand(operand.content2); + } + break; + case SpvSnippet::ASMOperandType::FloatUnsignedSignedSelection: + if (context.isResultTypeFloat) + { + emitOperand(operand.content); + } + else + { + if (context.isResultTypeSigned) + { + emitOperand(operand.content3); + } + else + { + emitOperand(operand.content2); + } + } + break; + case SpvSnippet::ASMOperandType::TypeReference: + { + emitSpvSnippetASMTypeOperand((SpvSnippet::ASMType)operand.content); + } + break; + case SpvSnippet::ASMOperandType::ConstantReference: + { + auto constant = snippet->constants[operand.content]; + if (constant.type == SpvSnippet::ASMType::FloatOrDouble) + { + switch (extractBaseType(context.irResultType)) + { + case BaseType::Float: + constant.type = SpvSnippet::ASMType::Float; + break; + case BaseType::Double: + constant.type = SpvSnippet::ASMType::Double; + break; + default: + break; + } + } + SpvInst* spvConstant = maybeEmitSpvConstant(constant); + emitOperand(spvConstant); + } break; } } @@ -2048,7 +2478,7 @@ struct SPIRVEmitContext baseId = getID(varInst); } SLANG_ASSERT(baseStructType && "field_address require base to be a struct."); - auto fieldId = emitConstant( + auto fieldId = emitIntConstant( getStructFieldId(baseStructType, as<IRStructKey>(fieldAddress->getField())), builder.getIntType()); return emitInst( @@ -2069,7 +2499,7 @@ struct SPIRVEmitContext IRStructType* baseStructType = as<IRStructType>(inst->getBase()->getDataType()); SLANG_ASSERT(baseStructType && "field_extract require base to be a struct."); - auto fieldId = emitConstant( + auto fieldId = emitIntConstant( getStructFieldId(baseStructType, as<IRStructKey>(inst->getField())), builder.getIntType()); @@ -2163,17 +2593,31 @@ struct SPIRVEmitContext SpvInst* emitSwizzle(SpvInstParent* parent, IRSwizzle* inst) { - return emitInstCustomOperandFunc(parent, inst, SpvOpVectorShuffle, [&]() { - emitOperand(inst->getDataType()); - emitOperand(kResultID); - emitOperand(inst->getBase()); - emitOperand(inst->getBase()); - for (UInt i = 0; i < inst->getElementCount(); i++) - { - auto index = as<IRIntLit>(inst->getElementIndex(i)); - emitOperand((SpvWord)index->getValue()); - } - }); + if (inst->getElementCount() == 1) + { + return emitInst( + parent, + inst, + SpvOpCompositeExtract, + inst->getDataType(), + kResultID, + inst->getBase(), + (SpvWord)as<IRIntLit>(inst->getElementIndex(0))->getValue()); + } + else + { + return emitInstCustomOperandFunc(parent, inst, SpvOpVectorShuffle, [&]() { + emitOperand(inst->getDataType()); + emitOperand(kResultID); + emitOperand(inst->getBase()); + emitOperand(inst->getBase()); + for (UInt i = 0; i < inst->getElementCount(); i++) + { + auto index = as<IRIntLit>(inst->getElementIndex(i)); + emitOperand((SpvWord)index->getValue()); + } + }); + } } SpvInst* emitConstruct(SpvInstParent* parent, IRInst* inst) @@ -2183,9 +2627,21 @@ struct SPIRVEmitContext if (inst->getOperandCount() == 1) { if (inst->getDataType() == inst->getOperand(0)->getDataType()) - return emitInst(parent, inst, SpvOpCopyObject, kResultID, inst->getOperand(0)); + return emitInst( + parent, + inst, + SpvOpCopyObject, + inst->getFullType(), + kResultID, + inst->getOperand(0)); else - return emitInst(parent, inst, SpvOpBitcast, inst->getDataType(), kResultID, inst->getOperand(0)); + return emitInst( + parent, + inst, + SpvOpBitcast, + inst->getFullType(), + kResultID, + inst->getOperand(0)); } else { @@ -2205,18 +2661,39 @@ struct SPIRVEmitContext } } - bool isSignedType(IRBasicType* basicType) + bool isSignedType(IRType* type) { - switch (basicType->getBaseType()) + switch (type->getOp()) { - case BaseType::Float: - case BaseType::Double: + case kIROp_FloatType: + case kIROp_DoubleType: return true; - case BaseType::Int: - case BaseType::Int16: - case BaseType::Int64: - case BaseType::Int8: + case kIROp_IntType: + case kIROp_Int16Type: + case kIROp_Int64Type: + case kIROp_Int8Type: return true; + case kIROp_VectorType: + return isSignedType(as<IRVectorType>(type)->getElementType()); + case kIROp_MatrixType: + return isSignedType(as<IRMatrixType>(type)->getElementType()); + default: + return false; + } + } + + bool isFloatType(IRInst* type) + { + switch (type->getOp()) + { + case kIROp_FloatType: + case kIROp_DoubleType: + case kIROp_HalfType: + return true; + case kIROp_VectorType: + return isFloatType(as<IRVectorType>(type)->getElementType()); + case kIROp_MatrixType: + return isFloatType(as<IRMatrixType>(type)->getElementType()); default: return false; } @@ -2224,7 +2701,7 @@ struct SPIRVEmitContext SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst) { - IRType* elementType = inst->getDataType(); + IRType* elementType = inst->getOperand(0)->getDataType(); if (auto vectorType = as<IRVectorType>(inst->getDataType())) { elementType = vectorType->getElementType(); @@ -2245,6 +2722,7 @@ struct SPIRVEmitContext break; case BaseType::Bool: isBool = true; + break; default: break; } @@ -2371,6 +2849,12 @@ struct SPIRVEmitContext } } + void diagnoseUnhandledInst(IRInst* inst) + { + m_sink->diagnose( + inst, Diagnostics::unimplemented, "unexpected IR opcode during code emit"); + } + SPIRVEmitContext(IRModule* module, TargetRequest* target, DiagnosticSink* sink) : SPIRVEmitSharedContext(module, target) , m_irModule(module) @@ -2390,7 +2874,7 @@ SlangResult emitSPIRVFromIR( spirvOut.clear(); SPIRVEmitContext context(irModule, targetRequest, compileRequest->getSink()); - legalizeIRForSPIRV(&context, irModule, compileRequest->getSink()); + legalizeIRForSPIRV(&context, irModule, irEntryPoints, compileRequest->getSink()); context.emitFrontMatter(); for (auto irEntryPoint : irEntryPoints) diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 72589912d..d81e6868c 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -627,10 +627,13 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) INST(SemanticDecoration, semantic, 2, 0) - INST_RANGE(Decoration, HighLevelDeclDecoration, SemanticDecoration) + /// Marks a struct type as being used as a structured buffer block. + /// Recognized by SPIRV-emit pass so we can emit a SPIRV `BufferBlock` decoration. + INST(SPIRVBufferBlockDecoration, spvBufferBlock, 0, 0) + INST_RANGE(Decoration, HighLevelDeclDecoration, SPIRVBufferBlockDecoration) -// + // // A `makeExistential(v : C, w) : I` instruction takes a value `v` of type `C` // and produces a value of interface type `I` by using the witness `w` which diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index f7fc53bdb..a75c26e86 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -1,6 +1,8 @@ // slang-ir-spirv-legalize.cpp #include "slang-ir-spirv-legalize.h" +#include "slang-ir-glsl-legalize.h" + #include "slang-ir.h" #include "slang-ir-insts.h" #include "slang-emit-base.h" @@ -13,7 +15,7 @@ namespace Slang // Legalization of IR for direct SPIRV emit. // -struct StorageClassPropagationContext : public SourceEmitterBase +struct SPIRVLegalizationContext : public SourceEmitterBase { SPIRVEmitSharedContext* m_sharedContext; @@ -42,12 +44,52 @@ struct StorageClassPropagationContext : public SourceEmitterBase } } - StorageClassPropagationContext(SPIRVEmitSharedContext* sharedContext, IRModule* module) + SPIRVLegalizationContext(SPIRVEmitSharedContext* sharedContext, IRModule* module) : m_sharedContext(sharedContext), m_module(module) { } - void processGlobalParam(IRGlobalParam* inst) { processGlobalVar(inst); } + void processGlobalParam(IRGlobalParam* inst) + { + // If the global param is not a pointer type, make it so and insert explicit load insts. + auto ptrType = as<IRPtrTypeBase>(inst->getDataType()); + if (!ptrType) + { + SpvStorageClass storageClass = SpvStorageClassPrivate; + // Figure out storage class based on var layout. + if (auto layout = getVarLayout(inst)) + { + if (auto systemValueAttr = layout->findAttr<IRSystemValueSemanticAttr>()) + { + String semanticName = systemValueAttr->getName(); + semanticName = semanticName.toLower(); + if (semanticName == "sv_dispatchthreadid") + { + storageClass = SpvStorageClassInput; + } + } + } + // Make a pointer type of storageClass. + IRBuilder builder; + builder.sharedBuilder = &m_sharedContext->m_sharedIRBuilder; + builder.setInsertBefore(inst); + ptrType = builder.getPtrType(kIROp_PtrType, inst->getFullType(), storageClass); + inst->setFullType(ptrType); + // Insert an explicit load at each use site. + List<IRUse*> uses; + for (auto use = inst->firstUse; use; use = use->nextUse) + { + uses.add(use); + } + for (auto use : uses) + { + builder.setInsertBefore(use->getUser()); + auto loadedValue = builder.emitLoad(inst); + use->set(loadedValue); + } + } + processGlobalVar(inst); + } void processGlobalVar(IRInst* inst) { @@ -195,13 +237,34 @@ struct StorageClassPropagationContext : public SourceEmitterBase builder.sharedBuilder = &m_sharedContext->m_sharedIRBuilder; builder.setInsertBefore(inst); auto arrayType = builder.getUnsizedArrayType(inst->getElementType()); - auto ptrType = builder.getPtrType(kIROp_PtrType, arrayType, SpvStorageClassStorageBuffer); + auto structType = builder.createStructType(); + auto arrayKey = builder.createStructKey(); + builder.createStructField(structType, arrayKey, arrayType); + auto ptrType = builder.getPtrType(kIROp_PtrType, structType, SpvStorageClassStorageBuffer); + StringBuilder nameSb; + switch (inst->getOp()) + { + case kIROp_HLSLRWStructuredBufferType: + nameSb << "RWStructuredBuffer"; + break; + case kIROp_HLSLAppendStructuredBufferType: + nameSb << "AppendStructuredBuffer"; + break; + case kIROp_HLSLConsumeStructuredBufferType: + nameSb << "ConsumeStructuredBuffer"; + break; + default: + nameSb << "StructuredBuffer"; + break; + } + builder.addNameHintDecoration(structType, nameSb.getUnownedSlice()); + builder.addDecoration(structType, kIROp_SPIRVBufferBlockDecoration); inst->replaceUsesWith(ptrType); inst->removeAndDeallocate(); addUsersToWorkList(ptrType); } - void propagate() + void processModule() { addToWorkList(m_module->getModuleInst()); while (workList.Count() != 0) @@ -240,19 +303,22 @@ struct StorageClassPropagationContext : public SourceEmitterBase } }; -void propagateStorageClass(SPIRVEmitSharedContext* sharedContext, IRModule* module) +void legalizeSPIRV(SPIRVEmitSharedContext* sharedContext, IRModule* module) { - StorageClassPropagationContext context(sharedContext, module); - context.propagate(); + SPIRVLegalizationContext context(sharedContext, module); + context.processModule(); } void legalizeIRForSPIRV( SPIRVEmitSharedContext* context, IRModule* module, + const List<IRFunc*>& entryPoints, DiagnosticSink* sink) { SLANG_UNUSED(sink); - propagateStorageClass(context, module); + GLSLExtensionTracker extensionTracker; + legalizeEntryPointsForGLSL(module->getSession(), module, entryPoints, sink, &extensionTracker); + legalizeSPIRV(context, module); } } // namespace Slang diff --git a/source/slang/slang-ir-spirv-legalize.h b/source/slang/slang-ir-spirv-legalize.h index bf43430d8..5bf8326fa 100644 --- a/source/slang/slang-ir-spirv-legalize.h +++ b/source/slang/slang-ir-spirv-legalize.h @@ -40,6 +40,7 @@ struct SPIRVEmitSharedContext void legalizeIRForSPIRV( SPIRVEmitSharedContext* context, IRModule* module, + const List<IRFunc*>& entryPoints, DiagnosticSink* sink); } diff --git a/source/slang/slang-ir-spirv-snippet.cpp b/source/slang/slang-ir-spirv-snippet.cpp index 4083f100d..5109d88c1 100644 --- a/source/slang/slang-ir-spirv-snippet.cpp +++ b/source/slang/slang-ir-spirv-snippet.cpp @@ -18,12 +18,30 @@ static SpvStorageClass translateStorageClass(String name) return (SpvStorageClass)-1; } +SpvSnippet::ASMType parseASMType(Slang::Misc::TokenReader& tokenReader) +{ + auto word = tokenReader.ReadWord(); + if (word == "float") + return SpvSnippet::ASMType::Float; + else if (word == "double") + return SpvSnippet::ASMType::Double; + else if (word == "uint2") + return SpvSnippet::ASMType::UInt2; + else if (word == "float2") + return SpvSnippet::ASMType::Float2; + else if (word == "int") + return SpvSnippet::ASMType::Int; + else if (word == "_p") + return SpvSnippet::ASMType::FloatOrDouble; + return SpvSnippet::ASMType::None; +} + RefPtr<SpvSnippet> SpvSnippet::parse(UnownedStringSlice definition) { RefPtr<SpvSnippet> snippet = new SpvSnippet(); try { - Dictionary<String, int> mapInstNameToIndex; + Dictionary<String, SpvWord> mapInstNameToIndex; Slang::Misc::TokenReader tokenReader(definition); // A leading "*" at the beginning of the snip modifies $resultType with // a storage class. @@ -46,7 +64,7 @@ RefPtr<SpvSnippet> SpvSnippet::parse(UnownedStringSlice definition) bool insideOperandList = true; while (insideOperandList) { - ASMOperand operand = {ASMOperandType::SpvWord, 0}; + ASMOperand operand = {ASMOperandType::SpvWord, 0, 0, 0}; switch (tokenReader.NextToken().Type) { case Slang::Misc::TokenType::Semicolon: @@ -60,6 +78,7 @@ RefPtr<SpvSnippet> SpvSnippet::parse(UnownedStringSlice definition) break; case Slang::Misc::TokenType::OpMod: { + tokenReader.ReadToken(); operand.type = SpvSnippet::ASMOperandType::InstReference; auto refName = tokenReader.ReadToken().Content; if (!mapInstNameToIndex.TryGetValue(refName, operand.content)) @@ -72,17 +91,10 @@ RefPtr<SpvSnippet> SpvSnippet::parse(UnownedStringSlice definition) case Slang::Misc::TokenType::Identifier: { auto identifier = tokenReader.ReadToken().Content; - if (identifier.startsWith("_")) - { - operand.type = SpvSnippet::ASMOperandType::ObjectReference; - operand.content = - StringToInt(identifier.subString(1, identifier.getLength() - 1)); - inst.operands.add(operand); - } - else if (identifier == "resultType") + if (identifier == "resultType") { operand.type = SpvSnippet::ASMOperandType::ResultTypeId; - operand.content = -1; + operand.content = (SpvWord)0xFFFFFFFF; if (tokenReader.AdvanceIf("*")) { // A "*" at operand qualifies the use of `resultType` with @@ -99,6 +111,78 @@ RefPtr<SpvSnippet> SpvSnippet::parse(UnownedStringSlice definition) operand.type = SpvSnippet::ASMOperandType::ResultId; inst.operands.add(operand); } + else if (identifier == "glsl450") + { + operand.type = SpvSnippet::ASMOperandType::GLSL450ExtInstSet; + inst.operands.add(operand); + } + else if (identifier == "fi") + { + operand.type = SpvSnippet::ASMOperandType::FloatIntegerSelection; + tokenReader.Read("("); + operand.content = (SpvWord)tokenReader.ReadInt(); + tokenReader.Read(","); + operand.content2 = (SpvWord)tokenReader.ReadInt(); + tokenReader.Read(")"); + inst.operands.add(operand); + } + else if (identifier == "fus") + { + operand.type = SpvSnippet::ASMOperandType::FloatUnsignedSignedSelection; + tokenReader.Read("("); + operand.content = (SpvWord)tokenReader.ReadInt(); + tokenReader.Read(","); + operand.content2 = (SpvWord)tokenReader.ReadInt(); + tokenReader.Read(","); + operand.content3 = (SpvWord)tokenReader.ReadInt(); + tokenReader.Read(")"); + inst.operands.add(operand); + } + else if (identifier == "_type") + { + operand.type = SpvSnippet::ASMOperandType::TypeReference; + tokenReader.Read("("); + operand.content = (SpvWord)parseASMType(tokenReader); + tokenReader.Read(")"); + inst.operands.add(operand); + } + else if (identifier.startsWith("_")) + { + operand.type = SpvSnippet::ASMOperandType::ObjectReference; + operand.content = (SpvWord)StringToInt( + identifier.subString(1, identifier.getLength() - 1)); + inst.operands.add(operand); + } + else if (identifier == "const") + { + operand.type = SpvSnippet::ASMOperandType::ConstantReference; + ASMConstant constant; + memset(&constant, 0, sizeof(ASMConstant)); + tokenReader.Read("("); + constant.type = parseASMType(tokenReader); + int i = 0; + while (tokenReader.AdvanceIf(",")) + { + switch (constant.type) + { + case ASMType::Float: + case ASMType::Float2: + case ASMType::FloatOrDouble: + constant.floatValues[i] = tokenReader.ReadFloat(); + ++i; + break; + + default: + constant.intValues[i] = tokenReader.ReadInt(); + ++i; + break; + } + } + tokenReader.Read(")"); + snippet->constants.add(constant); + operand.content = (SpvWord)(snippet->constants.getCount() - 1); + inst.operands.add(operand); + } else { SLANG_ASSERT(!"Invalid SPV ASM operand."); diff --git a/source/slang/slang-ir-spirv-snippet.h b/source/slang/slang-ir-spirv-snippet.h index 74a9b8cd7..e524abe3a 100644 --- a/source/slang/slang-ir-spirv-snippet.h +++ b/source/slang/slang-ir-spirv-snippet.h @@ -31,6 +31,18 @@ struct SpvSnippet : public RefObject ObjectReference, // Represents a reference to an ASM inst (e.g. `%t`). InstReference, + // Refer to the GLSL450 Instruction Set. + GLSL450ExtInstSet, + // A select expression based on whether result type is float, e.g. + // `fi(x,y)` selects `x` if resultType is `float`. + FloatIntegerSelection, + // A select expression based on whether result type is float, unsigned + // or signed integer. e.g. `fus(f_opcode, u_opcode, s_opcode)`. + FloatUnsignedSignedSelection, + // Reference to a type defined in `ASMType`. + TypeReference, + // Reference to a Constant defined in `SpvSnippet::constants`. + ConstantReference, }; struct ASMOperand @@ -40,18 +52,81 @@ struct SpvSnippet : public RefObject // The value of the spv word when type is `SpvWord`, or // the reference name when type is `ObjectReference` // (e.g. an argument reference (_1) has `content` == 1). - int content; + SpvWord content; + + // Additional value contents. + SpvWord content2; + SpvWord content3; + }; + + enum class ASMType : SpvWord + { + None, + Int, + Float, + Double, + FloatOrDouble, // Float or double type, depending on the result type of the intrinsic. + Float2, + UInt2, + }; + + struct ASMConstant + { + ASMType type; + SpvWord intValues[4]; + float floatValues[4]; + HashCode getHashCode() + { + HashCode result = (HashCode)type; + for (int i = 0; i < 4; i++) + { + switch (type) + { + case ASMType::Float: + case ASMType::Double: + case ASMType::Float2: + case ASMType::FloatOrDouble: + result = combineHash(result, Slang::getHashCode(floatValues[i])); + break; + default: + result = combineHash(result, Slang::getHashCode(intValues[i])); + break; + } + } + return result; + } + bool operator==(const ASMConstant& other) + { + if (type != other.type) + return false; + switch (type) + { + case ASMType::Float: + case ASMType::Double: + case ASMType::FloatOrDouble: + return floatValues[0] == other.floatValues[0]; + case ASMType::Float2: + return floatValues[0] == other.floatValues[0] && + floatValues[1] == other.floatValues[1]; + case ASMType::Int: + return intValues[0] == other.intValues[0]; + case ASMType::UInt2: + return intValues[0] == other.intValues[0] && intValues[1] == other.intValues[1]; + default: + return false; + } + } }; struct ASMInst { - SpvWord opCode; + SpvWord opCode = 0; List<ASMOperand> operands; }; List<ASMInst> instructions; List<SpvStorageClass> usedResultTypeStorageClasses; - + List<ASMConstant> constants; SpvStorageClass resultStorageClass = SpvStorageClassMax; static RefPtr<SpvSnippet> parse(UnownedStringSlice definition); diff --git a/source/slang/slang-options.cpp b/source/slang/slang-options.cpp index 73ccaab55..90c5a3896 100644 --- a/source/slang/slang-options.cpp +++ b/source/slang/slang-options.cpp @@ -1163,7 +1163,7 @@ struct OptionsParser } else if( argValue == "-emit-spirv-directly" ) { - requestImpl->getBackEndReq()->shouldEmitSPIRVDirectly = true; + getCurrentTarget()->targetFlags |= SLANG_TARGET_FLAG_GENERATE_SPIRV_DIRECTLY; } else if (argValue == "-default-downstream-compiler") { diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 913be346e..ef6872e76 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -1100,12 +1100,6 @@ void TargetRequest::addCapability(CapabilityAtom capability) cookedCapabilities = CapabilitySet::makeEmpty(); } -void TargetRequest::setDirectSPIRVEmitMode() -{ - m_emitSPIRVDirectly = true; - cookedCapabilities.makeEmpty(); -} - CapabilitySet TargetRequest::getTargetCaps() { if(!cookedCapabilities.isEmpty()) @@ -1140,7 +1134,7 @@ CapabilitySet TargetRequest::getTargetCaps() break; case CodeGenTarget::SPIRV: case CodeGenTarget::SPIRVAssembly: - if (m_emitSPIRVDirectly) + if (targetFlags & SLANG_TARGET_FLAG_GENERATE_SPIRV_DIRECTLY) { atoms.add(CapabilityAtom::SPIRV_DIRECT); } diff --git a/tests/hlsl-intrinsic/scalar-float.slang b/tests/hlsl-intrinsic/scalar-float.slang index 062f6c94b..a5756b01b 100644 --- a/tests/hlsl-intrinsic/scalar-float.slang +++ b/tests/hlsl-intrinsic/scalar-float.slang @@ -2,6 +2,7 @@ //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -dx12 -shaderobj //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj +//DISABLED_TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -emit-spirv-directly //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-cuda -compute -shaderobj //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name outputBuffer diff --git a/tests/spirv/direct-spirv-compute-simple.slang b/tests/spirv/direct-spirv-compute-simple.slang index 39b9074ed..2fa2798e5 100644 --- a/tests/spirv/direct-spirv-compute-simple.slang +++ b/tests/spirv/direct-spirv-compute-simple.slang @@ -1,6 +1,6 @@ // direct-spirv-compute-simple.slang - -//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -xslang -emit-spirv-directly +//TESTD:SIMPLE:-target spirv -entry computeMain -stage compute -emit-spirv-directly +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -emit-spirv-directly // Test runinng a shader generated from direct SPIR-V emit. diff --git a/tests/spirv/direct-spirv-control-flow-2.slang b/tests/spirv/direct-spirv-control-flow-2.slang index cc908100e..7dd829bda 100644 --- a/tests/spirv/direct-spirv-control-flow-2.slang +++ b/tests/spirv/direct-spirv-control-flow-2.slang @@ -1,6 +1,7 @@ // direct-spirv-control-flow-2.slang -//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -output-using-type -xslang -emit-spirv-directly +//TESTD:SIMPLE:-target spirv -entry computeMain -stage compute -emit-spirv-directly +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -output-using-type -emit-spirv-directly // Test direct SPIR-V emit on control flows. diff --git a/tests/spirv/direct-spirv-control-flow.slang b/tests/spirv/direct-spirv-control-flow.slang index 9efddeb12..10bee2522 100644 --- a/tests/spirv/direct-spirv-control-flow.slang +++ b/tests/spirv/direct-spirv-control-flow.slang @@ -1,6 +1,6 @@ // direct-spirv-control-flow.slang -//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -xslang -emit-spirv-directly +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -emit-spirv-directly // Test direct SPIRV emit on control fl. diff --git a/tools/gfx/vulkan/render-vk.cpp b/tools/gfx/vulkan/render-vk.cpp index 88770dbb9..f56a21d20 100644 --- a/tools/gfx/vulkan/render-vk.cpp +++ b/tools/gfx/vulkan/render-vk.cpp @@ -4944,6 +4944,7 @@ public: /// Note that the outShaderModule value should be cleaned up when no longer needed by caller /// via vkShaderModuleDestroy() VkPipelineShaderStageCreateInfo compileEntryPoint( + const char* entryPointName, ISlangBlob* code, VkShaderStageFlagBits stage, VkShaderModule& outShaderModule); @@ -5188,6 +5189,7 @@ VkBool32 VKDevice::handleDebugMessage(VkDebugReportFlagsEXT flags, VkDebugReport } VkPipelineShaderStageCreateInfo VKDevice::compileEntryPoint( + const char* entryPointName, ISlangBlob* code, VkShaderStageFlagBits stage, VkShaderModule& outShaderModule) @@ -5210,7 +5212,7 @@ VkPipelineShaderStageCreateInfo VKDevice::compileEntryPoint( shaderStageCreateInfo.stage = stage; shaderStageCreateInfo.module = module; - shaderStageCreateInfo.pName = "main"; + shaderStageCreateInfo.pName = entryPointName; return shaderStageCreateInfo; } @@ -6861,7 +6863,16 @@ Result VKDevice::createProgram(const IShaderProgram::Desc& desc, IShaderProgram* SLANG_RETURN_ON_FAIL(compileResult); shaderProgram->m_codeBlobs.add(kernelCode); VkShaderModule shaderModule; + // HACK: our direct-spirv-emit path generates SPIRV that respects + // the original entry point name, while the glslang path always + // uses "main" as the name. We should introduce a compiler parameter + // to control the entry point naming behavior in SPIRV-direct path + // so we can remove the ad-hoc logic here. + const char* entryPointName = "main"; + if (m_desc.slang.targetFlags & SLANG_TARGET_FLAG_GENERATE_SPIRV_DIRECTLY) + entryPointName = entryPointInfo->getName(); shaderProgram->m_stageCreateInfos.add(compileEntryPoint( + entryPointName, kernelCode, (VkShaderStageFlagBits)VulkanUtil::getShaderStage(stage), shaderModule)); @@ -7075,7 +7086,7 @@ Result VKDevice::createComputePipelineState(const ComputePipelineStateDesc& inDe returnComPtr(outState, pipelineStateImpl); return SLANG_OK; } - + VkPipelineCache pipelineCache = VK_NULL_HANDLE; VkPipeline pipeline = VK_NULL_HANDLE; diff --git a/tools/render-test/options.cpp b/tools/render-test/options.cpp index 5d3560351..a4a643e22 100644 --- a/tools/render-test/options.cpp +++ b/tools/render-test/options.cpp @@ -141,6 +141,10 @@ static gfx::DeviceType _toRenderType(Slang::RenderApiType apiType) { outOptions.useDXIL = true; } + else if (argValue == "-emit-spirv-directly") + { + outOptions.generateSPIRVDirectly = true; + } else if (argValue == "-only-startup") { outOptions.onlyStartup = true; diff --git a/tools/render-test/options.h b/tools/render-test/options.h index 7940d0d42..f41614360 100644 --- a/tools/render-test/options.h +++ b/tools/render-test/options.h @@ -73,6 +73,8 @@ struct Options Slang::DownstreamArgs downstreamArgs; ///< Args to downstream tools. Here it's just slang + bool generateSPIRVDirectly = false; + Options() { downstreamArgs.addName("slang"); } static SlangResult parse(int argc, const char*const* argv, Slang::WriterHelper stdError, Options& outOptions); diff --git a/tools/render-test/render-test-main.cpp b/tools/render-test/render-test-main.cpp index 96cc94aa3..271356cc3 100644 --- a/tools/render-test/render-test-main.cpp +++ b/tools/render-test/render-test-main.cpp @@ -1326,6 +1326,8 @@ static SlangResult _innerMain(Slang::StdWriters* stdWriters, SlangSession* sessi desc.adapter = options.adapter.getBuffer(); desc.slang.lineDirectiveMode = SLANG_LINE_DIRECTIVE_MODE_NONE; + if (options.generateSPIRVDirectly) + desc.slang.targetFlags = SLANG_TARGET_FLAG_GENERATE_SPIRV_DIRECTLY; List<const char*> requiredFeatureList; for (auto& name : options.renderFeatures) diff --git a/tools/render-test/slang-support.cpp b/tools/render-test/slang-support.cpp index a533a6378..65f8f244b 100644 --- a/tools/render-test/slang-support.cpp +++ b/tools/render-test/slang-support.cpp @@ -96,6 +96,8 @@ void ShaderCompilerUtil::Output::reset() spSetCodeGenTarget(slangRequest, input.target); spSetTargetProfile(slangRequest, 0, spFindProfile(out.session, input.profile.getBuffer())); + if (options.generateSPIRVDirectly) + spSetTargetFlags(slangRequest, 0, SLANG_TARGET_FLAG_GENERATE_SPIRV_DIRECTLY); // Define a macro so that shader code in a test can detect what language we // are nominally working with. |
