diff options
Diffstat (limited to 'source')
22 files changed, 507 insertions, 153 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 2e878b065..d8b5c38d3 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -2365,6 +2365,9 @@ __generic<T> __intrinsic_op($(kIROp_GetLegalizedSPIRVGlobalParamAddr)) Ptr<T> __getLegalizedSPIRVGlobalParamAddr(T val); +__intrinsic_op($(kIROp_RequireComputeDerivative)) +void __requireComputeDerivative(); + // Binding Attributes __attributeTarget(DeclBase) @@ -2627,4 +2630,10 @@ __attributeTarget(FunctionDeclBase) attribute_syntax [NonUniformReturn] : NonDynamicUniformAttribute; __attributeTarget(FunctionDeclBase) -attribute_syntax [__GLSLRequireShaderInputParameter(parameterNumber:int)] : GLSLRequireShaderInputParameterAttribute;
\ No newline at end of file +attribute_syntax [__GLSLRequireShaderInputParameter(parameterNumber:int)] : GLSLRequireShaderInputParameterAttribute; + +__attributeTarget(FuncDecl) +attribute_syntax [DerivativeGroupQuad] : DerivativeGroupQuadAttribute; + +__attributeTarget(FuncDecl) +attribute_syntax [DerivativeGroupLinear] : DerivativeGroupLinearAttribute;
\ No newline at end of file diff --git a/source/slang/glsl.meta.slang b/source/slang/glsl.meta.slang index 79f5bfdb8..b4e5cf7ab 100644 --- a/source/slang/glsl.meta.slang +++ b/source/slang/glsl.meta.slang @@ -7780,6 +7780,7 @@ public vector<float, N> dFdyCoarse(vector<float, N> p) [require(glsl_hlsl_spirv, fragmentprocessing_derivativecontrol)] public float fwidthFine(float p) { + __requireComputeDerivative(); __target_switch { case hlsl: @@ -7803,6 +7804,7 @@ __generic<let N : int> [require(glsl_hlsl_spirv, fragmentprocessing_derivativecontrol)] public vector<float, N> fwidthFine(vector<float, N> p) { + __requireComputeDerivative(); __target_switch { case hlsl: @@ -7826,6 +7828,7 @@ public vector<float, N> fwidthFine(vector<float, N> p) [require(glsl_hlsl_spirv, fragmentprocessing_derivativecontrol)] public float fwidthCoarse(float p) { + __requireComputeDerivative(); __target_switch { case hlsl: @@ -7849,6 +7852,7 @@ __generic<let N : int> [require(glsl_hlsl_spirv, fragmentprocessing_derivativecontrol)] public vector<float, N> fwidthCoarse(vector<float, N> p) { + __requireComputeDerivative(); __target_switch { case hlsl: diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 1002fb163..77a224b61 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -331,6 +331,7 @@ extension __TextureImpl<T,Shape,isArray,isMS,sampleCount,0,isShadow,1,format> [require(glsl_hlsl_spirv, texture_querylod)] float CalculateLevelOfDetail(TextureCoord location) { + __requireComputeDerivative(); __target_switch { case hlsl: @@ -352,6 +353,7 @@ extension __TextureImpl<T,Shape,isArray,isMS,sampleCount,0,isShadow,1,format> [require(glsl_hlsl_spirv, texture_querylod)] float CalculateLevelOfDetailUnclamped(TextureCoord location) { + __requireComputeDerivative(); __target_switch { case hlsl: @@ -373,6 +375,7 @@ extension __TextureImpl<T,Shape,isArray,isMS,sampleCount,0,isShadow,1,format> [require(cpp_cuda_glsl_hlsl_spirv, texture_sm_4_1_fragment)] T Sample(vector<float, Shape.dimensions+isArray> location) { + __requireComputeDerivative(); __target_switch { case cpp: @@ -424,6 +427,7 @@ extension __TextureImpl<T,Shape,isArray,isMS,sampleCount,0,isShadow,1,format> [require(cpp_glsl_hlsl_spirv, texture_sm_4_1_fragment)] T Sample(vector<float, Shape.dimensions+isArray> location, vector<int, Shape.planeDimensions> offset, float clamp) { + __requireComputeDerivative(); __target_switch { case cpp: @@ -456,6 +460,7 @@ extension __TextureImpl<T,Shape,isArray,isMS,sampleCount,0,isShadow,1,format> [require(cpp_glsl_hlsl_spirv, texture_sm_4_1_fragment)] T SampleBias(vector<float, Shape.dimensions+isArray> location, float bias) { + __requireComputeDerivative(); __target_switch { case cpp: @@ -478,19 +483,20 @@ extension __TextureImpl<T,Shape,isArray,isMS,sampleCount,0,isShadow,1,format> [require(cpp_glsl_hlsl_spirv, texture_sm_4_1_fragment)] T SampleBias(vector<float, Shape.dimensions+isArray> location, float bias, constexpr vector<int, Shape.planeDimensions> offset) { + __requireComputeDerivative(); __target_switch { case cpp: case hlsl: return __getTexture().SampleBias(__getSampler(), location, bias, offset); case glsl: - __intrinsic_asm "$ctextureOffset($0, $1, $3, $2)$z"; + __intrinsic_asm "$ctextureOffset($0, $1, $3, $2)$z"; case spirv: - return spirv_asm - { - %sampled : __sampledType(T) = OpImageSampleImplicitLod $this $location None|Bias|ConstOffset $bias $offset; - __truncate $$T result __sampledType(T) %sampled; - }; + return spirv_asm + { + %sampled : __sampledType(T) = OpImageSampleImplicitLod $this $location None|Bias|ConstOffset $bias $offset; + __truncate $$T result __sampledType(T) %sampled; + }; } } @@ -756,6 +762,7 @@ extension __TextureImpl<T,Shape,isArray,isMS,sampleCount,access,isShadow,0,forma [require(glsl_hlsl_spirv, texture_querylod)] float CalculateLevelOfDetail(SamplerState s, TextureCoord location) { + __requireComputeDerivative(); __target_switch { case hlsl: @@ -775,6 +782,7 @@ extension __TextureImpl<T,Shape,isArray,isMS,sampleCount,access,isShadow,0,forma [require(glsl_hlsl_spirv, texture_querylod)] float CalculateLevelOfDetailUnclamped(SamplerState s, TextureCoord location) { + __requireComputeDerivative(); __target_switch { case hlsl: @@ -798,6 +806,7 @@ extension __TextureImpl<T,Shape,isArray,isMS,sampleCount,0,isShadow,0,format> [require(cpp_cuda_glsl_hlsl_spirv, texture_sm_4_1_fragment)] T Sample(SamplerState s, vector<float, Shape.dimensions+isArray> location) { + __requireComputeDerivative(); __target_switch { case cpp: @@ -851,6 +860,7 @@ extension __TextureImpl<T,Shape,isArray,isMS,sampleCount,0,isShadow,0,format> [require(cpp_glsl_hlsl_spirv, texture_sm_4_1_fragment)] T Sample(SamplerState s, vector<float, Shape.dimensions+isArray> location, constexpr vector<int, Shape.planeDimensions> offset) { + __requireComputeDerivative(); __target_switch { case cpp: @@ -874,6 +884,7 @@ extension __TextureImpl<T,Shape,isArray,isMS,sampleCount,0,isShadow,0,format> [require(cpp_glsl_hlsl_spirv, texture_sm_4_1_fragment)] T Sample(SamplerState s, vector<float, Shape.dimensions+isArray> location, constexpr vector<int, Shape.planeDimensions> offset, float clamp) { + __requireComputeDerivative(); __target_switch { case cpp: @@ -906,6 +917,7 @@ extension __TextureImpl<T,Shape,isArray,isMS,sampleCount,0,isShadow,0,format> [require(cpp_glsl_hlsl_spirv, texture_sm_4_1_fragment)] T SampleBias(SamplerState s, vector<float, Shape.dimensions+isArray> location, float bias) { + __requireComputeDerivative(); __target_switch { case cpp: @@ -928,6 +940,7 @@ extension __TextureImpl<T,Shape,isArray,isMS,sampleCount,0,isShadow,0,format> [require(cpp_glsl_hlsl_spirv, texture_sm_4_1_fragment)] T SampleBias(SamplerState s, vector<float, Shape.dimensions+isArray> location, float bias, constexpr vector<int, Shape.planeDimensions> offset) { + __requireComputeDerivative(); __target_switch { case cpp: @@ -4843,6 +4856,7 @@ __generic<T : __BuiltinFloatingPointType> [require(cpp_cuda_glsl_hlsl_spirv, fragmentprocessing)] T dd$(xOrY)(T x) { + __requireComputeDerivative(); __target_switch { case hlsl: @@ -4861,6 +4875,7 @@ __generic<T : __BuiltinFloatingPointType, let N : int> [require(cpp_cuda_glsl_hlsl_spirv, fragmentprocessing)] vector<T, N> dd$(xOrY)(vector<T, N> x) { + __requireComputeDerivative(); __target_switch { case hlsl: @@ -4880,7 +4895,14 @@ __target_intrinsic(hlsl) [require(cpp_cuda_glsl_hlsl_spirv, fragmentprocessing)] matrix<T, N, M> dd$(xOrY)(matrix<T, N, M> x) { - MATRIX_MAP_UNARY(T, N, M, dd$(xOrY), x); + __requireComputeDerivative(); + __target_switch + { + case hlsl: + __intrinsic_asm "dd$(xOrY)"; + default: + MATRIX_MAP_UNARY(T, N, M, dd$(xOrY), x); + } } __generic<T : __BuiltinFloatingPointType> @@ -4889,6 +4911,7 @@ __glsl_extension(GL_ARB_derivative_control) [require(glsl_hlsl_spirv, fragmentprocessing_derivativecontrol)] T dd$(xOrY)_coarse(T x) { + __requireComputeDerivative(); __target_switch { case hlsl: __intrinsic_asm "dd$(xOrY)_coarse"; @@ -4903,6 +4926,7 @@ __glsl_extension(GL_ARB_derivative_control) [require(glsl_hlsl_spirv, fragmentprocessing_derivativecontrol)] vector<T, N> dd$(xOrY)_coarse(vector<T, N> x) { + __requireComputeDerivative(); __target_switch { case hlsl: __intrinsic_asm "dd$(xOrY)_coarse"; @@ -4912,11 +4936,18 @@ vector<T, N> dd$(xOrY)_coarse(vector<T, N> x) } __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> -__target_intrinsic(hlsl) [__readNone] +[require(glsl_hlsl_spirv, fragmentprocessing_derivativecontrol)] matrix<T, N, M> dd$(xOrY)_coarse(matrix<T, N, M> x) { - MATRIX_MAP_UNARY(T, N, M, dd$(xOrY)_coarse, x); + __requireComputeDerivative(); + __target_switch + { + case hlsl: + __intrinsic_asm "dd$(xOrY)_coarse"; + default: + MATRIX_MAP_UNARY(T, N, M, dd$(xOrY)_coarse, x); + } } __generic<T : __BuiltinFloatingPointType> @@ -4925,6 +4956,7 @@ __glsl_extension(GL_ARB_derivative_control) [require(glsl_hlsl_spirv, fragmentprocessing_derivativecontrol)] T dd$(xOrY)_fine(T x) { + __requireComputeDerivative(); __target_switch { case hlsl: __intrinsic_asm "dd$(xOrY)_fine"; @@ -4939,6 +4971,7 @@ __glsl_extension(GL_ARB_derivative_control) [require(glsl_hlsl_spirv, fragmentprocessing_derivativecontrol)] vector<T, N> dd$(xOrY)_fine(vector<T, N> x) { + __requireComputeDerivative(); __target_switch { case hlsl: __intrinsic_asm "dd$(xOrY)_fine"; @@ -4948,11 +4981,18 @@ vector<T, N> dd$(xOrY)_fine(vector<T, N> x) } __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> -__target_intrinsic(hlsl) [__readNone] +[require(glsl_hlsl_spirv, fragmentprocessing_derivativecontrol)] matrix<T, N, M> dd$(xOrY)_fine(matrix<T, N, M> x) { - MATRIX_MAP_UNARY(T, N, M, dd$(xOrY)_fine, x); + __requireComputeDerivative(); + __target_switch + { + case hlsl: + __intrinsic_asm "dd$(xOrY)_fine"; + default: + MATRIX_MAP_UNARY(T, N, M, dd$(xOrY)_fine, x); + } } ${{{{ @@ -5651,28 +5691,57 @@ matrix<T, N, M> frexp(matrix<T, N, M> x, out matrix<int, N, M, L> exp) // Texture filter width __generic<T : __BuiltinFloatingPointType> [__readNone] -__target_intrinsic(hlsl) -__target_intrinsic(glsl) -__target_intrinsic(spirv, "OpFwidth resultType resultId _0") [require(glsl_hlsl_spirv, fragmentprocessing)] -T fwidth(T x); +T fwidth(T x) +{ + __requireComputeDerivative(); + __target_switch + { + case hlsl: + __intrinsic_asm "fwidth($0)"; + case glsl: + __intrinsic_asm "fwidth($0)"; + case spirv: + return spirv_asm + { + OpFwidth $$T result $x; + }; + } +} __generic<T : __BuiltinFloatingPointType, let N : int> -__target_intrinsic(hlsl) -__target_intrinsic(glsl) -__target_intrinsic(spirv, "OpFwidth resultType resultId _0") [__readNone] +[require(glsl_hlsl_spirv, fragmentprocessing)] vector<T, N> fwidth(vector<T, N> x) { - VECTOR_MAP_UNARY(T, N, fwidth, x); + __requireComputeDerivative(); + __target_switch + { + case hlsl: + __intrinsic_asm "fwidth($0)"; + case glsl: + __intrinsic_asm "fwidth($0)"; + case spirv: + return spirv_asm + { + OpFwidth $$vector<T, N> result $x; + }; + } } __generic<T : __BuiltinFloatingPointType, let N : int, let M : int> __target_intrinsic(hlsl) [__readNone] +[require(glsl_hlsl_spirv, fragmentprocessing)] matrix<T, N, M> fwidth(matrix<T, N, M> x) { - MATRIX_MAP_UNARY(T, N, M, fwidth, x); + __target_switch + { + case hlsl: + __intrinsic_asm "fwidth($0)"; + default: + MATRIX_MAP_UNARY(T, N, M, fwidth, x); + } } /// Get the value of a vertex attribute at a specific vertex. diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 21661939c..1126e84ef 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -820,6 +820,16 @@ class GLSLLayoutLocalSizeAttribute : public Attribute IntVal* z; }; +class GLSLLayoutDerivativeGroupQuadAttribute : public Attribute +{ + SLANG_AST_CLASS(GLSLLayoutDerivativeGroupQuadAttribute) +}; + +class GLSLLayoutDerivativeGroupLinearAttribute : public Attribute +{ + SLANG_AST_CLASS(GLSLLayoutDerivativeGroupLinearAttribute) +}; + // TODO: for attributes that take arguments, the syntax node // classes should provide accessors for the values of those arguments. @@ -1437,6 +1447,16 @@ class NoInlineAttribute : public Attribute SLANG_AST_CLASS(NoInlineAttribute) }; +class DerivativeGroupQuadAttribute : public Attribute +{ + SLANG_AST_CLASS(DerivativeGroupQuadAttribute) +}; + +class DerivativeGroupLinearAttribute : public Attribute +{ + SLANG_AST_CLASS(DerivativeGroupLinearAttribute) +}; + /// A `[payload]` attribute indicates that a `struct` type will be used as /// a ray payload for `TraceRay()` calls, and thus also as input/output /// for shaders in the ray tracing pipeline that might be invoked for diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef index fe11be4b2..cdac0d4c1 100644 --- a/source/slang/slang-capabilities.capdef +++ b/source/slang/slang-capabilities.capdef @@ -120,6 +120,7 @@ alias pixel = fragment; alias raygeneration = raygen; alias tesscontrol = hull; alias tesseval = domain; +alias amplification_mesh = amplification | mesh; alias raytracing_stages = raygen | intersection | anyhit | closesthit | miss | callable; alias raytracing_stages_intersection = intersection; alias raytracing_stages_raygen = raygen; @@ -134,6 +135,7 @@ alias shader_stages_compute_fragment_geometry_vertex = compute | fragment | geom alias shader_stages_domain_hull = domain | hull; alias raytracing_stages_fragment = raytracing_stages | fragment; alias raytracing_stages_compute = raytracing_stages | compute; +alias raytracing_stages_compute_amplification_mesh = raytracing_stages_compute | amplification_mesh; alias raytracing_stages_compute_fragment = raytracing_stages | shader_stages_compute_fragment; alias raytracing_stages_compute_fragment_geometry_vertex = raytracing_stages | shader_stages_compute_fragment_geometry_vertex; @@ -197,6 +199,7 @@ def SPV_NV_shader_invocation_reorder : spirv_1_5 + SPV_KHR_ray_tracing; def SPV_KHR_shader_clock : spirv_1_0; def SPV_NV_shader_image_footprint : spirv_1_0; def SPV_GOOGLE_user_type : spirv_1_0; +def SPV_NV_compute_shader_derivatives : spirv_1_0; // SPIRV Capabilities. @@ -326,7 +329,7 @@ alias GL_KHR_shader_subgroup_shuffle = _GL_KHR_shader_subgroup_shuffle | spvGrou alias GL_KHR_shader_subgroup_shuffle_relative = _GL_KHR_shader_subgroup_shuffle_relative | spvGroupNonUniformShuffle; alias GL_KHR_shader_subgroup_vote = _GL_KHR_shader_subgroup_vote | spvGroupNonUniformVote; alias GL_KHR_shader_subgroup_quad = _GL_KHR_shader_subgroup_quad | spvGroupNonUniformQuad; -alias GL_NV_compute_shader_derivatives = _GL_NV_compute_shader_derivatives | SOURCE_EXT_GL_NV_compute_shader_derivatives; +alias GL_NV_compute_shader_derivatives = _GL_NV_compute_shader_derivatives | SOURCE_EXT_GL_NV_compute_shader_derivatives | SPV_NV_compute_shader_derivatives | _sm_6_6; alias GL_ARB_shader_image_size = _GL_ARB_shader_image_size | spvImageQuery; alias GL_ARB_shader_texture_image_samples = _GL_ARB_shader_texture_image_samples | spvImageQuery; alias GL_NV_shader_atomic_fp16_vector = _GL_NV_shader_atomic_fp16_vector + _GL_NV_gpu_shader5 | spirv_1_0; @@ -589,15 +592,13 @@ alias atomic_hlsl_sm_6_6 = _sm_6_6; alias byteaddressbuffer = sm_4_0; alias byteaddressbuffer_rw = sm_4_0 + raytracing_stages_compute_fragment; alias consumestructuredbuffer = sm_5_0 + raytracing_stages_compute_fragment; -alias fragmentprocessing = raytracing_stages_compute_fragment + _sm_5_0 +alias fragmentprocessing = fragment + _sm_5_0 | fragment + glsl_spirv - | raytracing_stages_compute + GL_NV_compute_shader_derivatives - | raytracing_stages_compute_fragment + GLSL_460 + | raytracing_stages_compute_amplification_mesh + GL_NV_compute_shader_derivatives ; -alias fragmentprocessing_derivativecontrol = raytracing_stages_compute_fragment + _sm_5_0 +alias fragmentprocessing_derivativecontrol = fragment + _sm_5_0 | fragment + GL_ARB_derivative_control - | compute + GL_NV_compute_shader_derivatives - | raytracing_stages_compute_fragment + GLSL_460 + | raytracing_stages_compute_amplification_mesh + GL_NV_compute_shader_derivatives ; alias getattributeatvertex = fragment + _sm_6_1 | fragment + GL_EXT_fragment_shader_barycentric; alias memorybarrier_compute = raytracing_stages_compute + sm_5_0; @@ -633,7 +634,9 @@ alias texture_querylod = texture_sm_4_1 + GL_EXT_texture_query_lod; alias texture_querylevels = texture_sm_4_1 + GL_ARB_texture_query_levels; alias texture_shadowlod = texture_sm_4_1 + GL_EXT_texture_shadow_lod + _GLSL_400 | texture_sm_4_1 + GL_EXT_texture_shadow_lod; -alias texture_shadowlod_cube = texture_shadowlod + GL_ARB_texture_cube_map; +alias texture_shadowlod_cube = texture_shadowlod | texture_shadowlod + GL_ARB_texture_cube_map; +alias texture_cube = texture_sm_4_1 + GL_ARB_texture_cube_map | texture_sm_4_1; +alias texture_querylevels_cube = texture_querylevels + GL_ARB_texture_cube_map | texture_querylevels; alias atomic_glsl_float1 = GL_EXT_shader_atomic_float; alias atomic_glsl_float2 = GL_EXT_shader_atomic_float2; @@ -646,7 +649,7 @@ alias nonuniformqualifier = sm_5_1; alias printf = GL_EXT_debug_printf | _sm_4_0 | _cuda_sm_2_0 | cpp; alias texturefootprint = GL_NV_shader_texture_footprint + GLSL_450 | hlsl_nvapi + _sm_4_0; alias texturefootprintclamp = texturefootprint + GL_ARB_sparse_texture_clamp; -alias texture_cube = GL_ARB_texture_cube_map; + alias shader5_sm_4_0 = GL_ARB_gpu_shader5 | sm_4_0; alias shader5_sm_5_0 = GL_ARB_gpu_shader5 | sm_5_0; diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 0c5d0f10c..b294383bb 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -451,6 +451,9 @@ DIAGNOSTIC(31205, Error, incompleteTypeCannotBeUsedInUniformParameter, "incomple DIAGNOSTIC(31206, Error, memoryQualifierNotAllowedOnANonImageTypeParameter, "modifier $0 is not allowed on a non image type parameter.") DIAGNOSTIC(31207, Error, InputAttachmentIndexOnlyAllowedOnSubpass, "input_attachment_index is only allowed on subpass images.") DIAGNOSTIC(31208, Error, requireInputDecoratedVarForParameter, "$0 expects for argument $1 a type which is a shader input (`in`) variable.") +DIAGNOSTIC(31210, Error, derivativeGroupQuadMustBeMultiple2ForXYThreads, "compute derivative group quad requires thread dispatch count of X and Y to each be at a multiple of 2") +DIAGNOSTIC(31211, Error, derivativeGroupLinearMustBeMultiple4ForTotalThreadCount, "compute derivative group linear requires total thread dispatch count to be at a multiple of 4") +DIAGNOSTIC(31212, Error, onlyOneOfDerivativeGroupLinearOrQuadCanBeSet, "cannot set compute derivative group linear and compute derivative group quad at the same time") // Enums diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index bd3337769..fa380e061 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -129,6 +129,11 @@ void CLikeSourceEmitter::emitPreModuleImpl() m_writer->emit(prelude->getStringSlice()); m_writer->emit("\n"); } + for (auto prelude : m_requiredPreludesRaw) + { + m_writer->emit(prelude); + m_writer->emit("\n"); + } } // @@ -2747,6 +2752,10 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO { break; //should already have set requirement; case covered for empty intrinsic block } + case kIROp_RequireComputeDerivative: + { + break; //should already have been parsed and used. + } default: diagnoseUnhandledInst(inst); break; @@ -3358,7 +3367,8 @@ void CLikeSourceEmitter::emitSimpleFuncImpl(IRFunc* func) // Deal with decorations that need // to be emitted as attributes - if ( IREntryPointDecoration* entryPointDecor = func->findDecoration<IREntryPointDecoration>()) + IREntryPointDecoration* entryPointDecor = func->findDecoration<IREntryPointDecoration>(); + if (entryPointDecor) { emitEntryPointAttributes(func, entryPointDecor); } @@ -4446,6 +4456,7 @@ void CLikeSourceEmitter::emitModuleImpl(IRModule* module, DiagnosticSink* sink) List<EmitAction> actions; + beforeComputeEmitActions(module); computeEmitActions(module, actions); executeEmitActions(actions); } diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index 64aa5f945..ba17caace 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -476,6 +476,8 @@ public: /// the appropriate generated declarations occur. virtual void emitPreModuleImpl(); + virtual void beforeComputeEmitActions(IRModule* module) { SLANG_UNUSED(module); }; + virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRIntegerValue addressSpace) { SLANG_UNUSED(rate); SLANG_UNUSED(addressSpace); } virtual void emitSemanticsImpl(IRInst* inst, bool allowOffsetLayout) { SLANG_UNUSED(inst); SLANG_UNUSED(allowOffsetLayout); } virtual void emitSimpleFuncParamImpl(IRParam* param); @@ -587,6 +589,7 @@ public: // to use for it when emitting code. Dictionary<IRInst*, String> m_mapInstToName; + OrderedHashSet<String> m_requiredPreludesRaw; OrderedHashSet<IRStringLit*> m_requiredPreludes; }; diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index 32301c418..d768cff97 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -10,6 +10,7 @@ #include "slang-legalize-types.h" #include "slang-ir-layout.h" #include "slang/slang-ir.h" +#include "slang-ir-call-graph.h" #include <assert.h> @@ -26,6 +27,11 @@ GLSLSourceEmitter::GLSLSourceEmitter(const Desc& desc) : SLANG_ASSERT(m_glslExtensionTracker); } +void GLSLSourceEmitter::beforeComputeEmitActions(IRModule* module) +{ + buildEntryPointReferenceGraph(this->m_referencingEntryPoints, module); +} + SlangResult GLSLSourceEmitter::init() { SLANG_RETURN_ON_FAIL(Super::init()); @@ -2171,7 +2177,7 @@ void GLSLSourceEmitter::handleRequiredCapabilitiesImpl(IRInst* inst) } } - // The function may have IRRequireGLSLExtensionInst in its body. We also need to look for them. + // The function may have various requirment declaring functions its body. We also need to look for them. auto func = as<IRFunc>(inst); if (!func) return; @@ -2184,6 +2190,36 @@ void GLSLSourceEmitter::handleRequiredCapabilitiesImpl(IRInst* inst) { _requireGLSLExtension(requireGLSLExt->getExtensionName()); } + else if (auto requireComputeDerivative = as<IRRequireComputeDerivative>(childInst)) + { + // only allowed 1 of derivative_group_quadsNV or derivative_group_linearNV + if (m_entryPointStage != Stage::Compute + || m_requiredPreludesRaw.contains("layout(derivative_group_quadsNV) in;") + || m_requiredPreludesRaw.contains("layout(derivative_group_linearNV) in;") + ) + return; + + _requireGLSLExtension(UnownedStringSlice("GL_NV_compute_shader_derivatives")); + + // This will only run once per program. + HashSet<IRFunc*>* entryPointsUsingInst = getReferencingEntryPoints(m_referencingEntryPoints, func); + + for (auto entryPoint : *entryPointsUsingInst) + { + bool isQuad = !entryPoint->findDecoration<IRDerivativeGroupLinearDecoration>(); + auto numThreadsDecor = entryPoint->findDecoration<IRNumThreadsDecoration>(); + if (isQuad) + { + verifyComputeDerivativeGroupModifiers(getSink(), inst->sourceLoc, true, false, numThreadsDecor); + m_requiredPreludesRaw.add("layout(derivative_group_quadsNV) in;"); + } + else + { + verifyComputeDerivativeGroupModifiers(getSink(), inst->sourceLoc, false, true, numThreadsDecor); + m_requiredPreludesRaw.add("layout(derivative_group_linearNV) in;"); + } + } + } } } diff --git a/source/slang/slang-emit-glsl.h b/source/slang/slang-emit-glsl.h index d8a3e4e81..a30195d75 100644 --- a/source/slang/slang-emit-glsl.h +++ b/source/slang/slang-emit-glsl.h @@ -22,6 +22,7 @@ public: protected: + virtual void beforeComputeEmitActions(IRModule* module) SLANG_OVERRIDE; virtual void emitParameterGroupImpl(IRGlobalParam* varDecl, IRUniformParameterGroupType* type) SLANG_OVERRIDE; virtual void emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPointDecoration* entryPointDecor) SLANG_OVERRIDE; virtual void emitImageFormatModifierImpl(IRInst* varDecl, IRType* varType) SLANG_OVERRIDE; @@ -129,6 +130,8 @@ protected: void _emitSpecialFloatImpl(IRType* type, const char* valueExpr); + Dictionary<IRInst*, HashSet<IRFunc*>> m_referencingEntryPoints; + RefPtr<GLSLExtensionTracker> m_glslExtensionTracker; }; diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 06e5f0766..106248ef8 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -4,6 +4,7 @@ #include "slang-emit-base.h" #include "slang-ir-util.h" +#include "slang-ir-call-graph.h" #include "slang-ir.h" #include "slang-ir-insts.h" #include "slang-ir-layout.h" @@ -437,6 +438,7 @@ constexpr bool isPlural<IRUse*> = true; template<typename T> constexpr bool isSingular = !isPlural<T>; + // Now that we've defined the intermediate data structures we will // use to represent SPIR-V code during emission, we will move on // to defining the main context type that will drive SPIR-V @@ -1278,6 +1280,11 @@ struct SPIRVEmitContext return result; } + bool hasExtensionDeclaration(const UnownedStringSlice& name) + { + return m_extensionInsts.containsKey(name); + } + struct SpvTypeInstKey { List<SpvWord> words; @@ -2732,6 +2739,43 @@ struct SPIRVEmitContext result = inner; break; } + case kIROp_RequireComputeDerivative: + { + auto parentFunc = getParentFunc(inst); + + HashSet<IRFunc*>* entryPointsUsingInst = getReferencingEntryPoints(m_referencingEntryPoints, parentFunc); + for (IRFunc* entryPoint : *entryPointsUsingInst) + { + bool isQuad = true; + IREntryPointDecoration* entryPointDecor = nullptr; + for(auto dec : entryPoint->getDecorations()) + { + if(auto maybeEntryPointDecor = as<IREntryPointDecoration>(dec)) + entryPointDecor = maybeEntryPointDecor; + if(as<IRDerivativeGroupLinearDecoration>(dec)) + isQuad = false; + } + if (!entryPointDecor || entryPointDecor->getProfile().getStage() != Stage::Compute) + continue; + + ensureExtensionDeclaration(UnownedStringSlice("SPV_NV_compute_shader_derivatives")); + auto numThreadsDecor = entryPointDecor->findDecoration<IRNumThreadsDecoration>(); + if (isQuad) + { + verifyComputeDerivativeGroupModifiers(this->m_sink, inst->sourceLoc, true, false, numThreadsDecor); + emitOpExecutionMode(getSection(SpvLogicalSectionID::ExecutionModes), nullptr, entryPoint, SpvExecutionModeDerivativeGroupQuadsNV); + emitOpCapability(getSection(SpvLogicalSectionID::Capabilities), nullptr, SpvCapabilityComputeDerivativeGroupQuadsNV); + } + else + { + verifyComputeDerivativeGroupModifiers(this->m_sink, inst->sourceLoc, false, true, numThreadsDecor); + emitOpExecutionMode(getSection(SpvLogicalSectionID::ExecutionModes), nullptr, entryPoint, SpvExecutionModeDerivativeGroupLinearNV); + emitOpCapability(getSection(SpvLogicalSectionID::Capabilities), nullptr, SpvCapabilityComputeDerivativeGroupLinearNV); + } + } + + break; + } case kIROp_Return: if (as<IRReturn>(inst)->getVal()->getOp() == kIROp_VoidLit) result = emitOpReturn(parent, inst); diff --git a/source/slang/slang-ir-call-graph.cpp b/source/slang/slang-ir-call-graph.cpp new file mode 100644 index 000000000..b3de60228 --- /dev/null +++ b/source/slang/slang-ir-call-graph.cpp @@ -0,0 +1,112 @@ +#include "slang-ir-call-graph.h" +#include "slang-ir-insts.h" +#include "slang-ir-clone.h" + +namespace Slang +{ + +void buildEntryPointReferenceGraph(Dictionary<IRInst*, HashSet<IRFunc*>>& referencingEntryPoints, IRModule* module) +{ + struct WorkItem + { + IRFunc* entryPoint; IRInst* inst; + + HashCode getHashCode() const + { + return combineHash(Slang::getHashCode(entryPoint), Slang::getHashCode(inst)); + } + bool operator == (const WorkItem& other) const + { + return entryPoint == other.entryPoint && inst == other.inst; + } + }; + HashSet<WorkItem> workListSet; + List<WorkItem> workList; + auto addToWorkList = [&](WorkItem item) + { + if (workListSet.add(item)) + workList.add(item); + }; + + auto registerEntryPointReference = [&](IRFunc* entryPoint, IRInst* inst) + { + if (auto set = referencingEntryPoints.tryGetValue(inst)) + set->add(entryPoint); + else + { + HashSet<IRFunc*> newSet; + newSet.add(entryPoint); + referencingEntryPoints.add(inst, _Move(newSet)); + } + }; + auto visit = [&](IRFunc* entryPoint, IRInst* inst) + { + if (auto code = as<IRGlobalValueWithCode>(inst)) + { + registerEntryPointReference(entryPoint, inst); + for (auto child : code->getChildren()) + { + addToWorkList({ entryPoint, child }); + } + return; + } + switch (inst->getOp()) + { + case kIROp_GlobalParam: + case kIROp_SPIRVAsmOperandBuiltinVar: + registerEntryPointReference(entryPoint, inst); + break; + case kIROp_Block: + case kIROp_SPIRVAsm: + for (auto child : inst->getChildren()) + { + addToWorkList({ entryPoint, child }); + } + break; + case kIROp_Call: + { + auto call = as<IRCall>(inst); + addToWorkList({ entryPoint, call->getCallee() }); + } + break; + case kIROp_SPIRVAsmOperandInst: + { + auto operand = as<IRSPIRVAsmOperandInst>(inst); + addToWorkList({ entryPoint, operand->getValue() }); + } + break; + } + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + auto operand = inst->getOperand(i); + switch (operand->getOp()) + { + case kIROp_GlobalParam: + case kIROp_GlobalVar: + case kIROp_SPIRVAsmOperandBuiltinVar: + addToWorkList({ entryPoint, operand }); + break; + } + } + }; + + for (auto globalInst : module->getGlobalInsts()) + { + if (globalInst->getOp() == kIROp_Func && globalInst->findDecoration<IREntryPointDecoration>()) + { + visit(as<IRFunc>(globalInst), globalInst); + } + } + for (Index i = 0; i < workList.getCount(); i++) + visit(workList[i].entryPoint, workList[i].inst); +} + +HashSet<IRFunc*>* getReferencingEntryPoints(Dictionary<IRInst*, HashSet<IRFunc*>>& m_referencingEntryPoints, IRInst* inst) +{ + auto* referencingEntryPoints = m_referencingEntryPoints.tryGetValue(inst); + if (!referencingEntryPoints) + return nullptr; + return referencingEntryPoints; +} + +} diff --git a/source/slang/slang-ir-call-graph.h b/source/slang/slang-ir-call-graph.h new file mode 100644 index 000000000..85532e9d1 --- /dev/null +++ b/source/slang/slang-ir-call-graph.h @@ -0,0 +1,11 @@ +#include "slang-ir-insts.h" +#include "slang-ir-clone.h" + +namespace Slang +{ + + void buildEntryPointReferenceGraph(Dictionary<IRInst*, HashSet<IRFunc*>>& referencingEntryPoints, IRModule* module); + + HashSet<IRFunc*>* getReferencingEntryPoints(Dictionary<IRInst*, HashSet<IRFunc*>>& m_referencingEntryPoints, IRInst* inst); + +} diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 47a2d2f7e..2aad3bf8e 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -594,6 +594,7 @@ INST_RANGE(TerminatorInst, Return, Unreachable) INST(RequirePrelude, RequirePrelude, 1, 0) INST(RequireGLSLExtension, RequireGLSLExtension, 1, 0) +INST(RequireComputeDerivative, RequireComputeDerivative, 0, 0) // TODO: We should consider splitting the basic arithmetic/comparison // ops into cases for signed integers, unsigned integers, and floating-point @@ -862,6 +863,9 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace) /// Applie to an IR function and signals that inlining should not be performed unless unavoidable. INST(NoInlineDecoration, noInline, 0, 0) + INST(DerivativeGroupQuadDecoration, DerivativeGroupQuad, 0, 0) + INST(DerivativeGroupLinearDecoration, DerivativeGroupLinear, 0, 0) + // Marks a type to be non copyable, causing SSA pass to skip turning variables of the the type into SSA values. INST(NonCopyableTypeDecoration, nonCopyable, 0, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 067a0cc2f..0d66efb14 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -393,6 +393,8 @@ IR_SIMPLE_DECORATION(HLSLExportDecoration) IR_SIMPLE_DECORATION(KeepAliveDecoration) IR_SIMPLE_DECORATION(RequiresNVAPIDecoration) IR_SIMPLE_DECORATION(NoInlineDecoration) +IR_SIMPLE_DECORATION(DerivativeGroupQuadDecoration) +IR_SIMPLE_DECORATION(DerivativeGroupLinearDecoration) IR_SIMPLE_DECORATION(AlwaysFoldIntoUseSiteDecoration) IR_SIMPLE_DECORATION(StaticRequirementDecoration) IR_SIMPLE_DECORATION(NonCopyableTypeDecoration) @@ -3208,6 +3210,11 @@ struct IRRequireGLSLExtension : IRInst UnownedStringSlice getExtensionName() { return as<IRStringLit>(getOperand(0))->getStringSlice(); } }; +struct IRRequireComputeDerivative : IRInst +{ + IR_LEAF_ISA(RequireComputeDerivative) +}; + struct IRBuilderSourceLocRAII; struct IRBuilder @@ -4285,9 +4292,9 @@ public: } template<typename T> - void addSimpleDecoration(IRInst* value) + IRDecoration* addSimpleDecoration(IRInst* value) { - addDecoration(value, IROp(T::kOp), (IRInst* const*) nullptr, 0); + return addDecoration(value, IROp(T::kOp), (IRInst* const*) nullptr, 0); } void addHighLevelDeclDecoration(IRInst* value, Decl* decl); diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index 23980e583..2f392a178 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -7,6 +7,7 @@ #include "slang-ir-legalize-mesh-outputs.h" #include "slang-ir.h" #include "slang-ir-insts.h" +#include "slang-ir-call-graph.h" #include "slang-emit-base.h" #include "slang-glsl-extension-tracker.h" #include "slang-ir-lower-buffer-element-type.h" @@ -2243,102 +2244,6 @@ void legalizeSPIRV(SPIRVEmitSharedContext* sharedContext, IRModule* module) context.processModule(); } -void buildEntryPointReferenceGraph(Dictionary<IRInst*, HashSet<IRFunc*>>& referencingEntryPoints, IRModule* module) -{ - struct WorkItem - { - IRFunc* entryPoint; IRInst* inst; - - HashCode getHashCode() const - { - return combineHash(Slang::getHashCode(entryPoint), Slang::getHashCode(inst)); - } - bool operator == (const WorkItem& other) const - { - return entryPoint == other.entryPoint && inst == other.inst; - } - }; - HashSet<WorkItem> workListSet; - List<WorkItem> workList; - auto addToWorkList = [&](WorkItem item) - { - if (workListSet.add(item)) - workList.add(item); - }; - - auto registerEntryPointReference = [&](IRFunc* entryPoint, IRInst* inst) - { - if (auto set = referencingEntryPoints.tryGetValue(inst)) - set->add(entryPoint); - else - { - HashSet<IRFunc*> newSet; - newSet.add(entryPoint); - referencingEntryPoints.add(inst, _Move(newSet)); - } - }; - auto visit = [&](IRFunc* entryPoint, IRInst* inst) - { - if (auto code = as<IRGlobalValueWithCode>(inst)) - { - registerEntryPointReference(entryPoint, inst); - for (auto child : code->getChildren()) - { - addToWorkList({ entryPoint, child }); - } - return; - } - switch (inst->getOp()) - { - case kIROp_GlobalParam: - case kIROp_SPIRVAsmOperandBuiltinVar: - registerEntryPointReference(entryPoint, inst); - break; - case kIROp_Block: - case kIROp_SPIRVAsm: - for (auto child : inst->getChildren()) - { - addToWorkList({ entryPoint, child }); - } - break; - case kIROp_Call: - { - auto call = as<IRCall>(inst); - addToWorkList({ entryPoint, call->getCallee() }); - } - break; - case kIROp_SPIRVAsmOperandInst: - { - auto operand = as<IRSPIRVAsmOperandInst>(inst); - addToWorkList({ entryPoint, operand->getValue() }); - } - break; - } - for (UInt i = 0; i < inst->getOperandCount(); i++) - { - auto operand = inst->getOperand(i); - switch (operand->getOp()) - { - case kIROp_GlobalParam: - case kIROp_GlobalVar: - case kIROp_SPIRVAsmOperandBuiltinVar: - addToWorkList({ entryPoint, operand }); - break; - } - } - }; - - for (auto globalInst : module->getGlobalInsts()) - { - if (globalInst->getOp() == kIROp_Func && globalInst->findDecoration<IREntryPointDecoration>()) - { - visit(as<IRFunc>(globalInst), globalInst); - } - } - for (Index i = 0; i < workList.getCount(); i++) - visit(workList[i].entryPoint, workList[i].inst); -} - void simplifyIRForSpirvLegalization(TargetProgram* target, DiagnosticSink* sink, IRModule* module) { bool changed = true; diff --git a/source/slang/slang-ir-translate-glsl-global-var.cpp b/source/slang/slang-ir-translate-glsl-global-var.cpp index 575a66457..d070cee68 100644 --- a/source/slang/slang-ir-translate-glsl-global-var.cpp +++ b/source/slang/slang-ir-translate-glsl-global-var.cpp @@ -3,11 +3,10 @@ #include "slang-ir.h" #include "slang-ir-insts.h" #include "slang-ir-util.h" +#include "slang-ir-call-graph.h" namespace Slang { - void buildEntryPointReferenceGraph(Dictionary<IRInst*, HashSet<IRFunc*>>& referencingEntryPoints, IRModule* module); - struct GlobalVarTranslationContext { CodeGenContext* context; diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index f7a93dca6..f6b0acaed 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -1743,5 +1743,41 @@ IRType* dropNormAttributes(IRType* const t) return t; } +void verifyComputeDerivativeGroupModifiers( + DiagnosticSink* sink, + SourceLoc errorLoc, + bool quadAttr, + bool linearAttr, + IRNumThreadsDecoration* numThreadsDecor) +{ + if (!numThreadsDecor) + return; + + if (quadAttr && linearAttr) + { + sink->diagnose(errorLoc, Diagnostics::onlyOneOfDerivativeGroupLinearOrQuadCanBeSet); + } + + IRIntegerValue x = 1; + IRIntegerValue y = 1; + IRIntegerValue z = 1; + if (numThreadsDecor->getX()) + x = numThreadsDecor->getX()->getValue(); + if (numThreadsDecor->getY()) + y = numThreadsDecor->getY()->getValue(); + if (numThreadsDecor->getZ()) + z = numThreadsDecor->getZ()->getValue(); + + if (quadAttr) + { + if (x % 2 != 0 || y % 2 != 0) + sink->diagnose(errorLoc, Diagnostics::derivativeGroupQuadMustBeMultiple2ForXYThreads); + } + else if (linearAttr) + { + if ((x * y * z) % 4 != 0) + sink->diagnose(errorLoc, Diagnostics::derivativeGroupLinearMustBeMultiple4ForTotalThreadCount); + } +} } diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 9046f3974..3b3df27ef 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -6,7 +6,6 @@ // #include "slang-ir.h" #include "slang-ir-insts.h" - namespace Slang { struct GenericChildrenMigrationContextImpl; @@ -313,10 +312,6 @@ static void overAllBlocks(IRModule* module, F f) void hoistInstOutOfASMBlocks(IRBlock* block); -IRType* getSPIRVSampledElementType(IRInst* sampledType); - -IRType* replaceVectorElementType(IRType* originalVectorType, IRType* t); - inline bool isCompositeType(IRType* type) { switch (type->getOp()) @@ -330,8 +325,19 @@ inline bool isCompositeType(IRType* type) } } +IRType* getSPIRVSampledElementType(IRInst* sampledType); + +IRType* replaceVectorElementType(IRType* originalVectorType, IRType* t); + IRParam* getParamAt(IRBlock* block, UIndex ii); +void verifyComputeDerivativeGroupModifiers( + DiagnosticSink* sink, + SourceLoc errorLoc, + bool quadAttr, + bool linearAttr, + IRNumThreadsDecoration* numThreadsDecor); + } #endif diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 4ea3755f8..5a917e88c 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -60,6 +60,8 @@ namespace Slang case kIROp_LineAdjInputPrimitiveTypeDecoration: case kIROp_LineInputPrimitiveTypeDecoration: case kIROp_NoInlineDecoration: + case kIROp_DerivativeGroupQuadDecoration: + case kIROp_DerivativeGroupLinearDecoration: case kIROp_PointInputPrimitiveTypeDecoration: case kIROp_PreciseDecoration: case kIROp_PublicDecoration: diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index a600963a7..e4fc33e33 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -12,6 +12,7 @@ #include "slang-ir-bit-field-accessors.h" #include "slang-ir-loop-inversion.h" #include "slang-ir.h" +#include "slang-ir-util.h" #include "slang-ir-constexpr.h" #include "slang-ir-dce.h" #include "slang-ir-diff-call.h" @@ -7264,26 +7265,61 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> #undef IGNORED_CASE + void getAllEntryPointsNoOverride(List<IRInst*>& entryPoints) + { + if(entryPoints.getCount() != 0 ) + return; + for(const auto d : context->irBuilder->getModule()->getModuleInst()->getGlobalInsts()) + if(d->findDecoration<IREntryPointDecoration>()) + entryPoints.add(d); + } + LoweredValInfo visitEmptyDecl(EmptyDecl* decl) { + bool verifyComputeDerivativeGroupModifier = false; + List<IRInst*> entryPoints {}; for(const auto modifier : decl->modifiers) { if(const auto layoutLocalSizeAttr = as<GLSLLayoutLocalSizeAttribute>(modifier)) { - for(const auto d : context->irBuilder->getModule()->getModuleInst()->getGlobalInsts()) - { - if(d->findDecoration<IREntryPointDecoration>()) - { - getBuilder()->addNumThreadsDecoration( + verifyComputeDerivativeGroupModifier = true; + getAllEntryPointsNoOverride(entryPoints); + for(auto d : entryPoints) + as<IRNumThreadsDecoration>(getBuilder()->addNumThreadsDecoration( d, getSimpleVal(context, lowerVal(context, layoutLocalSizeAttr->x)), getSimpleVal(context, lowerVal(context, layoutLocalSizeAttr->y)), getSimpleVal(context, lowerVal(context, layoutLocalSizeAttr->z)) - ); - } - } + )); + } + else if(as<GLSLLayoutDerivativeGroupQuadAttribute>(modifier)) + { + verifyComputeDerivativeGroupModifier = true; + getAllEntryPointsNoOverride(entryPoints); + for(auto d : entryPoints) + getBuilder()->addSimpleDecoration<IRDerivativeGroupQuadDecoration>(d); + } + else if(as<GLSLLayoutDerivativeGroupLinearAttribute>(modifier)) + { + verifyComputeDerivativeGroupModifier = true; + getAllEntryPointsNoOverride(entryPoints); + for(auto d : entryPoints) + getBuilder()->addSimpleDecoration<IRDerivativeGroupLinearDecoration>(d); } } + + if(!verifyComputeDerivativeGroupModifier) + return LoweredValInfo(); + for(auto d : entryPoints) + { + verifyComputeDerivativeGroupModifiers( + getSink(), + decl->loc, + d->findDecoration<IRDerivativeGroupQuadDecoration>(), + d->findDecoration<IRDerivativeGroupLinearDecoration>(), + d->findDecoration<IRNumThreadsDecoration>()); + } + return LoweredValInfo(); } @@ -9677,6 +9713,9 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> addBitFieldAccessorDecorations(irFunc, decl); + IRNumThreadsDecoration* numThreadsDecor = nullptr; + IRDecoration* derivativeGroupQuadDecor = nullptr; + IRDecoration* derivativeGroupLinearDecor = nullptr; for (auto modifier : decl->modifiers) { if (as<RequiresNVAPIAttribute>(modifier)) @@ -9691,6 +9730,14 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { getBuilder()->addSimpleDecoration<IRNoInlineDecoration>(irFunc); } + else if (auto derivativeGroupQuadMod = as<DerivativeGroupQuadAttribute>(modifier)) + { + derivativeGroupQuadDecor = getBuilder()->addSimpleDecoration<IRDerivativeGroupQuadDecoration>(irFunc); + } + else if (auto derivativeGroupLinearMod = as<DerivativeGroupLinearAttribute>(modifier)) + { + derivativeGroupLinearDecor = getBuilder()->addSimpleDecoration<IRDerivativeGroupLinearDecoration>(irFunc); + } else if (auto instanceAttr = as<InstanceAttribute>(modifier)) { IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), instanceAttr); @@ -9703,12 +9750,13 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> } else if (auto numThreadsAttr = as<NumThreadsAttribute>(modifier)) { - getBuilder()->addNumThreadsDecoration( - irFunc, - getSimpleVal(context, lowerVal(context, numThreadsAttr->x)), - getSimpleVal(context, lowerVal(context, numThreadsAttr->y)), - getSimpleVal(context, lowerVal(context, numThreadsAttr->z)) - ); + numThreadsDecor = as<IRNumThreadsDecoration>( + getBuilder()->addNumThreadsDecoration( + irFunc, + getSimpleVal(context, lowerVal(context, numThreadsAttr->x)), + getSimpleVal(context, lowerVal(context, numThreadsAttr->y)), + getSimpleVal(context, lowerVal(context, numThreadsAttr->z)) + )); } else if (auto waveSizeAttr = as<WaveSizeAttribute>(modifier)) { @@ -9862,6 +9910,13 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> getBuilder()->addDecoration(irFunc, kIROp_NonDynamicUniformReturnDecoration); } + verifyComputeDerivativeGroupModifiers( + getSink(), + decl->loc, + derivativeGroupQuadDecor, + derivativeGroupLinearDecor, + numThreadsDecor); + if (!isInline) { // If there are any constant expr rate parameters, we should inline this function. diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index e0b964a45..23ee5e29d 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -8036,6 +8036,8 @@ namespace Slang ModifierListBuilder listBuilder; GLSLLayoutLocalSizeAttribute* numThreadsAttrib = nullptr; + GLSLLayoutDerivativeGroupQuadAttribute* derivativeGroupQuadAttrib = nullptr; + GLSLLayoutDerivativeGroupLinearAttribute* derivativeGroupLinearAttrib = nullptr; ImageFormat format; @@ -8082,6 +8084,14 @@ namespace Slang numThreadsAttrib->args[localSizeIndex] = expr; } } + else if (nameText == "derivative_group_quadsNV") + { + derivativeGroupQuadAttrib = parser->astBuilder->create<GLSLLayoutDerivativeGroupQuadAttribute>(); + } + else if (nameText == "derivative_group_linearNV") + { + derivativeGroupLinearAttrib = parser->astBuilder->create<GLSLLayoutDerivativeGroupLinearAttribute>(); + } else if (nameText == "binding" || nameText == "set") { @@ -8189,9 +8199,11 @@ namespace Slang #undef CASE if (numThreadsAttrib) - { listBuilder.add(numThreadsAttrib); - } + if(derivativeGroupQuadAttrib) + listBuilder.add(derivativeGroupQuadAttrib); + if(derivativeGroupLinearAttrib) + listBuilder.add(derivativeGroupLinearAttrib); listBuilder.add(parser->astBuilder->create<GLSLLayoutModifierGroupEnd>()); |
