diff options
| author | Yong He <yonghe@outlook.com> | 2024-02-02 22:28:02 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-02-02 22:28:02 -0800 |
| commit | 14764896c34b230a5563f48d8b8e565de2f3aa10 (patch) | |
| tree | 2f105d3f6222103f458054f1cd38e280b6fb52b4 | |
| parent | c15e7ade4e27e1649d5b98f5854e9e52bb9e60ae (diff) | |
Capability type checking. (#3530)
* Capability type checking.
* Fix.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
51 files changed, 1869 insertions, 489 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 3ebb77f03..3961403e7 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -2475,6 +2475,9 @@ attribute_syntax [vk_image_format(format : String)] : FormatAttribute; __attributeTarget(Decl) attribute_syntax [allow(diagnostic: String)] : AllowAttribute; +__attributeTarget(Decl) +attribute_syntax[require(capability)] : RequireCapabilityAttribute; + // Linking __attributeTarget(Decl) attribute_syntax [__extern] : ExternAttribute; diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 0b60bda0d..e5ebc8409 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -4716,7 +4716,6 @@ T GetAttributeAtVertex(T attribute, uint vertexIndex) { case hlsl: __intrinsic_asm "GetAttributeAtVertex"; - case _GL_NV_fragment_shader_barycentric: case _GL_EXT_fragment_shader_barycentric: __intrinsic_asm "$0[$1]"; case spirv: @@ -4749,7 +4748,6 @@ vector<T,N> GetAttributeAtVertex(vector<T,N> attribute, uint vertexIndex) { case hlsl: __intrinsic_asm "GetAttributeAtVertex"; - case _GL_NV_fragment_shader_barycentric: case _GL_EXT_fragment_shader_barycentric: __intrinsic_asm "$0[$1]"; case spirv: @@ -4782,7 +4780,6 @@ matrix<T,N,M> GetAttributeAtVertex(matrix<T,N,M> attribute, uint vertexIndex) { case hlsl: __intrinsic_asm "GetAttributeAtVertex"; - case _GL_NV_fragment_shader_barycentric: case _GL_EXT_fragment_shader_barycentric: __intrinsic_asm "$0[$1]"; case spirv: @@ -9288,8 +9285,7 @@ struct BuiltInTriangleIntersectionAttributes // `executeCallableNV` is the GLSL intrinsic that will be used to implement // `CallShader()` for GLSL-based targets. // -__target_intrinsic(GL_NV_ray_tracing, "executeCallableNV") -__target_intrinsic(GL_EXT_ray_tracing, "executeCallableEXT") +__target_intrinsic(_GL_EXT_ray_tracing, "executeCallableEXT") void __executeCallable(uint shaderIndex, int payloadLocation); // Next is the custom intrinsic that will compute the payload location @@ -9335,8 +9331,7 @@ void CallShader(uint shaderIndex, inout Payload payload) // 10.3.2 -__target_intrinsic(GL_NV_ray_tracing, "traceNV") -__target_intrinsic(GL_EXT_ray_tracing, "traceRayEXT") +__target_intrinsic(_GL_EXT_ray_tracing, "traceRayEXT") void __traceRay( RaytracingAccelerationStructure AccelerationStructure, uint RayFlags, @@ -9528,7 +9523,6 @@ bool __reportIntersection(float tHit, uint hitKind) __target_switch { case _GL_EXT_ray_tracing: __intrinsic_asm "reportIntersectionEXT"; - case _GL_NV_ray_tracing: __intrinsic_asm "reportIntersectionNV"; case spirv: return spirv_asm { result:$$bool = OpReportIntersectionKHR $tHit $hitKind; @@ -9555,7 +9549,6 @@ void IgnoreHit() { case hlsl: __intrinsic_asm "IgnoreHit"; case _GL_EXT_ray_tracing: __intrinsic_asm "ignoreIntersectionEXT;"; - case _GL_NV_ray_tracing: __intrinsic_asm "ignoreIntersectionNV"; case cuda: __intrinsic_asm "optixIgnoreIntersection"; case spirv: spirv_asm { OpIgnoreIntersectionKHR; %_ = OpLabel }; } @@ -9568,7 +9561,6 @@ void AcceptHitAndEndSearch() { case hlsl: __intrinsic_asm "AcceptHitAndEndSearch"; case _GL_EXT_ray_tracing: __intrinsic_asm "terminateRayEXT;"; - case _GL_NV_ray_tracing: __intrinsic_asm "terminateRayNV"; case cuda: __intrinsic_asm "optixTerminateRay"; case spirv: spirv_asm { OpTerminateRayKHR; %_ = OpLabel }; } @@ -9587,7 +9579,6 @@ uint3 DispatchRaysIndex() { case hlsl: __intrinsic_asm "DispatchRaysIndex"; case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_LaunchIDEXT)"; - case _GL_NV_ray_tracing: __intrinsic_asm "(gl_LaunchIDNV)"; case cuda: __intrinsic_asm "optixGetLaunchIndex"; case spirv: return spirv_asm { @@ -9602,7 +9593,6 @@ uint3 DispatchRaysDimensions() { case hlsl: __intrinsic_asm "DispatchRaysDimensions"; case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_LaunchSizeEXT)"; - case _GL_NV_ray_tracing: __intrinsic_asm "(gl_LaunchSizeNV)"; case cuda: __intrinsic_asm "optixGetLaunchDimensions"; case spirv: return spirv_asm { @@ -9619,7 +9609,6 @@ float3 WorldRayOrigin() { case hlsl: __intrinsic_asm "WorldRayOrigin"; case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_WorldRayOriginEXT)"; - case _GL_NV_ray_tracing: __intrinsic_asm "(gl_WorldRayOriginNV)"; case cuda: __intrinsic_asm "optixGetWorldRayOrigin"; case spirv: return spirv_asm { @@ -9634,7 +9623,6 @@ float3 WorldRayDirection() { case hlsl: __intrinsic_asm "WorldRayDirection"; case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_WorldRayDirectionEXT)"; - case _GL_NV_ray_tracing: __intrinsic_asm "(gl_WorldRayDirectionNV)"; case cuda: __intrinsic_asm "optixGetWorldRayDirection"; case spirv: return spirv_asm { @@ -9649,7 +9637,6 @@ float RayTMin() { case hlsl: __intrinsic_asm "RayTMin"; case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_RayTminEXT)"; - case _GL_NV_ray_tracing: __intrinsic_asm "(gl_RayTminNV)"; case cuda: __intrinsic_asm "optixGetRayTmin"; case spirv: return spirv_asm { @@ -9674,7 +9661,6 @@ float RayTCurrent() { case hlsl: __intrinsic_asm "RayTCurrent"; case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_RayTmaxEXT)"; - case _GL_NV_ray_tracing: __intrinsic_asm "(gl_RayTmaxNV)"; case cuda: __intrinsic_asm "optixGetRayTmax"; case spirv: return spirv_asm { @@ -9689,7 +9675,6 @@ uint RayFlags() { case hlsl: __intrinsic_asm "RayFlags"; case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_IncomingRayFlagsEXT)"; - case _GL_NV_ray_tracing: __intrinsic_asm "(gl_IncomingRayFlagsNV)"; case cuda: __intrinsic_asm "optixGetRayFlags"; case spirv: return spirv_asm { @@ -9720,7 +9705,6 @@ uint InstanceID() { case hlsl: __intrinsic_asm "InstanceID"; case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_InstanceCustomIndexEXT)"; - case _GL_NV_ray_tracing: __intrinsic_asm "(gl_InstanceCustomIndexNV)"; case cuda: __intrinsic_asm "optixGetInstanceId"; case spirv: return spirv_asm { @@ -9749,7 +9733,6 @@ float3 ObjectRayOrigin() { case hlsl: __intrinsic_asm "ObjectRayOrigin"; case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_ObjectRayOriginEXT)"; - case _GL_NV_ray_tracing: __intrinsic_asm "(gl_ObjectRayOriginNV)"; case cuda: __intrinsic_asm "optixGetObjectRayOrigin"; case spirv: return spirv_asm { @@ -9764,7 +9747,6 @@ float3 ObjectRayDirection() { case hlsl: __intrinsic_asm "ObjectRayDirection"; case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_ObjectRayDirectionEXT)"; - case _GL_NV_ray_tracing: __intrinsic_asm "(gl_ObjectRayDirectionNV)"; case cuda: __intrinsic_asm "optixGetObjectRayDirection"; case spirv: return spirv_asm { @@ -9781,7 +9763,6 @@ float3x4 ObjectToWorld3x4() { case hlsl: __intrinsic_asm "ObjectToWorld3x4"; case _GL_EXT_ray_tracing: __intrinsic_asm "transpose(gl_ObjectToWorldEXT)"; - case _GL_NV_ray_tracing: __intrinsic_asm "transpose(gl_ObjectToWorldNV)"; case spirv: return spirv_asm { %mat:$$float4x3 = OpLoad builtin(ObjectToWorldKHR:float4x3); @@ -9796,7 +9777,6 @@ float3x4 WorldToObject3x4() { case hlsl: __intrinsic_asm "WorldToObject3x4"; case _GL_EXT_ray_tracing: __intrinsic_asm "transpose(gl_WorldToObjectEXT)"; - case _GL_NV_ray_tracing: __intrinsic_asm "transpose(gl_WorldToObjectNV)"; case spirv: return spirv_asm { %mat:$$float4x3 = OpLoad builtin(WorldToObjectKHR:float4x3); @@ -9811,7 +9791,6 @@ float4x3 ObjectToWorld4x3() { case hlsl: __intrinsic_asm "ObjectToWorld4x3"; case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_ObjectToWorldEXT)"; - case _GL_NV_ray_tracing: __intrinsic_asm "(gl_ObjectToWorldNV)"; case spirv: return spirv_asm { result:$$float4x3 = OpLoad builtin(ObjectToWorldKHR:float4x3); @@ -9825,7 +9804,6 @@ float4x3 WorldToObject4x3() { case hlsl: __intrinsic_asm "WorldToObject4x3"; case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_WorldToObjectEXT)"; - case _GL_NV_ray_tracing: __intrinsic_asm "(gl_WorldToObjectNV)"; case spirv: return spirv_asm { result:$$float4x3 = OpLoad builtin(WorldToObjectKHR:float4x3); @@ -9872,7 +9850,6 @@ uint HitKind() { case hlsl: __intrinsic_asm "HitKind"; case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_HitKindEXT)"; - case _GL_NV_ray_tracing: __intrinsic_asm "(gl_HitKindNV)"; case cuda: __intrinsic_asm "optixGetHitKind"; case spirv: return spirv_asm { @@ -11874,6 +11851,7 @@ void debugBreak(); [__requiresNVAPI] __glsl_extension(GL_EXT_shader_realtime_clock) +[require(shaderclock)] uint getRealtimeClockLow() { __target_switch @@ -11886,14 +11864,18 @@ uint getRealtimeClockLow() __intrinsic_asm "clock"; case spirv: return getRealtimeClock().x; + case cpp: + __intrinsic_asm "(uint32_t)std::chrono::high_resolution_clock::now().time_since_epoch().count()"; } } +__target_intrinsic(cpp, "std::chrono::high_resolution_clock::now().time_since_epoch().count()") __target_intrinsic(cuda, "clock64") -int64_t __cudaGetRealtimeClock(); +int64_t __cudaCppGetRealtimeClock(); [__requiresNVAPI] __glsl_extension(GL_EXT_shader_realtime_clock) +[require(shaderclock)] uint2 getRealtimeClock() { __target_switch @@ -11903,7 +11885,8 @@ uint2 getRealtimeClock() case glsl: __intrinsic_asm "clockRealtime2x32EXT()"; case cuda: - int64_t ticks = __cudaGetRealtimeClock(); + case cpp: + int64_t ticks = __cudaCppGetRealtimeClock(); return uint2(uint(ticks), uint(uint64_t(ticks) >> 32)); case spirv: return spirv_asm diff --git a/source/slang/slang-ast-base.h b/source/slang/slang-ast-base.h index 579bda73a..e11dbe259 100644 --- a/source/slang/slang-ast-base.h +++ b/source/slang/slang-ast-base.h @@ -6,7 +6,7 @@ #include "slang-generated-ast.h" #include "slang-ast-reflect.h" - +#include "slang-capability.h" #include "slang-serialize-reflection.h" // This file defines the primary base classes for the hierarchy of @@ -695,6 +695,11 @@ class ModifiableSyntaxNode : public SyntaxNode bool hasModifier() { return findModifier<T>() != nullptr; } }; +struct DeclReferenceWithLoc +{ + Decl* referencedDecl; + SourceLoc referenceLoc; +}; // An intermediate type to represent either a single declaration, or a group of declarations class DeclBase : public ModifiableSyntaxNode @@ -716,6 +721,7 @@ public: DeclRefBase* getDefaultDeclRef(); NameLoc nameAndLoc; + CapabilitySet inferredCapabilityRequirements; RefPtr<MarkupEntry> markup; @@ -736,6 +742,8 @@ public: } bool isChildOf(Decl* other) const; + // Track the decl reference that caused the requirement of a capability atom. + SLANG_UNREFLECTED Dictionary<CapabilityAtom, DeclReferenceWithLoc> capabilityRequirementProvenance; private: SLANG_UNREFLECTED DeclRefBase* m_defaultDeclRef = nullptr; SLANG_UNREFLECTED Index m_defaultDeclRefEpoch = -1; diff --git a/source/slang/slang-ast-dump.cpp b/source/slang/slang-ast-dump.cpp index cde7c6201..9c40fb12b 100644 --- a/source/slang/slang-ast-dump.cpp +++ b/source/slang/slang-ast-dump.cpp @@ -692,6 +692,31 @@ struct ASTDumpContext m_writer->emit("}"); } + void dump(const CapabilitySet& capSet) + { + m_writer->emit("capability_set("); + bool isFirstSet = true; + for (auto& set : capSet.getExpandedAtoms()) + { + if (!isFirstSet) + { + m_writer->emit(" | "); + } + bool isFirst = true; + for (auto atom : set.getExpandedAtoms()) + { + if (!isFirst) + { + m_writer->emit("+"); + } + dump(capabilityNameToString((CapabilityName)atom)); + isFirst = false; + } + isFirstSet = false; + } + m_writer->emit(")"); + } + void dumpObjectFull(NodeBase* node); ASTDumpContext(SourceWriter* writer, ASTDumpUtil::Flags flags, ASTDumpUtil::Style dumpStyle): diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h index 2e8f02697..e2d0638e0 100644 --- a/source/slang/slang-ast-iterator.h +++ b/source/slang/slang-ast-iterator.h @@ -3,16 +3,14 @@ namespace Slang { -template <typename Callback> +template <typename Callback, typename Filter> struct ASTIterator { const Callback& callback; - UnownedStringSlice fileName; - SourceManager* sourceManager; - ASTIterator(const Callback& func, SourceManager* manager, UnownedStringSlice sourceFileName) + const Filter& filter; + ASTIterator(const Callback& func, const Filter& filterFunc) : callback(func) - , fileName(sourceFileName) - , sourceManager(manager) + , filter(filterFunc) {} void visitDecl(DeclBase* decl); @@ -429,13 +427,11 @@ struct ASTIterator }; }; -template <typename CallbackFunc> -void ASTIterator<CallbackFunc>::visitDecl(DeclBase* decl) +template <typename CallbackFunc, typename FilterFunc> +void ASTIterator<CallbackFunc, FilterFunc>::visitDecl(DeclBase* decl) { // Don't look at the decl if it is defined in a different file. - if (!as<NamespaceDeclBase>(decl) && !sourceManager->getHumaneLoc(decl->loc, SourceLocType::Actual) - .pathInfo.foundPath.getUnownedSlice() - .endsWithCaseInsensitive(fileName)) + if (!filter(decl)) return; maybeDispatchCallback(decl); @@ -490,24 +486,23 @@ void ASTIterator<CallbackFunc>::visitDecl(DeclBase* decl) } } } -template <typename CallbackFunc> -void ASTIterator<CallbackFunc>::visitExpr(Expr* expr) +template <typename CallbackFunc, typename FilterFunc> +void ASTIterator<CallbackFunc, FilterFunc>::visitExpr(Expr* expr) { ASTIteratorExprVisitor visitor(this); visitor.dispatchIfNotNull(expr); } -template <typename CallbackFunc> -void ASTIterator<CallbackFunc>::visitStmt(Stmt* stmt) +template <typename CallbackFunc, typename FilterFunc> +void ASTIterator<CallbackFunc, FilterFunc>::visitStmt(Stmt* stmt) { ASTIteratorStmtVisitor visitor(this); visitor.dispatchIfNotNull(stmt); } -template <typename Func> -void iterateAST( - UnownedStringSlice fileName, SourceManager* manager, SyntaxNode* node, const Func& f) +template <typename Func, typename FilterFunc> +void iterateAST(SyntaxNode* node, const FilterFunc& filterFunc, const Func& f) { - ASTIterator<Func> iter(f, manager, fileName); + ASTIterator<Func, FilterFunc> iter(f, filterFunc); if (auto decl = as<Decl>(node)) { iter.visitDecl(decl); @@ -521,4 +516,18 @@ void iterateAST( iter.visitStmt(stmt); } } + +template <typename Func> +void iterateASTWithLanguageServerFilter( + UnownedStringSlice fileName, SourceManager* sourceManager, SyntaxNode* node, const Func& f) +{ + auto filter = [&](DeclBase* decl) + { + return as<NamespaceDeclBase>(decl) || + sourceManager->getHumaneLoc(decl->loc, SourceLocType::Actual) + .pathInfo.foundPath.getUnownedSlice() + .endsWithCaseInsensitive(fileName); + }; + iterateAST(node, filter, f); +} } // namespace Slang diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 53d13d1b5..dae8966de 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -605,7 +605,6 @@ class Attribute : public AttributeBase class UserDefinedAttribute : public Attribute { SLANG_AST_CLASS(UserDefinedAttribute) - }; class AttributeUsageAttribute : public Attribute @@ -615,6 +614,14 @@ class AttributeUsageAttribute : public Attribute SyntaxClass<NodeBase> targetSyntaxClass; }; + +class RequireCapabilityAttribute : public Attribute +{ + SLANG_AST_CLASS(RequireCapabilityAttribute) + CapabilitySet capabilitySet; +}; + + // An `[unroll]` or `[unroll(count)]` attribute class UnrollAttribute : public Attribute { diff --git a/source/slang/slang-ast-stmt.h b/source/slang/slang-ast-stmt.h index af1fe9ec1..055785333 100644 --- a/source/slang/slang-ast-stmt.h +++ b/source/slang/slang-ast-stmt.h @@ -98,6 +98,7 @@ class TargetCaseStmt : public Stmt { SLANG_AST_CLASS(TargetCaseStmt) int32_t capability; + Token capabilityToken; Stmt* body = nullptr; }; diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 955aff06c..882e26078 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -51,7 +51,7 @@ namespace Slang class NodeBase; class LookupDeclRef; class GenericAppDeclRef; - + struct CapabilitySet; template <typename T> T* as(NodeBase* node); @@ -66,6 +66,8 @@ namespace Slang void printDiagnosticArg(StringBuilder& sb, Val* val); void printDiagnosticArg(StringBuilder& sb, DeclRefBase* declRefBase); void printDiagnosticArg(StringBuilder& sb, ASTNodeType nodeType); + void printDiagnosticArg(StringBuilder& sb, const CapabilitySet& set); + struct QualifiedDeclPath { @@ -442,7 +444,7 @@ namespace Slang /// checking the function body. AttributesChecked, - /// The declaration is fully checked. + /// The body/definition is checked. /// /// This step includes any validation of the declaration that is /// immaterial to clients code using the declaration, but that is @@ -453,7 +455,11 @@ namespace Slang /// but we still need to (eventually) check the bodies of all /// functions, so it belongs in the last phase of checking. /// - Checked, + DefinitionChecked, + + /// The capabilities required by the decl is infered and validated. + /// + CapabilityChecked, // For convenience at sites that call `ensureDecl()`, we define // some aliases for the above states that are expressed in terms diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef index 7e5931edf..5f15b8e57 100644 --- a/source/slang/slang-capabilities.capdef +++ b/source/slang/slang-capabilities.capdef @@ -13,9 +13,9 @@ // // A capability name is defined by a unique disjunction of conjunction of capability atoms. // For example, `raytracing` is a name that expands to -// glsl + _GL_EXT_ray_tracing | spirv_1_4 + _GL_EXT_ray_tracing | hlsl + _sm_6_4 -// which means it requires the `_GL_EXT_ray_tracing` extension when generating code for glsl -// or spirv, and requires sm_6_4 when generating hlsl. +// glsl + _GL_EXT_ray_tracing | spirv_1_4 + SPV_KHR_ray_tracing | hlsl + _sm_6_4 +// which means it requires the `_GL_EXT_ray_tracing` extension when generating code for glsl, +// requires SPV_KHR_ray_tracing for spirv, and requires sm_6_4 when generating hlsl. // // There are three types of capability definitions: // - `def`: this will introduce an new capability atom. If there is an inheritance clause, @@ -63,7 +63,7 @@ alias spirv_latest = spirv_1_6; // Capabilities that stand for target spirv version for GLSL backend. // These are not compilation targets. -def glsl_spirv_1_0; +def glsl_spirv_1_0 : glsl; def glsl_spirv_1_1 : glsl_spirv_1_0; def glsl_spirv_1_2 : glsl_spirv_1_1; def glsl_spirv_1_3 : glsl_spirv_1_2; @@ -80,13 +80,24 @@ def hull : stage; def domain : stage; def geometry : stage; def raygen : stage; +alias raygeneration = raygen; def intersection : stage; def anyhit : stage; def closesthit: stage; def miss : stage; def mesh : stage; +def amplification : stage; +def callable : stage; -def _sm_5_1 : hlsl; +alias all_stages = vertex + fragment + compute + hull + domain + + geometry + raygen + intersection + anyhit + + closesthit + miss + mesh + amplification + + callable; + +def _sm_4_0 : hlsl; +def _sm_4_1 : _sm_4_0; +def _sm_5_0 : _sm_4_1; +def _sm_5_1 : _sm_5_0; def _sm_6_0 : _sm_5_1; def _sm_6_1 : _sm_6_0; def _sm_6_2 : _sm_6_1; @@ -96,6 +107,60 @@ def _sm_6_5 : _sm_6_4; def _sm_6_6 : _sm_6_5; def _sm_6_7 : _sm_6_6; +def hlsl_nvapi : hlsl; + +// SPIRV extensions. + +def SPV_EXT_fragment_shader_interlock : spirv_1_0; +def SPV_KHR_fragment_shader_barycentric : spirv_1_0; +def SPV_EXT_fragment_fully_covered : spirv_1_0; +def SPV_EXT_descriptor_indexing : spirv_1_0; +def SPV_EXT_shader_atomic_float_add : spirv_1_0; +def SPV_EXT_shader_atomic_float16_add : SPV_EXT_shader_atomic_float_add; +def SPV_EXT_shader_atomic_float_min_max : spirv_1_0; +def SPV_KHR_non_semantic_info : spirv_1_0; +def SPV_NV_shader_subgroup_partitioned : spirv_1_0; +def SPV_NV_ray_tracing_motion_blur : spirv_1_0; +def SPV_EXT_mesh_shader : spirv_1_4; +def SPV_KHR_ray_tracing : spirv_1_4; +def SPV_KHR_ray_query : spirv_1_0; +def SPV_KHR_ray_tracing_position_fetch : SPV_KHR_ray_tracing + SPV_KHR_ray_query; +def SPV_NV_shader_invocation_reorder : spirv_1_5 + SPV_KHR_ray_tracing; +def SPV_KHR_shader_clock : spirv_1_0; +def SPV_NV_shader_image_footprint : spirv_1_0; +def SPV_GOOGLE_user_type : spirv_1_0; + +// SPIRV Capabilities. + +def spvAtomicFloat32AddEXT : SPV_EXT_shader_atomic_float_add; +def spvAtomicFloat16AddEXT : SPV_EXT_shader_atomic_float16_add; +def spvInt64Atomics : spirv_1_0; +def spvAtomicFloat32MinMaxEXT : SPV_EXT_shader_atomic_float_min_max; +def spvAtomicFloat16MinMaxEXT : SPV_EXT_shader_atomic_float_min_max; +def spvDerivativeControl : spirv_1_0; +def spvImageQuery : spirv_1_0; +def spvImageGatherExtended : spirv_1_0; +def spvImageFootprintNV : SPV_NV_shader_image_footprint; +def spvMinLod : spirv_1_0; +def spvFragmentShaderPixelInterlockEXT : SPV_EXT_fragment_shader_interlock; +def spvFragmentBarycentricKHR : SPV_KHR_fragment_shader_barycentric; +def spvFragmentFullyCoveredEXT : SPV_EXT_fragment_fully_covered; +def spvGroupNonUniformBallot : spirv_1_3; +def spvGroupNonUniformShuffle : spirv_1_3; +def spvGroupNonUniformArithmetic : spirv_1_3; +def spvGroupNonUniformQuad : spirv_1_3; +def spvGroupNonUniformVote : spirv_1_3; +def spvGroupNonUniformPartitionedNV : spirv_1_3 + SPV_NV_shader_subgroup_partitioned; +def spvRayTracingMotionBlurNV : SPV_NV_ray_tracing_motion_blur; +def spvMeshShadingEXT : SPV_EXT_mesh_shader; +def spvRayTracingKHR : SPV_KHR_ray_tracing; +def spvRayTracingPositionFetchKHR : SPV_KHR_ray_tracing_position_fetch; +def spvRayQueryKHR : SPV_KHR_ray_query; +def spvRayQueryPositionFetchKHR : SPV_KHR_ray_tracing_position_fetch; +def spvShaderInvocationReorderNV : SPV_NV_shader_invocation_reorder; +def spvShaderClockKHR : SPV_KHR_shader_clock; +def spvShaderNonUniform : spirv_1_5; + // The following capabilities all pertain to how ray tracing shaders are translated // to GLSL, where there are two different extensions that can provide the core // functionality of `TraceRay` and the related operations. @@ -104,32 +169,213 @@ def _sm_6_7 : _sm_6_6; // as conflicting on the `RayTracingExtension` axis, so that a compilation target // cannot have both enabled at once. // -// The `GL_EXT_ray_tracing` extension should be favored, so it has a rank of `1` +// The `_GL_EXT_ray_tracing` extension should be favored, so it has a rank of `1` // instead of `0`, which means that when comparing overloads that require these // extensions, the `EXT` extension will be favored over the `NV` extension, if // all other factors are equal. // -def _GL_EXT_ray_tracing : glsl + glsl_spirv_1_4 = 1; -def _GL_NV_ray_tracing : _GL_EXT_ray_tracing; -def _SPV_KHR_ray_tracing : spirv_1_4; -alias GL_NV_ray_tracing = _GL_NV_ray_tracing | _SPV_KHR_ray_tracing | _sm_6_4 | cuda; -alias GL_EXT_ray_tracing = _GL_EXT_ray_tracing | _SPV_KHR_ray_tracing | _sm_6_4 | cuda; -alias raytracing = GL_EXT_ray_tracing; - -def _GL_EXT_fragment_shader_barycentric : glsl + fragment; -def _GL_NV_fragment_shader_barycentric : _GL_EXT_fragment_shader_barycentric; -def _SPV_KHR_fragment_shader_barycentric : spirv_1_0 + fragment; -alias GL_NV_fragment_shader_barycentric = _GL_NV_fragment_shader_barycentric | _SPV_KHR_fragment_shader_barycentric | hlsl + fragment; -alias GL_EXT_fragment_shader_barycentric = _GL_EXT_fragment_shader_barycentric | _SPV_KHR_fragment_shader_barycentric | hlsl + fragment; - -// TODO: define what SM means for all supported targets. - -alias sm_5_1 = _sm_5_1; -alias sm_6_0 = _sm_6_0; -alias sm_6_1 = _sm_6_1; -alias sm_6_2 = _sm_6_2; -alias sm_6_3 = _sm_6_3; -alias sm_6_4 = _sm_6_4 | raytracing; -alias sm_6_5 = _sm_6_5; -alias sm_6_6 = _sm_6_6; -alias sm_6_7 = _sm_6_7; + +def _GL_ARB_derivative_control : glsl; +def _GL_ARB_fragment_shader_interlock : glsl; +def _GL_ARB_gpu_shader5 : glsl; +def _GL_ARB_sparse_texture_clamp : glsl; +def _GL_EXT_buffer_reference : glsl; +def _GL_EXT_debug_printf : glsl; +def _GL_EXT_fragment_shader_barycentric : glsl; +def _GL_EXT_mesh_shader : glsl; +def _GL_EXT_nonuniform_qualifier : glsl; +def _GL_EXT_ray_tracing : glsl_spirv_1_4; +def _GL_EXT_ray_tracing_position_fetch : glsl_spirv_1_4; +def _GL_EXT_samplerless_texture_functions : glsl; +def _GL_EXT_shader_atomic_float : glsl; +def _GL_EXT_shader_atomic_float2 : glsl; +def _GL_EXT_shader_atomic_int64 : glsl; +def _GL_EXT_shader_atomic_float_min_max : glsl; +def _GL_EXT_shader_explicit_arithmetic_types_int64 : glsl; +def _GL_EXT_shader_realtime_clock : glsl; +def _GL_EXT_texture_shadow_lod : glsl; +def _GL_KHR_memory_scope_semantics : glsl; +def _GL_KHR_shader_subgroup_arithmetic : glsl; +def _GL_KHR_shader_subgroup_basic : glsl; +def _GL_KHR_shader_subgroup_ballot : glsl; +def _GL_KHR_shader_subgroup_quad : glsl; +def _GL_KHR_shader_subgroup_shuffle : glsl; +def _GL_KHR_shader_subgroup_vote : glsl; +def _GL_NV_shader_subgroup_partitioned : glsl; +def _GL_NV_ray_tracing_motion_blur : glsl_spirv_1_4; +def _GL_NV_shader_invocation_reorder : glsl_spirv_1_4; +def _GL_NV_shader_texture_footprint : glsl; +alias _GL_NV_fragment_shader_barycentric = _GL_EXT_fragment_shader_barycentric; +alias _GL_NV_ray_tracing = _GL_EXT_ray_tracing; + +// GLSL extension and SPV extension associations. +alias GL_ARB_derivative_control = _GL_ARB_derivative_control | spvDerivativeControl; +alias GL_ARB_fragment_shader_interlock = _GL_ARB_fragment_shader_interlock | spvFragmentShaderPixelInterlockEXT; +alias GL_ARB_gpu_shader5 = _GL_ARB_fragment_shader_interlock | spirv_1_0; +alias GL_ARB_sparse_texture_clamp = _GL_ARB_fragment_shader_interlock | spirv_1_0; +alias GL_EXT_buffer_reference = _GL_ARB_fragment_shader_interlock | spirv_1_5; +alias GL_EXT_debug_printf = _GL_EXT_debug_printf | SPV_KHR_non_semantic_info; +alias GL_EXT_fragment_shader_barycentric = _GL_EXT_fragment_shader_barycentric | spvFragmentBarycentricKHR; +alias GL_EXT_mesh_shader = _GL_EXT_mesh_shader | spvMeshShadingEXT; +alias GL_EXT_nonuniform_qualifier = _GL_EXT_nonuniform_qualifier | spvShaderNonUniform; +alias GL_EXT_ray_tracing = _GL_EXT_ray_tracing | spvRayTracingKHR + spvRayQueryKHR; +alias GL_EXT_ray_tracing_position_fetch = _GL_EXT_ray_tracing_position_fetch | spvRayTracingPositionFetchKHR + spvRayQueryPositionFetchKHR; +alias GL_EXT_samplerless_texture_functions = _GL_EXT_samplerless_texture_functions | spirv_1_0; +alias GL_EXT_shader_atomic_float = _GL_EXT_shader_atomic_float | spvAtomicFloat32AddEXT + spvAtomicFloat32MinMaxEXT; +alias GL_EXT_shader_atomic_float2 = _GL_EXT_shader_atomic_float2 | spvAtomicFloat32AddEXT + spvAtomicFloat32MinMaxEXT + spvAtomicFloat16AddEXT + spvAtomicFloat16MinMaxEXT; +alias GL_EXT_shader_atomic_int64 = _GL_EXT_shader_atomic_int64 | spvInt64Atomics; +alias GL_EXT_shader_atomic_float_min_max = _GL_EXT_shader_atomic_float_min_max | spvAtomicFloat32MinMaxEXT + spvAtomicFloat16MinMaxEXT; +alias GL_EXT_shader_explicit_arithmetic_types_int64 = _GL_EXT_shader_explicit_arithmetic_types_int64 | spirv_1_0; +alias GL_EXT_shader_realtime_clock = _GL_EXT_shader_realtime_clock | spvShaderClockKHR; +alias GL_EXT_texture_shadow_lod = _GL_EXT_texture_shadow_lod | spirv_1_0; +alias GL_KHR_memory_scope_semantics = _GL_KHR_memory_scope_semantics | spirv_1_0; +alias GL_KHR_shader_subgroup_arithmetic = _GL_KHR_shader_subgroup_arithmetic | spvGroupNonUniformArithmetic; +alias GL_KHR_shader_subgroup_basic = _GL_KHR_shader_subgroup_basic | spvGroupNonUniformBallot; +alias GL_KHR_shader_subgroup_ballot = _GL_KHR_shader_subgroup_ballot | spvGroupNonUniformBallot; +alias GL_KHR_shader_subgroup_quad = _GL_KHR_shader_subgroup_quad | spvGroupNonUniformQuad; +alias GL_KHR_shader_subgroup_shuffle = _GL_KHR_shader_subgroup_shuffle | spvGroupNonUniformShuffle; +alias GL_KHR_shader_subgroup_vote = _GL_KHR_shader_subgroup_vote | spvGroupNonUniformVote; +alias GL_NV_shader_subgroup_partitioned = _GL_NV_shader_subgroup_partitioned | spvGroupNonUniformPartitionedNV; +alias GL_NV_ray_tracing_motion_blur = _GL_NV_ray_tracing_motion_blur | spvRayTracingMotionBlurNV; +alias GL_NV_shader_invocation_reorder = _GL_NV_shader_invocation_reorder | spvShaderInvocationReorderNV; +alias GL_NV_shader_texture_footprint = _GL_NV_shader_texture_footprint | spvImageFootprintNV; + +alias GL_NV_fragment_shader_barycentric = GL_EXT_fragment_shader_barycentric; +alias GL_NV_ray_tracing = GL_EXT_ray_tracing; + +// Define feature names + +alias nvapi = hlsl_nvapi; +alias raytracing = spvRayTracingKHR + spvRayQueryKHR + spvRayQueryPositionFetchKHR | _GL_EXT_ray_tracing+_GL_EXT_ray_tracing_position_fetch | _sm_6_5 | cuda; +alias ser = spvShaderInvocationReorderNV | _GL_NV_shader_invocation_reorder | _sm_6_6 + hlsl_nvapi; +alias shaderclock = spvShaderClockKHR | hlsl_nvapi | _GL_EXT_shader_realtime_clock | cpp | cuda; +alias meshshading = spvMeshShadingEXT | _sm_6_5 | _GL_EXT_mesh_shader; +alias motionblur = spvRayTracingMotionBlurNV | hlsl_nvapi | _GL_NV_ray_tracing_motion_blur; +alias texturefootprint = GL_NV_shader_texture_footprint | hlsl_nvapi; +alias fragmentshaderinterlock = _GL_ARB_fragment_shader_interlock | hlsl_nvapi | spvFragmentShaderPixelInterlockEXT; +alias atomic64 = GL_EXT_shader_atomic_int64 | _sm_6_6 | cpp | cuda; +alias atomicfloat = GL_EXT_shader_atomic_float | _sm_6_0 + hlsl_nvapi | cpp | cuda; +alias atomicfloat2 = GL_EXT_shader_atomic_float2 | _sm_6_6 + hlsl_nvapi | cpp | cuda; +alias groupnonuniform = GL_KHR_shader_subgroup_ballot + GL_KHR_shader_subgroup_shuffle + + GL_KHR_shader_subgroup_arithmetic + GL_KHR_shader_subgroup_quad + GL_KHR_shader_subgroup_vote + | _sm_6_0 | cuda; +alias fragmentshaderbarycentric = GL_EXT_fragment_shader_barycentric | _sm_6_1; + + +// Define what each HLSL shader model means on different targets. + + +alias sm_4_0 = _sm_4_0 + | glsl_spirv_1_0 + | spirv_1_0 + spvImageQuery + spvImageGatherExtended + spvMinLod + SPV_GOOGLE_user_type + | cuda + | cpp; + +alias sm_4_1 = _sm_4_1 + | glsl_spirv_1_0 + sm_4_0 + | spirv_1_0 + sm_4_0 + | cuda + | cpp; + +alias sm_5_0 = _sm_5_0 + | glsl_spirv_1_0 + sm_4_1 + _GL_KHR_memory_scope_semantics + | spirv_1_0 + sm_4_1 + spvDerivativeControl + spvFragmentFullyCoveredEXT + | cuda + | cpp; + +alias sm_5_1 = _sm_5_1 + | glsl_spirv_1_0 + sm_5_0 + _GL_ARB_gpu_shader5 + _GL_ARB_sparse_texture_clamp + _GL_EXT_nonuniform_qualifier + | spirv_1_0 + sm_5_0 + spvShaderNonUniform + | cuda + | cpp; + +alias sm_6_0 = _sm_6_0 + | glsl_spirv_1_3 + sm_5_1 + + groupnonuniform + atomicfloat + | spirv_1_3 + sm_5_1 + + groupnonuniform + atomicfloat + | cuda + | cpp; + +alias sm_6_1 = _sm_6_1 + | glsl_spirv_1_3 + sm_6_0 + fragmentshaderbarycentric + | spirv_1_3 + sm_6_0 + fragmentshaderbarycentric + | cuda + | cpp; + +alias sm_6_2 = _sm_6_2 + | glsl_spirv_1_3 + sm_6_1 + | spirv_1_3 + sm_6_1 + | cuda + | cpp; + +alias sm_6_3 = _sm_6_3 + | glsl_spirv_1_4 + sm_6_2 + _GL_EXT_ray_tracing + | spirv_1_4 + sm_6_2 + SPV_KHR_ray_tracing + | cuda + | cpp; + +alias sm_6_4 = _sm_6_4 + | glsl_spirv_1_4 + sm_6_3 + | spirv_1_4 + sm_6_3 + | cuda + | cpp; + +alias sm_6_5 = _sm_6_5 + | glsl_spirv_1_4 + sm_6_4 + raytracing + meshshading + | spirv_1_4 + sm_6_4 + raytracing + meshshading + | cuda + | cpp; + +alias sm_6_6 = _sm_6_6 + | glsl_spirv_1_5 + sm_6_5 + + GL_EXT_shader_atomic_int64 + atomicfloat2 + | spirv_1_5 + sm_6_5 + + GL_EXT_shader_atomic_int64 + atomicfloat2 + + SPV_EXT_descriptor_indexing + | cuda + | cpp; + +alias sm_6_7 = _sm_6_7 + | glsl_spirv_1_5 + sm_6_6 + | spirv_1_5 + sm_6_6 + | cuda + | cpp; + +alias all = _sm_6_7 + hlsl_nvapi + | glsl_spirv_1_5 + sm_6_7 + + ser + shaderclock + texturefootprint + fragmentshaderinterlock + _GL_NV_shader_subgroup_partitioned + + _GL_NV_ray_tracing_motion_blur + _GL_NV_shader_texture_footprint + | spirv_1_5 + sm_6_7 + + ser + shaderclock + texturefootprint + fragmentshaderinterlock + spvGroupNonUniformPartitionedNV + + spvRayTracingMotionBlurNV + spvRayTracingMotionBlurNV; + +// Profiles + +alias GLSL_150 = glsl + sm_5_1 | spirv_1_0; +alias GLSL_330 = GLSL_150 | spirv_1_0 + sm_5_1; +alias GLSL_400 = GLSL_150 | spirv_1_0 + sm_5_1; +alias GLSL_410 = glsl + sm_5_1 | spirv_1_5 + sm_5_1; +alias GLSL_420 = glsl + sm_5_1 | spirv_1_5 + sm_5_1; +alias GLSL_430 = glsl + sm_5_1 | spirv_1_5 + sm_5_1; +alias GLSL_440 = glsl + sm_6_0 | spirv_1_5 + sm_6_0; +alias GLSL_450 = glsl + sm_6_3 | spirv_1_5 + sm_6_3; +alias GLSL_460 = glsl_spirv_1_5 + all | spirv_1_5 + all; + +alias tess_control = hull; +alias tess_eval = domain; + +alias DX_4_0 = sm_4_0; +alias DX_4_1 = sm_4_1; +alias DX_5_0 = sm_5_0; +alias DX_5_1 = sm_5_1; +alias DX_6_0 = sm_6_0; +alias DX_6_1 = sm_6_1; +alias DX_6_2 = sm_6_2; +alias DX_6_3 = sm_6_3; +alias DX_6_4 = sm_6_4; +alias DX_6_5 = sm_6_5; +alias DX_6_6 = sm_6_6; +alias DX_6_7 = sm_6_7; + + diff --git a/source/slang/slang-capability.cpp b/source/slang/slang-capability.cpp index a3f8157e7..fbe37892a 100644 --- a/source/slang/slang-capability.cpp +++ b/source/slang/slang-capability.cpp @@ -96,6 +96,16 @@ void getCapabilityNames(List<UnownedStringSlice>& ioNames) } } +UnownedStringSlice capabilityNameToString(CapabilityName name) +{ + return UnownedStringSlice(_getInfo(name).name); +} + +bool isDirectChildOfAbstractAtom(CapabilityAtom name) +{ + return _getInfo(name).abstractBase != CapabilityName::Invalid; +} + bool lookupCapabilityName(const UnownedStringSlice& str, CapabilityName& value); CapabilityName findCapabilityName(UnownedStringSlice const& name) @@ -482,7 +492,10 @@ bool CapabilityConjunctionSet::implies(CapabilityConjunctionSet const& that) con return false; } } - return true; + // We reached the end of either this or that atom. + // If we reached the end of 'that', we know everything in 'that' + // is also contained in this, so this implies that. + return thatIndex == thatCount; } /// Helper functor for binary search on lists of `CapabilityAtom` @@ -935,6 +948,46 @@ void CapabilitySet::calcCompactedAtoms(List<List<CapabilityAtom>>& outAtoms) con } } +void CapabilitySet::unionWith(const CapabilityConjunctionSet& conjunctionToAdd) +{ + // We add conjunctionToAdd to resultSet only if it does not imply any existing conjunctions. + // For example, if `resultSet` is (a), and conjunctionToAdd is (ab), then we don't want to add the conjunction + // to form (a | ab) because that would reduce to (a). + bool skipAdd = false; + for (auto& c : m_conjunctions) + { + if (conjunctionToAdd.implies(c)) + { + skipAdd = true; + break; + } + } + if (!skipAdd) + { + // Once we added the new conjunction, any existing conjunctions that implies the new one can be + // removed. + // For example, if resultSet was (ab), and we are adding (a), the result should be just (a). + for (Index i = 0; i < m_conjunctions.getCount();) + { + if (m_conjunctions[i].implies(conjunctionToAdd)) + { + m_conjunctions.fastRemoveAt(i); + } + else + { + i++; + } + } + m_conjunctions.add(conjunctionToAdd); + } +} + +void CapabilitySet::canonicalize() +{ + // Make sure conjunctions are sorted so equality tests are trivial. + m_conjunctions.sort(); +} + void CapabilitySet::join(const CapabilitySet& other) { if (isEmpty() || other.isInvalid()) @@ -947,7 +1000,7 @@ void CapabilitySet::join(const CapabilitySet& other) if (other.isEmpty()) return; - List<CapabilityConjunctionSet> resultSet; + CapabilitySet resultSet; for (auto& thatConjunction : other.m_conjunctions) { for (auto& thisConjunction : m_conjunctions) @@ -980,42 +1033,20 @@ void CapabilitySet::join(const CapabilitySet& other) // Otherwise, thisConjunction implies thatConjunction, so we just add thisConjunction to resultSet. conjunctionToAdd = &thisConjunction; } - // We add conjunctionToAdd to resultSet only if it does not imply any existing conjunctions. - // For example, if `resultSet` is (a), and conjunctionToAdd is (ab), then we don't want to add the conjunction - // to form (a | ab) because that would reduce to (a). - bool skipAdd = false; - for (auto& c : resultSet) - { - if (conjunctionToAdd->implies(c)) - { - skipAdd = true; - break; - } - } - if (!skipAdd) - { - // Once we added the new conjunction, any existing conjunctions that implies the new one can be - // removed. - // For example, if resultSet was (ab), and we are adding (a), the result should be just (a). - for (Index i = 0; i < resultSet.getCount();) - { - if (resultSet[i].implies(*conjunctionToAdd)) - { - resultSet.fastRemoveAt(i); - } - else - { - i++; - } - } - resultSet.add(*conjunctionToAdd); - } + resultSet.unionWith(*conjunctionToAdd); } } - m_conjunctions = _Move(resultSet); + m_conjunctions = _Move(resultSet.m_conjunctions); - // Make sure conjunctions are sorted so equality tests are trivial. - m_conjunctions.sort(); + if (m_conjunctions.getCount() == 0) + { + // If the result is empty, then we should return as impossible. + *this = CapabilitySet::makeInvalid(); + } + else + { + canonicalize(); + } } bool CapabilitySet::isBetterForTarget(CapabilitySet const& that, CapabilitySet const& targetCaps) const @@ -1102,4 +1133,86 @@ bool CapabilitySet::isBetterForTarget(CapabilitySet const& that, CapabilitySet c return false; } +bool CapabilitySet::checkCapabilityRequirement(CapabilitySet const& available, CapabilitySet const& required, const CapabilityConjunctionSet*& outFailedAvailableSet) +{ + // Requirements x are met by available disjoint capabilities (a | b) iff + // both 'a' satisfies x and 'b' satisfies x. + // If we have a caller function F() decorated with: + // [require(hlsl, _sm_6_3)] [require(spirv, _spv_ray_tracing)] void F() { g(); } + // We'd better make sure that `g()` can be compiled with both (hlsl+_sm_6_3) and (spirv+_spv_ray_tracing) capability sets. + // In this method, F()'s capability declaration is represented by `available`, + // and g()'s capability is represented by `required`. + // We will check that for every capability conjunction X of F(), there is one capability conjunction Y in g() such that X implies Y. + // + + outFailedAvailableSet = nullptr; + + if (required.isInvalid()) + return false; + + // If F's capability is empty, we can satisfy any non-empty requirements. + // + if (available.isEmpty() && !required.isEmpty()) + return false; + + for (auto& availTargetSet : available.getExpandedAtoms()) + { + bool implied = false; + for (auto& requiredTargetSet : required.getExpandedAtoms()) + { + if (availTargetSet.implies(requiredTargetSet)) + { + implied = true; + break; + } + } + if (!implied) + { + outFailedAvailableSet = &availTargetSet; + return false; + } + } + + return true; +} + +void printDiagnosticArg(StringBuilder& sb, const CapabilitySet& capSet) +{ + bool isFirstSet = true; + for (auto& set : capSet.getExpandedAtoms()) + { + List<CapabilityAtom> compactAtomList; + set.calcCompactedAtoms(compactAtomList); + + if (!isFirstSet) + { + sb<< " | "; + } + bool isFirst = true; + for (auto atom : compactAtomList) + { + if (!isFirst) + { + sb << " + "; + } + auto name = capabilityNameToString((CapabilityName)atom); + if (name.startsWith("_")) + name = name.tail(1); + sb << name; + isFirst = false; + } + isFirstSet = false; + } +} + +void printDiagnosticArg(StringBuilder& sb, CapabilityAtom atom) +{ + printDiagnosticArg(sb, (CapabilityName)atom); +} + +void printDiagnosticArg(StringBuilder& sb, CapabilityName name) +{ + sb << _getInfo(name).name; +} + } diff --git a/source/slang/slang-capability.h b/source/slang/slang-capability.h index e686ff5a9..b0ca9231a 100644 --- a/source/slang/slang-capability.h +++ b/source/slang/slang-capability.h @@ -108,6 +108,7 @@ public: /// Does this capability set imply all the capabilities in `other`? bool implies(CapabilityConjunctionSet const& other) const; + /// Does this capability set imply the atomic capability `other`? bool implies(CapabilityAtom other) const; @@ -206,6 +207,10 @@ public: /// Join two capability sets to form (this & other). void join(const CapabilitySet& other); + void unionWith(const CapabilityConjunctionSet& other); + + void canonicalize(); + /// Are these two capability sets equal? bool operator==(CapabilitySet const& that) const; @@ -219,6 +224,8 @@ public: bool isBetterForTarget(CapabilitySet const& that, CapabilitySet const& targetCaps) const; + static bool checkCapabilityRequirement(CapabilitySet const& available, CapabilitySet const& required, const CapabilityConjunctionSet*& outFailedAvailableSet); + private: // The underlying representation we use is a list of conjunctions. // @@ -242,4 +249,11 @@ CapabilityName findCapabilityName(UnownedStringSlice const& name); /// Gets the capability names. void getCapabilityNames(List<UnownedStringSlice>& ioNames); +UnownedStringSlice capabilityNameToString(CapabilityName name); + +bool isDirectChildOfAbstractAtom(CapabilityAtom name); + +void printDiagnosticArg(StringBuilder& sb, CapabilityAtom atom); +void printDiagnosticArg(StringBuilder& sb, CapabilityName name); + } diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index fe4a7d64c..a7e197d81 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -14,7 +14,7 @@ #include "slang-syntax.h" #include "slang-ast-synthesis.h" #include "slang-ast-reflect.h" - +#include "slang-ast-iterator.h" #include <limits> namespace Slang @@ -304,6 +304,416 @@ namespace Slang void visitParamDecl(ParamDecl* paramDecl); }; + template<typename VisitorType> + struct SemanticsDeclReferenceVisitor + : public SemanticsDeclVisitorBase + , public StmtVisitor<VisitorType> + , public ExprVisitor<VisitorType> + , public ValVisitor<VisitorType> + , public DeclVisitor<VisitorType> + { + SemanticsDeclReferenceVisitor(SemanticsContext const& outer) + : SemanticsDeclVisitorBase(outer) + {} + + List<SourceLoc> sourceLocStack; + + struct PushSourceLocRAII + { + List<SourceLoc>& stack; + bool shouldPop = false; + PushSourceLocRAII(List<SourceLoc>& sourceLocStack, SourceLoc loc) + : stack(sourceLocStack) + { + if (loc.isValid()) + { + stack.add(loc); + shouldPop = true; + } + } + ~PushSourceLocRAII() + { + if (shouldPop) + { + stack.removeLast(); + } + } + }; + + virtual void processReferencedDecl(Decl* decl) = 0; + + void dispatchIfNotNull(Stmt* stmt) + { + if (!stmt) + return; + PushSourceLocRAII sourceLocRAII(sourceLocStack, stmt->loc); + return StmtVisitor<VisitorType>::dispatch(stmt); + } + void dispatchIfNotNull(Expr* expr) + { + if (!expr) + return; + PushSourceLocRAII sourceLocRAII(sourceLocStack, expr->loc); + return ExprVisitor<VisitorType>::dispatch(expr); + } + void dispatchIfNotNull(Val* val) + { + if (!val) + return; + return ValVisitor<VisitorType>::dispatch(val); + } + void dispatchIfNotNull(DeclBase* val) + { + if (!val) + return; + return DeclVisitor<VisitorType>::dispatch(val); + } + // Expr Visitor + void visitExpr(Expr*) { } + void visitIndexExpr(IndexExpr* subscriptExpr) + { + for (auto arg : subscriptExpr->indexExprs) + dispatchIfNotNull(arg); + dispatchIfNotNull(subscriptExpr->baseExpression); + } + + void visitParenExpr(ParenExpr* expr) + { + dispatchIfNotNull(expr->base); + } + + void visitAssignExpr(AssignExpr* expr) + { + dispatchIfNotNull(expr->left); + dispatchIfNotNull(expr->right); + } + + void visitGenericAppExpr(GenericAppExpr* genericAppExpr) + { + dispatchIfNotNull(genericAppExpr->functionExpr); + for (auto arg : genericAppExpr->arguments) + dispatchIfNotNull(arg); + } + + void visitSharedTypeExpr(SharedTypeExpr* expr) { dispatchIfNotNull(expr->base.exp); } + + void visitInvokeExpr(InvokeExpr* expr) + { + dispatchIfNotNull(expr->functionExpr); + for (auto arg : expr->arguments) + dispatchIfNotNull(arg); + } + + void visitTypeCastExpr(TypeCastExpr* expr) + { + dispatchIfNotNull(expr->functionExpr); + for (auto arg : expr->arguments) + dispatchIfNotNull(arg); + } + + void visitDerefExpr(DerefExpr* expr) { dispatchIfNotNull(expr->base); } + void visitMatrixSwizzleExpr(MatrixSwizzleExpr* expr) + { + dispatchIfNotNull(expr->base); + } + void visitSwizzleExpr(SwizzleExpr* expr) + { + dispatchIfNotNull(expr->base); + } + void visitOverloadedExpr(OverloadedExpr*) + { + return; + } + void visitOverloadedExpr2(OverloadedExpr2*) + { + return; + } + void visitAggTypeCtorExpr(AggTypeCtorExpr*) + { + return; + } + void visitCastToSuperTypeExpr(CastToSuperTypeExpr* expr) + { + dispatchIfNotNull(expr->valueArg); + } + void visitModifierCastExpr(ModifierCastExpr* expr) { dispatchIfNotNull(expr->valueArg); } + void visitLetExpr(LetExpr* expr) + { + dispatchIfNotNull(expr->body); + } + void visitExtractExistentialValueExpr(ExtractExistentialValueExpr* expr) + { + dispatchIfNotNull(expr->declRef.declRefBase); + } + + void visitDeclRefExpr(DeclRefExpr* expr) + { + dispatchIfNotNull(expr->declRef.declRefBase); + } + void visitStaticMemberExpr(StaticMemberExpr* expr) + { + dispatchIfNotNull(expr->declRef.declRefBase); + } + void visitInitializerListExpr(InitializerListExpr* expr) + { + for (auto arg : expr->args) + { + dispatchIfNotNull(arg); + } + } + + void visitThisExpr(ThisExpr*) + { + return; + } + + void visitThisTypeExpr(ThisTypeExpr*) { return; } + void visitAndTypeExpr(AndTypeExpr* expr) + { + dispatchIfNotNull(expr->left.type); + dispatchIfNotNull(expr->right.type); + } + void visitPointerTypeExpr(PointerTypeExpr* expr) + { + dispatchIfNotNull(expr->base.type); + } + void visitAsTypeExpr(AsTypeExpr* expr) + { + dispatchIfNotNull(expr->value); + dispatchIfNotNull(expr->witnessArg); + } + void visitIsTypeExpr(IsTypeExpr* expr) + { + dispatchIfNotNull(expr->value); + dispatchIfNotNull(expr->witnessArg); + } + void visitMakeOptionalExpr(MakeOptionalExpr* expr) + { + dispatchIfNotNull(expr->value); + dispatchIfNotNull(expr->typeExpr); + } + void visitPartiallyAppliedGenericExpr(PartiallyAppliedGenericExpr*) + { + return; + } + void visitSPIRVAsmExpr(SPIRVAsmExpr*) + { + return; + } + void visitModifiedTypeExpr(ModifiedTypeExpr* expr) { dispatchIfNotNull(expr->base.type); } + void visitFuncTypeExpr(FuncTypeExpr* expr) + { + for (const auto& t : expr->parameters) + { + dispatchIfNotNull(t.type); + } + dispatchIfNotNull(expr->result.type); + } + void visitTupleTypeExpr(TupleTypeExpr* expr) + { + for (auto t : expr->members) + { + dispatchIfNotNull(t.type); + } + } + void visitTryExpr(TryExpr* expr) { dispatchIfNotNull(expr->base); } + void visitHigherOrderInvokeExpr(HigherOrderInvokeExpr* expr) + { + dispatchIfNotNull(expr->baseFunction); + } + void visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr) + { + dispatchIfNotNull(expr->innerExpr); + } + + // Stmt Visitor + + void visitDeclStmt(DeclStmt* stmt) { dispatchIfNotNull(stmt->decl); } + + void visitBlockStmt(BlockStmt* stmt) + { + dispatchIfNotNull(stmt->body); + } + + void visitSeqStmt(SeqStmt* seqStmt) + { + for (auto stmt : seqStmt->stmts) + dispatchIfNotNull(stmt); + } + + void visitLabelStmt(LabelStmt* stmt) + { + dispatchIfNotNull(stmt->innerStmt); + } + + void visitBreakStmt(BreakStmt*) { return; } + + void visitContinueStmt(ContinueStmt*) { return; } + + void visitDoWhileStmt(DoWhileStmt* stmt) + { + dispatchIfNotNull(stmt->predicate); + dispatchIfNotNull(stmt->statement); + } + + void visitForStmt(ForStmt* stmt) + { + dispatchIfNotNull(stmt->initialStatement); + dispatchIfNotNull(stmt->predicateExpression); + dispatchIfNotNull(stmt->sideEffectExpression); + dispatchIfNotNull(stmt->statement); + } + + void visitCompileTimeForStmt(CompileTimeForStmt* stmt) + { + dispatchIfNotNull(stmt->rangeBeginExpr); + dispatchIfNotNull(stmt->rangeEndExpr); + dispatchIfNotNull(stmt->body); + } + + void visitSwitchStmt(SwitchStmt* stmt) + { + dispatchIfNotNull(stmt->condition); + dispatchIfNotNull(stmt->body); + } + + void visitCaseStmt(CaseStmt* stmt) { dispatchIfNotNull(stmt->expr); } + + void visitTargetSwitchStmt(TargetSwitchStmt* stmt) + { + for (auto targetCase : stmt->targetCases) + dispatchIfNotNull(targetCase); + } + + void visitTargetCaseStmt(TargetCaseStmt* stmt) + { + dispatchIfNotNull(stmt->body); + } + + void visitIntrinsicAsmStmt(IntrinsicAsmStmt*) { return; } + + void visitDefaultStmt(DefaultStmt*) { return; } + + void visitIfStmt(IfStmt* stmt) + { + dispatchIfNotNull(stmt->predicate); + dispatchIfNotNull(stmt->positiveStatement); + dispatchIfNotNull(stmt->negativeStatement); + } + + void visitUnparsedStmt(UnparsedStmt*) { return; } + + void visitEmptyStmt(EmptyStmt*) { return; } + + void visitDiscardStmt(DiscardStmt*) { return; } + + void visitReturnStmt(ReturnStmt* stmt) { dispatchIfNotNull(stmt->expression); } + + void visitWhileStmt(WhileStmt* stmt) + { + dispatchIfNotNull(stmt->predicate); + dispatchIfNotNull(stmt->statement); + } + + void visitGpuForeachStmt(GpuForeachStmt*) { return; } + + void visitExpressionStmt(ExpressionStmt* stmt) + { + dispatchIfNotNull(stmt->expression); + } + + // Val Visitor + + void visitDirectDeclRef(DirectDeclRef* declRef) + { + // If we have already visited, return. + // Otherwise add it to visited set. + if (!visitedVals.add(declRef)) + return; + + processReferencedDecl(declRef->getDecl()); + } + + void visitVal(Val* val) + { + // If we have already visited, return. + // Otherwise add it to visited set. + if (!visitedVals.add(val)) + return; + + for (Index i = 0; i < val->getOperandCount(); i++) + { + auto& operand = val->m_operands[i]; + switch (operand.kind) + { + case ValNodeOperandKind::ValNode: + dispatchIfNotNull(val->getOperand(i)); + break; + default: + break; + } + } + return; + } + + HashSet<Val*> visitedVals; + + // Decl visitor + void visitDeclBase(DeclBase*) + {} + + void visitContainerDecl(ContainerDecl* decl) + { + for (auto m : decl->members) + { + dispatchIfNotNull(m); + } + } + + void visitFunctionDeclBase(FunctionDeclBase* decl) + { + visitContainerDecl(decl); + dispatchIfNotNull(decl->body); + } + + void visitVarDeclBase(VarDeclBase* varDecl) + { + dispatchIfNotNull(varDecl->type.type); + dispatchIfNotNull(varDecl->initExpr); + } + }; + + struct SemanticsDeclCapabilityVisitor + : public SemanticsDeclVisitorBase + , public DeclVisitor<SemanticsDeclCapabilityVisitor> + { + SemanticsDeclCapabilityVisitor(SemanticsContext const& outer) + : SemanticsDeclVisitorBase(outer) + {} + + void visitDecl(Decl*) {} + void visitDeclGroup(DeclGroup*) {} + + void checkVarDeclCommon(VarDeclBase* varDecl); + + void visitVarDecl(VarDecl* varDecl) + { + checkVarDeclCommon(varDecl); + } + + void visitParamDecl(ParamDecl* paramDecl) + { + checkVarDeclCommon(paramDecl); + } + + void visitFunctionDeclBase(FunctionDeclBase* funcDecl); + + void visitInheritanceDecl(InheritanceDecl* inheritanceDecl); + + void diagnoseUndeclaredCapability(Decl* decl, const DiagnosticInfo& diagnosticInfo, const CapabilityConjunctionSet* failedAvailableSet); + }; + + /// Should the given `decl` nested in `parentDecl` be treated as a static rather than instance declaration? bool isEffectivelyStatic( Decl* decl, @@ -528,7 +938,7 @@ namespace Slang } else if( auto enumCaseDeclRef = declRef.as<EnumCaseDecl>() ) { - sema->ensureDecl(declRef.declRefBase, DeclCheckState::Checked); + sema->ensureDecl(declRef.declRefBase, DeclCheckState::DefinitionChecked); QualType qualType; qualType.type = getType(astBuilder, enumCaseDeclRef); qualType.isLeftValue = false; @@ -873,7 +1283,7 @@ namespace Slang bool SemanticsVisitor::shouldSkipChecking(Decl* decl, DeclCheckState state) { - if (state != DeclCheckState::Checked) + if (state < DeclCheckState::DefinitionChecked) return false; // If we are in language server, we should skip checking all the function bodies // except for the module or function that the user cared about. @@ -1058,7 +1468,7 @@ namespace Slang // If we've gone down this path, then the variable // declaration is actually pretty far along in checking - varDecl->setCheckState(DeclCheckState::Checked); + varDecl->setCheckState(DeclCheckState::DefinitionChecked); } else { @@ -1087,7 +1497,7 @@ namespace Slang maybeInferArraySizeForVariable(varDecl); - varDecl->setCheckState(DeclCheckState::Checked); + varDecl->setCheckState(DeclCheckState::DefinitionChecked); } } // @@ -1306,7 +1716,7 @@ namespace Slang // We need to ensure that any variable doesn't introduce // a constant with a circular definition. // - varDecl->setCheckState(DeclCheckState::Checked); + varDecl->setCheckState(DeclCheckState::DefinitionChecked); _validateCircularVarDefinition(varDecl); } else @@ -1553,7 +1963,7 @@ namespace Slang assocTypeDef->nameAndLoc.name = getName("Differential"); assocTypeDef->type.type = satisfyingType; assocTypeDef->parentDecl = aggTypeDecl; - assocTypeDef->setCheckState(DeclCheckState::Checked); + assocTypeDef->setCheckState(DeclCheckState::DefinitionChecked); aggTypeDecl->members.add(assocTypeDef); } @@ -1861,7 +2271,7 @@ namespace Slang // for(auto importDecl : moduleDecl->getMembersOfType<ImportDecl>()) { - ensureDecl(importDecl, DeclCheckState::Checked); + ensureDecl(importDecl, DeclCheckState::DefinitionChecked); } // Next, make sure all `__include` decls are processed and the referenced @@ -1873,15 +2283,15 @@ namespace Slang auto decl = fileDecl->members[i]; if (auto includeDecl = as<IncludeDecl>(decl)) { - ensureDecl(includeDecl, DeclCheckState::Checked); + ensureDecl(includeDecl, DeclCheckState::DefinitionChecked); } else if (auto implementingDecl = as<ImplementingDecl>(decl)) { - ensureDecl(implementingDecl, DeclCheckState::Checked); + ensureDecl(implementingDecl, DeclCheckState::DefinitionChecked); } else if (auto importDecl = as<ImportDecl>(decl)) { - ensureDecl(importDecl, DeclCheckState::Checked); + ensureDecl(importDecl, DeclCheckState::DefinitionChecked); } } }; @@ -1893,7 +2303,7 @@ namespace Slang } // The entire goal of semantic checking is to get all of the - // declarations in the module up to `DeclCheckState::Checked`. + // declarations in the module up to `DeclCheckState::DefinitionChecked`. // // The main catch is that checking one declaration A up to state M // may required that declaration B is checked up to state N. @@ -1950,7 +2360,8 @@ namespace Slang DeclCheckState::ReadyForReference, DeclCheckState::ReadyForLookup, DeclCheckState::ReadyForConformances, - DeclCheckState::Checked + DeclCheckState::DefinitionChecked, + DeclCheckState::CapabilityChecked, }; for(auto s : states) { @@ -2855,6 +3266,9 @@ namespace Slang ThisExpr*& synThis) { auto synFuncDecl = m_astBuilder->create<FuncDecl>(); + synFuncDecl->ownedScope = m_astBuilder->create<Scope>(); + synFuncDecl->ownedScope->containerDecl = synFuncDecl; + synFuncDecl->ownedScope->parent = getScope(context->parentDecl); // For now our synthesized method will use the name and source // location of the requirement we are trying to satisfy. @@ -2954,6 +3368,7 @@ namespace Slang // For a non-`static` requirement, we need a `this` parameter. // synThis = m_astBuilder->create<ThisExpr>(); + synThis->scope = synFuncDecl->ownedScope; // The type of `this` in our method will be the type for // which we are synthesizing a conformance. @@ -3314,6 +3729,9 @@ namespace Slang // the required accessor. // auto synAccessorDecl = (AccessorDecl*) m_astBuilder->createByNodeType(requiredAccessorDeclRef.getDecl()->astNodeType); + synAccessorDecl->ownedScope = m_astBuilder->create<Scope>(); + synAccessorDecl->ownedScope->containerDecl = synAccessorDecl; + synAccessorDecl->ownedScope->parent = getScope(context->parentDecl); // Whatever the required accessor returns, that is what our synthesized accessor will return. // @@ -3359,6 +3777,7 @@ namespace Slang // a `this` expression. // ThisExpr* synThis = m_astBuilder->create<ThisExpr>(); + synThis->scope = synAccessorDecl->ownedScope; // The type of `this` in our accessor will be the type for // which we are synthesizing a conformance. @@ -5029,7 +5448,7 @@ namespace Slang // the min/max tag values, or the total number of tags, so that people don't // have to declare these as additional cases. - enumConformanceDecl->setCheckState(DeclCheckState::Checked); + enumConformanceDecl->setCheckState(DeclCheckState::DefinitionChecked); } } @@ -5055,7 +5474,7 @@ namespace Slang // doing its own header checking, rather than rely on this... caseDecl->type.type = enumType; - ensureDecl(caseDecl, DeclCheckState::Checked); + ensureDecl(caseDecl, DeclCheckState::DefinitionChecked); } // For any enum case that didn't provide an explicit @@ -7569,9 +7988,13 @@ namespace Slang SemanticsDeclAttributesVisitor(shared).dispatch(decl); break; - case DeclCheckState::Checked: + case DeclCheckState::DefinitionChecked: SemanticsDeclBodyVisitor(shared).dispatch(decl); break; + + case DeclCheckState::CapabilityChecked: + SemanticsDeclCapabilityVisitor(shared).dispatch(decl); + break; } } @@ -8144,4 +8567,493 @@ namespace Slang checkDerivativeAttribute(this, decl, primalAttr); } } + + static void _propagateRequirement(SemanticsVisitor* visitor, CapabilitySet& resultCaps, SyntaxNode* userNode, SyntaxNode* referencedNode, const CapabilitySet& nodeCaps, SourceLoc referenceLoc) + { + auto referencedDecl = as<Decl>(referencedNode); + + // Ignore cyclic references. + if (referencedDecl) + { + if (referencedDecl->checkState.isBeingChecked()) + return; + + ensureDecl(visitor, referencedDecl, DeclCheckState::CapabilityChecked); + } + + if (resultCaps.implies(nodeCaps)) + return; + + auto oldCaps = resultCaps; + bool isAnyInvalid = resultCaps.isInvalid() || nodeCaps.isInvalid(); + resultCaps.join(nodeCaps); + + auto decl = as<Decl>(userNode); + + if (!isAnyInvalid && resultCaps.isInvalid()) + { + // If joining the referenced decl's requirements results an invalid capability set, + // then the decl is using things that require conflicting set of capabilities, and we should diagnose an error. + if (referencedDecl && decl) + { + visitor->getSink()->diagnose( + referenceLoc, + Diagnostics::conflictingCapabilityDueToUseOfDecl, + referencedDecl, + nodeCaps, + decl, + oldCaps); + } + else if (decl) + { + visitor->getSink()->diagnose( + referenceLoc, + Diagnostics::conflictingCapabilityDueToStatement, + nodeCaps, + decl, + oldCaps); + } + else + { + visitor->getSink()->diagnose( + referenceLoc, + Diagnostics::conflictingCapabilityDueToStatementEnclosingFunc, + nodeCaps, + oldCaps); + } + } + if (referencedDecl && decl) + { + for (auto& capSet : nodeCaps.getExpandedAtoms()) + { + for (auto atom : capSet.getExpandedAtoms()) + { + decl->capabilityRequirementProvenance.addIfNotExists(atom, DeclReferenceWithLoc{ referencedDecl, referenceLoc }); + } + } + } + }; + + CapabilitySet getStatementCapabilityUsage(SemanticsVisitor* visitor, Stmt* stmt); + + template<typename ProcessFunc> + struct CapabilityDeclReferenceVisitor + : public SemanticsDeclReferenceVisitor<CapabilityDeclReferenceVisitor<ProcessFunc>> + { + typedef SemanticsDeclReferenceVisitor<CapabilityDeclReferenceVisitor<ProcessFunc>> Base; + + const ProcessFunc& handleReferenceFunc; + CapabilityDeclReferenceVisitor(const ProcessFunc& processFunc, SemanticsContext const& outer) + : handleReferenceFunc(processFunc) + , SemanticsDeclReferenceVisitor<CapabilityDeclReferenceVisitor<ProcessFunc>>(outer) + { + } + virtual void processReferencedDecl(Decl* decl) override + { + SourceLoc loc = SourceLoc(); + if (Base::sourceLocStack.getCount()) + loc = Base::sourceLocStack.getLast(); + handleReferenceFunc(decl, decl->inferredCapabilityRequirements, loc); + } + void visitDiscardStmt(DiscardStmt* stmt) + { + handleReferenceFunc(stmt, CapabilitySet(CapabilityName::fragment), stmt->loc); + } + void visitTargetSwitchStmt(TargetSwitchStmt* stmt) + { + CapabilitySet set; + for (auto targetCase : stmt->targetCases) + { + auto targetCap = CapabilitySet(CapabilityName(targetCase->capability)); + auto oldCap = targetCap; + auto bodyCap = getStatementCapabilityUsage(this, targetCase->body); + targetCap.join(bodyCap); + if (targetCap.isInvalid()) + { + Base::getSink()->diagnose(targetCase->body->loc, Diagnostics::conflictingCapabilityDueToStatement, bodyCap, "target_switch", oldCap); + } + for (auto& conjunction : targetCap.getExpandedAtoms()) + set.unionWith(conjunction); + } + set.canonicalize(); + handleReferenceFunc(stmt, set, stmt->loc); + } + }; + + template<typename ProcessFunc> + void visitReferencedDecls(SemanticsContext& context, NodeBase* node, SourceLoc initialLoc, const ProcessFunc& func) + { + CapabilityDeclReferenceVisitor<ProcessFunc> visitor(func, context); + visitor.sourceLocStack.add(initialLoc); + + if (auto val = as<Val>(node)) + visitor.dispatchIfNotNull(val); + if (auto stmt = as<Stmt>(node)) + visitor.dispatchIfNotNull(stmt); + if (auto expr = as<Expr>(node)) + visitor.dispatchIfNotNull(expr); + if (auto decl = as<Decl>(node)) + visitor.dispatchIfNotNull(decl); + } + + CapabilitySet getStatementCapabilityUsage(SemanticsVisitor* visitor, Stmt* stmt) + { + if (stmt == nullptr) + return CapabilitySet(); + + CapabilitySet inferredRequirements; + visitReferencedDecls(*visitor, stmt, stmt->loc, [&](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc) + { + _propagateRequirement(visitor, inferredRequirements, stmt, node, nodeCaps, refLoc); + }); + return inferredRequirements; + } + + void SemanticsDeclCapabilityVisitor::checkVarDeclCommon(VarDeclBase* varDecl) + { + visitReferencedDecls(*this, varDecl->type.type, varDecl->loc, [this, varDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc) + { + _propagateRequirement(this, varDecl->inferredCapabilityRequirements, varDecl, node, nodeCaps, refLoc); + }); + visitReferencedDecls(*this, varDecl->initExpr, varDecl->loc, [this, varDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc) + { + _propagateRequirement(this, varDecl->inferredCapabilityRequirements, varDecl, node, nodeCaps, refLoc); + }); + } + + void SemanticsDeclCapabilityVisitor::visitFunctionDeclBase(FunctionDeclBase* funcDecl) + { + for (auto member : funcDecl->members) + { + ensureDecl(member, DeclCheckState::CapabilityChecked); + _propagateRequirement(this, funcDecl->inferredCapabilityRequirements, funcDecl, member, member->inferredCapabilityRequirements, member->loc); + } + visitReferencedDecls(*this, funcDecl->body, funcDecl->loc, [this, funcDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc) + { + _propagateRequirement(this, funcDecl->inferredCapabilityRequirements, funcDecl, node, nodeCaps, refLoc); + }); + + // A decls's declared capability set is a transitive join of all parent declarations. + CapabilitySet declaredCaps; + for (Decl* parent = funcDecl; parent; parent = getParentDecl(parent)) + { + CapabilitySet localDeclaredCaps; + + for (auto decoration : parent->getModifiersOfType<RequireCapabilityAttribute>()) + { + for (auto& set : decoration->capabilitySet.getExpandedAtoms()) + localDeclaredCaps.unionWith(set); + } + declaredCaps.join(localDeclaredCaps); + } + + if (!declaredCaps.isEmpty()) + { + // If the function is an entrypoint, add the stage to declaredCaps. + if (auto entryPointAttr = funcDecl->findModifier<EntryPointAttribute>()) + { + auto stageCaps = CapabilitySet(Profile(entryPointAttr->stage).getCapabilityName()); + if (declaredCaps.isIncompatibleWith(stageCaps)) + { + getSink()->diagnose(funcDecl->loc, Diagnostics::stageIsInCompatibleWithCapabilityDefinition, funcDecl, stageCaps, declaredCaps); + } + else + { + declaredCaps.join(stageCaps); + } + } + } + + auto vis = getDeclVisibility(funcDecl); + if (declaredCaps.isEmpty()) + { + // If the user has not declared any capabilities, + // we should diagnose an error if this is a public symbol. + if (vis == DeclVisibility::Public && !funcDecl->inferredCapabilityRequirements.isEmpty()) + { + if (!getModuleDecl(funcDecl)->isInLegacyLanguage) + { + getSink()->diagnose(funcDecl->loc, Diagnostics::missingCapabilityRequirementOnPublicDecl, funcDecl); + } + } + } + else + { + if (vis == DeclVisibility::Public) + { + // For public decls, we need to enforce that the function + // only uses capabilities that it declares. + const CapabilityConjunctionSet* failedAvailableCapabilityConjunction = nullptr; + if (!CapabilitySet::checkCapabilityRequirement( + declaredCaps, + funcDecl->inferredCapabilityRequirements, + failedAvailableCapabilityConjunction)) + { + diagnoseUndeclaredCapability(funcDecl, Diagnostics::useOfUndeclaredCapability, failedAvailableCapabilityConjunction); + funcDecl->inferredCapabilityRequirements = declaredCaps; + } + } + else + { + // For internal decls, their inferred capability should be joined + // with the declared capabilities. + funcDecl->inferredCapabilityRequirements.join(declaredCaps); + } + } + } + + void SemanticsDeclCapabilityVisitor::visitInheritanceDecl(InheritanceDecl* inheritanceDecl) + { + // Check that the implementation of an interface requirement is not using more capabilities + // than what's declared on the interface method. + if (inheritanceDecl->witnessTable) + { + for (auto& kv : inheritanceDecl->witnessTable->m_requirementDictionary) + { + if (kv.value.getFlavor() != RequirementWitness::Flavor::declRef) + continue; + auto requirementDecl = kv.key; + auto implDecl = kv.value.getDeclRef(); + if (!implDecl) + continue; + + if (getModuleDecl(implDecl.getDecl())->isInLegacyLanguage) + break; + + ensureDecl(requirementDecl, DeclCheckState::CapabilityChecked); + ensureDecl(implDecl.declRefBase, DeclCheckState::CapabilityChecked); + + const CapabilityConjunctionSet* failedAvailableCapabilityConjunction = nullptr; + if (!CapabilitySet::checkCapabilityRequirement( + requirementDecl->inferredCapabilityRequirements, + implDecl.getDecl()->inferredCapabilityRequirements, + failedAvailableCapabilityConjunction)) + { + diagnoseUndeclaredCapability(implDecl.getDecl(), Diagnostics::useOfUndeclaredCapabilityOfInterfaceRequirement, failedAvailableCapabilityConjunction); + } + } + } + } + + DeclVisibility getDeclVisibility(Decl* decl) + { + if (as<GenericTypeParamDecl>(decl) || as<GenericValueParamDecl>(decl) || as<GenericTypeConstraintDecl>(decl)) + { + auto genericDecl = as<GenericDecl>(decl->parentDecl); + if (!genericDecl) + return DeclVisibility::Default; + if (genericDecl->inner) + return getDeclVisibility(genericDecl->inner); + return DeclVisibility::Default; + } + if (auto genericDecl = as<GenericDecl>(decl)) + decl = genericDecl->inner; + for (; decl; decl = getParentDecl(decl)) + { + if (as<AccessorDecl>(decl)) + continue; + if (as<EnumCaseDecl>(decl)) + continue; + break; + } + if (!decl) + return DeclVisibility::Public; + + for (auto modifier : decl->modifiers) + { + if (as<PublicModifier>(modifier)) + return DeclVisibility::Public; + else if (as<InternalModifier>(modifier)) + return DeclVisibility::Internal; + else if (as<PrivateModifier>(modifier)) + return DeclVisibility::Private; + } + + // Interface members will always have the same visibility as the interface itself. + if (auto interfaceDecl = findParentInterfaceDecl(decl)) + { + return getDeclVisibility(interfaceDecl); + } + else if (as<NamespaceDecl>(decl)) + { + return DeclVisibility::Public; + } + if (auto parentModule = getModuleDecl(decl)) + return parentModule->isInLegacyLanguage ? DeclVisibility::Public : DeclVisibility::Internal; + + return DeclVisibility::Default; + } + + void diagnoseCapabilityProvenance(DiagnosticSink* sink, Decl* decl, CapabilityAtom missingAtom) + { + HashSet<Decl*> printedDecls; + auto thisModule = getModuleDecl(decl); + Decl* declToPrint = decl; + while (declToPrint) + { + printedDecls.add(declToPrint); + if (auto provenance = declToPrint->capabilityRequirementProvenance.tryGetValue(missingAtom)) + { + sink->diagnose(provenance->referenceLoc, Diagnostics::seeUsingOf, provenance->referencedDecl); + declToPrint = provenance->referencedDecl; + if (printedDecls.contains(declToPrint)) + break; + if (declToPrint->findModifier<RequireCapabilityAttribute>()) + break; + auto moduleDecl = getModuleDecl(declToPrint); + if (thisModule != moduleDecl) + break; + if (moduleDecl && moduleDecl->isInLegacyLanguage) + continue; + if (getDeclVisibility(declToPrint) == DeclVisibility::Public) + break; + } + else + { + break; + } + } + if (declToPrint) + { + sink->diagnose(declToPrint->loc, Diagnostics::seeDefinitionOf, declToPrint); + } + } + + // Print diagnostics tracing which referenced decls are not compatible with the given atom. + void diagnoseIncompatibleAtomProvenance(SemanticsVisitor* visitor, DiagnosticSink* sink, Decl* decl, CapabilityAtom incompatibleAtom, int traceLevels = 10) + { + Decl* refDecl = nullptr; + SourceLoc loc; + while (traceLevels > 0) + { + refDecl = nullptr; + visitReferencedDecls(*visitor, decl, decl->loc, [&](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc) + { + if (nodeCaps.isIncompatibleWith(incompatibleAtom)) + { + if (auto referencedDecl = as<Decl>(node)) + { + refDecl = referencedDecl; + loc = refLoc; + } + else + sink->diagnose(refLoc, Diagnostics::seeDefinitionOf, "statement"); + } + }); + if (refDecl) + { + sink->diagnose(loc, Diagnostics::seeUsingOf, refDecl); + decl = refDecl; + } + else + { + break; + } + traceLevels--; + } + } + + void SemanticsDeclCapabilityVisitor::diagnoseUndeclaredCapability(Decl* decl, const DiagnosticInfo& diagnosticInfo, const CapabilityConjunctionSet* failedAvailableSet) + { + if (decl->inferredCapabilityRequirements.getExpandedAtoms().getCount() == 0) + return; + + // There are two causes for why type checking failed on failedAvailableSet. + // The first scenario is that failedAvailableSet defines a set of capabilities on a + // compilation target (e.g. hlsl) that isn't defined by some callees, for example, if we have + // a function: + // [require(hlsl)] // <-- failedAvailableSet + // [require(cpp)] + // void caller() + // { + // printf(); // assume this is defined for (cpp | cuda). + // } + // In this case we should diagnose error reporting printf isn't defined on a required target. + // + // The second scenario is when the callee is using a capability that is not provided by the requirement. + // For example: + // [require(hlsl,b,c)] + // void caller() + // { + // useD(); // require capability (hlsl,d) + // } + // In this case we should report that useD() is using a capability that is not declared by caller. + // + + // Now, we detect if we are case 1. + if (decl->inferredCapabilityRequirements.isIncompatibleWith(*failedAvailableSet)) + { + // Find the most derived atom that is leading to the incompatiblity. + for (Index i = failedAvailableSet->getExpandedAtoms().getCount() - 1; i >= 0; i--) + { + auto atom = failedAvailableSet->getExpandedAtoms()[i]; + if (!isDirectChildOfAbstractAtom(atom)) + continue; + if (decl->inferredCapabilityRequirements.isIncompatibleWith(atom)) + { + getSink()->diagnose(decl->loc, Diagnostics::declHasDependenciesNotDefinedOnTarget, decl, atom); + diagnoseIncompatibleAtomProvenance(this, getSink(), decl, atom); + return; + } + } + return; + } + + // If we reach here, we are case 2. + + CapabilityConjunctionSet* matchingRequirement = &decl->inferredCapabilityRequirements.getExpandedAtoms().getFirst(); + CapabilityAtom missingAtom = matchingRequirement->getExpandedAtoms().getFirst(); + if (missingAtom == CapabilityAtom::Invalid) + return; + + if (failedAvailableSet) + { + Int maxIntersectionCount = 0; + for (auto& usedSet : decl->inferredCapabilityRequirements.getExpandedAtoms()) + { + auto intersection = usedSet.countIntersectionWith(*failedAvailableSet); + if (intersection > maxIntersectionCount) + { + matchingRequirement = &usedSet; + maxIntersectionCount = intersection; + } + } + Index pos = 0; + for (Index i = 0; i < matchingRequirement->getExpandedAtoms().getCount(); i++) + { + auto atom = matchingRequirement->getExpandedAtoms()[i]; + while (pos < failedAvailableSet->getExpandedAtoms().getCount()) + { + if (failedAvailableSet->getExpandedAtoms()[pos] < atom) + pos++; + else + break; + } + + if (pos >= failedAvailableSet->getExpandedAtoms().getCount() || + failedAvailableSet->getExpandedAtoms()[pos] != atom) + { + missingAtom = atom; + break; + } + } + + // Select the most derived atom of `missingAtom`. + for (Index i = matchingRequirement->getExpandedAtoms().getCount() - 1; i >= 0 ; i--) + { + auto atom = matchingRequirement->getExpandedAtoms()[i]; + if (CapabilityConjunctionSet(atom).implies(missingAtom)) + { + missingAtom = atom; + break; + } + } + } + + getSink()->diagnose(decl->loc, diagnosticInfo, decl, missingAtom); + + // Print provenances. + diagnoseCapabilityProvenance(getSink(), decl, missingAtom); + } + } diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index e5e990fe5..f9adcc91a 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -62,7 +62,7 @@ namespace Slang { VarDecl* varDecl = m_astBuilder->create<VarDecl>(); varDecl->parentDecl = nullptr; // TODO: need to fill this in somehow! - varDecl->checkState = DeclCheckState::Checked; + varDecl->checkState = DeclCheckState::DefinitionChecked; varDecl->nameAndLoc.loc = expr->loc; varDecl->initExpr = expr; varDecl->type.type = expr->type.type; @@ -827,55 +827,6 @@ namespace Slang } } - DeclVisibility SemanticsVisitor::getDeclVisibility(Decl* decl) - { - if (as<GenericTypeParamDecl>(decl) || as<GenericValueParamDecl>(decl) || as<GenericTypeConstraintDecl>(decl)) - { - auto genericDecl = as<GenericDecl>(decl->parentDecl); - if (!genericDecl) - return DeclVisibility::Default; - if (genericDecl->inner) - return getDeclVisibility(genericDecl->inner); - return DeclVisibility::Default; - } - if (auto genericDecl = as<GenericDecl>(decl)) - decl = genericDecl->inner; - for (; decl; decl = getParentDecl(decl)) - { - if (as<AccessorDecl>(decl)) - continue; - if (as<EnumCaseDecl>(decl)) - continue; - break; - } - if (!decl) - return DeclVisibility::Public; - - for (auto modifier : decl->modifiers) - { - if (as<PublicModifier>(modifier)) - return DeclVisibility::Public; - else if (as<InternalModifier>(modifier)) - return DeclVisibility::Internal; - else if (as<PrivateModifier>(modifier)) - return DeclVisibility::Private; - } - - // Interface members will always have the same visibility as the interface itself. - if (auto interfaceDecl = findParentInterfaceDecl(decl)) - { - return getDeclVisibility(interfaceDecl); - } - else if (as<NamespaceDecl>(decl)) - { - return DeclVisibility::Public; - } - if (auto parentModule = getModuleDecl(decl)) - return parentModule->isInLegacyLanguage ? DeclVisibility::Public : DeclVisibility::Internal; - - return DeclVisibility::Default; - } - DeclVisibility SemanticsVisitor::getTypeVisibility(Type* type) { if (auto declRefType = as<DeclRefType>(type)) @@ -1721,7 +1672,7 @@ namespace Slang if (!getInitExpr(m_astBuilder, declRef)) return nullptr; - ensureDecl(declRef.getDecl(), DeclCheckState::Checked); + ensureDecl(declRef.getDecl(), DeclCheckState::DefinitionChecked); ConstantFoldingCircularityInfo newCircularityInfo(decl, circularityInfo); return tryConstantFoldExpr(getInitExpr(m_astBuilder, declRef), &newCircularityInfo); } diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 4102b3eba..1808274f3 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1099,7 +1099,6 @@ namespace Slang SourceLoc loc, Expr* originalExpr); - DeclVisibility getDeclVisibility(Decl* decl); DeclVisibility getTypeVisibility(Type* type); bool isDeclVisibleFromScope(DeclRef<Decl> declRef, Scope* scope); LookupResult filterLookupResultByVisibility(const LookupResult& lookupResult); @@ -1472,6 +1471,8 @@ namespace Slang Expr* expr, String* outVal); + bool checkCapabilityName(Expr* expr, CapabilityName& outCapabilityName); + void visitModifier(Modifier*); AttributeDecl* lookUpAttributeDecl(Name* attributeName, Scope* scope); @@ -2646,4 +2647,8 @@ namespace Slang }; bool isUnsizedArrayType(Type* type); + + DeclVisibility getDeclVisibility(Decl* decl); + + void diagnoseCapabilityProvenance(DiagnosticSink* sink, Decl* decl, CapabilityAtom missingAtom); } diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index 1a7f64944..51cb5346a 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -74,6 +74,30 @@ namespace Slang return false; } + bool SemanticsVisitor::checkCapabilityName(Expr* expr, CapabilityName& outCapabilityName) + { + if (auto varExpr = as<VarExpr>(expr)) + { + if (!varExpr->name) + return false; + if (varExpr->name == getSession()->getCompletionRequestTokenName()) + { + auto& suggestions = getLinkage()->contentAssistInfo.completionSuggestions; + suggestions.clear(); + suggestions.scopeKind = CompletionSuggestions::ScopeKind::Capabilities; + } + outCapabilityName = findCapabilityName(varExpr->name->text.getUnownedSlice()); + if (outCapabilityName == CapabilityName::Invalid) + { + getSink()->diagnose(expr, Diagnostics::unknownCapability, varExpr->name); + return false; + } + return true; + } + getSink()->diagnose(expr, Diagnostics::expectCapability); + return false; + } + void SemanticsVisitor::visitModifier(Modifier*) { // Do nothing with modifiers for now @@ -209,7 +233,7 @@ namespace Slang paramDecl->nameAndLoc = member->nameAndLoc; paramDecl->type = varMember->type; paramDecl->loc = member->loc; - paramDecl->setCheckState(DeclCheckState::Checked); + paramDecl->setCheckState(DeclCheckState::DefinitionChecked); paramDecl->parentDecl = attrDecl; attrDecl->members.add(paramDecl); @@ -233,7 +257,7 @@ namespace Slang // // TODO: what check state is relevant here? // - ensureDecl(attrDecl, DeclCheckState::Checked); + ensureDecl(attrDecl, DeclCheckState::DefinitionChecked); return attrDecl; } @@ -783,6 +807,19 @@ namespace Slang pyExportAttr->name = name; } + else if (auto requireCapAttr = as<RequireCapabilityAttribute>(attr)) + { + List<CapabilityName> capabilityNames; + for (auto& arg : attr->args) + { + CapabilityName capName; + if (checkCapabilityName(arg, capName)) + { + capabilityNames.add(capName); + } + } + requireCapAttr->capabilitySet = CapabilitySet(capabilityNames); + } else { if(attr->args.getCount() == 0) diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 37c92a317..c6e6677b3 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -351,7 +351,7 @@ namespace Slang { // Otherwise, the generic decl had better provide a default value // or this reference is ill-formed. - ensureDecl(valParamRef, DeclCheckState::Checked); + ensureDecl(valParamRef, DeclCheckState::DefinitionChecked); ConstantFoldingCircularityInfo newCircularityInfo(valParamRef.getDecl(), nullptr); auto defaultVal = tryConstantFoldExpr(valParamRef.substitute(m_astBuilder, valParamRef.getDecl()->initExpr), &newCircularityInfo); if (!defaultVal) diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index b45107640..2e854554e 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -396,7 +396,6 @@ namespace Slang auto module = getModule(entryPointFuncDecl); auto linkage = module->getLinkage(); - // Every entry point needs to have a stage specified either via // command-line/API options, or via an explicit `[shader("...")]` attribute. // @@ -506,6 +505,38 @@ namespace Slang } } } + + for (auto target : linkage->targets) + { + auto targetCaps = target->getTargetCaps(); + auto stageCapabilitySet = CapabilitySet(entryPoint->getProfile().getCapabilityName()); + targetCaps.join(stageCapabilitySet); + if (targetCaps.isIncompatibleWith(entryPointFuncDecl->inferredCapabilityRequirements)) + { + sink->diagnose(entryPointFuncDecl, Diagnostics::entryPointUsesUnavailableCapability, entryPointFuncDecl, entryPointFuncDecl->inferredCapabilityRequirements, targetCaps); + auto& interredCapConjunctions = entryPointFuncDecl->inferredCapabilityRequirements.getExpandedAtoms(); + + // Find out what exactly is incompatible and print out a trace of provenance to + // help user diagnose their code. + auto& conjunctions = targetCaps.getExpandedAtoms(); + if (conjunctions.getCount() == 1 && interredCapConjunctions.getCount() == 1) + { + for (auto atom : conjunctions[0].getExpandedAtoms()) + { + for (auto inferredAtom : interredCapConjunctions[0].getExpandedAtoms()) + { + if (CapabilityConjunctionSet(inferredAtom).isIncompatibleWith(atom)) + { + diagnoseCapabilityProvenance(sink, entryPointFuncDecl, inferredAtom); + goto breakLabel; + } + } + } + } + } + } + breakLabel:; + } // Given an entry point specified via API or command line options, @@ -533,7 +564,7 @@ namespace Slang auto entryPointName = entryPointReq->getName(); FuncDecl* entryPointFuncDecl = findFunctionDeclByName(translationUnit->getModule(), entryPointName, sink); - + // Did we find a function declaration in our search? if(!entryPointFuncDecl) { diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp index b2c58ccc7..1090655e5 100644 --- a/source/slang/slang-check-stmt.cpp +++ b/source/slang/slang-check-stmt.cpp @@ -52,7 +52,7 @@ namespace Slang // local `struct` declaration, where it would have members // that need to be recursively checked. // - ensureDeclBase(stmt->decl, DeclCheckState::Checked, this); + ensureDeclBase(stmt->decl, DeclCheckState::DefinitionChecked, this); } void SemanticsStmtVisitor::visitBlockStmt(BlockStmt* stmt) @@ -207,7 +207,7 @@ namespace Slang stmt->varDecl->type.type = m_astBuilder->getIntType(); addModifier(stmt->varDecl, m_astBuilder->create<ConstModifier>()); - stmt->varDecl->setCheckState(DeclCheckState::Checked); + stmt->varDecl->setCheckState(DeclCheckState::DefinitionChecked); IntVal* rangeBeginVal = nullptr; IntVal* rangeEndVal = nullptr; @@ -280,7 +280,20 @@ namespace Slang void SemanticsStmtVisitor::visitTargetCaseStmt(TargetCaseStmt* stmt) { auto switchStmt = FindOuterStmt<TargetSwitchStmt>(); + CapabilitySet set((CapabilityName)stmt->capability); + if (getShared()->isInLanguageServer() && getShared()->getSession()->getCompletionRequestTokenName() == stmt->capabilityToken.getName()) + { + getShared()->getLinkage()->contentAssistInfo.completionSuggestions.scopeKind = CompletionSuggestions::ScopeKind::Capabilities; + } + if (stmt->capabilityToken.getContentLength() != 0 && + (set.getExpandedAtoms().getCount() != 1 || set.isInvalid() || set.isEmpty())) + { + getSink()->diagnose( + stmt->capabilityToken.loc, + Diagnostics::invalidTargetSwitchCase, + capabilityNameToString((CapabilityName)stmt->capability)); + } if (!switchStmt) { getSink()->diagnose(stmt, Diagnostics::caseOutsideSwitch); @@ -648,7 +661,7 @@ namespace Slang { stmt->device = CheckExpr(stmt->device); stmt->gridDims = CheckExpr(stmt->gridDims); - ensureDeclBase(stmt->dispatchThreadID, DeclCheckState::Checked, this); + ensureDeclBase(stmt->dispatchThreadID, DeclCheckState::DefinitionChecked, this); WithOuterStmt subContext(this, stmt); stmt->kernelCall = subContext.CheckExpr(stmt->kernelCall); return; diff --git a/source/slang/slang-check.cpp b/source/slang/slang-check.cpp index de86a333f..3f79b7f41 100644 --- a/source/slang/slang-check.cpp +++ b/source/slang/slang-check.cpp @@ -218,4 +218,5 @@ namespace Slang { return sv->getASTBuilder(); } + } diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index fbdb91ac7..f8ad95108 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -359,6 +359,26 @@ namespace Slang return lookUp(UnownedTerminatedStringSlice(name)); } + List<CapabilityName> Profile::getCapabilityName() + { + List<CapabilityName> result; + switch (getVersion()) + { + #define PROFILE_VERSION(TAG, NAME) case ProfileVersion::TAG: result.add(CapabilityName::TAG); break; + #include "slang-profile-defs.h" + default: + break; + } + switch (getStage()) + { +#define PROFILE_STAGE(TAG, NAME, VAL) case Stage::TAG: result.add(CapabilityName::NAME); break; +#include "slang-profile-defs.h" + default: + break; + } + return result; + } + char const* Profile::getName() { switch( raw ) @@ -711,7 +731,7 @@ namespace Slang // to clobber it with something to get a default. // // TODO: This is a huge hack... - profile.setVersion(ProfileVersion::DX_5_0); + profile.setVersion(ProfileVersion::DX_5_1); break; } @@ -755,9 +775,6 @@ namespace Slang { #define CASE(TAG, SUFFIX) case ProfileVersion::TAG: versionSuffix = #SUFFIX; break CASE(DX_4_0, _4_0); - CASE(DX_4_0_Level_9_0, _4_0_level_9_0); - CASE(DX_4_0_Level_9_1, _4_0_level_9_1); - CASE(DX_4_0_Level_9_3, _4_0_level_9_3); CASE(DX_4_1, _4_1); CASE(DX_5_0, _5_0); CASE(DX_5_1, _5_1); diff --git a/source/slang/slang-content-assist-info.h b/source/slang/slang-content-assist-info.h index bd2c4b7d9..8f4105184 100644 --- a/source/slang/slang-content-assist-info.h +++ b/source/slang/slang-content-assist-info.h @@ -19,7 +19,8 @@ struct CompletionSuggestions Stmt, Expr, Attribute, - HLSLSemantics + HLSLSemantics, + Capabilities }; ScopeKind scopeKind = ScopeKind::Invalid; List<LookupResultItem> candidateItems; diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index d0e188292..786dbda35 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -297,8 +297,8 @@ DIAGNOSTIC(30041, Error, bitOperationNonIntegral, "bit operation: operand must b DIAGNOSTIC(30043, Error, getStringHashRequiresStringLiteral, "getStringHash parameter can only accept a string literal") DIAGNOSTIC(30047, Error, argumentExpectedLValue, "argument passed to parameter '$0' must be l-value.") -DIAGNOSTIC(30049, Note, thisIsImmutableByDefault, "a 'this' parameter is an immutable parameter by default in Slang; apply the `[mutating]` attribute to the function declaration to opt in to a mutable `this`") -DIAGNOSTIC(30050, Error, mutatingMethodOnImmutableValue, "mutating method '$0' cannot be called on an immutable value") +DIAGNOSTIC(30049, Note, thisIsImmutableByDefault, "a 'this' parameter is an immutable parameter by default in Slang; apply the `[mutating]` attribute to the function declaration to opt in to a mutable `this`") +DIAGNOSTIC(30050, Error, mutatingMethodOnImmutableValue, "mutating method '$0' cannot be called on an immutable value") DIAGNOSTIC(30051, Error, invalidValueForArgument, "invalid value for argument '$0'") DIAGNOSTIC(30052, Error, invalidSwizzleExpr, "invalid swizzle pattern '$0' on type '$1'") @@ -312,9 +312,9 @@ DIAGNOSTIC(30058, Warning, danglingEqualityExpr, "result of '==' not used, did y DIAGNOSTIC(30060, Error, expectedAType, "expected a type, got a '$0'") DIAGNOSTIC(30061, Error, expectedANamespace, "expected a namespace, got a '$0'") -DIAGNOSTIC(30062, Note, implicitCastUsedAsLValueRef, "argument was implicitly cast from '$0' to '$1', and Slang does not support using an implicit cast as an l-value with a reference") -DIAGNOSTIC(30063, Note, implicitCastUsedAsLValueType, "argument was implicitly cast from '$0' to '$1', and Slang does not support using an implicit cast as an l-value with this type") -DIAGNOSTIC(30064, Note, implicitCastUsedAsLValue, "argument was implicitly cast from '$0' to '$1', and Slang does not support using an implicit cast as an l-value for this usage") +DIAGNOSTIC(30062, Note, implicitCastUsedAsLValueRef, "argument was implicitly cast from '$0' to '$1', and Slang does not support using an implicit cast as an l-value with a reference") +DIAGNOSTIC(30063, Note, implicitCastUsedAsLValueType, "argument was implicitly cast from '$0' to '$1', and Slang does not support using an implicit cast as an l-value with this type") +DIAGNOSTIC(30064, Note, implicitCastUsedAsLValue, "argument was implicitly cast from '$0' to '$1', and Slang does not support using an implicit cast as an l-value for this usage") DIAGNOSTIC(30065, Error, newCanOnlyBeUsedToInitializeAClass, "`new` can only be used to initialize a class") DIAGNOSTIC(30066, Error, classCanOnlyBeInitializedWithNew, "a class can only be initialized by a `new` clause") @@ -333,7 +333,7 @@ DIAGNOSTIC(30300, Error, isOperatorValueMustBeInterfaceType, "'is'/'as' operator DIAGNOSTIC(33070, Error, expectedFunction, "expected a function, got '$0'") DIAGNOSTIC(33071, Error, expectedAStringLiteral, "expected a string literal") -DIAGNOSTIC( -1, Note, noteExplicitConversionPossible, "explicit conversion from '$0' to '$1' is possible") +DIAGNOSTIC(-1, Note, noteExplicitConversionPossible, "explicit conversion from '$0' to '$1' is possible") DIAGNOSTIC(30080, Error, ambiguousConversion, "more than one implicit conversion exists from '$0' to '$1'") DIAGNOSTIC(30081, Warning, unrecommendedImplicitConversion, "implicit conversion from '$0' to '$1' is not recommended") DIAGNOSTIC(30082, Warning, implicitConversionToDouble, " implicit float-to-double conversion may cause unexpected performance issues, use explicit cast if intended.") @@ -366,6 +366,20 @@ DIAGNOSTIC(30603, Error, invalidUseOfPrivateVisibility, "'$0' cannot have privat DIAGNOSTIC(30604, Error, useOfLessVisibleType, "'$0' references less visible type '$1'.") DIAGNOSTIC(36005, Error, invalidVisibilityModifierOnTypeOfDecl, "visibility modifier is not allowed on '$0'.") +// Capability +DIAGNOSTIC(36100, Error, conflictingCapabilityDueToUseOfDecl, "'$0' requires capability '$1' that is conflicting with the '$2''s current capability requirement '$3'.") +DIAGNOSTIC(36101, Error, conflictingCapabilityDueToStatement, "statement requires capability '$0' that is conflicting with the '$1''s current capability requirement '$2'.") +DIAGNOSTIC(36102, Error, conflictingCapabilityDueToStatementEnclosingFunc, "statement requires capability '$0' that is conflicting with the current function's capability requirement '$1'.") +DIAGNOSTIC(36103, Error, missingCapabilityRequirementOnPublicDecl, "public symbol '$0' is missing capability requirement declaration.") +DIAGNOSTIC(36104, Error, useOfUndeclaredCapability, "'$0' uses undeclared capability '$1'.") +DIAGNOSTIC(36104, Error, useOfUndeclaredCapabilityOfInterfaceRequirement, "'$0' uses capability '$1' that is missing from the interface requirement.") +DIAGNOSTIC(36105, Error, unknownCapability, "unknown capability name '$0'.") +DIAGNOSTIC(36106, Error, expectCapability, "expect a capability name.") +DIAGNOSTIC(36107, Error, entryPointUsesUnavailableCapability, "entrypoint '$0' requires capability '$1', which is incompatible with the current compilation target '$2'.") +DIAGNOSTIC(36108, Error, declHasDependenciesNotDefinedOnTarget, "'$0' has dependencies that are not defined on the required target '$1'.") +DIAGNOSTIC(36109, Error, invalidTargetSwitchCase, "'$0' cannot be used as a target_switch case.") +DIAGNOSTIC(36110, Error, stageIsInCompatibleWithCapabilityDefinition, "'$0' is defined for stage '$1', which is incompatible with the declared capability set '$2'.") + // Attributes DIAGNOSTIC(31000, Error, unknownAttributeName, "unknown attribute '$0'") DIAGNOSTIC(31001, Error, attributeArgumentCountMismatch, "attribute '$0' expects $1 arguments ($2 provided)") diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index a0151133d..248b803f4 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -64,43 +64,14 @@ SlangResult GLSLSourceEmitter::init() void GLSLSourceEmitter::_requireRayTracing() { - // There is more than one extension that provides ray-tracing capabilities, - // and we need to pick which one to enable. - // - // By default, we will use the `GL_EXT_ray_tracing` extension, but if - // the user has explicitly opted in to the `GL_NV_ray_tracing` extension - // we will use that one instead. - // - if( getTargetCaps().implies(CapabilityAtom::_GL_NV_ray_tracing) ) - { - m_glslExtensionTracker->requireExtension(UnownedStringSlice::fromLiteral("GL_NV_ray_tracing")); - } - else - { - m_glslExtensionTracker->requireExtension(UnownedStringSlice::fromLiteral("GL_EXT_ray_tracing")); - m_glslExtensionTracker->requireSPIRVVersion(SemanticVersion(1, 4)); - } - + m_glslExtensionTracker->requireExtension(UnownedStringSlice::fromLiteral("GL_EXT_ray_tracing")); + m_glslExtensionTracker->requireSPIRVVersion(SemanticVersion(1, 4)); m_glslExtensionTracker->requireVersion(ProfileVersion::GLSL_460); } void GLSLSourceEmitter::_requireFragmentShaderBarycentric() { - // There is more than one extension that provides barycentric coords in fragment shaders, - // and we need to pick which one to enable. - // - // By default, we will use the `GL_EXT_fragment_shader_barycentric` extension, but if - // the user has explicitly opted in to the `GL_NV_fragment_shader_barycentric` extension - // we will use that one instead. - - if( getTargetCaps().implies(CapabilityAtom::_GL_NV_fragment_shader_barycentric) ) - { - m_glslExtensionTracker->requireExtension(UnownedStringSlice::fromLiteral("GL_NV_fragment_shader_barycentric")); - } - else - { - m_glslExtensionTracker->requireExtension(UnownedStringSlice::fromLiteral("GL_EXT_fragment_shader_barycentric")); - } + m_glslExtensionTracker->requireExtension(UnownedStringSlice::fromLiteral("GL_EXT_fragment_shader_barycentric")); m_glslExtensionTracker->requireVersion(ProfileVersion::GLSL_450); } @@ -129,11 +100,6 @@ void GLSLSourceEmitter::_requireGLSLVersion(int version) { #define CASE(NUMBER) \ case NUMBER: _requireGLSLVersion(ProfileVersion::GLSL_##NUMBER); break - - CASE(110); - CASE(120); - CASE(130); - CASE(140); CASE(150); CASE(330); CASE(400); @@ -684,14 +650,7 @@ bool GLSLSourceEmitter::_emitGLSLLayoutQualifierWithBindingKinds(LayoutResourceK m_writer->emit("layout(push_constant)\n"); break; case LayoutResourceKind::ShaderRecord: - if (getTargetCaps().implies(CapabilityAtom::_GL_NV_ray_tracing)) - { - m_writer->emit("layout(shaderRecordNV)\n"); - } - else - { - m_writer->emit("layout(shaderRecordEXT)\n"); - } + m_writer->emit("layout(shaderRecordEXT)\n"); break; } @@ -1430,40 +1389,19 @@ void GLSLSourceEmitter::emitLayoutQualifiersImpl(IRVarLayout* layout) case LayoutResourceKind::RayPayload: { - if( getTargetCaps().implies(CapabilityAtom::_GL_NV_ray_tracing) ) - { - m_writer->emit("rayPayloadInNV "); - } - else - { - m_writer->emit("rayPayloadInEXT "); - } + m_writer->emit("rayPayloadInEXT "); } break; case LayoutResourceKind::CallablePayload: { - if( getTargetCaps().implies(CapabilityAtom::_GL_NV_ray_tracing) ) - { - m_writer->emit("callableDataInNV "); - } - else - { - m_writer->emit("callableDataInEXT "); - } + m_writer->emit("callableDataInEXT "); } break; case LayoutResourceKind::HitAttributes: { - if( getTargetCaps().implies(CapabilityAtom::_GL_NV_ray_tracing) ) - { - m_writer->emit("hitAttributeNV "); - } - else - { - m_writer->emit("hitAttributeEXT "); - } + m_writer->emit("hitAttributeEXT "); } break; @@ -2136,10 +2074,6 @@ static Index _getGLSLVersion(ProfileVersion profile) switch (profile) { #define CASE(TAG, VALUE) case ProfileVersion::TAG: return VALUE; - CASE(GLSL_110, 110); - CASE(GLSL_120, 120); - CASE(GLSL_130, 130); - CASE(GLSL_140, 140); CASE(GLSL_150, 150); CASE(GLSL_330, 330); CASE(GLSL_400, 400); @@ -2479,45 +2413,8 @@ void GLSLSourceEmitter::emitSimpleTypeImpl(IRType* type) { case kIROp_RaytracingAccelerationStructureType: { - // Note: We have the problem here that we want to do `_requireRayTracing()`, - // but just based on the use of a ray-tracing acceleration structure we - // cannot know which extension the user means to use. The current options are: - // - // * GL_NV_ray_tracing - // * GL_EXT_ray_tracing - // * GL_EXT_ray_query - // - // The first two options there are basically equivalent extensions with - // different GLSL syntax. We end up requiring the user to opt in to - // `GL_NV_ray_tracing` using target capabilities, and will always default - // to `GL_EXT_ray_tracing` otherwise. - // - if( getTargetCaps().implies(CapabilityAtom::_GL_NV_ray_tracing) ) - { - // If the user has explicitly opted in to `GL_NV_ray_tracing`, - // then we don't need to explicitly request the extentsion again. - // We know that the acceleration structure type will translate - // to the one from that extension: - // - _requireRayTracing(); - m_writer->emit("accelerationStructureNV"); - } - else - { - // If the user does *not* opt into a specific extension, then we - // have the problem that either `GL_EXT_ray_tracing` or `GL_EXT_ray-query` - // could provide the `accelerationSturctureEXT` type, but there - // can be drivers that provide only one and not the other. - // - // For now we will just kludge this by assuming that any driver - // that supports one of these extensions supports the other. - // - // TODO: Revisit that decision once the driver landscape is more stable/clear. - // - _requireRayTracing(); - - m_writer->emit("accelerationStructureEXT"); - } + _requireRayTracing(); + m_writer->emit("accelerationStructureEXT"); break; } @@ -2580,14 +2477,7 @@ bool GLSLSourceEmitter::_maybeEmitInterpolationModifierText(IRInterpolationMode if( stage == Stage::Fragment && isInput) { _requireFragmentShaderBarycentric(); - if (getTargetCaps().implies(CapabilityAtom::_GL_NV_fragment_shader_barycentric)) - { - m_writer->emit("pervertexNV "); - } - else - { - m_writer->emit("pervertexEXT "); - } + m_writer->emit("pervertexEXT "); } else { @@ -2694,6 +2584,7 @@ void GLSLSourceEmitter::emitVarDecorationsImpl(IRInst* varDecl) for (auto decoration : varDecl->getDecorations()) { UnownedStringSlice prefix; + UnownedStringSlice postfix = toSlice("EXT"); if (as<IRVulkanHitAttributesDecoration>(decoration)) { prefix = toSlice("hitAttribute"); @@ -2713,6 +2604,7 @@ void GLSLSourceEmitter::emitVarDecorationsImpl(IRInst* varDecl) break; case kIROp_VulkanHitObjectAttributesDecoration: prefix = toSlice("hitObjectAttribute"); + postfix = toSlice("NV"); locationValue = getIntVal(decoration->getOperand(0)); break; default: @@ -2725,17 +2617,7 @@ void GLSLSourceEmitter::emitVarDecorationsImpl(IRInst* varDecl) SLANG_ASSERT(prefix.getLength()); m_writer->emit(prefix); - - // Special case hitObjectAttribute as is only NV currently - if (decoration->getOp() == kIROp_VulkanHitObjectAttributesDecoration || - getTargetCaps().implies(CapabilityAtom::_GL_NV_ray_tracing)) - { - m_writer->emit(toSlice("NV")); - } - else - { - m_writer->emit(toSlice("EXT")); - } + m_writer->emit(postfix); m_writer->emit(toSlice("\n")); // If we emit a location we are done. diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 192a40f54..7ff6ac7d6 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -3359,18 +3359,9 @@ struct SPIRVEmitContext } else if (semanticName == "sv_barycentrics") { - if (m_targetRequest->getTargetCaps().implies(CapabilityAtom::_GL_NV_fragment_shader_barycentric)) - { - requireSPIRVCapability(SpvCapabilityFragmentBarycentricNV); - ensureExtensionDeclaration(UnownedStringSlice("SPV_NV_fragment_shader_barycentric")); - return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInBaryCoordNV); - } - else - { - requireSPIRVCapability(SpvCapabilityFragmentBarycentricKHR); - ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_fragment_shader_barycentric")); - return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInBaryCoordKHR); - } + requireSPIRVCapability(SpvCapabilityFragmentBarycentricKHR); + ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_fragment_shader_barycentric")); + return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInBaryCoordKHR); // TODO: There is also the `gl_BaryCoordNoPerspNV` builtin, which // we ought to use if the `noperspective` modifier has been diff --git a/source/slang/slang-glsl-extension-tracker.h b/source/slang/slang-glsl-extension-tracker.h index cee11cad5..966e1e927 100644 --- a/source/slang/slang-glsl-extension-tracker.h +++ b/source/slang/slang-glsl-extension-tracker.h @@ -39,7 +39,7 @@ protected: uint32_t m_hasBaseTypeFlags = _getFlag(BaseType::Float) | _getFlag(BaseType::Int) | _getFlag(BaseType::UInt) | _getFlag(BaseType::Void) | _getFlag(BaseType::Bool); - ProfileVersion m_profileVersion = ProfileVersion::GLSL_110; + ProfileVersion m_profileVersion = ProfileVersion::GLSL_150; StringSlicePool m_extensionPool; diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp index 0d279549c..dd165769c 100644 --- a/source/slang/slang-ir-glsl-legalize.cpp +++ b/source/slang/slang-ir-glsl-legalize.cpp @@ -518,6 +518,8 @@ GLSLSystemValueInfo* getGLSLSystemValueInfo( GlobalVaryingDeclarator* declarator, GLSLSystemValueInfo* inStorage) { + SLANG_UNUSED(codeGenContext); + if(auto indicesSemantic = getMeshOutputIndicesSystemValueInfo( context, kind, @@ -914,16 +916,8 @@ GLSLSystemValueInfo* getGLSLSystemValueInfo( else if (semanticName == "sv_barycentrics") { context->requireGLSLVersion(ProfileVersion::GLSL_450); - if (codeGenContext->getTargetCaps().implies(CapabilityAtom::_GL_NV_fragment_shader_barycentric)) - { - context->requireGLSLExtension(UnownedStringSlice::fromLiteral("GL_NV_fragment_shader_barycentric")); - name = "gl_BaryCoordNV"; - } - else - { - context->requireGLSLExtension(UnownedStringSlice::fromLiteral("GL_EXT_fragment_shader_barycentric")); - name = "gl_BaryCoordEXT"; - } + context->requireGLSLExtension(UnownedStringSlice::fromLiteral("GL_EXT_fragment_shader_barycentric")); + name = "gl_BaryCoordEXT"; // TODO: There is also the `gl_BaryCoordNoPerspNV` builtin, which // we ought to use if the `noperspective` modifier has been diff --git a/source/slang/slang-language-server-completion.cpp b/source/slang/slang-language-server-completion.cpp index 6038a432a..b723e14b8 100644 --- a/source/slang/slang-language-server-completion.cpp +++ b/source/slang/slang-language-server-completion.cpp @@ -443,6 +443,11 @@ List<LanguageServerProtocol::CompletionItem> CompletionContext::collectMembersAn linkage->contentAssistInfo.completionSuggestions.swizzleBaseType, linkage->contentAssistInfo.completionSuggestions.elementCount); } + else if (linkage->contentAssistInfo.completionSuggestions.scopeKind == + CompletionSuggestions::ScopeKind::Capabilities) + { + return createCapabilityCandidates(); + } List<LanguageServerProtocol::CompletionItem> result; bool useCommitChars = true; bool addKeywords = false; @@ -595,6 +600,24 @@ List<LanguageServerProtocol::CompletionItem> CompletionContext::collectMembersAn return result; } +List<LanguageServerProtocol::CompletionItem> CompletionContext::createCapabilityCandidates() +{ + List<LanguageServerProtocol::CompletionItem> result; + List<UnownedStringSlice> names; + getCapabilityNames(names); + for (auto name : names.getArrayView(1, names.getCount()-1)) + { + if (name.startsWith("_")) + continue; + LanguageServerProtocol::CompletionItem item; + item.data = 0; + item.kind = LanguageServerProtocol::kCompletionItemKindEnumMember; + item.label = name; + result.add(item); + } + return result; +} + List<LanguageServerProtocol::CompletionItem> CompletionContext::createSwizzleCandidates( Type* type, IntegerLiteralValue elementCount[2]) { diff --git a/source/slang/slang-language-server-completion.h b/source/slang/slang-language-server-completion.h index 5a09ba371..d3910bcfd 100644 --- a/source/slang/slang-language-server-completion.h +++ b/source/slang/slang-language-server-completion.h @@ -39,6 +39,7 @@ struct CompletionContext List<LanguageServerProtocol::CompletionItem> collectMembersAndSymbols(); List<LanguageServerProtocol::CompletionItem> createSwizzleCandidates( Type* baseType, IntegerLiteralValue elementCount[2]); + List<LanguageServerProtocol::CompletionItem> createCapabilityCandidates(); List<LanguageServerProtocol::CompletionItem> collectAttributes(); LanguageServerProtocol::CompletionItem generateGUIDCompletionItem(); List<LanguageServerProtocol::TextEditCompletionItem> gatherFileAndModuleCompletionItems( diff --git a/source/slang/slang-language-server-inlay-hints.cpp b/source/slang/slang-language-server-inlay-hints.cpp index 35b603b20..0eee347d1 100644 --- a/source/slang/slang-language-server-inlay-hints.cpp +++ b/source/slang/slang-language-server-inlay-hints.cpp @@ -18,7 +18,7 @@ List<LanguageServerProtocol::InlayHint> getInlayHints( List<LanguageServerProtocol::InlayHint> result; auto manager = linkage->getSourceManager(); auto docText = doc->getText().getUnownedSlice(); - iterateAST(fileName, manager, module->getModuleDecl(), [&](SyntaxNode* node) + iterateASTWithLanguageServerFilter(fileName, manager, module->getModuleDecl(), [&](SyntaxNode* node) { if (auto invokeExpr = as<InvokeExpr>(node)) { diff --git a/source/slang/slang-language-server-semantic-tokens.cpp b/source/slang/slang-language-server-semantic-tokens.cpp index ae10d62e8..3a40c8e92 100644 --- a/source/slang/slang-language-server-semantic-tokens.cpp +++ b/source/slang/slang-language-server-semantic-tokens.cpp @@ -126,7 +126,7 @@ List<SemanticToken> getSemanticTokens(Linkage* linkage, Module* module, UnownedS } maybeInsertToken(token); }; - iterateAST( + iterateASTWithLanguageServerFilter( fileName, manager, module->getModuleDecl(), @@ -240,8 +240,34 @@ List<SemanticToken> getSemanticTokens(Linkage* linkage, Module* module, UnownedS token.length = (int)attr->originalIdentifierToken.getContentLength(); token.type = SemanticTokenType::Type; maybeInsertToken(token); + + // Insert capability names as enum cases. + if (as<RequireCapabilityAttribute>(attr)) + { + for (auto arg : attr->args) + { + if (auto varExpr = as<VarExpr>(arg)) + { + if (varExpr->name) + { + SemanticToken capToken = _createSemanticToken( + manager, varExpr->loc, nullptr); + capToken.length = (int)varExpr->name->text.getLength(); + capToken.type = SemanticTokenType::EnumMember; + maybeInsertToken(capToken); + } + } + } + } } } + else if (auto targetCase = as<TargetCaseStmt>(node)) + { + SemanticToken token = _createSemanticToken( + manager, targetCase->capabilityToken.loc, targetCase->capabilityToken.getName()); + token.type = SemanticTokenType::EnumMember; + maybeInsertToken(token); + } else if (auto spirvAsmExpr = as<SPIRVAsmExpr>(node)) { // Highlight opcodes and enums. diff --git a/source/slang/slang-language-server.cpp b/source/slang/slang-language-server.cpp index 76e21c2ca..0256464c9 100644 --- a/source/slang/slang-language-server.cpp +++ b/source/slang/slang-language-server.cpp @@ -127,6 +127,8 @@ SlangResult LanguageServer::parseNextMessage() caps.completionProvider.triggerCharacters.add("."); caps.completionProvider.triggerCharacters.add(":"); caps.completionProvider.triggerCharacters.add("["); + caps.completionProvider.triggerCharacters.add(" "); + caps.completionProvider.triggerCharacters.add("("); caps.completionProvider.triggerCharacters.add("\""); caps.completionProvider.triggerCharacters.add("/"); caps.completionProvider.resolveProvider = true; @@ -985,6 +987,12 @@ SlangResult LanguageServer::completion( context.line = utf8Line; context.col = utf8Col; context.commitCharacterBehavior = m_commitCharacterBehavior; + if (args.context.triggerKind == kCompletionTriggerKindTriggerCharacter && + (args.context.triggerCharacter == " " || args.context.triggerCharacter == "[" || args.context.triggerCharacter == "(")) + { + // Never use commit character if completion request is triggerred by these characters to prevent annoyance. + context.commitCharacterBehavior = CommitCharacterBehavior::Disabled; + } if (SLANG_SUCCEEDED(context.tryCompleteInclude())) { diff --git a/source/slang/slang-options.cpp b/source/slang/slang-options.cpp index 3edc7dfb5..ca8b60d31 100644 --- a/source/slang/slang-options.cpp +++ b/source/slang/slang-options.cpp @@ -3185,10 +3185,10 @@ SlangResult OptionsParser::parse( m_sink = nullptr; - if (requestSink->getErrorCount() > 0) + if (m_parseSink.getErrorCount() > 0) { // Put the errors in the diagnostic - m_requestImpl->m_diagnosticOutput = requestSink->outputBuffer.produceString(); + m_requestImpl->m_diagnosticOutput = m_parseSink.outputBuffer.produceString(); } return res; diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index c5007569e..3940e59cf 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -4900,6 +4900,7 @@ namespace Slang parser->sink->diagnose(caseName.loc, Diagnostics::unknownTargetName, caseName.getContent()); } targetCase->capability = int32_t(cap); + targetCase->capabilityToken = caseName; targetCase->loc = caseName.loc; targetCase->body = bodyStmt; stmt->targetCases.add(targetCase); diff --git a/source/slang/slang-profile-defs.h b/source/slang/slang-profile-defs.h index 2eaf6f897..9a9c128a5 100644 --- a/source/slang/slang-profile-defs.h +++ b/source/slang/slang-profile-defs.h @@ -88,12 +88,7 @@ PROFILE_FAMILY(DX) PROFILE_FAMILY(GLSL) // Profile versions - - PROFILE_VERSION(DX_4_0, DX) -PROFILE_VERSION(DX_4_0_Level_9_0, DX) -PROFILE_VERSION(DX_4_0_Level_9_1, DX) -PROFILE_VERSION(DX_4_0_Level_9_3, DX) PROFILE_VERSION(DX_4_1, DX) PROFILE_VERSION(DX_5_0, DX) PROFILE_VERSION(DX_5_1, DX) @@ -106,10 +101,6 @@ PROFILE_VERSION(DX_6_5, DX) PROFILE_VERSION(DX_6_6, DX) PROFILE_VERSION(DX_6_7, DX) -PROFILE_VERSION(GLSL_110, GLSL) -PROFILE_VERSION(GLSL_120, GLSL) -PROFILE_VERSION(GLSL_130, GLSL) -PROFILE_VERSION(GLSL_140, GLSL) PROFILE_VERSION(GLSL_150, GLSL) PROFILE_VERSION(GLSL_330, GLSL) PROFILE_VERSION(GLSL_400, GLSL) @@ -122,7 +113,6 @@ PROFILE_VERSION(GLSL_460, GLSL) // Specific profiles - PROFILE(DX_Compute_4_0, cs_4_0, Compute, DX_4_0) PROFILE(DX_Compute_4_1, cs_4_1, Compute, DX_4_1) PROFILE(DX_Compute_5_0, cs_5_0, Compute, DX_5_0) @@ -160,7 +150,6 @@ PROFILE(DX_Geometry_6_5, gs_6_5, Geometry, DX_6_5) PROFILE(DX_Geometry_6_6, gs_6_6, Geometry, DX_6_6) PROFILE(DX_Geometry_6_7, gs_6_7, Geometry, DX_6_7) - PROFILE(DX_Hull_5_0, hs_5_0, Hull, DX_5_0) PROFILE(DX_Hull_5_1, hs_5_1, Hull, DX_5_1) PROFILE(DX_Hull_6_0, hs_6_0, Hull, DX_6_0) @@ -172,13 +161,9 @@ PROFILE(DX_Hull_6_5, hs_6_5, Hull, DX_6_5) PROFILE(DX_Hull_6_6, hs_6_6, Hull, DX_6_6) PROFILE(DX_Hull_6_7, hs_6_7, Hull, DX_6_7) - -PROFILE(DX_Fragment_4_0, ps_4_0, Fragment, DX_4_0) -PROFILE(DX_Fragment_4_0_Level_9_0, ps_4_0_level_9_0, Fragment, DX_4_0_Level_9_0) -PROFILE(DX_Fragment_4_0_Level_9_1, ps_4_0_level_9_1, Fragment, DX_4_0_Level_9_1) -PROFILE(DX_Fragment_4_0_Level_9_3, ps_4_0_level_9_3, Fragment, DX_4_0_Level_9_3) -PROFILE(DX_Fragment_4_1, ps_4_1, Fragment, DX_4_1) -PROFILE(DX_Fragment_5_0, ps_5_0, Fragment, DX_5_0) +PROFILE(DX_Fragment_4_0, ps_4_0, Fragment, DX_4_0) +PROFILE(DX_Fragment_4_1, ps_4_1, Fragment, DX_4_1) +PROFILE(DX_Fragment_5_0, ps_5_0, Fragment, DX_5_0) PROFILE(DX_Fragment_5_1, ps_5_1, Fragment, DX_5_1) PROFILE(DX_Fragment_6_0, ps_6_0, Fragment, DX_6_0) PROFILE(DX_Fragment_6_1, ps_6_1, Fragment, DX_6_1) @@ -189,11 +174,7 @@ PROFILE(DX_Fragment_6_5, ps_6_5, Fragment, DX_6_5) PROFILE(DX_Fragment_6_6, ps_6_6, Fragment, DX_6_6) PROFILE(DX_Fragment_6_7, ps_6_7, Fragment, DX_6_7) - PROFILE(DX_Vertex_4_0, vs_4_0, Vertex, DX_4_0) -PROFILE(DX_Vertex_4_0_Level_9_0, vs_4_0_level_9_0, Vertex, DX_4_0_Level_9_0) -PROFILE(DX_Vertex_4_0_Level_9_1, vs_4_0_level_9_1, Vertex, DX_4_0_Level_9_1) -PROFILE(DX_Vertex_4_0_Level_9_3, vs_4_0_level_9_3, Vertex, DX_4_0_Level_9_3) PROFILE(DX_Vertex_4_1, vs_4_1, Vertex, DX_4_1) PROFILE(DX_Vertex_5_0, vs_5_0, Vertex, DX_5_0) PROFILE(DX_Vertex_5_1, vs_5_1, Vertex, DX_5_1) @@ -216,9 +197,6 @@ PROFILE(DX_Amplification_6_7, as_6_7, Amplification, DX_6_7) // TODO: consider making `lib_*_*` alias these... PROFILE(DX_None_4_0, sm_4_0, Unknown, DX_4_0) -PROFILE(DX_None_4_0_Level_9_0, sm_4_0_level_9_0, Unknown, DX_4_0_Level_9_0) -PROFILE(DX_None_4_0_Level_9_1, sm_4_0_level_9_1, Unknown, DX_4_0_Level_9_1) -PROFILE(DX_None_4_0_Level_9_3, sm_4_0_level_9_3, Unknown, DX_4_0_Level_9_3) PROFILE(DX_None_4_1, sm_4_1, Unknown, DX_4_1) PROFILE(DX_None_5_0, sm_5_0, Unknown, DX_5_0) PROFILE(DX_None_5_1, sm_5_1, Unknown, DX_5_1) @@ -254,10 +232,6 @@ PROFILE_ALIAS(DX_None_6_7, DX_Lib_6_7, sm_6_7) // Define all the GLSL profiles -PROFILE(GLSL_None_110, glsl_110, Unknown, GLSL_110) -PROFILE(GLSL_None_120, glsl_120, Unknown, GLSL_120) -PROFILE(GLSL_None_130, glsl_130, Unknown, GLSL_130) -PROFILE(GLSL_None_140, glsl_140, Unknown, GLSL_140) PROFILE(GLSL_None_150, glsl_150, Unknown, GLSL_150) PROFILE(GLSL_None_330, glsl_330, Unknown, GLSL_330) PROFILE(GLSL_None_400, glsl_400, Unknown, GLSL_400) @@ -271,10 +245,6 @@ PROFILE(GLSL_None_460, glsl_460, Unknown, GLSL_460) #define P(UPPER, LOWER, VERSION) \ PROFILE(GLSL_##UPPER##_##VERSION, glsl_##LOWER##_##VERSION, UPPER, GLSL_##VERSION) -P(Vertex, vertex, 110) -P(Vertex, vertex, 120) -P(Vertex, vertex, 130) -P(Vertex, vertex, 140) P(Vertex, vertex, 150) P(Vertex, vertex, 330) P(Vertex, vertex, 400) @@ -284,10 +254,6 @@ P(Vertex, vertex, 430) P(Vertex, vertex, 440) P(Vertex, vertex, 450) -P(Fragment, fragment, 110) -P(Fragment, fragment, 120) -P(Fragment, fragment, 130) -P(Fragment, fragment, 140) P(Fragment, fragment, 150) P(Fragment, fragment, 330) P(Fragment, fragment, 400) diff --git a/source/slang/slang-profile.h b/source/slang/slang-profile.h index a0284215f..a1c08fe6a 100644 --- a/source/slang/slang-profile.h +++ b/source/slang/slang-profile.h @@ -3,6 +3,7 @@ #include "../core/slang-basic.h" #include "../../slang.h" +#include "slang-capability.h" namespace Slang { @@ -109,6 +110,8 @@ namespace Slang static Profile lookUp(char const* name); char const* getName(); + List<CapabilityName> getCapabilityName(); + RawVal raw = Unknown; }; diff --git a/source/slang/slang-serialize-ast-type-info.h b/source/slang/slang-serialize-ast-type-info.h index f5d636b01..f7d8cab08 100644 --- a/source/slang/slang-serialize-ast-type-info.h +++ b/source/slang/slang-serialize-ast-type-info.h @@ -78,6 +78,50 @@ struct PtrSerialTypeInfo<T, std::enable_if_t<std::is_base_of_v<Val, T>>> template <typename T> struct SerialTypeInfo<DeclRef<T>> : public SerialTypeInfo<DeclRefBase*> {}; +template<> +struct SerialTypeInfo<CapabilitySet> +{ + typedef CapabilitySet NativeType; + typedef SerialIndex SerialType; + enum { SerialAlignment = SLANG_ALIGN_OF(SerialType) }; + static void toSerial(SerialWriter* writer, const void* native, void* serial) + { + auto& src = *(const NativeType*)native; + auto& dst = *(SerialType*)serial; + + dst = writer->addArray(src.getExpandedAtoms().getBuffer(), src.getExpandedAtoms().getCount()); + } + static void toNative(SerialReader* reader, const void* serial, void* native) + { + auto& dst = *(NativeType*)native; + auto& src = *(const SerialType*)serial; + + reader->getArray(src, dst.getExpandedAtoms()); + } +}; + +template<> +struct SerialTypeInfo<CapabilityConjunctionSet> +{ + typedef CapabilityConjunctionSet NativeType; + typedef SerialIndex SerialType; + enum { SerialAlignment = SLANG_ALIGN_OF(SerialType) }; + static void toSerial(SerialWriter* writer, const void* native, void* serial) + { + auto& src = *(const NativeType*)native; + auto& dst = *(SerialType*)serial; + + dst = writer->addArray(src.getExpandedAtoms().getBuffer(), src.getExpandedAtoms().getCount()); + } + static void toNative(SerialReader* reader, const void* serial, void* native) + { + auto& dst = *(NativeType*)native; + auto& src = *(const SerialType*)serial; + + reader->getArray(src, dst.getExpandedAtoms()); + } +}; + // ValNodeOperand template <> struct SerialTypeInfo<ValNodeOperand> diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index b95b21bb5..e99f94484 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -858,7 +858,7 @@ Profile getEffectiveProfile(EntryPoint* entryPoint, TargetRequest* target) case CodeGenTarget::SPIRVAssembly: if(targetProfile.getFamily() != ProfileFamily::GLSL) { - targetProfile.setVersion(ProfileVersion::GLSL_110); + targetProfile.setVersion(ProfileVersion::GLSL_150); } break; @@ -869,7 +869,7 @@ Profile getEffectiveProfile(EntryPoint* entryPoint, TargetRequest* target) case CodeGenTarget::DXILAssembly: if(targetProfile.getFamily() != ProfileFamily::DX) { - targetProfile.setVersion(ProfileVersion::DX_4_0); + targetProfile.setVersion(ProfileVersion::DX_5_1); } break; } @@ -1608,6 +1608,8 @@ CapabilitySet TargetRequest::getTargetCaps() break; } + CapabilitySet targetCap = CapabilitySet(atoms); + CapabilitySet latestSpirvCapSet = CapabilitySet(CapabilityName::spirv_latest); CapabilityName latestSpirvAtom = (CapabilityName)latestSpirvCapSet.getExpandedAtoms()[0].getExpandedAtoms().getLast(); for (auto atom : rawCapabilities) @@ -1623,7 +1625,11 @@ CapabilitySet TargetRequest::getTargetCaps() atom = (CapabilityName)((Int)CapabilityName::glsl_spirv_1_0 + ((Int)atom - (Int)CapabilityName::spirv_1_0)); } } - atoms.add(atom); + if (!targetCap.isIncompatibleWith(atom)) + { + // Only add atoms that are compatible with the current target. + atoms.add(atom); + } } cookedCapabilities = CapabilitySet(atoms); diff --git a/tests/bugs/vk-structured-buffer-load.hlsl.glsl b/tests/bugs/vk-structured-buffer-load.hlsl.glsl index 35fad779b..93181d0a3 100644 --- a/tests/bugs/vk-structured-buffer-load.hlsl.glsl +++ b/tests/bugs/vk-structured-buffer-load.hlsl.glsl @@ -1,12 +1,10 @@ #version 460 -#extension GL_NV_ray_tracing : require +#extension GL_EXT_ray_tracing : require layout(row_major) uniform; layout(row_major) buffer; - layout(std430, binding = 1) readonly buffer StructuredBuffer_float_t_0 { float _data[]; } gParamBlock_sbuf_0; - float rcp_0(float x_0) { return 1.0 / x_0; @@ -17,24 +15,22 @@ struct RayHitInfoPacked_0 vec4 PackedHitInfoA_0; }; -rayPayloadInNV RayHitInfoPacked_0 _S1; +rayPayloadInEXT RayHitInfoPacked_0 _S1; struct BuiltInTriangleIntersectionAttributes_0 { vec2 barycentrics_0; }; -hitAttributeNV BuiltInTriangleIntersectionAttributes_0 _S2; +hitAttributeEXT BuiltInTriangleIntersectionAttributes_0 _S2; void main() { - float HitT_0 = ((gl_RayTmaxNV)); + float HitT_0 = ((gl_RayTmaxEXT)); _S1.PackedHitInfoA_0[0] = HitT_0; float offsfloat_0 = gParamBlock_sbuf_0._data[0]; - uint use_rcp_0 = 0U | uint(HitT_0 > 0.0); - if(use_rcp_0 != 0U) { _S1.PackedHitInfoA_0[1] = rcp_0(offsfloat_0); diff --git a/tests/cross-compile/barycentrics-nv.slang b/tests/cross-compile/barycentrics-nv.slang index fb0272679..60070f913 100644 --- a/tests/cross-compile/barycentrics-nv.slang +++ b/tests/cross-compile/barycentrics-nv.slang @@ -1,4 +1,9 @@ -//TEST:CROSS_COMPILE: -target spirv-assembly -capability GL_NV_fragment_shader_barycentric -entry main -stage fragment +//TEST:SIMPLE(filecheck=CHECK): -target spirv-assembly -capability GL_NV_fragment_shader_barycentric -entry main -stage fragment + +// CHECK: OpCapability FragmentBarycentricKHR +// CHECK: OpDecorate [[NAME:%[A-Za-z0-9_]+]] BuiltIn BaryCoordKHR +// CHECK: [[NAME]] = OpVariable {{.*}} Input +// CHECK: {{.*}} = OpLoad %v3float [[NAME]] float4 main(float3 bary : SV_Barycentrics) : SV_Target { diff --git a/tests/cross-compile/barycentrics-nv.slang.glsl b/tests/cross-compile/barycentrics-nv.slang.glsl deleted file mode 100644 index 583310125..000000000 --- a/tests/cross-compile/barycentrics-nv.slang.glsl +++ /dev/null @@ -1,12 +0,0 @@ -#version 450 - -#extension GL_NV_fragment_shader_barycentric : enable - -layout(location = 0) -out vec4 main_0; - -void main() -{ - main_0 = vec4(gl_BaryCoordNV, float(0)); - return; -} diff --git a/tests/diagnostics/discard-in-compute.slang b/tests/diagnostics/discard-in-compute.slang new file mode 100644 index 000000000..e530881bd --- /dev/null +++ b/tests/diagnostics/discard-in-compute.slang @@ -0,0 +1,13 @@ +//TEST:SIMPLE(filecheck=CHECK): -target hlsl -entry main -profile cs_6_1 +void test() +{ + discard; // This should lead to `test` having `fragment` capability requirement. +} + +[shader("compute")] +[numthreads(1,1,1)] +void main() +{ + // CHECK: error 36107 + test(); // compute shader cannot call `test` that require capabiltiy `fragment`. +} diff --git a/tests/language-feature/capability/capability1.slang b/tests/language-feature/capability/capability1.slang new file mode 100644 index 000000000..bccccb964 --- /dev/null +++ b/tests/language-feature/capability/capability1.slang @@ -0,0 +1,28 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv -emit-spirv-directly -entry main2 -stage compute + +[require(spvShaderClockKHR)] +void leafFunc1() {} + +[require(spvShaderNonUniform)] +void leafFunc2() {} + +void caller() +{ + leafFunc1(); + leafFunc2(); +} + +[require(spirv, shaderclock)] +// CHECK: ([[# @LINE+1]]): error 36104: +void main1() +{ + caller(); // Error, shaderclock does not imply spvShaderNonUniform. +} + + +[require(spirv, shaderclock)] +void main2() +{ + // CHECK-NOT: error + leafFunc1(); // OK, shaderclock implies spvShaderClockKHR. +} diff --git a/tests/language-feature/capability/capability2.slang b/tests/language-feature/capability/capability2.slang new file mode 100644 index 000000000..743f998cf --- /dev/null +++ b/tests/language-feature/capability/capability2.slang @@ -0,0 +1,61 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv -emit-spirv-directly -entry main -stage compute +module test; + +[require(spvAtomicFloat16AddEXT)] +interface IFoo +{ + [require(spvRayQueryKHR)] + void method1(); + + void method2(); +} + +[require(spvGroupNonUniformArithmetic)] +void useNonUniformArithmetic() +{} + +[require(spvRayQueryKHR)] +void useRayQueryKHR() +{} + +[require(spvAtomicFloat16AddEXT)] +void useAtomicFloat16() +{} + +// This should be OK, uses nothing past what is declared in the interface. +struct Impl1 : IFoo +{ + void method1() + { + useAtomicFloat16(); + useRayQueryKHR(); + } + + void method2() + { + useAtomicFloat16(); + } +} + +// CHECK-NOT: error 361 + +struct Impl2 : IFoo +{ + // CHECK: ([[# @LINE+1]]): error 36104: {{.*}}spvGroupNonUniformArithmetic + void method1() + { + useRayQueryKHR(); // OK. + useNonUniformArithmetic(); // error. + } + // CHECK-NOT: error 361 + + // CHECK: ([[# @LINE+1]]): error 36104: {{.*}}spvGroupNonUniformArithmetic + void method2() + { + useAtomicFloat16(); + useNonUniformArithmetic(); // error. + } +} + +void main() +{} diff --git a/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang.glsl b/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang.glsl index 1897c6467..820918d8b 100644 --- a/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang.glsl +++ b/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang.glsl @@ -2,11 +2,11 @@ //TEST_IGNORE_FILE: #version 450 -#extension GL_NV_fragment_shader_barycentric : require +#extension GL_EXT_fragment_shader_barycentric : require layout(row_major) uniform; layout(row_major) buffer; -pervertexNV layout(location = 0) +pervertexEXT layout(location = 0) in vec4 color_0[3]; layout(location = 0) @@ -14,6 +14,7 @@ out vec4 result_0; void main() { - result_0 = gl_BaryCoordNV.x * ((color_0)[(0U)]) + gl_BaryCoordNV.y * ((color_0)[(1U)]) + gl_BaryCoordNV.z * ((color_0)[(2U)]); + result_0 = gl_BaryCoordEXT.x * ((color_0)[(0U)]) + gl_BaryCoordEXT.y * ((color_0)[(1U)]) + gl_BaryCoordEXT.z * ((color_0)[(2U)]); return; } + diff --git a/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang.hlsl b/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang.hlsl index ce23492c9..a6b45eab4 100644 --- a/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang.hlsl +++ b/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang.hlsl @@ -2,6 +2,8 @@ //TEST_IGNORE_FILE: +#pragma warning(disable: 3557) + [shader("pixel")] void main( nointerpolation vector<float,4> color_0 : COLOR, diff --git a/tests/pipeline/rasterization/get-attribute-at-vertex.slang.hlsl b/tests/pipeline/rasterization/get-attribute-at-vertex.slang.hlsl index ce23492c9..9322964d5 100644 --- a/tests/pipeline/rasterization/get-attribute-at-vertex.slang.hlsl +++ b/tests/pipeline/rasterization/get-attribute-at-vertex.slang.hlsl @@ -11,4 +11,4 @@ void main( result_0 = bary_0.x * GetAttributeAtVertex(color_0, 0U) + bary_0.y * GetAttributeAtVertex(color_0, 1U) + bary_0.z * GetAttributeAtVertex(color_0, 2U); -} +}
\ No newline at end of file diff --git a/tests/vkray/callable-caller.slang b/tests/vkray/callable-caller.slang index 64311988a..6a0c85c38 100644 --- a/tests/vkray/callable-caller.slang +++ b/tests/vkray/callable-caller.slang @@ -1,6 +1,8 @@ // callable-caller.slang -//TEST:CROSS_COMPILE: -profile glsl_460 -capability GL_NV_ray_tracing -stage raygeneration -entry main -target spirv-assembly +//TEST:SIMPLE(filecheck=CHECK): -profile glsl_460 -capability GL_NV_ray_tracing -stage raygeneration -entry main -target spirv-assembly +//TEST:SIMPLE(filecheck=CHECK): -profile glsl_460 -capability GL_NV_ray_tracing -stage raygeneration -entry main -target spirv-assembly -emit-spirv-directly + import callable_shared; @@ -16,7 +18,7 @@ void main() MaterialPayload payload; payload.albedo = 0; payload.uv = float2(DispatchRaysIndex().xy) / float2(DispatchRaysDimensions().xy); - + // CHECK: OpExecuteCallable CallShader(shaderIndex, payload); gImage[DispatchRaysIndex().xy] = payload.albedo; diff --git a/tests/vkray/callable-caller.slang.glsl b/tests/vkray/callable-caller.slang.glsl deleted file mode 100644 index a42e6eaf3..000000000 --- a/tests/vkray/callable-caller.slang.glsl +++ /dev/null @@ -1,49 +0,0 @@ -#version 460 -#extension GL_NV_ray_tracing : require -layout(row_major) uniform; -layout(row_major) buffer; -struct MaterialPayload_0 -{ - vec4 albedo_0; - vec2 uv_0; -}; - -layout(location = 0) -callableDataNV -MaterialPayload_0 p_0; - -struct SLANG_ParameterGroup_C_0 -{ - uint shaderIndex_0; -}; - -layout(binding = 0) -layout(std140) uniform _S1 -{ - uint shaderIndex_0; -} C_0; -void CallShader_0(uint shaderIndex_1, inout MaterialPayload_0 payload_0) -{ - p_0 = payload_0; - executeCallableNV(shaderIndex_1, (0)); - payload_0 = p_0; - return; -} - -layout(rgba32f) -layout(binding = 1) -uniform image2D gImage_0; - -void main() -{ - MaterialPayload_0 payload_1; - payload_1.albedo_0 = vec4(0.0); - uvec3 _S2 = ((gl_LaunchIDNV)); - vec2 _S3 = vec2(_S2.xy); - uvec3 _S4 = ((gl_LaunchSizeNV)); - payload_1.uv_0 = _S3 / vec2(_S4.xy); - CallShader_0(C_0.shaderIndex_0, payload_1); - uvec3 _S5 = ((gl_LaunchIDNV)); - imageStore((gImage_0), ivec2((_S5.xy)), payload_1.albedo_0); - return; -} diff --git a/tests/vkray/miss.slang.glsl b/tests/vkray/miss.slang.glsl index df7647411..1bc6af5b3 100644 --- a/tests/vkray/miss.slang.glsl +++ b/tests/vkray/miss.slang.glsl @@ -1,17 +1,7 @@ //TEST_IGNORE_FILE: #version 460 -#if USE_NV_RT -#extension GL_NV_ray_tracing : require -#define callableDataInEXT callableDataInNV -#define gl_LaunchIDEXT gl_LaunchIDNV -#define hitAttributeEXT hitAttributeNV -#define ignoreIntersectionEXT ignoreIntersectionNV -#define rayPayloadInEXT rayPayloadInNV -#define terminateRayEXT terminateRayNV -#else #extension GL_EXT_ray_tracing : require -#endif struct ShadowRay_0 { diff --git a/tools/slang-test/slang-test-main.cpp b/tools/slang-test/slang-test-main.cpp index 5de9a8dc6..c93bebf33 100644 --- a/tools/slang-test/slang-test-main.cpp +++ b/tools/slang-test/slang-test-main.cpp @@ -2698,6 +2698,8 @@ TestResult generateActualOutput(TestContext* const context, const TestInput& inp return TestResult::Pass; } + actualOutput = getOutput(actualExeRes); + // Always fail if the compilation produced a failure, just // to catch situations where, e.g., command-line options parsing // caused the same error in both the Slang and glslang cases. @@ -2707,7 +2709,6 @@ TestResult generateActualOutput(TestContext* const context, const TestInput& inp return TestResult::Fail; } - actualOutput = getOutput(actualExeRes); return TestResult::Pass; } diff --git a/tools/slang-test/slangc-tool.cpp b/tools/slang-test/slangc-tool.cpp index a62ea6975..4c2a3244c 100644 --- a/tools/slang-test/slangc-tool.cpp +++ b/tools/slang-test/slangc-tool.cpp @@ -49,7 +49,7 @@ SlangResult SlangCTool::innerMain(StdWriters* stdWriters, slang::IGlobalSession* const SlangResult res = compileRequest->processCommandLineArguments(&argv[1], argc - 1); if (SLANG_FAILED(res)) { - // TODO: print usage message + StdWriters::getOut().print("%s", compileRequest->getDiagnosticOutput()); return res; } } |
