summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-02-02 22:28:02 -0800
committerGitHub <noreply@github.com>2024-02-02 22:28:02 -0800
commit14764896c34b230a5563f48d8b8e565de2f3aa10 (patch)
tree2f105d3f6222103f458054f1cd38e280b6fb52b4
parentc15e7ade4e27e1649d5b98f5854e9e52bb9e60ae (diff)
Capability type checking. (#3530)
* Capability type checking. * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--source/slang/core.meta.slang3
-rw-r--r--source/slang/hlsl.meta.slang37
-rw-r--r--source/slang/slang-ast-base.h10
-rw-r--r--source/slang/slang-ast-dump.cpp25
-rw-r--r--source/slang/slang-ast-iterator.h47
-rw-r--r--source/slang/slang-ast-modifier.h9
-rw-r--r--source/slang/slang-ast-stmt.h1
-rw-r--r--source/slang/slang-ast-support-types.h12
-rw-r--r--source/slang/slang-capabilities.capdef306
-rw-r--r--source/slang/slang-capability.cpp183
-rw-r--r--source/slang/slang-capability.h14
-rw-r--r--source/slang/slang-check-decl.cpp944
-rw-r--r--source/slang/slang-check-expr.cpp53
-rw-r--r--source/slang/slang-check-impl.h7
-rw-r--r--source/slang/slang-check-modifier.cpp41
-rw-r--r--source/slang/slang-check-overload.cpp2
-rw-r--r--source/slang/slang-check-shader.cpp35
-rw-r--r--source/slang/slang-check-stmt.cpp19
-rw-r--r--source/slang/slang-check.cpp1
-rw-r--r--source/slang/slang-compiler.cpp25
-rw-r--r--source/slang/slang-content-assist-info.h3
-rw-r--r--source/slang/slang-diagnostic-defs.h26
-rw-r--r--source/slang/slang-emit-glsl.cpp144
-rw-r--r--source/slang/slang-emit-spirv.cpp15
-rw-r--r--source/slang/slang-glsl-extension-tracker.h2
-rw-r--r--source/slang/slang-ir-glsl-legalize.cpp14
-rw-r--r--source/slang/slang-language-server-completion.cpp23
-rw-r--r--source/slang/slang-language-server-completion.h1
-rw-r--r--source/slang/slang-language-server-inlay-hints.cpp2
-rw-r--r--source/slang/slang-language-server-semantic-tokens.cpp28
-rw-r--r--source/slang/slang-language-server.cpp8
-rw-r--r--source/slang/slang-options.cpp4
-rw-r--r--source/slang/slang-parser.cpp1
-rw-r--r--source/slang/slang-profile-defs.h40
-rw-r--r--source/slang/slang-profile.h3
-rw-r--r--source/slang/slang-serialize-ast-type-info.h44
-rw-r--r--source/slang/slang.cpp12
-rw-r--r--tests/bugs/vk-structured-buffer-load.hlsl.glsl12
-rw-r--r--tests/cross-compile/barycentrics-nv.slang7
-rw-r--r--tests/cross-compile/barycentrics-nv.slang.glsl12
-rw-r--r--tests/diagnostics/discard-in-compute.slang13
-rw-r--r--tests/language-feature/capability/capability1.slang28
-rw-r--r--tests/language-feature/capability/capability2.slang61
-rw-r--r--tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang.glsl7
-rw-r--r--tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang.hlsl2
-rw-r--r--tests/pipeline/rasterization/get-attribute-at-vertex.slang.hlsl2
-rw-r--r--tests/vkray/callable-caller.slang6
-rw-r--r--tests/vkray/callable-caller.slang.glsl49
-rw-r--r--tests/vkray/miss.slang.glsl10
-rw-r--r--tools/slang-test/slang-test-main.cpp3
-rw-r--r--tools/slang-test/slangc-tool.cpp2
51 files changed, 1869 insertions, 489 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 3ebb77f03..3961403e7 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -2475,6 +2475,9 @@ attribute_syntax [vk_image_format(format : String)] : FormatAttribute;
__attributeTarget(Decl)
attribute_syntax [allow(diagnostic: String)] : AllowAttribute;
+__attributeTarget(Decl)
+attribute_syntax[require(capability)] : RequireCapabilityAttribute;
+
// Linking
__attributeTarget(Decl)
attribute_syntax [__extern] : ExternAttribute;
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 0b60bda0d..e5ebc8409 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -4716,7 +4716,6 @@ T GetAttributeAtVertex(T attribute, uint vertexIndex)
{
case hlsl:
__intrinsic_asm "GetAttributeAtVertex";
- case _GL_NV_fragment_shader_barycentric:
case _GL_EXT_fragment_shader_barycentric:
__intrinsic_asm "$0[$1]";
case spirv:
@@ -4749,7 +4748,6 @@ vector<T,N> GetAttributeAtVertex(vector<T,N> attribute, uint vertexIndex)
{
case hlsl:
__intrinsic_asm "GetAttributeAtVertex";
- case _GL_NV_fragment_shader_barycentric:
case _GL_EXT_fragment_shader_barycentric:
__intrinsic_asm "$0[$1]";
case spirv:
@@ -4782,7 +4780,6 @@ matrix<T,N,M> GetAttributeAtVertex(matrix<T,N,M> attribute, uint vertexIndex)
{
case hlsl:
__intrinsic_asm "GetAttributeAtVertex";
- case _GL_NV_fragment_shader_barycentric:
case _GL_EXT_fragment_shader_barycentric:
__intrinsic_asm "$0[$1]";
case spirv:
@@ -9288,8 +9285,7 @@ struct BuiltInTriangleIntersectionAttributes
// `executeCallableNV` is the GLSL intrinsic that will be used to implement
// `CallShader()` for GLSL-based targets.
//
-__target_intrinsic(GL_NV_ray_tracing, "executeCallableNV")
-__target_intrinsic(GL_EXT_ray_tracing, "executeCallableEXT")
+__target_intrinsic(_GL_EXT_ray_tracing, "executeCallableEXT")
void __executeCallable(uint shaderIndex, int payloadLocation);
// Next is the custom intrinsic that will compute the payload location
@@ -9335,8 +9331,7 @@ void CallShader(uint shaderIndex, inout Payload payload)
// 10.3.2
-__target_intrinsic(GL_NV_ray_tracing, "traceNV")
-__target_intrinsic(GL_EXT_ray_tracing, "traceRayEXT")
+__target_intrinsic(_GL_EXT_ray_tracing, "traceRayEXT")
void __traceRay(
RaytracingAccelerationStructure AccelerationStructure,
uint RayFlags,
@@ -9528,7 +9523,6 @@ bool __reportIntersection(float tHit, uint hitKind)
__target_switch
{
case _GL_EXT_ray_tracing: __intrinsic_asm "reportIntersectionEXT";
- case _GL_NV_ray_tracing: __intrinsic_asm "reportIntersectionNV";
case spirv:
return spirv_asm {
result:$$bool = OpReportIntersectionKHR $tHit $hitKind;
@@ -9555,7 +9549,6 @@ void IgnoreHit()
{
case hlsl: __intrinsic_asm "IgnoreHit";
case _GL_EXT_ray_tracing: __intrinsic_asm "ignoreIntersectionEXT;";
- case _GL_NV_ray_tracing: __intrinsic_asm "ignoreIntersectionNV";
case cuda: __intrinsic_asm "optixIgnoreIntersection";
case spirv: spirv_asm { OpIgnoreIntersectionKHR; %_ = OpLabel };
}
@@ -9568,7 +9561,6 @@ void AcceptHitAndEndSearch()
{
case hlsl: __intrinsic_asm "AcceptHitAndEndSearch";
case _GL_EXT_ray_tracing: __intrinsic_asm "terminateRayEXT;";
- case _GL_NV_ray_tracing: __intrinsic_asm "terminateRayNV";
case cuda: __intrinsic_asm "optixTerminateRay";
case spirv: spirv_asm { OpTerminateRayKHR; %_ = OpLabel };
}
@@ -9587,7 +9579,6 @@ uint3 DispatchRaysIndex()
{
case hlsl: __intrinsic_asm "DispatchRaysIndex";
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_LaunchIDEXT)";
- case _GL_NV_ray_tracing: __intrinsic_asm "(gl_LaunchIDNV)";
case cuda: __intrinsic_asm "optixGetLaunchIndex";
case spirv:
return spirv_asm {
@@ -9602,7 +9593,6 @@ uint3 DispatchRaysDimensions()
{
case hlsl: __intrinsic_asm "DispatchRaysDimensions";
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_LaunchSizeEXT)";
- case _GL_NV_ray_tracing: __intrinsic_asm "(gl_LaunchSizeNV)";
case cuda: __intrinsic_asm "optixGetLaunchDimensions";
case spirv:
return spirv_asm {
@@ -9619,7 +9609,6 @@ float3 WorldRayOrigin()
{
case hlsl: __intrinsic_asm "WorldRayOrigin";
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_WorldRayOriginEXT)";
- case _GL_NV_ray_tracing: __intrinsic_asm "(gl_WorldRayOriginNV)";
case cuda: __intrinsic_asm "optixGetWorldRayOrigin";
case spirv:
return spirv_asm {
@@ -9634,7 +9623,6 @@ float3 WorldRayDirection()
{
case hlsl: __intrinsic_asm "WorldRayDirection";
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_WorldRayDirectionEXT)";
- case _GL_NV_ray_tracing: __intrinsic_asm "(gl_WorldRayDirectionNV)";
case cuda: __intrinsic_asm "optixGetWorldRayDirection";
case spirv:
return spirv_asm {
@@ -9649,7 +9637,6 @@ float RayTMin()
{
case hlsl: __intrinsic_asm "RayTMin";
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_RayTminEXT)";
- case _GL_NV_ray_tracing: __intrinsic_asm "(gl_RayTminNV)";
case cuda: __intrinsic_asm "optixGetRayTmin";
case spirv:
return spirv_asm {
@@ -9674,7 +9661,6 @@ float RayTCurrent()
{
case hlsl: __intrinsic_asm "RayTCurrent";
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_RayTmaxEXT)";
- case _GL_NV_ray_tracing: __intrinsic_asm "(gl_RayTmaxNV)";
case cuda: __intrinsic_asm "optixGetRayTmax";
case spirv:
return spirv_asm {
@@ -9689,7 +9675,6 @@ uint RayFlags()
{
case hlsl: __intrinsic_asm "RayFlags";
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_IncomingRayFlagsEXT)";
- case _GL_NV_ray_tracing: __intrinsic_asm "(gl_IncomingRayFlagsNV)";
case cuda: __intrinsic_asm "optixGetRayFlags";
case spirv:
return spirv_asm {
@@ -9720,7 +9705,6 @@ uint InstanceID()
{
case hlsl: __intrinsic_asm "InstanceID";
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_InstanceCustomIndexEXT)";
- case _GL_NV_ray_tracing: __intrinsic_asm "(gl_InstanceCustomIndexNV)";
case cuda: __intrinsic_asm "optixGetInstanceId";
case spirv:
return spirv_asm {
@@ -9749,7 +9733,6 @@ float3 ObjectRayOrigin()
{
case hlsl: __intrinsic_asm "ObjectRayOrigin";
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_ObjectRayOriginEXT)";
- case _GL_NV_ray_tracing: __intrinsic_asm "(gl_ObjectRayOriginNV)";
case cuda: __intrinsic_asm "optixGetObjectRayOrigin";
case spirv:
return spirv_asm {
@@ -9764,7 +9747,6 @@ float3 ObjectRayDirection()
{
case hlsl: __intrinsic_asm "ObjectRayDirection";
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_ObjectRayDirectionEXT)";
- case _GL_NV_ray_tracing: __intrinsic_asm "(gl_ObjectRayDirectionNV)";
case cuda: __intrinsic_asm "optixGetObjectRayDirection";
case spirv:
return spirv_asm {
@@ -9781,7 +9763,6 @@ float3x4 ObjectToWorld3x4()
{
case hlsl: __intrinsic_asm "ObjectToWorld3x4";
case _GL_EXT_ray_tracing: __intrinsic_asm "transpose(gl_ObjectToWorldEXT)";
- case _GL_NV_ray_tracing: __intrinsic_asm "transpose(gl_ObjectToWorldNV)";
case spirv:
return spirv_asm {
%mat:$$float4x3 = OpLoad builtin(ObjectToWorldKHR:float4x3);
@@ -9796,7 +9777,6 @@ float3x4 WorldToObject3x4()
{
case hlsl: __intrinsic_asm "WorldToObject3x4";
case _GL_EXT_ray_tracing: __intrinsic_asm "transpose(gl_WorldToObjectEXT)";
- case _GL_NV_ray_tracing: __intrinsic_asm "transpose(gl_WorldToObjectNV)";
case spirv:
return spirv_asm {
%mat:$$float4x3 = OpLoad builtin(WorldToObjectKHR:float4x3);
@@ -9811,7 +9791,6 @@ float4x3 ObjectToWorld4x3()
{
case hlsl: __intrinsic_asm "ObjectToWorld4x3";
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_ObjectToWorldEXT)";
- case _GL_NV_ray_tracing: __intrinsic_asm "(gl_ObjectToWorldNV)";
case spirv:
return spirv_asm {
result:$$float4x3 = OpLoad builtin(ObjectToWorldKHR:float4x3);
@@ -9825,7 +9804,6 @@ float4x3 WorldToObject4x3()
{
case hlsl: __intrinsic_asm "WorldToObject4x3";
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_WorldToObjectEXT)";
- case _GL_NV_ray_tracing: __intrinsic_asm "(gl_WorldToObjectNV)";
case spirv:
return spirv_asm {
result:$$float4x3 = OpLoad builtin(WorldToObjectKHR:float4x3);
@@ -9872,7 +9850,6 @@ uint HitKind()
{
case hlsl: __intrinsic_asm "HitKind";
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_HitKindEXT)";
- case _GL_NV_ray_tracing: __intrinsic_asm "(gl_HitKindNV)";
case cuda: __intrinsic_asm "optixGetHitKind";
case spirv:
return spirv_asm {
@@ -11874,6 +11851,7 @@ void debugBreak();
[__requiresNVAPI]
__glsl_extension(GL_EXT_shader_realtime_clock)
+[require(shaderclock)]
uint getRealtimeClockLow()
{
__target_switch
@@ -11886,14 +11864,18 @@ uint getRealtimeClockLow()
__intrinsic_asm "clock";
case spirv:
return getRealtimeClock().x;
+ case cpp:
+ __intrinsic_asm "(uint32_t)std::chrono::high_resolution_clock::now().time_since_epoch().count()";
}
}
+__target_intrinsic(cpp, "std::chrono::high_resolution_clock::now().time_since_epoch().count()")
__target_intrinsic(cuda, "clock64")
-int64_t __cudaGetRealtimeClock();
+int64_t __cudaCppGetRealtimeClock();
[__requiresNVAPI]
__glsl_extension(GL_EXT_shader_realtime_clock)
+[require(shaderclock)]
uint2 getRealtimeClock()
{
__target_switch
@@ -11903,7 +11885,8 @@ uint2 getRealtimeClock()
case glsl:
__intrinsic_asm "clockRealtime2x32EXT()";
case cuda:
- int64_t ticks = __cudaGetRealtimeClock();
+ case cpp:
+ int64_t ticks = __cudaCppGetRealtimeClock();
return uint2(uint(ticks), uint(uint64_t(ticks) >> 32));
case spirv:
return spirv_asm
diff --git a/source/slang/slang-ast-base.h b/source/slang/slang-ast-base.h
index 579bda73a..e11dbe259 100644
--- a/source/slang/slang-ast-base.h
+++ b/source/slang/slang-ast-base.h
@@ -6,7 +6,7 @@
#include "slang-generated-ast.h"
#include "slang-ast-reflect.h"
-
+#include "slang-capability.h"
#include "slang-serialize-reflection.h"
// This file defines the primary base classes for the hierarchy of
@@ -695,6 +695,11 @@ class ModifiableSyntaxNode : public SyntaxNode
bool hasModifier() { return findModifier<T>() != nullptr; }
};
+struct DeclReferenceWithLoc
+{
+ Decl* referencedDecl;
+ SourceLoc referenceLoc;
+};
// An intermediate type to represent either a single declaration, or a group of declarations
class DeclBase : public ModifiableSyntaxNode
@@ -716,6 +721,7 @@ public:
DeclRefBase* getDefaultDeclRef();
NameLoc nameAndLoc;
+ CapabilitySet inferredCapabilityRequirements;
RefPtr<MarkupEntry> markup;
@@ -736,6 +742,8 @@ public:
}
bool isChildOf(Decl* other) const;
+ // Track the decl reference that caused the requirement of a capability atom.
+ SLANG_UNREFLECTED Dictionary<CapabilityAtom, DeclReferenceWithLoc> capabilityRequirementProvenance;
private:
SLANG_UNREFLECTED DeclRefBase* m_defaultDeclRef = nullptr;
SLANG_UNREFLECTED Index m_defaultDeclRefEpoch = -1;
diff --git a/source/slang/slang-ast-dump.cpp b/source/slang/slang-ast-dump.cpp
index cde7c6201..9c40fb12b 100644
--- a/source/slang/slang-ast-dump.cpp
+++ b/source/slang/slang-ast-dump.cpp
@@ -692,6 +692,31 @@ struct ASTDumpContext
m_writer->emit("}");
}
+ void dump(const CapabilitySet& capSet)
+ {
+ m_writer->emit("capability_set(");
+ bool isFirstSet = true;
+ for (auto& set : capSet.getExpandedAtoms())
+ {
+ if (!isFirstSet)
+ {
+ m_writer->emit(" | ");
+ }
+ bool isFirst = true;
+ for (auto atom : set.getExpandedAtoms())
+ {
+ if (!isFirst)
+ {
+ m_writer->emit("+");
+ }
+ dump(capabilityNameToString((CapabilityName)atom));
+ isFirst = false;
+ }
+ isFirstSet = false;
+ }
+ m_writer->emit(")");
+ }
+
void dumpObjectFull(NodeBase* node);
ASTDumpContext(SourceWriter* writer, ASTDumpUtil::Flags flags, ASTDumpUtil::Style dumpStyle):
diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h
index 2e8f02697..e2d0638e0 100644
--- a/source/slang/slang-ast-iterator.h
+++ b/source/slang/slang-ast-iterator.h
@@ -3,16 +3,14 @@
namespace Slang
{
-template <typename Callback>
+template <typename Callback, typename Filter>
struct ASTIterator
{
const Callback& callback;
- UnownedStringSlice fileName;
- SourceManager* sourceManager;
- ASTIterator(const Callback& func, SourceManager* manager, UnownedStringSlice sourceFileName)
+ const Filter& filter;
+ ASTIterator(const Callback& func, const Filter& filterFunc)
: callback(func)
- , fileName(sourceFileName)
- , sourceManager(manager)
+ , filter(filterFunc)
{}
void visitDecl(DeclBase* decl);
@@ -429,13 +427,11 @@ struct ASTIterator
};
};
-template <typename CallbackFunc>
-void ASTIterator<CallbackFunc>::visitDecl(DeclBase* decl)
+template <typename CallbackFunc, typename FilterFunc>
+void ASTIterator<CallbackFunc, FilterFunc>::visitDecl(DeclBase* decl)
{
// Don't look at the decl if it is defined in a different file.
- if (!as<NamespaceDeclBase>(decl) && !sourceManager->getHumaneLoc(decl->loc, SourceLocType::Actual)
- .pathInfo.foundPath.getUnownedSlice()
- .endsWithCaseInsensitive(fileName))
+ if (!filter(decl))
return;
maybeDispatchCallback(decl);
@@ -490,24 +486,23 @@ void ASTIterator<CallbackFunc>::visitDecl(DeclBase* decl)
}
}
}
-template <typename CallbackFunc>
-void ASTIterator<CallbackFunc>::visitExpr(Expr* expr)
+template <typename CallbackFunc, typename FilterFunc>
+void ASTIterator<CallbackFunc, FilterFunc>::visitExpr(Expr* expr)
{
ASTIteratorExprVisitor visitor(this);
visitor.dispatchIfNotNull(expr);
}
-template <typename CallbackFunc>
-void ASTIterator<CallbackFunc>::visitStmt(Stmt* stmt)
+template <typename CallbackFunc, typename FilterFunc>
+void ASTIterator<CallbackFunc, FilterFunc>::visitStmt(Stmt* stmt)
{
ASTIteratorStmtVisitor visitor(this);
visitor.dispatchIfNotNull(stmt);
}
-template <typename Func>
-void iterateAST(
- UnownedStringSlice fileName, SourceManager* manager, SyntaxNode* node, const Func& f)
+template <typename Func, typename FilterFunc>
+void iterateAST(SyntaxNode* node, const FilterFunc& filterFunc, const Func& f)
{
- ASTIterator<Func> iter(f, manager, fileName);
+ ASTIterator<Func, FilterFunc> iter(f, filterFunc);
if (auto decl = as<Decl>(node))
{
iter.visitDecl(decl);
@@ -521,4 +516,18 @@ void iterateAST(
iter.visitStmt(stmt);
}
}
+
+template <typename Func>
+void iterateASTWithLanguageServerFilter(
+ UnownedStringSlice fileName, SourceManager* sourceManager, SyntaxNode* node, const Func& f)
+{
+ auto filter = [&](DeclBase* decl)
+ {
+ return as<NamespaceDeclBase>(decl) ||
+ sourceManager->getHumaneLoc(decl->loc, SourceLocType::Actual)
+ .pathInfo.foundPath.getUnownedSlice()
+ .endsWithCaseInsensitive(fileName);
+ };
+ iterateAST(node, filter, f);
+}
} // namespace Slang
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 53d13d1b5..dae8966de 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -605,7 +605,6 @@ class Attribute : public AttributeBase
class UserDefinedAttribute : public Attribute
{
SLANG_AST_CLASS(UserDefinedAttribute)
-
};
class AttributeUsageAttribute : public Attribute
@@ -615,6 +614,14 @@ class AttributeUsageAttribute : public Attribute
SyntaxClass<NodeBase> targetSyntaxClass;
};
+
+class RequireCapabilityAttribute : public Attribute
+{
+ SLANG_AST_CLASS(RequireCapabilityAttribute)
+ CapabilitySet capabilitySet;
+};
+
+
// An `[unroll]` or `[unroll(count)]` attribute
class UnrollAttribute : public Attribute
{
diff --git a/source/slang/slang-ast-stmt.h b/source/slang/slang-ast-stmt.h
index af1fe9ec1..055785333 100644
--- a/source/slang/slang-ast-stmt.h
+++ b/source/slang/slang-ast-stmt.h
@@ -98,6 +98,7 @@ class TargetCaseStmt : public Stmt
{
SLANG_AST_CLASS(TargetCaseStmt)
int32_t capability;
+ Token capabilityToken;
Stmt* body = nullptr;
};
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h
index 955aff06c..882e26078 100644
--- a/source/slang/slang-ast-support-types.h
+++ b/source/slang/slang-ast-support-types.h
@@ -51,7 +51,7 @@ namespace Slang
class NodeBase;
class LookupDeclRef;
class GenericAppDeclRef;
-
+ struct CapabilitySet;
template <typename T>
T* as(NodeBase* node);
@@ -66,6 +66,8 @@ namespace Slang
void printDiagnosticArg(StringBuilder& sb, Val* val);
void printDiagnosticArg(StringBuilder& sb, DeclRefBase* declRefBase);
void printDiagnosticArg(StringBuilder& sb, ASTNodeType nodeType);
+ void printDiagnosticArg(StringBuilder& sb, const CapabilitySet& set);
+
struct QualifiedDeclPath
{
@@ -442,7 +444,7 @@ namespace Slang
/// checking the function body.
AttributesChecked,
- /// The declaration is fully checked.
+ /// The body/definition is checked.
///
/// This step includes any validation of the declaration that is
/// immaterial to clients code using the declaration, but that is
@@ -453,7 +455,11 @@ namespace Slang
/// but we still need to (eventually) check the bodies of all
/// functions, so it belongs in the last phase of checking.
///
- Checked,
+ DefinitionChecked,
+
+ /// The capabilities required by the decl is infered and validated.
+ ///
+ CapabilityChecked,
// For convenience at sites that call `ensureDecl()`, we define
// some aliases for the above states that are expressed in terms
diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef
index 7e5931edf..5f15b8e57 100644
--- a/source/slang/slang-capabilities.capdef
+++ b/source/slang/slang-capabilities.capdef
@@ -13,9 +13,9 @@
//
// A capability name is defined by a unique disjunction of conjunction of capability atoms.
// For example, `raytracing` is a name that expands to
-// glsl + _GL_EXT_ray_tracing | spirv_1_4 + _GL_EXT_ray_tracing | hlsl + _sm_6_4
-// which means it requires the `_GL_EXT_ray_tracing` extension when generating code for glsl
-// or spirv, and requires sm_6_4 when generating hlsl.
+// glsl + _GL_EXT_ray_tracing | spirv_1_4 + SPV_KHR_ray_tracing | hlsl + _sm_6_4
+// which means it requires the `_GL_EXT_ray_tracing` extension when generating code for glsl,
+// requires SPV_KHR_ray_tracing for spirv, and requires sm_6_4 when generating hlsl.
//
// There are three types of capability definitions:
// - `def`: this will introduce an new capability atom. If there is an inheritance clause,
@@ -63,7 +63,7 @@ alias spirv_latest = spirv_1_6;
// Capabilities that stand for target spirv version for GLSL backend.
// These are not compilation targets.
-def glsl_spirv_1_0;
+def glsl_spirv_1_0 : glsl;
def glsl_spirv_1_1 : glsl_spirv_1_0;
def glsl_spirv_1_2 : glsl_spirv_1_1;
def glsl_spirv_1_3 : glsl_spirv_1_2;
@@ -80,13 +80,24 @@ def hull : stage;
def domain : stage;
def geometry : stage;
def raygen : stage;
+alias raygeneration = raygen;
def intersection : stage;
def anyhit : stage;
def closesthit: stage;
def miss : stage;
def mesh : stage;
+def amplification : stage;
+def callable : stage;
-def _sm_5_1 : hlsl;
+alias all_stages = vertex + fragment + compute + hull + domain
+ + geometry + raygen + intersection + anyhit
+ + closesthit + miss + mesh + amplification
+ + callable;
+
+def _sm_4_0 : hlsl;
+def _sm_4_1 : _sm_4_0;
+def _sm_5_0 : _sm_4_1;
+def _sm_5_1 : _sm_5_0;
def _sm_6_0 : _sm_5_1;
def _sm_6_1 : _sm_6_0;
def _sm_6_2 : _sm_6_1;
@@ -96,6 +107,60 @@ def _sm_6_5 : _sm_6_4;
def _sm_6_6 : _sm_6_5;
def _sm_6_7 : _sm_6_6;
+def hlsl_nvapi : hlsl;
+
+// SPIRV extensions.
+
+def SPV_EXT_fragment_shader_interlock : spirv_1_0;
+def SPV_KHR_fragment_shader_barycentric : spirv_1_0;
+def SPV_EXT_fragment_fully_covered : spirv_1_0;
+def SPV_EXT_descriptor_indexing : spirv_1_0;
+def SPV_EXT_shader_atomic_float_add : spirv_1_0;
+def SPV_EXT_shader_atomic_float16_add : SPV_EXT_shader_atomic_float_add;
+def SPV_EXT_shader_atomic_float_min_max : spirv_1_0;
+def SPV_KHR_non_semantic_info : spirv_1_0;
+def SPV_NV_shader_subgroup_partitioned : spirv_1_0;
+def SPV_NV_ray_tracing_motion_blur : spirv_1_0;
+def SPV_EXT_mesh_shader : spirv_1_4;
+def SPV_KHR_ray_tracing : spirv_1_4;
+def SPV_KHR_ray_query : spirv_1_0;
+def SPV_KHR_ray_tracing_position_fetch : SPV_KHR_ray_tracing + SPV_KHR_ray_query;
+def SPV_NV_shader_invocation_reorder : spirv_1_5 + SPV_KHR_ray_tracing;
+def SPV_KHR_shader_clock : spirv_1_0;
+def SPV_NV_shader_image_footprint : spirv_1_0;
+def SPV_GOOGLE_user_type : spirv_1_0;
+
+// SPIRV Capabilities.
+
+def spvAtomicFloat32AddEXT : SPV_EXT_shader_atomic_float_add;
+def spvAtomicFloat16AddEXT : SPV_EXT_shader_atomic_float16_add;
+def spvInt64Atomics : spirv_1_0;
+def spvAtomicFloat32MinMaxEXT : SPV_EXT_shader_atomic_float_min_max;
+def spvAtomicFloat16MinMaxEXT : SPV_EXT_shader_atomic_float_min_max;
+def spvDerivativeControl : spirv_1_0;
+def spvImageQuery : spirv_1_0;
+def spvImageGatherExtended : spirv_1_0;
+def spvImageFootprintNV : SPV_NV_shader_image_footprint;
+def spvMinLod : spirv_1_0;
+def spvFragmentShaderPixelInterlockEXT : SPV_EXT_fragment_shader_interlock;
+def spvFragmentBarycentricKHR : SPV_KHR_fragment_shader_barycentric;
+def spvFragmentFullyCoveredEXT : SPV_EXT_fragment_fully_covered;
+def spvGroupNonUniformBallot : spirv_1_3;
+def spvGroupNonUniformShuffle : spirv_1_3;
+def spvGroupNonUniformArithmetic : spirv_1_3;
+def spvGroupNonUniformQuad : spirv_1_3;
+def spvGroupNonUniformVote : spirv_1_3;
+def spvGroupNonUniformPartitionedNV : spirv_1_3 + SPV_NV_shader_subgroup_partitioned;
+def spvRayTracingMotionBlurNV : SPV_NV_ray_tracing_motion_blur;
+def spvMeshShadingEXT : SPV_EXT_mesh_shader;
+def spvRayTracingKHR : SPV_KHR_ray_tracing;
+def spvRayTracingPositionFetchKHR : SPV_KHR_ray_tracing_position_fetch;
+def spvRayQueryKHR : SPV_KHR_ray_query;
+def spvRayQueryPositionFetchKHR : SPV_KHR_ray_tracing_position_fetch;
+def spvShaderInvocationReorderNV : SPV_NV_shader_invocation_reorder;
+def spvShaderClockKHR : SPV_KHR_shader_clock;
+def spvShaderNonUniform : spirv_1_5;
+
// The following capabilities all pertain to how ray tracing shaders are translated
// to GLSL, where there are two different extensions that can provide the core
// functionality of `TraceRay` and the related operations.
@@ -104,32 +169,213 @@ def _sm_6_7 : _sm_6_6;
// as conflicting on the `RayTracingExtension` axis, so that a compilation target
// cannot have both enabled at once.
//
-// The `GL_EXT_ray_tracing` extension should be favored, so it has a rank of `1`
+// The `_GL_EXT_ray_tracing` extension should be favored, so it has a rank of `1`
// instead of `0`, which means that when comparing overloads that require these
// extensions, the `EXT` extension will be favored over the `NV` extension, if
// all other factors are equal.
//
-def _GL_EXT_ray_tracing : glsl + glsl_spirv_1_4 = 1;
-def _GL_NV_ray_tracing : _GL_EXT_ray_tracing;
-def _SPV_KHR_ray_tracing : spirv_1_4;
-alias GL_NV_ray_tracing = _GL_NV_ray_tracing | _SPV_KHR_ray_tracing | _sm_6_4 | cuda;
-alias GL_EXT_ray_tracing = _GL_EXT_ray_tracing | _SPV_KHR_ray_tracing | _sm_6_4 | cuda;
-alias raytracing = GL_EXT_ray_tracing;
-
-def _GL_EXT_fragment_shader_barycentric : glsl + fragment;
-def _GL_NV_fragment_shader_barycentric : _GL_EXT_fragment_shader_barycentric;
-def _SPV_KHR_fragment_shader_barycentric : spirv_1_0 + fragment;
-alias GL_NV_fragment_shader_barycentric = _GL_NV_fragment_shader_barycentric | _SPV_KHR_fragment_shader_barycentric | hlsl + fragment;
-alias GL_EXT_fragment_shader_barycentric = _GL_EXT_fragment_shader_barycentric | _SPV_KHR_fragment_shader_barycentric | hlsl + fragment;
-
-// TODO: define what SM means for all supported targets.
-
-alias sm_5_1 = _sm_5_1;
-alias sm_6_0 = _sm_6_0;
-alias sm_6_1 = _sm_6_1;
-alias sm_6_2 = _sm_6_2;
-alias sm_6_3 = _sm_6_3;
-alias sm_6_4 = _sm_6_4 | raytracing;
-alias sm_6_5 = _sm_6_5;
-alias sm_6_6 = _sm_6_6;
-alias sm_6_7 = _sm_6_7;
+
+def _GL_ARB_derivative_control : glsl;
+def _GL_ARB_fragment_shader_interlock : glsl;
+def _GL_ARB_gpu_shader5 : glsl;
+def _GL_ARB_sparse_texture_clamp : glsl;
+def _GL_EXT_buffer_reference : glsl;
+def _GL_EXT_debug_printf : glsl;
+def _GL_EXT_fragment_shader_barycentric : glsl;
+def _GL_EXT_mesh_shader : glsl;
+def _GL_EXT_nonuniform_qualifier : glsl;
+def _GL_EXT_ray_tracing : glsl_spirv_1_4;
+def _GL_EXT_ray_tracing_position_fetch : glsl_spirv_1_4;
+def _GL_EXT_samplerless_texture_functions : glsl;
+def _GL_EXT_shader_atomic_float : glsl;
+def _GL_EXT_shader_atomic_float2 : glsl;
+def _GL_EXT_shader_atomic_int64 : glsl;
+def _GL_EXT_shader_atomic_float_min_max : glsl;
+def _GL_EXT_shader_explicit_arithmetic_types_int64 : glsl;
+def _GL_EXT_shader_realtime_clock : glsl;
+def _GL_EXT_texture_shadow_lod : glsl;
+def _GL_KHR_memory_scope_semantics : glsl;
+def _GL_KHR_shader_subgroup_arithmetic : glsl;
+def _GL_KHR_shader_subgroup_basic : glsl;
+def _GL_KHR_shader_subgroup_ballot : glsl;
+def _GL_KHR_shader_subgroup_quad : glsl;
+def _GL_KHR_shader_subgroup_shuffle : glsl;
+def _GL_KHR_shader_subgroup_vote : glsl;
+def _GL_NV_shader_subgroup_partitioned : glsl;
+def _GL_NV_ray_tracing_motion_blur : glsl_spirv_1_4;
+def _GL_NV_shader_invocation_reorder : glsl_spirv_1_4;
+def _GL_NV_shader_texture_footprint : glsl;
+alias _GL_NV_fragment_shader_barycentric = _GL_EXT_fragment_shader_barycentric;
+alias _GL_NV_ray_tracing = _GL_EXT_ray_tracing;
+
+// GLSL extension and SPV extension associations.
+alias GL_ARB_derivative_control = _GL_ARB_derivative_control | spvDerivativeControl;
+alias GL_ARB_fragment_shader_interlock = _GL_ARB_fragment_shader_interlock | spvFragmentShaderPixelInterlockEXT;
+alias GL_ARB_gpu_shader5 = _GL_ARB_fragment_shader_interlock | spirv_1_0;
+alias GL_ARB_sparse_texture_clamp = _GL_ARB_fragment_shader_interlock | spirv_1_0;
+alias GL_EXT_buffer_reference = _GL_ARB_fragment_shader_interlock | spirv_1_5;
+alias GL_EXT_debug_printf = _GL_EXT_debug_printf | SPV_KHR_non_semantic_info;
+alias GL_EXT_fragment_shader_barycentric = _GL_EXT_fragment_shader_barycentric | spvFragmentBarycentricKHR;
+alias GL_EXT_mesh_shader = _GL_EXT_mesh_shader | spvMeshShadingEXT;
+alias GL_EXT_nonuniform_qualifier = _GL_EXT_nonuniform_qualifier | spvShaderNonUniform;
+alias GL_EXT_ray_tracing = _GL_EXT_ray_tracing | spvRayTracingKHR + spvRayQueryKHR;
+alias GL_EXT_ray_tracing_position_fetch = _GL_EXT_ray_tracing_position_fetch | spvRayTracingPositionFetchKHR + spvRayQueryPositionFetchKHR;
+alias GL_EXT_samplerless_texture_functions = _GL_EXT_samplerless_texture_functions | spirv_1_0;
+alias GL_EXT_shader_atomic_float = _GL_EXT_shader_atomic_float | spvAtomicFloat32AddEXT + spvAtomicFloat32MinMaxEXT;
+alias GL_EXT_shader_atomic_float2 = _GL_EXT_shader_atomic_float2 | spvAtomicFloat32AddEXT + spvAtomicFloat32MinMaxEXT + spvAtomicFloat16AddEXT + spvAtomicFloat16MinMaxEXT;
+alias GL_EXT_shader_atomic_int64 = _GL_EXT_shader_atomic_int64 | spvInt64Atomics;
+alias GL_EXT_shader_atomic_float_min_max = _GL_EXT_shader_atomic_float_min_max | spvAtomicFloat32MinMaxEXT + spvAtomicFloat16MinMaxEXT;
+alias GL_EXT_shader_explicit_arithmetic_types_int64 = _GL_EXT_shader_explicit_arithmetic_types_int64 | spirv_1_0;
+alias GL_EXT_shader_realtime_clock = _GL_EXT_shader_realtime_clock | spvShaderClockKHR;
+alias GL_EXT_texture_shadow_lod = _GL_EXT_texture_shadow_lod | spirv_1_0;
+alias GL_KHR_memory_scope_semantics = _GL_KHR_memory_scope_semantics | spirv_1_0;
+alias GL_KHR_shader_subgroup_arithmetic = _GL_KHR_shader_subgroup_arithmetic | spvGroupNonUniformArithmetic;
+alias GL_KHR_shader_subgroup_basic = _GL_KHR_shader_subgroup_basic | spvGroupNonUniformBallot;
+alias GL_KHR_shader_subgroup_ballot = _GL_KHR_shader_subgroup_ballot | spvGroupNonUniformBallot;
+alias GL_KHR_shader_subgroup_quad = _GL_KHR_shader_subgroup_quad | spvGroupNonUniformQuad;
+alias GL_KHR_shader_subgroup_shuffle = _GL_KHR_shader_subgroup_shuffle | spvGroupNonUniformShuffle;
+alias GL_KHR_shader_subgroup_vote = _GL_KHR_shader_subgroup_vote | spvGroupNonUniformVote;
+alias GL_NV_shader_subgroup_partitioned = _GL_NV_shader_subgroup_partitioned | spvGroupNonUniformPartitionedNV;
+alias GL_NV_ray_tracing_motion_blur = _GL_NV_ray_tracing_motion_blur | spvRayTracingMotionBlurNV;
+alias GL_NV_shader_invocation_reorder = _GL_NV_shader_invocation_reorder | spvShaderInvocationReorderNV;
+alias GL_NV_shader_texture_footprint = _GL_NV_shader_texture_footprint | spvImageFootprintNV;
+
+alias GL_NV_fragment_shader_barycentric = GL_EXT_fragment_shader_barycentric;
+alias GL_NV_ray_tracing = GL_EXT_ray_tracing;
+
+// Define feature names
+
+alias nvapi = hlsl_nvapi;
+alias raytracing = spvRayTracingKHR + spvRayQueryKHR + spvRayQueryPositionFetchKHR | _GL_EXT_ray_tracing+_GL_EXT_ray_tracing_position_fetch | _sm_6_5 | cuda;
+alias ser = spvShaderInvocationReorderNV | _GL_NV_shader_invocation_reorder | _sm_6_6 + hlsl_nvapi;
+alias shaderclock = spvShaderClockKHR | hlsl_nvapi | _GL_EXT_shader_realtime_clock | cpp | cuda;
+alias meshshading = spvMeshShadingEXT | _sm_6_5 | _GL_EXT_mesh_shader;
+alias motionblur = spvRayTracingMotionBlurNV | hlsl_nvapi | _GL_NV_ray_tracing_motion_blur;
+alias texturefootprint = GL_NV_shader_texture_footprint | hlsl_nvapi;
+alias fragmentshaderinterlock = _GL_ARB_fragment_shader_interlock | hlsl_nvapi | spvFragmentShaderPixelInterlockEXT;
+alias atomic64 = GL_EXT_shader_atomic_int64 | _sm_6_6 | cpp | cuda;
+alias atomicfloat = GL_EXT_shader_atomic_float | _sm_6_0 + hlsl_nvapi | cpp | cuda;
+alias atomicfloat2 = GL_EXT_shader_atomic_float2 | _sm_6_6 + hlsl_nvapi | cpp | cuda;
+alias groupnonuniform = GL_KHR_shader_subgroup_ballot + GL_KHR_shader_subgroup_shuffle
+ + GL_KHR_shader_subgroup_arithmetic + GL_KHR_shader_subgroup_quad + GL_KHR_shader_subgroup_vote
+ | _sm_6_0 | cuda;
+alias fragmentshaderbarycentric = GL_EXT_fragment_shader_barycentric | _sm_6_1;
+
+
+// Define what each HLSL shader model means on different targets.
+
+
+alias sm_4_0 = _sm_4_0
+ | glsl_spirv_1_0
+ | spirv_1_0 + spvImageQuery + spvImageGatherExtended + spvMinLod + SPV_GOOGLE_user_type
+ | cuda
+ | cpp;
+
+alias sm_4_1 = _sm_4_1
+ | glsl_spirv_1_0 + sm_4_0
+ | spirv_1_0 + sm_4_0
+ | cuda
+ | cpp;
+
+alias sm_5_0 = _sm_5_0
+ | glsl_spirv_1_0 + sm_4_1 + _GL_KHR_memory_scope_semantics
+ | spirv_1_0 + sm_4_1 + spvDerivativeControl + spvFragmentFullyCoveredEXT
+ | cuda
+ | cpp;
+
+alias sm_5_1 = _sm_5_1
+ | glsl_spirv_1_0 + sm_5_0 + _GL_ARB_gpu_shader5 + _GL_ARB_sparse_texture_clamp + _GL_EXT_nonuniform_qualifier
+ | spirv_1_0 + sm_5_0 + spvShaderNonUniform
+ | cuda
+ | cpp;
+
+alias sm_6_0 = _sm_6_0
+ | glsl_spirv_1_3 + sm_5_1
+ + groupnonuniform + atomicfloat
+ | spirv_1_3 + sm_5_1
+ + groupnonuniform + atomicfloat
+ | cuda
+ | cpp;
+
+alias sm_6_1 = _sm_6_1
+ | glsl_spirv_1_3 + sm_6_0 + fragmentshaderbarycentric
+ | spirv_1_3 + sm_6_0 + fragmentshaderbarycentric
+ | cuda
+ | cpp;
+
+alias sm_6_2 = _sm_6_2
+ | glsl_spirv_1_3 + sm_6_1
+ | spirv_1_3 + sm_6_1
+ | cuda
+ | cpp;
+
+alias sm_6_3 = _sm_6_3
+ | glsl_spirv_1_4 + sm_6_2 + _GL_EXT_ray_tracing
+ | spirv_1_4 + sm_6_2 + SPV_KHR_ray_tracing
+ | cuda
+ | cpp;
+
+alias sm_6_4 = _sm_6_4
+ | glsl_spirv_1_4 + sm_6_3
+ | spirv_1_4 + sm_6_3
+ | cuda
+ | cpp;
+
+alias sm_6_5 = _sm_6_5
+ | glsl_spirv_1_4 + sm_6_4 + raytracing + meshshading
+ | spirv_1_4 + sm_6_4 + raytracing + meshshading
+ | cuda
+ | cpp;
+
+alias sm_6_6 = _sm_6_6
+ | glsl_spirv_1_5 + sm_6_5
+ + GL_EXT_shader_atomic_int64 + atomicfloat2
+ | spirv_1_5 + sm_6_5
+ + GL_EXT_shader_atomic_int64 + atomicfloat2
+ + SPV_EXT_descriptor_indexing
+ | cuda
+ | cpp;
+
+alias sm_6_7 = _sm_6_7
+ | glsl_spirv_1_5 + sm_6_6
+ | spirv_1_5 + sm_6_6
+ | cuda
+ | cpp;
+
+alias all = _sm_6_7 + hlsl_nvapi
+ | glsl_spirv_1_5 + sm_6_7
+ + ser + shaderclock + texturefootprint + fragmentshaderinterlock + _GL_NV_shader_subgroup_partitioned
+ + _GL_NV_ray_tracing_motion_blur + _GL_NV_shader_texture_footprint
+ | spirv_1_5 + sm_6_7
+ + ser + shaderclock + texturefootprint + fragmentshaderinterlock + spvGroupNonUniformPartitionedNV
+ + spvRayTracingMotionBlurNV + spvRayTracingMotionBlurNV;
+
+// Profiles
+
+alias GLSL_150 = glsl + sm_5_1 | spirv_1_0;
+alias GLSL_330 = GLSL_150 | spirv_1_0 + sm_5_1;
+alias GLSL_400 = GLSL_150 | spirv_1_0 + sm_5_1;
+alias GLSL_410 = glsl + sm_5_1 | spirv_1_5 + sm_5_1;
+alias GLSL_420 = glsl + sm_5_1 | spirv_1_5 + sm_5_1;
+alias GLSL_430 = glsl + sm_5_1 | spirv_1_5 + sm_5_1;
+alias GLSL_440 = glsl + sm_6_0 | spirv_1_5 + sm_6_0;
+alias GLSL_450 = glsl + sm_6_3 | spirv_1_5 + sm_6_3;
+alias GLSL_460 = glsl_spirv_1_5 + all | spirv_1_5 + all;
+
+alias tess_control = hull;
+alias tess_eval = domain;
+
+alias DX_4_0 = sm_4_0;
+alias DX_4_1 = sm_4_1;
+alias DX_5_0 = sm_5_0;
+alias DX_5_1 = sm_5_1;
+alias DX_6_0 = sm_6_0;
+alias DX_6_1 = sm_6_1;
+alias DX_6_2 = sm_6_2;
+alias DX_6_3 = sm_6_3;
+alias DX_6_4 = sm_6_4;
+alias DX_6_5 = sm_6_5;
+alias DX_6_6 = sm_6_6;
+alias DX_6_7 = sm_6_7;
+
+
diff --git a/source/slang/slang-capability.cpp b/source/slang/slang-capability.cpp
index a3f8157e7..fbe37892a 100644
--- a/source/slang/slang-capability.cpp
+++ b/source/slang/slang-capability.cpp
@@ -96,6 +96,16 @@ void getCapabilityNames(List<UnownedStringSlice>& ioNames)
}
}
+UnownedStringSlice capabilityNameToString(CapabilityName name)
+{
+ return UnownedStringSlice(_getInfo(name).name);
+}
+
+bool isDirectChildOfAbstractAtom(CapabilityAtom name)
+{
+ return _getInfo(name).abstractBase != CapabilityName::Invalid;
+}
+
bool lookupCapabilityName(const UnownedStringSlice& str, CapabilityName& value);
CapabilityName findCapabilityName(UnownedStringSlice const& name)
@@ -482,7 +492,10 @@ bool CapabilityConjunctionSet::implies(CapabilityConjunctionSet const& that) con
return false;
}
}
- return true;
+ // We reached the end of either this or that atom.
+ // If we reached the end of 'that', we know everything in 'that'
+ // is also contained in this, so this implies that.
+ return thatIndex == thatCount;
}
/// Helper functor for binary search on lists of `CapabilityAtom`
@@ -935,6 +948,46 @@ void CapabilitySet::calcCompactedAtoms(List<List<CapabilityAtom>>& outAtoms) con
}
}
+void CapabilitySet::unionWith(const CapabilityConjunctionSet& conjunctionToAdd)
+{
+ // We add conjunctionToAdd to resultSet only if it does not imply any existing conjunctions.
+ // For example, if `resultSet` is (a), and conjunctionToAdd is (ab), then we don't want to add the conjunction
+ // to form (a | ab) because that would reduce to (a).
+ bool skipAdd = false;
+ for (auto& c : m_conjunctions)
+ {
+ if (conjunctionToAdd.implies(c))
+ {
+ skipAdd = true;
+ break;
+ }
+ }
+ if (!skipAdd)
+ {
+ // Once we added the new conjunction, any existing conjunctions that implies the new one can be
+ // removed.
+ // For example, if resultSet was (ab), and we are adding (a), the result should be just (a).
+ for (Index i = 0; i < m_conjunctions.getCount();)
+ {
+ if (m_conjunctions[i].implies(conjunctionToAdd))
+ {
+ m_conjunctions.fastRemoveAt(i);
+ }
+ else
+ {
+ i++;
+ }
+ }
+ m_conjunctions.add(conjunctionToAdd);
+ }
+}
+
+void CapabilitySet::canonicalize()
+{
+ // Make sure conjunctions are sorted so equality tests are trivial.
+ m_conjunctions.sort();
+}
+
void CapabilitySet::join(const CapabilitySet& other)
{
if (isEmpty() || other.isInvalid())
@@ -947,7 +1000,7 @@ void CapabilitySet::join(const CapabilitySet& other)
if (other.isEmpty())
return;
- List<CapabilityConjunctionSet> resultSet;
+ CapabilitySet resultSet;
for (auto& thatConjunction : other.m_conjunctions)
{
for (auto& thisConjunction : m_conjunctions)
@@ -980,42 +1033,20 @@ void CapabilitySet::join(const CapabilitySet& other)
// Otherwise, thisConjunction implies thatConjunction, so we just add thisConjunction to resultSet.
conjunctionToAdd = &thisConjunction;
}
- // We add conjunctionToAdd to resultSet only if it does not imply any existing conjunctions.
- // For example, if `resultSet` is (a), and conjunctionToAdd is (ab), then we don't want to add the conjunction
- // to form (a | ab) because that would reduce to (a).
- bool skipAdd = false;
- for (auto& c : resultSet)
- {
- if (conjunctionToAdd->implies(c))
- {
- skipAdd = true;
- break;
- }
- }
- if (!skipAdd)
- {
- // Once we added the new conjunction, any existing conjunctions that implies the new one can be
- // removed.
- // For example, if resultSet was (ab), and we are adding (a), the result should be just (a).
- for (Index i = 0; i < resultSet.getCount();)
- {
- if (resultSet[i].implies(*conjunctionToAdd))
- {
- resultSet.fastRemoveAt(i);
- }
- else
- {
- i++;
- }
- }
- resultSet.add(*conjunctionToAdd);
- }
+ resultSet.unionWith(*conjunctionToAdd);
}
}
- m_conjunctions = _Move(resultSet);
+ m_conjunctions = _Move(resultSet.m_conjunctions);
- // Make sure conjunctions are sorted so equality tests are trivial.
- m_conjunctions.sort();
+ if (m_conjunctions.getCount() == 0)
+ {
+ // If the result is empty, then we should return as impossible.
+ *this = CapabilitySet::makeInvalid();
+ }
+ else
+ {
+ canonicalize();
+ }
}
bool CapabilitySet::isBetterForTarget(CapabilitySet const& that, CapabilitySet const& targetCaps) const
@@ -1102,4 +1133,86 @@ bool CapabilitySet::isBetterForTarget(CapabilitySet const& that, CapabilitySet c
return false;
}
+bool CapabilitySet::checkCapabilityRequirement(CapabilitySet const& available, CapabilitySet const& required, const CapabilityConjunctionSet*& outFailedAvailableSet)
+{
+ // Requirements x are met by available disjoint capabilities (a | b) iff
+ // both 'a' satisfies x and 'b' satisfies x.
+ // If we have a caller function F() decorated with:
+ // [require(hlsl, _sm_6_3)] [require(spirv, _spv_ray_tracing)] void F() { g(); }
+ // We'd better make sure that `g()` can be compiled with both (hlsl+_sm_6_3) and (spirv+_spv_ray_tracing) capability sets.
+ // In this method, F()'s capability declaration is represented by `available`,
+ // and g()'s capability is represented by `required`.
+ // We will check that for every capability conjunction X of F(), there is one capability conjunction Y in g() such that X implies Y.
+ //
+
+ outFailedAvailableSet = nullptr;
+
+ if (required.isInvalid())
+ return false;
+
+ // If F's capability is empty, we can satisfy any non-empty requirements.
+ //
+ if (available.isEmpty() && !required.isEmpty())
+ return false;
+
+ for (auto& availTargetSet : available.getExpandedAtoms())
+ {
+ bool implied = false;
+ for (auto& requiredTargetSet : required.getExpandedAtoms())
+ {
+ if (availTargetSet.implies(requiredTargetSet))
+ {
+ implied = true;
+ break;
+ }
+ }
+ if (!implied)
+ {
+ outFailedAvailableSet = &availTargetSet;
+ return false;
+ }
+ }
+
+ return true;
+}
+
+void printDiagnosticArg(StringBuilder& sb, const CapabilitySet& capSet)
+{
+ bool isFirstSet = true;
+ for (auto& set : capSet.getExpandedAtoms())
+ {
+ List<CapabilityAtom> compactAtomList;
+ set.calcCompactedAtoms(compactAtomList);
+
+ if (!isFirstSet)
+ {
+ sb<< " | ";
+ }
+ bool isFirst = true;
+ for (auto atom : compactAtomList)
+ {
+ if (!isFirst)
+ {
+ sb << " + ";
+ }
+ auto name = capabilityNameToString((CapabilityName)atom);
+ if (name.startsWith("_"))
+ name = name.tail(1);
+ sb << name;
+ isFirst = false;
+ }
+ isFirstSet = false;
+ }
+}
+
+void printDiagnosticArg(StringBuilder& sb, CapabilityAtom atom)
+{
+ printDiagnosticArg(sb, (CapabilityName)atom);
+}
+
+void printDiagnosticArg(StringBuilder& sb, CapabilityName name)
+{
+ sb << _getInfo(name).name;
+}
+
}
diff --git a/source/slang/slang-capability.h b/source/slang/slang-capability.h
index e686ff5a9..b0ca9231a 100644
--- a/source/slang/slang-capability.h
+++ b/source/slang/slang-capability.h
@@ -108,6 +108,7 @@ public:
/// Does this capability set imply all the capabilities in `other`?
bool implies(CapabilityConjunctionSet const& other) const;
+
/// Does this capability set imply the atomic capability `other`?
bool implies(CapabilityAtom other) const;
@@ -206,6 +207,10 @@ public:
/// Join two capability sets to form (this & other).
void join(const CapabilitySet& other);
+ void unionWith(const CapabilityConjunctionSet& other);
+
+ void canonicalize();
+
/// Are these two capability sets equal?
bool operator==(CapabilitySet const& that) const;
@@ -219,6 +224,8 @@ public:
bool isBetterForTarget(CapabilitySet const& that, CapabilitySet const& targetCaps) const;
+ static bool checkCapabilityRequirement(CapabilitySet const& available, CapabilitySet const& required, const CapabilityConjunctionSet*& outFailedAvailableSet);
+
private:
// The underlying representation we use is a list of conjunctions.
//
@@ -242,4 +249,11 @@ CapabilityName findCapabilityName(UnownedStringSlice const& name);
/// Gets the capability names.
void getCapabilityNames(List<UnownedStringSlice>& ioNames);
+UnownedStringSlice capabilityNameToString(CapabilityName name);
+
+bool isDirectChildOfAbstractAtom(CapabilityAtom name);
+
+void printDiagnosticArg(StringBuilder& sb, CapabilityAtom atom);
+void printDiagnosticArg(StringBuilder& sb, CapabilityName name);
+
}
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index fe4a7d64c..a7e197d81 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -14,7 +14,7 @@
#include "slang-syntax.h"
#include "slang-ast-synthesis.h"
#include "slang-ast-reflect.h"
-
+#include "slang-ast-iterator.h"
#include <limits>
namespace Slang
@@ -304,6 +304,416 @@ namespace Slang
void visitParamDecl(ParamDecl* paramDecl);
};
+ template<typename VisitorType>
+ struct SemanticsDeclReferenceVisitor
+ : public SemanticsDeclVisitorBase
+ , public StmtVisitor<VisitorType>
+ , public ExprVisitor<VisitorType>
+ , public ValVisitor<VisitorType>
+ , public DeclVisitor<VisitorType>
+ {
+ SemanticsDeclReferenceVisitor(SemanticsContext const& outer)
+ : SemanticsDeclVisitorBase(outer)
+ {}
+
+ List<SourceLoc> sourceLocStack;
+
+ struct PushSourceLocRAII
+ {
+ List<SourceLoc>& stack;
+ bool shouldPop = false;
+ PushSourceLocRAII(List<SourceLoc>& sourceLocStack, SourceLoc loc)
+ : stack(sourceLocStack)
+ {
+ if (loc.isValid())
+ {
+ stack.add(loc);
+ shouldPop = true;
+ }
+ }
+ ~PushSourceLocRAII()
+ {
+ if (shouldPop)
+ {
+ stack.removeLast();
+ }
+ }
+ };
+
+ virtual void processReferencedDecl(Decl* decl) = 0;
+
+ void dispatchIfNotNull(Stmt* stmt)
+ {
+ if (!stmt)
+ return;
+ PushSourceLocRAII sourceLocRAII(sourceLocStack, stmt->loc);
+ return StmtVisitor<VisitorType>::dispatch(stmt);
+ }
+ void dispatchIfNotNull(Expr* expr)
+ {
+ if (!expr)
+ return;
+ PushSourceLocRAII sourceLocRAII(sourceLocStack, expr->loc);
+ return ExprVisitor<VisitorType>::dispatch(expr);
+ }
+ void dispatchIfNotNull(Val* val)
+ {
+ if (!val)
+ return;
+ return ValVisitor<VisitorType>::dispatch(val);
+ }
+ void dispatchIfNotNull(DeclBase* val)
+ {
+ if (!val)
+ return;
+ return DeclVisitor<VisitorType>::dispatch(val);
+ }
+ // Expr Visitor
+ void visitExpr(Expr*) { }
+ void visitIndexExpr(IndexExpr* subscriptExpr)
+ {
+ for (auto arg : subscriptExpr->indexExprs)
+ dispatchIfNotNull(arg);
+ dispatchIfNotNull(subscriptExpr->baseExpression);
+ }
+
+ void visitParenExpr(ParenExpr* expr)
+ {
+ dispatchIfNotNull(expr->base);
+ }
+
+ void visitAssignExpr(AssignExpr* expr)
+ {
+ dispatchIfNotNull(expr->left);
+ dispatchIfNotNull(expr->right);
+ }
+
+ void visitGenericAppExpr(GenericAppExpr* genericAppExpr)
+ {
+ dispatchIfNotNull(genericAppExpr->functionExpr);
+ for (auto arg : genericAppExpr->arguments)
+ dispatchIfNotNull(arg);
+ }
+
+ void visitSharedTypeExpr(SharedTypeExpr* expr) { dispatchIfNotNull(expr->base.exp); }
+
+ void visitInvokeExpr(InvokeExpr* expr)
+ {
+ dispatchIfNotNull(expr->functionExpr);
+ for (auto arg : expr->arguments)
+ dispatchIfNotNull(arg);
+ }
+
+ void visitTypeCastExpr(TypeCastExpr* expr)
+ {
+ dispatchIfNotNull(expr->functionExpr);
+ for (auto arg : expr->arguments)
+ dispatchIfNotNull(arg);
+ }
+
+ void visitDerefExpr(DerefExpr* expr) { dispatchIfNotNull(expr->base); }
+ void visitMatrixSwizzleExpr(MatrixSwizzleExpr* expr)
+ {
+ dispatchIfNotNull(expr->base);
+ }
+ void visitSwizzleExpr(SwizzleExpr* expr)
+ {
+ dispatchIfNotNull(expr->base);
+ }
+ void visitOverloadedExpr(OverloadedExpr*)
+ {
+ return;
+ }
+ void visitOverloadedExpr2(OverloadedExpr2*)
+ {
+ return;
+ }
+ void visitAggTypeCtorExpr(AggTypeCtorExpr*)
+ {
+ return;
+ }
+ void visitCastToSuperTypeExpr(CastToSuperTypeExpr* expr)
+ {
+ dispatchIfNotNull(expr->valueArg);
+ }
+ void visitModifierCastExpr(ModifierCastExpr* expr) { dispatchIfNotNull(expr->valueArg); }
+ void visitLetExpr(LetExpr* expr)
+ {
+ dispatchIfNotNull(expr->body);
+ }
+ void visitExtractExistentialValueExpr(ExtractExistentialValueExpr* expr)
+ {
+ dispatchIfNotNull(expr->declRef.declRefBase);
+ }
+
+ void visitDeclRefExpr(DeclRefExpr* expr)
+ {
+ dispatchIfNotNull(expr->declRef.declRefBase);
+ }
+ void visitStaticMemberExpr(StaticMemberExpr* expr)
+ {
+ dispatchIfNotNull(expr->declRef.declRefBase);
+ }
+ void visitInitializerListExpr(InitializerListExpr* expr)
+ {
+ for (auto arg : expr->args)
+ {
+ dispatchIfNotNull(arg);
+ }
+ }
+
+ void visitThisExpr(ThisExpr*)
+ {
+ return;
+ }
+
+ void visitThisTypeExpr(ThisTypeExpr*) { return; }
+ void visitAndTypeExpr(AndTypeExpr* expr)
+ {
+ dispatchIfNotNull(expr->left.type);
+ dispatchIfNotNull(expr->right.type);
+ }
+ void visitPointerTypeExpr(PointerTypeExpr* expr)
+ {
+ dispatchIfNotNull(expr->base.type);
+ }
+ void visitAsTypeExpr(AsTypeExpr* expr)
+ {
+ dispatchIfNotNull(expr->value);
+ dispatchIfNotNull(expr->witnessArg);
+ }
+ void visitIsTypeExpr(IsTypeExpr* expr)
+ {
+ dispatchIfNotNull(expr->value);
+ dispatchIfNotNull(expr->witnessArg);
+ }
+ void visitMakeOptionalExpr(MakeOptionalExpr* expr)
+ {
+ dispatchIfNotNull(expr->value);
+ dispatchIfNotNull(expr->typeExpr);
+ }
+ void visitPartiallyAppliedGenericExpr(PartiallyAppliedGenericExpr*)
+ {
+ return;
+ }
+ void visitSPIRVAsmExpr(SPIRVAsmExpr*)
+ {
+ return;
+ }
+ void visitModifiedTypeExpr(ModifiedTypeExpr* expr) { dispatchIfNotNull(expr->base.type); }
+ void visitFuncTypeExpr(FuncTypeExpr* expr)
+ {
+ for (const auto& t : expr->parameters)
+ {
+ dispatchIfNotNull(t.type);
+ }
+ dispatchIfNotNull(expr->result.type);
+ }
+ void visitTupleTypeExpr(TupleTypeExpr* expr)
+ {
+ for (auto t : expr->members)
+ {
+ dispatchIfNotNull(t.type);
+ }
+ }
+ void visitTryExpr(TryExpr* expr) { dispatchIfNotNull(expr->base); }
+ void visitHigherOrderInvokeExpr(HigherOrderInvokeExpr* expr)
+ {
+ dispatchIfNotNull(expr->baseFunction);
+ }
+ void visitTreatAsDifferentiableExpr(TreatAsDifferentiableExpr* expr)
+ {
+ dispatchIfNotNull(expr->innerExpr);
+ }
+
+ // Stmt Visitor
+
+ void visitDeclStmt(DeclStmt* stmt) { dispatchIfNotNull(stmt->decl); }
+
+ void visitBlockStmt(BlockStmt* stmt)
+ {
+ dispatchIfNotNull(stmt->body);
+ }
+
+ void visitSeqStmt(SeqStmt* seqStmt)
+ {
+ for (auto stmt : seqStmt->stmts)
+ dispatchIfNotNull(stmt);
+ }
+
+ void visitLabelStmt(LabelStmt* stmt)
+ {
+ dispatchIfNotNull(stmt->innerStmt);
+ }
+
+ void visitBreakStmt(BreakStmt*) { return; }
+
+ void visitContinueStmt(ContinueStmt*) { return; }
+
+ void visitDoWhileStmt(DoWhileStmt* stmt)
+ {
+ dispatchIfNotNull(stmt->predicate);
+ dispatchIfNotNull(stmt->statement);
+ }
+
+ void visitForStmt(ForStmt* stmt)
+ {
+ dispatchIfNotNull(stmt->initialStatement);
+ dispatchIfNotNull(stmt->predicateExpression);
+ dispatchIfNotNull(stmt->sideEffectExpression);
+ dispatchIfNotNull(stmt->statement);
+ }
+
+ void visitCompileTimeForStmt(CompileTimeForStmt* stmt)
+ {
+ dispatchIfNotNull(stmt->rangeBeginExpr);
+ dispatchIfNotNull(stmt->rangeEndExpr);
+ dispatchIfNotNull(stmt->body);
+ }
+
+ void visitSwitchStmt(SwitchStmt* stmt)
+ {
+ dispatchIfNotNull(stmt->condition);
+ dispatchIfNotNull(stmt->body);
+ }
+
+ void visitCaseStmt(CaseStmt* stmt) { dispatchIfNotNull(stmt->expr); }
+
+ void visitTargetSwitchStmt(TargetSwitchStmt* stmt)
+ {
+ for (auto targetCase : stmt->targetCases)
+ dispatchIfNotNull(targetCase);
+ }
+
+ void visitTargetCaseStmt(TargetCaseStmt* stmt)
+ {
+ dispatchIfNotNull(stmt->body);
+ }
+
+ void visitIntrinsicAsmStmt(IntrinsicAsmStmt*) { return; }
+
+ void visitDefaultStmt(DefaultStmt*) { return; }
+
+ void visitIfStmt(IfStmt* stmt)
+ {
+ dispatchIfNotNull(stmt->predicate);
+ dispatchIfNotNull(stmt->positiveStatement);
+ dispatchIfNotNull(stmt->negativeStatement);
+ }
+
+ void visitUnparsedStmt(UnparsedStmt*) { return; }
+
+ void visitEmptyStmt(EmptyStmt*) { return; }
+
+ void visitDiscardStmt(DiscardStmt*) { return; }
+
+ void visitReturnStmt(ReturnStmt* stmt) { dispatchIfNotNull(stmt->expression); }
+
+ void visitWhileStmt(WhileStmt* stmt)
+ {
+ dispatchIfNotNull(stmt->predicate);
+ dispatchIfNotNull(stmt->statement);
+ }
+
+ void visitGpuForeachStmt(GpuForeachStmt*) { return; }
+
+ void visitExpressionStmt(ExpressionStmt* stmt)
+ {
+ dispatchIfNotNull(stmt->expression);
+ }
+
+ // Val Visitor
+
+ void visitDirectDeclRef(DirectDeclRef* declRef)
+ {
+ // If we have already visited, return.
+ // Otherwise add it to visited set.
+ if (!visitedVals.add(declRef))
+ return;
+
+ processReferencedDecl(declRef->getDecl());
+ }
+
+ void visitVal(Val* val)
+ {
+ // If we have already visited, return.
+ // Otherwise add it to visited set.
+ if (!visitedVals.add(val))
+ return;
+
+ for (Index i = 0; i < val->getOperandCount(); i++)
+ {
+ auto& operand = val->m_operands[i];
+ switch (operand.kind)
+ {
+ case ValNodeOperandKind::ValNode:
+ dispatchIfNotNull(val->getOperand(i));
+ break;
+ default:
+ break;
+ }
+ }
+ return;
+ }
+
+ HashSet<Val*> visitedVals;
+
+ // Decl visitor
+ void visitDeclBase(DeclBase*)
+ {}
+
+ void visitContainerDecl(ContainerDecl* decl)
+ {
+ for (auto m : decl->members)
+ {
+ dispatchIfNotNull(m);
+ }
+ }
+
+ void visitFunctionDeclBase(FunctionDeclBase* decl)
+ {
+ visitContainerDecl(decl);
+ dispatchIfNotNull(decl->body);
+ }
+
+ void visitVarDeclBase(VarDeclBase* varDecl)
+ {
+ dispatchIfNotNull(varDecl->type.type);
+ dispatchIfNotNull(varDecl->initExpr);
+ }
+ };
+
+ struct SemanticsDeclCapabilityVisitor
+ : public SemanticsDeclVisitorBase
+ , public DeclVisitor<SemanticsDeclCapabilityVisitor>
+ {
+ SemanticsDeclCapabilityVisitor(SemanticsContext const& outer)
+ : SemanticsDeclVisitorBase(outer)
+ {}
+
+ void visitDecl(Decl*) {}
+ void visitDeclGroup(DeclGroup*) {}
+
+ void checkVarDeclCommon(VarDeclBase* varDecl);
+
+ void visitVarDecl(VarDecl* varDecl)
+ {
+ checkVarDeclCommon(varDecl);
+ }
+
+ void visitParamDecl(ParamDecl* paramDecl)
+ {
+ checkVarDeclCommon(paramDecl);
+ }
+
+ void visitFunctionDeclBase(FunctionDeclBase* funcDecl);
+
+ void visitInheritanceDecl(InheritanceDecl* inheritanceDecl);
+
+ void diagnoseUndeclaredCapability(Decl* decl, const DiagnosticInfo& diagnosticInfo, const CapabilityConjunctionSet* failedAvailableSet);
+ };
+
+
/// Should the given `decl` nested in `parentDecl` be treated as a static rather than instance declaration?
bool isEffectivelyStatic(
Decl* decl,
@@ -528,7 +938,7 @@ namespace Slang
}
else if( auto enumCaseDeclRef = declRef.as<EnumCaseDecl>() )
{
- sema->ensureDecl(declRef.declRefBase, DeclCheckState::Checked);
+ sema->ensureDecl(declRef.declRefBase, DeclCheckState::DefinitionChecked);
QualType qualType;
qualType.type = getType(astBuilder, enumCaseDeclRef);
qualType.isLeftValue = false;
@@ -873,7 +1283,7 @@ namespace Slang
bool SemanticsVisitor::shouldSkipChecking(Decl* decl, DeclCheckState state)
{
- if (state != DeclCheckState::Checked)
+ if (state < DeclCheckState::DefinitionChecked)
return false;
// If we are in language server, we should skip checking all the function bodies
// except for the module or function that the user cared about.
@@ -1058,7 +1468,7 @@ namespace Slang
// If we've gone down this path, then the variable
// declaration is actually pretty far along in checking
- varDecl->setCheckState(DeclCheckState::Checked);
+ varDecl->setCheckState(DeclCheckState::DefinitionChecked);
}
else
{
@@ -1087,7 +1497,7 @@ namespace Slang
maybeInferArraySizeForVariable(varDecl);
- varDecl->setCheckState(DeclCheckState::Checked);
+ varDecl->setCheckState(DeclCheckState::DefinitionChecked);
}
}
//
@@ -1306,7 +1716,7 @@ namespace Slang
// We need to ensure that any variable doesn't introduce
// a constant with a circular definition.
//
- varDecl->setCheckState(DeclCheckState::Checked);
+ varDecl->setCheckState(DeclCheckState::DefinitionChecked);
_validateCircularVarDefinition(varDecl);
}
else
@@ -1553,7 +1963,7 @@ namespace Slang
assocTypeDef->nameAndLoc.name = getName("Differential");
assocTypeDef->type.type = satisfyingType;
assocTypeDef->parentDecl = aggTypeDecl;
- assocTypeDef->setCheckState(DeclCheckState::Checked);
+ assocTypeDef->setCheckState(DeclCheckState::DefinitionChecked);
aggTypeDecl->members.add(assocTypeDef);
}
@@ -1861,7 +2271,7 @@ namespace Slang
//
for(auto importDecl : moduleDecl->getMembersOfType<ImportDecl>())
{
- ensureDecl(importDecl, DeclCheckState::Checked);
+ ensureDecl(importDecl, DeclCheckState::DefinitionChecked);
}
// Next, make sure all `__include` decls are processed and the referenced
@@ -1873,15 +2283,15 @@ namespace Slang
auto decl = fileDecl->members[i];
if (auto includeDecl = as<IncludeDecl>(decl))
{
- ensureDecl(includeDecl, DeclCheckState::Checked);
+ ensureDecl(includeDecl, DeclCheckState::DefinitionChecked);
}
else if (auto implementingDecl = as<ImplementingDecl>(decl))
{
- ensureDecl(implementingDecl, DeclCheckState::Checked);
+ ensureDecl(implementingDecl, DeclCheckState::DefinitionChecked);
}
else if (auto importDecl = as<ImportDecl>(decl))
{
- ensureDecl(importDecl, DeclCheckState::Checked);
+ ensureDecl(importDecl, DeclCheckState::DefinitionChecked);
}
}
};
@@ -1893,7 +2303,7 @@ namespace Slang
}
// The entire goal of semantic checking is to get all of the
- // declarations in the module up to `DeclCheckState::Checked`.
+ // declarations in the module up to `DeclCheckState::DefinitionChecked`.
//
// The main catch is that checking one declaration A up to state M
// may required that declaration B is checked up to state N.
@@ -1950,7 +2360,8 @@ namespace Slang
DeclCheckState::ReadyForReference,
DeclCheckState::ReadyForLookup,
DeclCheckState::ReadyForConformances,
- DeclCheckState::Checked
+ DeclCheckState::DefinitionChecked,
+ DeclCheckState::CapabilityChecked,
};
for(auto s : states)
{
@@ -2855,6 +3266,9 @@ namespace Slang
ThisExpr*& synThis)
{
auto synFuncDecl = m_astBuilder->create<FuncDecl>();
+ synFuncDecl->ownedScope = m_astBuilder->create<Scope>();
+ synFuncDecl->ownedScope->containerDecl = synFuncDecl;
+ synFuncDecl->ownedScope->parent = getScope(context->parentDecl);
// For now our synthesized method will use the name and source
// location of the requirement we are trying to satisfy.
@@ -2954,6 +3368,7 @@ namespace Slang
// For a non-`static` requirement, we need a `this` parameter.
//
synThis = m_astBuilder->create<ThisExpr>();
+ synThis->scope = synFuncDecl->ownedScope;
// The type of `this` in our method will be the type for
// which we are synthesizing a conformance.
@@ -3314,6 +3729,9 @@ namespace Slang
// the required accessor.
//
auto synAccessorDecl = (AccessorDecl*) m_astBuilder->createByNodeType(requiredAccessorDeclRef.getDecl()->astNodeType);
+ synAccessorDecl->ownedScope = m_astBuilder->create<Scope>();
+ synAccessorDecl->ownedScope->containerDecl = synAccessorDecl;
+ synAccessorDecl->ownedScope->parent = getScope(context->parentDecl);
// Whatever the required accessor returns, that is what our synthesized accessor will return.
//
@@ -3359,6 +3777,7 @@ namespace Slang
// a `this` expression.
//
ThisExpr* synThis = m_astBuilder->create<ThisExpr>();
+ synThis->scope = synAccessorDecl->ownedScope;
// The type of `this` in our accessor will be the type for
// which we are synthesizing a conformance.
@@ -5029,7 +5448,7 @@ namespace Slang
// the min/max tag values, or the total number of tags, so that people don't
// have to declare these as additional cases.
- enumConformanceDecl->setCheckState(DeclCheckState::Checked);
+ enumConformanceDecl->setCheckState(DeclCheckState::DefinitionChecked);
}
}
@@ -5055,7 +5474,7 @@ namespace Slang
// doing its own header checking, rather than rely on this...
caseDecl->type.type = enumType;
- ensureDecl(caseDecl, DeclCheckState::Checked);
+ ensureDecl(caseDecl, DeclCheckState::DefinitionChecked);
}
// For any enum case that didn't provide an explicit
@@ -7569,9 +7988,13 @@ namespace Slang
SemanticsDeclAttributesVisitor(shared).dispatch(decl);
break;
- case DeclCheckState::Checked:
+ case DeclCheckState::DefinitionChecked:
SemanticsDeclBodyVisitor(shared).dispatch(decl);
break;
+
+ case DeclCheckState::CapabilityChecked:
+ SemanticsDeclCapabilityVisitor(shared).dispatch(decl);
+ break;
}
}
@@ -8144,4 +8567,493 @@ namespace Slang
checkDerivativeAttribute(this, decl, primalAttr);
}
}
+
+ static void _propagateRequirement(SemanticsVisitor* visitor, CapabilitySet& resultCaps, SyntaxNode* userNode, SyntaxNode* referencedNode, const CapabilitySet& nodeCaps, SourceLoc referenceLoc)
+ {
+ auto referencedDecl = as<Decl>(referencedNode);
+
+ // Ignore cyclic references.
+ if (referencedDecl)
+ {
+ if (referencedDecl->checkState.isBeingChecked())
+ return;
+
+ ensureDecl(visitor, referencedDecl, DeclCheckState::CapabilityChecked);
+ }
+
+ if (resultCaps.implies(nodeCaps))
+ return;
+
+ auto oldCaps = resultCaps;
+ bool isAnyInvalid = resultCaps.isInvalid() || nodeCaps.isInvalid();
+ resultCaps.join(nodeCaps);
+
+ auto decl = as<Decl>(userNode);
+
+ if (!isAnyInvalid && resultCaps.isInvalid())
+ {
+ // If joining the referenced decl's requirements results an invalid capability set,
+ // then the decl is using things that require conflicting set of capabilities, and we should diagnose an error.
+ if (referencedDecl && decl)
+ {
+ visitor->getSink()->diagnose(
+ referenceLoc,
+ Diagnostics::conflictingCapabilityDueToUseOfDecl,
+ referencedDecl,
+ nodeCaps,
+ decl,
+ oldCaps);
+ }
+ else if (decl)
+ {
+ visitor->getSink()->diagnose(
+ referenceLoc,
+ Diagnostics::conflictingCapabilityDueToStatement,
+ nodeCaps,
+ decl,
+ oldCaps);
+ }
+ else
+ {
+ visitor->getSink()->diagnose(
+ referenceLoc,
+ Diagnostics::conflictingCapabilityDueToStatementEnclosingFunc,
+ nodeCaps,
+ oldCaps);
+ }
+ }
+ if (referencedDecl && decl)
+ {
+ for (auto& capSet : nodeCaps.getExpandedAtoms())
+ {
+ for (auto atom : capSet.getExpandedAtoms())
+ {
+ decl->capabilityRequirementProvenance.addIfNotExists(atom, DeclReferenceWithLoc{ referencedDecl, referenceLoc });
+ }
+ }
+ }
+ };
+
+ CapabilitySet getStatementCapabilityUsage(SemanticsVisitor* visitor, Stmt* stmt);
+
+ template<typename ProcessFunc>
+ struct CapabilityDeclReferenceVisitor
+ : public SemanticsDeclReferenceVisitor<CapabilityDeclReferenceVisitor<ProcessFunc>>
+ {
+ typedef SemanticsDeclReferenceVisitor<CapabilityDeclReferenceVisitor<ProcessFunc>> Base;
+
+ const ProcessFunc& handleReferenceFunc;
+ CapabilityDeclReferenceVisitor(const ProcessFunc& processFunc, SemanticsContext const& outer)
+ : handleReferenceFunc(processFunc)
+ , SemanticsDeclReferenceVisitor<CapabilityDeclReferenceVisitor<ProcessFunc>>(outer)
+ {
+ }
+ virtual void processReferencedDecl(Decl* decl) override
+ {
+ SourceLoc loc = SourceLoc();
+ if (Base::sourceLocStack.getCount())
+ loc = Base::sourceLocStack.getLast();
+ handleReferenceFunc(decl, decl->inferredCapabilityRequirements, loc);
+ }
+ void visitDiscardStmt(DiscardStmt* stmt)
+ {
+ handleReferenceFunc(stmt, CapabilitySet(CapabilityName::fragment), stmt->loc);
+ }
+ void visitTargetSwitchStmt(TargetSwitchStmt* stmt)
+ {
+ CapabilitySet set;
+ for (auto targetCase : stmt->targetCases)
+ {
+ auto targetCap = CapabilitySet(CapabilityName(targetCase->capability));
+ auto oldCap = targetCap;
+ auto bodyCap = getStatementCapabilityUsage(this, targetCase->body);
+ targetCap.join(bodyCap);
+ if (targetCap.isInvalid())
+ {
+ Base::getSink()->diagnose(targetCase->body->loc, Diagnostics::conflictingCapabilityDueToStatement, bodyCap, "target_switch", oldCap);
+ }
+ for (auto& conjunction : targetCap.getExpandedAtoms())
+ set.unionWith(conjunction);
+ }
+ set.canonicalize();
+ handleReferenceFunc(stmt, set, stmt->loc);
+ }
+ };
+
+ template<typename ProcessFunc>
+ void visitReferencedDecls(SemanticsContext& context, NodeBase* node, SourceLoc initialLoc, const ProcessFunc& func)
+ {
+ CapabilityDeclReferenceVisitor<ProcessFunc> visitor(func, context);
+ visitor.sourceLocStack.add(initialLoc);
+
+ if (auto val = as<Val>(node))
+ visitor.dispatchIfNotNull(val);
+ if (auto stmt = as<Stmt>(node))
+ visitor.dispatchIfNotNull(stmt);
+ if (auto expr = as<Expr>(node))
+ visitor.dispatchIfNotNull(expr);
+ if (auto decl = as<Decl>(node))
+ visitor.dispatchIfNotNull(decl);
+ }
+
+ CapabilitySet getStatementCapabilityUsage(SemanticsVisitor* visitor, Stmt* stmt)
+ {
+ if (stmt == nullptr)
+ return CapabilitySet();
+
+ CapabilitySet inferredRequirements;
+ visitReferencedDecls(*visitor, stmt, stmt->loc, [&](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
+ {
+ _propagateRequirement(visitor, inferredRequirements, stmt, node, nodeCaps, refLoc);
+ });
+ return inferredRequirements;
+ }
+
+ void SemanticsDeclCapabilityVisitor::checkVarDeclCommon(VarDeclBase* varDecl)
+ {
+ visitReferencedDecls(*this, varDecl->type.type, varDecl->loc, [this, varDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
+ {
+ _propagateRequirement(this, varDecl->inferredCapabilityRequirements, varDecl, node, nodeCaps, refLoc);
+ });
+ visitReferencedDecls(*this, varDecl->initExpr, varDecl->loc, [this, varDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
+ {
+ _propagateRequirement(this, varDecl->inferredCapabilityRequirements, varDecl, node, nodeCaps, refLoc);
+ });
+ }
+
+ void SemanticsDeclCapabilityVisitor::visitFunctionDeclBase(FunctionDeclBase* funcDecl)
+ {
+ for (auto member : funcDecl->members)
+ {
+ ensureDecl(member, DeclCheckState::CapabilityChecked);
+ _propagateRequirement(this, funcDecl->inferredCapabilityRequirements, funcDecl, member, member->inferredCapabilityRequirements, member->loc);
+ }
+ visitReferencedDecls(*this, funcDecl->body, funcDecl->loc, [this, funcDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
+ {
+ _propagateRequirement(this, funcDecl->inferredCapabilityRequirements, funcDecl, node, nodeCaps, refLoc);
+ });
+
+ // A decls's declared capability set is a transitive join of all parent declarations.
+ CapabilitySet declaredCaps;
+ for (Decl* parent = funcDecl; parent; parent = getParentDecl(parent))
+ {
+ CapabilitySet localDeclaredCaps;
+
+ for (auto decoration : parent->getModifiersOfType<RequireCapabilityAttribute>())
+ {
+ for (auto& set : decoration->capabilitySet.getExpandedAtoms())
+ localDeclaredCaps.unionWith(set);
+ }
+ declaredCaps.join(localDeclaredCaps);
+ }
+
+ if (!declaredCaps.isEmpty())
+ {
+ // If the function is an entrypoint, add the stage to declaredCaps.
+ if (auto entryPointAttr = funcDecl->findModifier<EntryPointAttribute>())
+ {
+ auto stageCaps = CapabilitySet(Profile(entryPointAttr->stage).getCapabilityName());
+ if (declaredCaps.isIncompatibleWith(stageCaps))
+ {
+ getSink()->diagnose(funcDecl->loc, Diagnostics::stageIsInCompatibleWithCapabilityDefinition, funcDecl, stageCaps, declaredCaps);
+ }
+ else
+ {
+ declaredCaps.join(stageCaps);
+ }
+ }
+ }
+
+ auto vis = getDeclVisibility(funcDecl);
+ if (declaredCaps.isEmpty())
+ {
+ // If the user has not declared any capabilities,
+ // we should diagnose an error if this is a public symbol.
+ if (vis == DeclVisibility::Public && !funcDecl->inferredCapabilityRequirements.isEmpty())
+ {
+ if (!getModuleDecl(funcDecl)->isInLegacyLanguage)
+ {
+ getSink()->diagnose(funcDecl->loc, Diagnostics::missingCapabilityRequirementOnPublicDecl, funcDecl);
+ }
+ }
+ }
+ else
+ {
+ if (vis == DeclVisibility::Public)
+ {
+ // For public decls, we need to enforce that the function
+ // only uses capabilities that it declares.
+ const CapabilityConjunctionSet* failedAvailableCapabilityConjunction = nullptr;
+ if (!CapabilitySet::checkCapabilityRequirement(
+ declaredCaps,
+ funcDecl->inferredCapabilityRequirements,
+ failedAvailableCapabilityConjunction))
+ {
+ diagnoseUndeclaredCapability(funcDecl, Diagnostics::useOfUndeclaredCapability, failedAvailableCapabilityConjunction);
+ funcDecl->inferredCapabilityRequirements = declaredCaps;
+ }
+ }
+ else
+ {
+ // For internal decls, their inferred capability should be joined
+ // with the declared capabilities.
+ funcDecl->inferredCapabilityRequirements.join(declaredCaps);
+ }
+ }
+ }
+
+ void SemanticsDeclCapabilityVisitor::visitInheritanceDecl(InheritanceDecl* inheritanceDecl)
+ {
+ // Check that the implementation of an interface requirement is not using more capabilities
+ // than what's declared on the interface method.
+ if (inheritanceDecl->witnessTable)
+ {
+ for (auto& kv : inheritanceDecl->witnessTable->m_requirementDictionary)
+ {
+ if (kv.value.getFlavor() != RequirementWitness::Flavor::declRef)
+ continue;
+ auto requirementDecl = kv.key;
+ auto implDecl = kv.value.getDeclRef();
+ if (!implDecl)
+ continue;
+
+ if (getModuleDecl(implDecl.getDecl())->isInLegacyLanguage)
+ break;
+
+ ensureDecl(requirementDecl, DeclCheckState::CapabilityChecked);
+ ensureDecl(implDecl.declRefBase, DeclCheckState::CapabilityChecked);
+
+ const CapabilityConjunctionSet* failedAvailableCapabilityConjunction = nullptr;
+ if (!CapabilitySet::checkCapabilityRequirement(
+ requirementDecl->inferredCapabilityRequirements,
+ implDecl.getDecl()->inferredCapabilityRequirements,
+ failedAvailableCapabilityConjunction))
+ {
+ diagnoseUndeclaredCapability(implDecl.getDecl(), Diagnostics::useOfUndeclaredCapabilityOfInterfaceRequirement, failedAvailableCapabilityConjunction);
+ }
+ }
+ }
+ }
+
+ DeclVisibility getDeclVisibility(Decl* decl)
+ {
+ if (as<GenericTypeParamDecl>(decl) || as<GenericValueParamDecl>(decl) || as<GenericTypeConstraintDecl>(decl))
+ {
+ auto genericDecl = as<GenericDecl>(decl->parentDecl);
+ if (!genericDecl)
+ return DeclVisibility::Default;
+ if (genericDecl->inner)
+ return getDeclVisibility(genericDecl->inner);
+ return DeclVisibility::Default;
+ }
+ if (auto genericDecl = as<GenericDecl>(decl))
+ decl = genericDecl->inner;
+ for (; decl; decl = getParentDecl(decl))
+ {
+ if (as<AccessorDecl>(decl))
+ continue;
+ if (as<EnumCaseDecl>(decl))
+ continue;
+ break;
+ }
+ if (!decl)
+ return DeclVisibility::Public;
+
+ for (auto modifier : decl->modifiers)
+ {
+ if (as<PublicModifier>(modifier))
+ return DeclVisibility::Public;
+ else if (as<InternalModifier>(modifier))
+ return DeclVisibility::Internal;
+ else if (as<PrivateModifier>(modifier))
+ return DeclVisibility::Private;
+ }
+
+ // Interface members will always have the same visibility as the interface itself.
+ if (auto interfaceDecl = findParentInterfaceDecl(decl))
+ {
+ return getDeclVisibility(interfaceDecl);
+ }
+ else if (as<NamespaceDecl>(decl))
+ {
+ return DeclVisibility::Public;
+ }
+ if (auto parentModule = getModuleDecl(decl))
+ return parentModule->isInLegacyLanguage ? DeclVisibility::Public : DeclVisibility::Internal;
+
+ return DeclVisibility::Default;
+ }
+
+ void diagnoseCapabilityProvenance(DiagnosticSink* sink, Decl* decl, CapabilityAtom missingAtom)
+ {
+ HashSet<Decl*> printedDecls;
+ auto thisModule = getModuleDecl(decl);
+ Decl* declToPrint = decl;
+ while (declToPrint)
+ {
+ printedDecls.add(declToPrint);
+ if (auto provenance = declToPrint->capabilityRequirementProvenance.tryGetValue(missingAtom))
+ {
+ sink->diagnose(provenance->referenceLoc, Diagnostics::seeUsingOf, provenance->referencedDecl);
+ declToPrint = provenance->referencedDecl;
+ if (printedDecls.contains(declToPrint))
+ break;
+ if (declToPrint->findModifier<RequireCapabilityAttribute>())
+ break;
+ auto moduleDecl = getModuleDecl(declToPrint);
+ if (thisModule != moduleDecl)
+ break;
+ if (moduleDecl && moduleDecl->isInLegacyLanguage)
+ continue;
+ if (getDeclVisibility(declToPrint) == DeclVisibility::Public)
+ break;
+ }
+ else
+ {
+ break;
+ }
+ }
+ if (declToPrint)
+ {
+ sink->diagnose(declToPrint->loc, Diagnostics::seeDefinitionOf, declToPrint);
+ }
+ }
+
+ // Print diagnostics tracing which referenced decls are not compatible with the given atom.
+ void diagnoseIncompatibleAtomProvenance(SemanticsVisitor* visitor, DiagnosticSink* sink, Decl* decl, CapabilityAtom incompatibleAtom, int traceLevels = 10)
+ {
+ Decl* refDecl = nullptr;
+ SourceLoc loc;
+ while (traceLevels > 0)
+ {
+ refDecl = nullptr;
+ visitReferencedDecls(*visitor, decl, decl->loc, [&](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
+ {
+ if (nodeCaps.isIncompatibleWith(incompatibleAtom))
+ {
+ if (auto referencedDecl = as<Decl>(node))
+ {
+ refDecl = referencedDecl;
+ loc = refLoc;
+ }
+ else
+ sink->diagnose(refLoc, Diagnostics::seeDefinitionOf, "statement");
+ }
+ });
+ if (refDecl)
+ {
+ sink->diagnose(loc, Diagnostics::seeUsingOf, refDecl);
+ decl = refDecl;
+ }
+ else
+ {
+ break;
+ }
+ traceLevels--;
+ }
+ }
+
+ void SemanticsDeclCapabilityVisitor::diagnoseUndeclaredCapability(Decl* decl, const DiagnosticInfo& diagnosticInfo, const CapabilityConjunctionSet* failedAvailableSet)
+ {
+ if (decl->inferredCapabilityRequirements.getExpandedAtoms().getCount() == 0)
+ return;
+
+ // There are two causes for why type checking failed on failedAvailableSet.
+ // The first scenario is that failedAvailableSet defines a set of capabilities on a
+ // compilation target (e.g. hlsl) that isn't defined by some callees, for example, if we have
+ // a function:
+ // [require(hlsl)] // <-- failedAvailableSet
+ // [require(cpp)]
+ // void caller()
+ // {
+ // printf(); // assume this is defined for (cpp | cuda).
+ // }
+ // In this case we should diagnose error reporting printf isn't defined on a required target.
+ //
+ // The second scenario is when the callee is using a capability that is not provided by the requirement.
+ // For example:
+ // [require(hlsl,b,c)]
+ // void caller()
+ // {
+ // useD(); // require capability (hlsl,d)
+ // }
+ // In this case we should report that useD() is using a capability that is not declared by caller.
+ //
+
+ // Now, we detect if we are case 1.
+ if (decl->inferredCapabilityRequirements.isIncompatibleWith(*failedAvailableSet))
+ {
+ // Find the most derived atom that is leading to the incompatiblity.
+ for (Index i = failedAvailableSet->getExpandedAtoms().getCount() - 1; i >= 0; i--)
+ {
+ auto atom = failedAvailableSet->getExpandedAtoms()[i];
+ if (!isDirectChildOfAbstractAtom(atom))
+ continue;
+ if (decl->inferredCapabilityRequirements.isIncompatibleWith(atom))
+ {
+ getSink()->diagnose(decl->loc, Diagnostics::declHasDependenciesNotDefinedOnTarget, decl, atom);
+ diagnoseIncompatibleAtomProvenance(this, getSink(), decl, atom);
+ return;
+ }
+ }
+ return;
+ }
+
+ // If we reach here, we are case 2.
+
+ CapabilityConjunctionSet* matchingRequirement = &decl->inferredCapabilityRequirements.getExpandedAtoms().getFirst();
+ CapabilityAtom missingAtom = matchingRequirement->getExpandedAtoms().getFirst();
+ if (missingAtom == CapabilityAtom::Invalid)
+ return;
+
+ if (failedAvailableSet)
+ {
+ Int maxIntersectionCount = 0;
+ for (auto& usedSet : decl->inferredCapabilityRequirements.getExpandedAtoms())
+ {
+ auto intersection = usedSet.countIntersectionWith(*failedAvailableSet);
+ if (intersection > maxIntersectionCount)
+ {
+ matchingRequirement = &usedSet;
+ maxIntersectionCount = intersection;
+ }
+ }
+ Index pos = 0;
+ for (Index i = 0; i < matchingRequirement->getExpandedAtoms().getCount(); i++)
+ {
+ auto atom = matchingRequirement->getExpandedAtoms()[i];
+ while (pos < failedAvailableSet->getExpandedAtoms().getCount())
+ {
+ if (failedAvailableSet->getExpandedAtoms()[pos] < atom)
+ pos++;
+ else
+ break;
+ }
+
+ if (pos >= failedAvailableSet->getExpandedAtoms().getCount() ||
+ failedAvailableSet->getExpandedAtoms()[pos] != atom)
+ {
+ missingAtom = atom;
+ break;
+ }
+ }
+
+ // Select the most derived atom of `missingAtom`.
+ for (Index i = matchingRequirement->getExpandedAtoms().getCount() - 1; i >= 0 ; i--)
+ {
+ auto atom = matchingRequirement->getExpandedAtoms()[i];
+ if (CapabilityConjunctionSet(atom).implies(missingAtom))
+ {
+ missingAtom = atom;
+ break;
+ }
+ }
+ }
+
+ getSink()->diagnose(decl->loc, diagnosticInfo, decl, missingAtom);
+
+ // Print provenances.
+ diagnoseCapabilityProvenance(getSink(), decl, missingAtom);
+ }
+
}
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index e5e990fe5..f9adcc91a 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -62,7 +62,7 @@ namespace Slang
{
VarDecl* varDecl = m_astBuilder->create<VarDecl>();
varDecl->parentDecl = nullptr; // TODO: need to fill this in somehow!
- varDecl->checkState = DeclCheckState::Checked;
+ varDecl->checkState = DeclCheckState::DefinitionChecked;
varDecl->nameAndLoc.loc = expr->loc;
varDecl->initExpr = expr;
varDecl->type.type = expr->type.type;
@@ -827,55 +827,6 @@ namespace Slang
}
}
- DeclVisibility SemanticsVisitor::getDeclVisibility(Decl* decl)
- {
- if (as<GenericTypeParamDecl>(decl) || as<GenericValueParamDecl>(decl) || as<GenericTypeConstraintDecl>(decl))
- {
- auto genericDecl = as<GenericDecl>(decl->parentDecl);
- if (!genericDecl)
- return DeclVisibility::Default;
- if (genericDecl->inner)
- return getDeclVisibility(genericDecl->inner);
- return DeclVisibility::Default;
- }
- if (auto genericDecl = as<GenericDecl>(decl))
- decl = genericDecl->inner;
- for (; decl; decl = getParentDecl(decl))
- {
- if (as<AccessorDecl>(decl))
- continue;
- if (as<EnumCaseDecl>(decl))
- continue;
- break;
- }
- if (!decl)
- return DeclVisibility::Public;
-
- for (auto modifier : decl->modifiers)
- {
- if (as<PublicModifier>(modifier))
- return DeclVisibility::Public;
- else if (as<InternalModifier>(modifier))
- return DeclVisibility::Internal;
- else if (as<PrivateModifier>(modifier))
- return DeclVisibility::Private;
- }
-
- // Interface members will always have the same visibility as the interface itself.
- if (auto interfaceDecl = findParentInterfaceDecl(decl))
- {
- return getDeclVisibility(interfaceDecl);
- }
- else if (as<NamespaceDecl>(decl))
- {
- return DeclVisibility::Public;
- }
- if (auto parentModule = getModuleDecl(decl))
- return parentModule->isInLegacyLanguage ? DeclVisibility::Public : DeclVisibility::Internal;
-
- return DeclVisibility::Default;
- }
-
DeclVisibility SemanticsVisitor::getTypeVisibility(Type* type)
{
if (auto declRefType = as<DeclRefType>(type))
@@ -1721,7 +1672,7 @@ namespace Slang
if (!getInitExpr(m_astBuilder, declRef))
return nullptr;
- ensureDecl(declRef.getDecl(), DeclCheckState::Checked);
+ ensureDecl(declRef.getDecl(), DeclCheckState::DefinitionChecked);
ConstantFoldingCircularityInfo newCircularityInfo(decl, circularityInfo);
return tryConstantFoldExpr(getInitExpr(m_astBuilder, declRef), &newCircularityInfo);
}
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 4102b3eba..1808274f3 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -1099,7 +1099,6 @@ namespace Slang
SourceLoc loc,
Expr* originalExpr);
- DeclVisibility getDeclVisibility(Decl* decl);
DeclVisibility getTypeVisibility(Type* type);
bool isDeclVisibleFromScope(DeclRef<Decl> declRef, Scope* scope);
LookupResult filterLookupResultByVisibility(const LookupResult& lookupResult);
@@ -1472,6 +1471,8 @@ namespace Slang
Expr* expr,
String* outVal);
+ bool checkCapabilityName(Expr* expr, CapabilityName& outCapabilityName);
+
void visitModifier(Modifier*);
AttributeDecl* lookUpAttributeDecl(Name* attributeName, Scope* scope);
@@ -2646,4 +2647,8 @@ namespace Slang
};
bool isUnsizedArrayType(Type* type);
+
+ DeclVisibility getDeclVisibility(Decl* decl);
+
+ void diagnoseCapabilityProvenance(DiagnosticSink* sink, Decl* decl, CapabilityAtom missingAtom);
}
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index 1a7f64944..51cb5346a 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -74,6 +74,30 @@ namespace Slang
return false;
}
+ bool SemanticsVisitor::checkCapabilityName(Expr* expr, CapabilityName& outCapabilityName)
+ {
+ if (auto varExpr = as<VarExpr>(expr))
+ {
+ if (!varExpr->name)
+ return false;
+ if (varExpr->name == getSession()->getCompletionRequestTokenName())
+ {
+ auto& suggestions = getLinkage()->contentAssistInfo.completionSuggestions;
+ suggestions.clear();
+ suggestions.scopeKind = CompletionSuggestions::ScopeKind::Capabilities;
+ }
+ outCapabilityName = findCapabilityName(varExpr->name->text.getUnownedSlice());
+ if (outCapabilityName == CapabilityName::Invalid)
+ {
+ getSink()->diagnose(expr, Diagnostics::unknownCapability, varExpr->name);
+ return false;
+ }
+ return true;
+ }
+ getSink()->diagnose(expr, Diagnostics::expectCapability);
+ return false;
+ }
+
void SemanticsVisitor::visitModifier(Modifier*)
{
// Do nothing with modifiers for now
@@ -209,7 +233,7 @@ namespace Slang
paramDecl->nameAndLoc = member->nameAndLoc;
paramDecl->type = varMember->type;
paramDecl->loc = member->loc;
- paramDecl->setCheckState(DeclCheckState::Checked);
+ paramDecl->setCheckState(DeclCheckState::DefinitionChecked);
paramDecl->parentDecl = attrDecl;
attrDecl->members.add(paramDecl);
@@ -233,7 +257,7 @@ namespace Slang
//
// TODO: what check state is relevant here?
//
- ensureDecl(attrDecl, DeclCheckState::Checked);
+ ensureDecl(attrDecl, DeclCheckState::DefinitionChecked);
return attrDecl;
}
@@ -783,6 +807,19 @@ namespace Slang
pyExportAttr->name = name;
}
+ else if (auto requireCapAttr = as<RequireCapabilityAttribute>(attr))
+ {
+ List<CapabilityName> capabilityNames;
+ for (auto& arg : attr->args)
+ {
+ CapabilityName capName;
+ if (checkCapabilityName(arg, capName))
+ {
+ capabilityNames.add(capName);
+ }
+ }
+ requireCapAttr->capabilitySet = CapabilitySet(capabilityNames);
+ }
else
{
if(attr->args.getCount() == 0)
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index 37c92a317..c6e6677b3 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -351,7 +351,7 @@ namespace Slang
{
// Otherwise, the generic decl had better provide a default value
// or this reference is ill-formed.
- ensureDecl(valParamRef, DeclCheckState::Checked);
+ ensureDecl(valParamRef, DeclCheckState::DefinitionChecked);
ConstantFoldingCircularityInfo newCircularityInfo(valParamRef.getDecl(), nullptr);
auto defaultVal = tryConstantFoldExpr(valParamRef.substitute(m_astBuilder, valParamRef.getDecl()->initExpr), &newCircularityInfo);
if (!defaultVal)
diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp
index b45107640..2e854554e 100644
--- a/source/slang/slang-check-shader.cpp
+++ b/source/slang/slang-check-shader.cpp
@@ -396,7 +396,6 @@ namespace Slang
auto module = getModule(entryPointFuncDecl);
auto linkage = module->getLinkage();
-
// Every entry point needs to have a stage specified either via
// command-line/API options, or via an explicit `[shader("...")]` attribute.
//
@@ -506,6 +505,38 @@ namespace Slang
}
}
}
+
+ for (auto target : linkage->targets)
+ {
+ auto targetCaps = target->getTargetCaps();
+ auto stageCapabilitySet = CapabilitySet(entryPoint->getProfile().getCapabilityName());
+ targetCaps.join(stageCapabilitySet);
+ if (targetCaps.isIncompatibleWith(entryPointFuncDecl->inferredCapabilityRequirements))
+ {
+ sink->diagnose(entryPointFuncDecl, Diagnostics::entryPointUsesUnavailableCapability, entryPointFuncDecl, entryPointFuncDecl->inferredCapabilityRequirements, targetCaps);
+ auto& interredCapConjunctions = entryPointFuncDecl->inferredCapabilityRequirements.getExpandedAtoms();
+
+ // Find out what exactly is incompatible and print out a trace of provenance to
+ // help user diagnose their code.
+ auto& conjunctions = targetCaps.getExpandedAtoms();
+ if (conjunctions.getCount() == 1 && interredCapConjunctions.getCount() == 1)
+ {
+ for (auto atom : conjunctions[0].getExpandedAtoms())
+ {
+ for (auto inferredAtom : interredCapConjunctions[0].getExpandedAtoms())
+ {
+ if (CapabilityConjunctionSet(inferredAtom).isIncompatibleWith(atom))
+ {
+ diagnoseCapabilityProvenance(sink, entryPointFuncDecl, inferredAtom);
+ goto breakLabel;
+ }
+ }
+ }
+ }
+ }
+ }
+ breakLabel:;
+
}
// Given an entry point specified via API or command line options,
@@ -533,7 +564,7 @@ namespace Slang
auto entryPointName = entryPointReq->getName();
FuncDecl* entryPointFuncDecl = findFunctionDeclByName(translationUnit->getModule(), entryPointName, sink);
-
+
// Did we find a function declaration in our search?
if(!entryPointFuncDecl)
{
diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp
index b2c58ccc7..1090655e5 100644
--- a/source/slang/slang-check-stmt.cpp
+++ b/source/slang/slang-check-stmt.cpp
@@ -52,7 +52,7 @@ namespace Slang
// local `struct` declaration, where it would have members
// that need to be recursively checked.
//
- ensureDeclBase(stmt->decl, DeclCheckState::Checked, this);
+ ensureDeclBase(stmt->decl, DeclCheckState::DefinitionChecked, this);
}
void SemanticsStmtVisitor::visitBlockStmt(BlockStmt* stmt)
@@ -207,7 +207,7 @@ namespace Slang
stmt->varDecl->type.type = m_astBuilder->getIntType();
addModifier(stmt->varDecl, m_astBuilder->create<ConstModifier>());
- stmt->varDecl->setCheckState(DeclCheckState::Checked);
+ stmt->varDecl->setCheckState(DeclCheckState::DefinitionChecked);
IntVal* rangeBeginVal = nullptr;
IntVal* rangeEndVal = nullptr;
@@ -280,7 +280,20 @@ namespace Slang
void SemanticsStmtVisitor::visitTargetCaseStmt(TargetCaseStmt* stmt)
{
auto switchStmt = FindOuterStmt<TargetSwitchStmt>();
+ CapabilitySet set((CapabilityName)stmt->capability);
+ if (getShared()->isInLanguageServer() && getShared()->getSession()->getCompletionRequestTokenName() == stmt->capabilityToken.getName())
+ {
+ getShared()->getLinkage()->contentAssistInfo.completionSuggestions.scopeKind = CompletionSuggestions::ScopeKind::Capabilities;
+ }
+ if (stmt->capabilityToken.getContentLength() != 0 &&
+ (set.getExpandedAtoms().getCount() != 1 || set.isInvalid() || set.isEmpty()))
+ {
+ getSink()->diagnose(
+ stmt->capabilityToken.loc,
+ Diagnostics::invalidTargetSwitchCase,
+ capabilityNameToString((CapabilityName)stmt->capability));
+ }
if (!switchStmt)
{
getSink()->diagnose(stmt, Diagnostics::caseOutsideSwitch);
@@ -648,7 +661,7 @@ namespace Slang
{
stmt->device = CheckExpr(stmt->device);
stmt->gridDims = CheckExpr(stmt->gridDims);
- ensureDeclBase(stmt->dispatchThreadID, DeclCheckState::Checked, this);
+ ensureDeclBase(stmt->dispatchThreadID, DeclCheckState::DefinitionChecked, this);
WithOuterStmt subContext(this, stmt);
stmt->kernelCall = subContext.CheckExpr(stmt->kernelCall);
return;
diff --git a/source/slang/slang-check.cpp b/source/slang/slang-check.cpp
index de86a333f..3f79b7f41 100644
--- a/source/slang/slang-check.cpp
+++ b/source/slang/slang-check.cpp
@@ -218,4 +218,5 @@ namespace Slang
{
return sv->getASTBuilder();
}
+
}
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp
index fbdb91ac7..f8ad95108 100644
--- a/source/slang/slang-compiler.cpp
+++ b/source/slang/slang-compiler.cpp
@@ -359,6 +359,26 @@ namespace Slang
return lookUp(UnownedTerminatedStringSlice(name));
}
+ List<CapabilityName> Profile::getCapabilityName()
+ {
+ List<CapabilityName> result;
+ switch (getVersion())
+ {
+ #define PROFILE_VERSION(TAG, NAME) case ProfileVersion::TAG: result.add(CapabilityName::TAG); break;
+ #include "slang-profile-defs.h"
+ default:
+ break;
+ }
+ switch (getStage())
+ {
+#define PROFILE_STAGE(TAG, NAME, VAL) case Stage::TAG: result.add(CapabilityName::NAME); break;
+#include "slang-profile-defs.h"
+ default:
+ break;
+ }
+ return result;
+ }
+
char const* Profile::getName()
{
switch( raw )
@@ -711,7 +731,7 @@ namespace Slang
// to clobber it with something to get a default.
//
// TODO: This is a huge hack...
- profile.setVersion(ProfileVersion::DX_5_0);
+ profile.setVersion(ProfileVersion::DX_5_1);
break;
}
@@ -755,9 +775,6 @@ namespace Slang
{
#define CASE(TAG, SUFFIX) case ProfileVersion::TAG: versionSuffix = #SUFFIX; break
CASE(DX_4_0, _4_0);
- CASE(DX_4_0_Level_9_0, _4_0_level_9_0);
- CASE(DX_4_0_Level_9_1, _4_0_level_9_1);
- CASE(DX_4_0_Level_9_3, _4_0_level_9_3);
CASE(DX_4_1, _4_1);
CASE(DX_5_0, _5_0);
CASE(DX_5_1, _5_1);
diff --git a/source/slang/slang-content-assist-info.h b/source/slang/slang-content-assist-info.h
index bd2c4b7d9..8f4105184 100644
--- a/source/slang/slang-content-assist-info.h
+++ b/source/slang/slang-content-assist-info.h
@@ -19,7 +19,8 @@ struct CompletionSuggestions
Stmt,
Expr,
Attribute,
- HLSLSemantics
+ HLSLSemantics,
+ Capabilities
};
ScopeKind scopeKind = ScopeKind::Invalid;
List<LookupResultItem> candidateItems;
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index d0e188292..786dbda35 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -297,8 +297,8 @@ DIAGNOSTIC(30041, Error, bitOperationNonIntegral, "bit operation: operand must b
DIAGNOSTIC(30043, Error, getStringHashRequiresStringLiteral, "getStringHash parameter can only accept a string literal")
DIAGNOSTIC(30047, Error, argumentExpectedLValue, "argument passed to parameter '$0' must be l-value.")
-DIAGNOSTIC(30049, Note, thisIsImmutableByDefault, "a 'this' parameter is an immutable parameter by default in Slang; apply the `[mutating]` attribute to the function declaration to opt in to a mutable `this`")
-DIAGNOSTIC(30050, Error, mutatingMethodOnImmutableValue, "mutating method '$0' cannot be called on an immutable value")
+DIAGNOSTIC(30049, Note, thisIsImmutableByDefault, "a 'this' parameter is an immutable parameter by default in Slang; apply the `[mutating]` attribute to the function declaration to opt in to a mutable `this`")
+DIAGNOSTIC(30050, Error, mutatingMethodOnImmutableValue, "mutating method '$0' cannot be called on an immutable value")
DIAGNOSTIC(30051, Error, invalidValueForArgument, "invalid value for argument '$0'")
DIAGNOSTIC(30052, Error, invalidSwizzleExpr, "invalid swizzle pattern '$0' on type '$1'")
@@ -312,9 +312,9 @@ DIAGNOSTIC(30058, Warning, danglingEqualityExpr, "result of '==' not used, did y
DIAGNOSTIC(30060, Error, expectedAType, "expected a type, got a '$0'")
DIAGNOSTIC(30061, Error, expectedANamespace, "expected a namespace, got a '$0'")
-DIAGNOSTIC(30062, Note, implicitCastUsedAsLValueRef, "argument was implicitly cast from '$0' to '$1', and Slang does not support using an implicit cast as an l-value with a reference")
-DIAGNOSTIC(30063, Note, implicitCastUsedAsLValueType, "argument was implicitly cast from '$0' to '$1', and Slang does not support using an implicit cast as an l-value with this type")
-DIAGNOSTIC(30064, Note, implicitCastUsedAsLValue, "argument was implicitly cast from '$0' to '$1', and Slang does not support using an implicit cast as an l-value for this usage")
+DIAGNOSTIC(30062, Note, implicitCastUsedAsLValueRef, "argument was implicitly cast from '$0' to '$1', and Slang does not support using an implicit cast as an l-value with a reference")
+DIAGNOSTIC(30063, Note, implicitCastUsedAsLValueType, "argument was implicitly cast from '$0' to '$1', and Slang does not support using an implicit cast as an l-value with this type")
+DIAGNOSTIC(30064, Note, implicitCastUsedAsLValue, "argument was implicitly cast from '$0' to '$1', and Slang does not support using an implicit cast as an l-value for this usage")
DIAGNOSTIC(30065, Error, newCanOnlyBeUsedToInitializeAClass, "`new` can only be used to initialize a class")
DIAGNOSTIC(30066, Error, classCanOnlyBeInitializedWithNew, "a class can only be initialized by a `new` clause")
@@ -333,7 +333,7 @@ DIAGNOSTIC(30300, Error, isOperatorValueMustBeInterfaceType, "'is'/'as' operator
DIAGNOSTIC(33070, Error, expectedFunction, "expected a function, got '$0'")
DIAGNOSTIC(33071, Error, expectedAStringLiteral, "expected a string literal")
-DIAGNOSTIC( -1, Note, noteExplicitConversionPossible, "explicit conversion from '$0' to '$1' is possible")
+DIAGNOSTIC(-1, Note, noteExplicitConversionPossible, "explicit conversion from '$0' to '$1' is possible")
DIAGNOSTIC(30080, Error, ambiguousConversion, "more than one implicit conversion exists from '$0' to '$1'")
DIAGNOSTIC(30081, Warning, unrecommendedImplicitConversion, "implicit conversion from '$0' to '$1' is not recommended")
DIAGNOSTIC(30082, Warning, implicitConversionToDouble, " implicit float-to-double conversion may cause unexpected performance issues, use explicit cast if intended.")
@@ -366,6 +366,20 @@ DIAGNOSTIC(30603, Error, invalidUseOfPrivateVisibility, "'$0' cannot have privat
DIAGNOSTIC(30604, Error, useOfLessVisibleType, "'$0' references less visible type '$1'.")
DIAGNOSTIC(36005, Error, invalidVisibilityModifierOnTypeOfDecl, "visibility modifier is not allowed on '$0'.")
+// Capability
+DIAGNOSTIC(36100, Error, conflictingCapabilityDueToUseOfDecl, "'$0' requires capability '$1' that is conflicting with the '$2''s current capability requirement '$3'.")
+DIAGNOSTIC(36101, Error, conflictingCapabilityDueToStatement, "statement requires capability '$0' that is conflicting with the '$1''s current capability requirement '$2'.")
+DIAGNOSTIC(36102, Error, conflictingCapabilityDueToStatementEnclosingFunc, "statement requires capability '$0' that is conflicting with the current function's capability requirement '$1'.")
+DIAGNOSTIC(36103, Error, missingCapabilityRequirementOnPublicDecl, "public symbol '$0' is missing capability requirement declaration.")
+DIAGNOSTIC(36104, Error, useOfUndeclaredCapability, "'$0' uses undeclared capability '$1'.")
+DIAGNOSTIC(36104, Error, useOfUndeclaredCapabilityOfInterfaceRequirement, "'$0' uses capability '$1' that is missing from the interface requirement.")
+DIAGNOSTIC(36105, Error, unknownCapability, "unknown capability name '$0'.")
+DIAGNOSTIC(36106, Error, expectCapability, "expect a capability name.")
+DIAGNOSTIC(36107, Error, entryPointUsesUnavailableCapability, "entrypoint '$0' requires capability '$1', which is incompatible with the current compilation target '$2'.")
+DIAGNOSTIC(36108, Error, declHasDependenciesNotDefinedOnTarget, "'$0' has dependencies that are not defined on the required target '$1'.")
+DIAGNOSTIC(36109, Error, invalidTargetSwitchCase, "'$0' cannot be used as a target_switch case.")
+DIAGNOSTIC(36110, Error, stageIsInCompatibleWithCapabilityDefinition, "'$0' is defined for stage '$1', which is incompatible with the declared capability set '$2'.")
+
// Attributes
DIAGNOSTIC(31000, Error, unknownAttributeName, "unknown attribute '$0'")
DIAGNOSTIC(31001, Error, attributeArgumentCountMismatch, "attribute '$0' expects $1 arguments ($2 provided)")
diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp
index a0151133d..248b803f4 100644
--- a/source/slang/slang-emit-glsl.cpp
+++ b/source/slang/slang-emit-glsl.cpp
@@ -64,43 +64,14 @@ SlangResult GLSLSourceEmitter::init()
void GLSLSourceEmitter::_requireRayTracing()
{
- // There is more than one extension that provides ray-tracing capabilities,
- // and we need to pick which one to enable.
- //
- // By default, we will use the `GL_EXT_ray_tracing` extension, but if
- // the user has explicitly opted in to the `GL_NV_ray_tracing` extension
- // we will use that one instead.
- //
- if( getTargetCaps().implies(CapabilityAtom::_GL_NV_ray_tracing) )
- {
- m_glslExtensionTracker->requireExtension(UnownedStringSlice::fromLiteral("GL_NV_ray_tracing"));
- }
- else
- {
- m_glslExtensionTracker->requireExtension(UnownedStringSlice::fromLiteral("GL_EXT_ray_tracing"));
- m_glslExtensionTracker->requireSPIRVVersion(SemanticVersion(1, 4));
- }
-
+ m_glslExtensionTracker->requireExtension(UnownedStringSlice::fromLiteral("GL_EXT_ray_tracing"));
+ m_glslExtensionTracker->requireSPIRVVersion(SemanticVersion(1, 4));
m_glslExtensionTracker->requireVersion(ProfileVersion::GLSL_460);
}
void GLSLSourceEmitter::_requireFragmentShaderBarycentric()
{
- // There is more than one extension that provides barycentric coords in fragment shaders,
- // and we need to pick which one to enable.
- //
- // By default, we will use the `GL_EXT_fragment_shader_barycentric` extension, but if
- // the user has explicitly opted in to the `GL_NV_fragment_shader_barycentric` extension
- // we will use that one instead.
-
- if( getTargetCaps().implies(CapabilityAtom::_GL_NV_fragment_shader_barycentric) )
- {
- m_glslExtensionTracker->requireExtension(UnownedStringSlice::fromLiteral("GL_NV_fragment_shader_barycentric"));
- }
- else
- {
- m_glslExtensionTracker->requireExtension(UnownedStringSlice::fromLiteral("GL_EXT_fragment_shader_barycentric"));
- }
+ m_glslExtensionTracker->requireExtension(UnownedStringSlice::fromLiteral("GL_EXT_fragment_shader_barycentric"));
m_glslExtensionTracker->requireVersion(ProfileVersion::GLSL_450);
}
@@ -129,11 +100,6 @@ void GLSLSourceEmitter::_requireGLSLVersion(int version)
{
#define CASE(NUMBER) \
case NUMBER: _requireGLSLVersion(ProfileVersion::GLSL_##NUMBER); break
-
- CASE(110);
- CASE(120);
- CASE(130);
- CASE(140);
CASE(150);
CASE(330);
CASE(400);
@@ -684,14 +650,7 @@ bool GLSLSourceEmitter::_emitGLSLLayoutQualifierWithBindingKinds(LayoutResourceK
m_writer->emit("layout(push_constant)\n");
break;
case LayoutResourceKind::ShaderRecord:
- if (getTargetCaps().implies(CapabilityAtom::_GL_NV_ray_tracing))
- {
- m_writer->emit("layout(shaderRecordNV)\n");
- }
- else
- {
- m_writer->emit("layout(shaderRecordEXT)\n");
- }
+ m_writer->emit("layout(shaderRecordEXT)\n");
break;
}
@@ -1430,40 +1389,19 @@ void GLSLSourceEmitter::emitLayoutQualifiersImpl(IRVarLayout* layout)
case LayoutResourceKind::RayPayload:
{
- if( getTargetCaps().implies(CapabilityAtom::_GL_NV_ray_tracing) )
- {
- m_writer->emit("rayPayloadInNV ");
- }
- else
- {
- m_writer->emit("rayPayloadInEXT ");
- }
+ m_writer->emit("rayPayloadInEXT ");
}
break;
case LayoutResourceKind::CallablePayload:
{
- if( getTargetCaps().implies(CapabilityAtom::_GL_NV_ray_tracing) )
- {
- m_writer->emit("callableDataInNV ");
- }
- else
- {
- m_writer->emit("callableDataInEXT ");
- }
+ m_writer->emit("callableDataInEXT ");
}
break;
case LayoutResourceKind::HitAttributes:
{
- if( getTargetCaps().implies(CapabilityAtom::_GL_NV_ray_tracing) )
- {
- m_writer->emit("hitAttributeNV ");
- }
- else
- {
- m_writer->emit("hitAttributeEXT ");
- }
+ m_writer->emit("hitAttributeEXT ");
}
break;
@@ -2136,10 +2074,6 @@ static Index _getGLSLVersion(ProfileVersion profile)
switch (profile)
{
#define CASE(TAG, VALUE) case ProfileVersion::TAG: return VALUE;
- CASE(GLSL_110, 110);
- CASE(GLSL_120, 120);
- CASE(GLSL_130, 130);
- CASE(GLSL_140, 140);
CASE(GLSL_150, 150);
CASE(GLSL_330, 330);
CASE(GLSL_400, 400);
@@ -2479,45 +2413,8 @@ void GLSLSourceEmitter::emitSimpleTypeImpl(IRType* type)
{
case kIROp_RaytracingAccelerationStructureType:
{
- // Note: We have the problem here that we want to do `_requireRayTracing()`,
- // but just based on the use of a ray-tracing acceleration structure we
- // cannot know which extension the user means to use. The current options are:
- //
- // * GL_NV_ray_tracing
- // * GL_EXT_ray_tracing
- // * GL_EXT_ray_query
- //
- // The first two options there are basically equivalent extensions with
- // different GLSL syntax. We end up requiring the user to opt in to
- // `GL_NV_ray_tracing` using target capabilities, and will always default
- // to `GL_EXT_ray_tracing` otherwise.
- //
- if( getTargetCaps().implies(CapabilityAtom::_GL_NV_ray_tracing) )
- {
- // If the user has explicitly opted in to `GL_NV_ray_tracing`,
- // then we don't need to explicitly request the extentsion again.
- // We know that the acceleration structure type will translate
- // to the one from that extension:
- //
- _requireRayTracing();
- m_writer->emit("accelerationStructureNV");
- }
- else
- {
- // If the user does *not* opt into a specific extension, then we
- // have the problem that either `GL_EXT_ray_tracing` or `GL_EXT_ray-query`
- // could provide the `accelerationSturctureEXT` type, but there
- // can be drivers that provide only one and not the other.
- //
- // For now we will just kludge this by assuming that any driver
- // that supports one of these extensions supports the other.
- //
- // TODO: Revisit that decision once the driver landscape is more stable/clear.
- //
- _requireRayTracing();
-
- m_writer->emit("accelerationStructureEXT");
- }
+ _requireRayTracing();
+ m_writer->emit("accelerationStructureEXT");
break;
}
@@ -2580,14 +2477,7 @@ bool GLSLSourceEmitter::_maybeEmitInterpolationModifierText(IRInterpolationMode
if( stage == Stage::Fragment && isInput)
{
_requireFragmentShaderBarycentric();
- if (getTargetCaps().implies(CapabilityAtom::_GL_NV_fragment_shader_barycentric))
- {
- m_writer->emit("pervertexNV ");
- }
- else
- {
- m_writer->emit("pervertexEXT ");
- }
+ m_writer->emit("pervertexEXT ");
}
else
{
@@ -2694,6 +2584,7 @@ void GLSLSourceEmitter::emitVarDecorationsImpl(IRInst* varDecl)
for (auto decoration : varDecl->getDecorations())
{
UnownedStringSlice prefix;
+ UnownedStringSlice postfix = toSlice("EXT");
if (as<IRVulkanHitAttributesDecoration>(decoration))
{
prefix = toSlice("hitAttribute");
@@ -2713,6 +2604,7 @@ void GLSLSourceEmitter::emitVarDecorationsImpl(IRInst* varDecl)
break;
case kIROp_VulkanHitObjectAttributesDecoration:
prefix = toSlice("hitObjectAttribute");
+ postfix = toSlice("NV");
locationValue = getIntVal(decoration->getOperand(0));
break;
default:
@@ -2725,17 +2617,7 @@ void GLSLSourceEmitter::emitVarDecorationsImpl(IRInst* varDecl)
SLANG_ASSERT(prefix.getLength());
m_writer->emit(prefix);
-
- // Special case hitObjectAttribute as is only NV currently
- if (decoration->getOp() == kIROp_VulkanHitObjectAttributesDecoration ||
- getTargetCaps().implies(CapabilityAtom::_GL_NV_ray_tracing))
- {
- m_writer->emit(toSlice("NV"));
- }
- else
- {
- m_writer->emit(toSlice("EXT"));
- }
+ m_writer->emit(postfix);
m_writer->emit(toSlice("\n"));
// If we emit a location we are done.
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 192a40f54..7ff6ac7d6 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -3359,18 +3359,9 @@ struct SPIRVEmitContext
}
else if (semanticName == "sv_barycentrics")
{
- if (m_targetRequest->getTargetCaps().implies(CapabilityAtom::_GL_NV_fragment_shader_barycentric))
- {
- requireSPIRVCapability(SpvCapabilityFragmentBarycentricNV);
- ensureExtensionDeclaration(UnownedStringSlice("SPV_NV_fragment_shader_barycentric"));
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInBaryCoordNV);
- }
- else
- {
- requireSPIRVCapability(SpvCapabilityFragmentBarycentricKHR);
- ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_fragment_shader_barycentric"));
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInBaryCoordKHR);
- }
+ requireSPIRVCapability(SpvCapabilityFragmentBarycentricKHR);
+ ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_fragment_shader_barycentric"));
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInBaryCoordKHR);
// TODO: There is also the `gl_BaryCoordNoPerspNV` builtin, which
// we ought to use if the `noperspective` modifier has been
diff --git a/source/slang/slang-glsl-extension-tracker.h b/source/slang/slang-glsl-extension-tracker.h
index cee11cad5..966e1e927 100644
--- a/source/slang/slang-glsl-extension-tracker.h
+++ b/source/slang/slang-glsl-extension-tracker.h
@@ -39,7 +39,7 @@ protected:
uint32_t m_hasBaseTypeFlags = _getFlag(BaseType::Float) | _getFlag(BaseType::Int) | _getFlag(BaseType::UInt) | _getFlag(BaseType::Void) | _getFlag(BaseType::Bool);
- ProfileVersion m_profileVersion = ProfileVersion::GLSL_110;
+ ProfileVersion m_profileVersion = ProfileVersion::GLSL_150;
StringSlicePool m_extensionPool;
diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp
index 0d279549c..dd165769c 100644
--- a/source/slang/slang-ir-glsl-legalize.cpp
+++ b/source/slang/slang-ir-glsl-legalize.cpp
@@ -518,6 +518,8 @@ GLSLSystemValueInfo* getGLSLSystemValueInfo(
GlobalVaryingDeclarator* declarator,
GLSLSystemValueInfo* inStorage)
{
+ SLANG_UNUSED(codeGenContext);
+
if(auto indicesSemantic = getMeshOutputIndicesSystemValueInfo(
context,
kind,
@@ -914,16 +916,8 @@ GLSLSystemValueInfo* getGLSLSystemValueInfo(
else if (semanticName == "sv_barycentrics")
{
context->requireGLSLVersion(ProfileVersion::GLSL_450);
- if (codeGenContext->getTargetCaps().implies(CapabilityAtom::_GL_NV_fragment_shader_barycentric))
- {
- context->requireGLSLExtension(UnownedStringSlice::fromLiteral("GL_NV_fragment_shader_barycentric"));
- name = "gl_BaryCoordNV";
- }
- else
- {
- context->requireGLSLExtension(UnownedStringSlice::fromLiteral("GL_EXT_fragment_shader_barycentric"));
- name = "gl_BaryCoordEXT";
- }
+ context->requireGLSLExtension(UnownedStringSlice::fromLiteral("GL_EXT_fragment_shader_barycentric"));
+ name = "gl_BaryCoordEXT";
// TODO: There is also the `gl_BaryCoordNoPerspNV` builtin, which
// we ought to use if the `noperspective` modifier has been
diff --git a/source/slang/slang-language-server-completion.cpp b/source/slang/slang-language-server-completion.cpp
index 6038a432a..b723e14b8 100644
--- a/source/slang/slang-language-server-completion.cpp
+++ b/source/slang/slang-language-server-completion.cpp
@@ -443,6 +443,11 @@ List<LanguageServerProtocol::CompletionItem> CompletionContext::collectMembersAn
linkage->contentAssistInfo.completionSuggestions.swizzleBaseType,
linkage->contentAssistInfo.completionSuggestions.elementCount);
}
+ else if (linkage->contentAssistInfo.completionSuggestions.scopeKind ==
+ CompletionSuggestions::ScopeKind::Capabilities)
+ {
+ return createCapabilityCandidates();
+ }
List<LanguageServerProtocol::CompletionItem> result;
bool useCommitChars = true;
bool addKeywords = false;
@@ -595,6 +600,24 @@ List<LanguageServerProtocol::CompletionItem> CompletionContext::collectMembersAn
return result;
}
+List<LanguageServerProtocol::CompletionItem> CompletionContext::createCapabilityCandidates()
+{
+ List<LanguageServerProtocol::CompletionItem> result;
+ List<UnownedStringSlice> names;
+ getCapabilityNames(names);
+ for (auto name : names.getArrayView(1, names.getCount()-1))
+ {
+ if (name.startsWith("_"))
+ continue;
+ LanguageServerProtocol::CompletionItem item;
+ item.data = 0;
+ item.kind = LanguageServerProtocol::kCompletionItemKindEnumMember;
+ item.label = name;
+ result.add(item);
+ }
+ return result;
+}
+
List<LanguageServerProtocol::CompletionItem> CompletionContext::createSwizzleCandidates(
Type* type, IntegerLiteralValue elementCount[2])
{
diff --git a/source/slang/slang-language-server-completion.h b/source/slang/slang-language-server-completion.h
index 5a09ba371..d3910bcfd 100644
--- a/source/slang/slang-language-server-completion.h
+++ b/source/slang/slang-language-server-completion.h
@@ -39,6 +39,7 @@ struct CompletionContext
List<LanguageServerProtocol::CompletionItem> collectMembersAndSymbols();
List<LanguageServerProtocol::CompletionItem> createSwizzleCandidates(
Type* baseType, IntegerLiteralValue elementCount[2]);
+ List<LanguageServerProtocol::CompletionItem> createCapabilityCandidates();
List<LanguageServerProtocol::CompletionItem> collectAttributes();
LanguageServerProtocol::CompletionItem generateGUIDCompletionItem();
List<LanguageServerProtocol::TextEditCompletionItem> gatherFileAndModuleCompletionItems(
diff --git a/source/slang/slang-language-server-inlay-hints.cpp b/source/slang/slang-language-server-inlay-hints.cpp
index 35b603b20..0eee347d1 100644
--- a/source/slang/slang-language-server-inlay-hints.cpp
+++ b/source/slang/slang-language-server-inlay-hints.cpp
@@ -18,7 +18,7 @@ List<LanguageServerProtocol::InlayHint> getInlayHints(
List<LanguageServerProtocol::InlayHint> result;
auto manager = linkage->getSourceManager();
auto docText = doc->getText().getUnownedSlice();
- iterateAST(fileName, manager, module->getModuleDecl(), [&](SyntaxNode* node)
+ iterateASTWithLanguageServerFilter(fileName, manager, module->getModuleDecl(), [&](SyntaxNode* node)
{
if (auto invokeExpr = as<InvokeExpr>(node))
{
diff --git a/source/slang/slang-language-server-semantic-tokens.cpp b/source/slang/slang-language-server-semantic-tokens.cpp
index ae10d62e8..3a40c8e92 100644
--- a/source/slang/slang-language-server-semantic-tokens.cpp
+++ b/source/slang/slang-language-server-semantic-tokens.cpp
@@ -126,7 +126,7 @@ List<SemanticToken> getSemanticTokens(Linkage* linkage, Module* module, UnownedS
}
maybeInsertToken(token);
};
- iterateAST(
+ iterateASTWithLanguageServerFilter(
fileName,
manager,
module->getModuleDecl(),
@@ -240,8 +240,34 @@ List<SemanticToken> getSemanticTokens(Linkage* linkage, Module* module, UnownedS
token.length = (int)attr->originalIdentifierToken.getContentLength();
token.type = SemanticTokenType::Type;
maybeInsertToken(token);
+
+ // Insert capability names as enum cases.
+ if (as<RequireCapabilityAttribute>(attr))
+ {
+ for (auto arg : attr->args)
+ {
+ if (auto varExpr = as<VarExpr>(arg))
+ {
+ if (varExpr->name)
+ {
+ SemanticToken capToken = _createSemanticToken(
+ manager, varExpr->loc, nullptr);
+ capToken.length = (int)varExpr->name->text.getLength();
+ capToken.type = SemanticTokenType::EnumMember;
+ maybeInsertToken(capToken);
+ }
+ }
+ }
+ }
}
}
+ else if (auto targetCase = as<TargetCaseStmt>(node))
+ {
+ SemanticToken token = _createSemanticToken(
+ manager, targetCase->capabilityToken.loc, targetCase->capabilityToken.getName());
+ token.type = SemanticTokenType::EnumMember;
+ maybeInsertToken(token);
+ }
else if (auto spirvAsmExpr = as<SPIRVAsmExpr>(node))
{
// Highlight opcodes and enums.
diff --git a/source/slang/slang-language-server.cpp b/source/slang/slang-language-server.cpp
index 76e21c2ca..0256464c9 100644
--- a/source/slang/slang-language-server.cpp
+++ b/source/slang/slang-language-server.cpp
@@ -127,6 +127,8 @@ SlangResult LanguageServer::parseNextMessage()
caps.completionProvider.triggerCharacters.add(".");
caps.completionProvider.triggerCharacters.add(":");
caps.completionProvider.triggerCharacters.add("[");
+ caps.completionProvider.triggerCharacters.add(" ");
+ caps.completionProvider.triggerCharacters.add("(");
caps.completionProvider.triggerCharacters.add("\"");
caps.completionProvider.triggerCharacters.add("/");
caps.completionProvider.resolveProvider = true;
@@ -985,6 +987,12 @@ SlangResult LanguageServer::completion(
context.line = utf8Line;
context.col = utf8Col;
context.commitCharacterBehavior = m_commitCharacterBehavior;
+ if (args.context.triggerKind == kCompletionTriggerKindTriggerCharacter &&
+ (args.context.triggerCharacter == " " || args.context.triggerCharacter == "[" || args.context.triggerCharacter == "("))
+ {
+ // Never use commit character if completion request is triggerred by these characters to prevent annoyance.
+ context.commitCharacterBehavior = CommitCharacterBehavior::Disabled;
+ }
if (SLANG_SUCCEEDED(context.tryCompleteInclude()))
{
diff --git a/source/slang/slang-options.cpp b/source/slang/slang-options.cpp
index 3edc7dfb5..ca8b60d31 100644
--- a/source/slang/slang-options.cpp
+++ b/source/slang/slang-options.cpp
@@ -3185,10 +3185,10 @@ SlangResult OptionsParser::parse(
m_sink = nullptr;
- if (requestSink->getErrorCount() > 0)
+ if (m_parseSink.getErrorCount() > 0)
{
// Put the errors in the diagnostic
- m_requestImpl->m_diagnosticOutput = requestSink->outputBuffer.produceString();
+ m_requestImpl->m_diagnosticOutput = m_parseSink.outputBuffer.produceString();
}
return res;
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index c5007569e..3940e59cf 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -4900,6 +4900,7 @@ namespace Slang
parser->sink->diagnose(caseName.loc, Diagnostics::unknownTargetName, caseName.getContent());
}
targetCase->capability = int32_t(cap);
+ targetCase->capabilityToken = caseName;
targetCase->loc = caseName.loc;
targetCase->body = bodyStmt;
stmt->targetCases.add(targetCase);
diff --git a/source/slang/slang-profile-defs.h b/source/slang/slang-profile-defs.h
index 2eaf6f897..9a9c128a5 100644
--- a/source/slang/slang-profile-defs.h
+++ b/source/slang/slang-profile-defs.h
@@ -88,12 +88,7 @@ PROFILE_FAMILY(DX)
PROFILE_FAMILY(GLSL)
// Profile versions
-
-
PROFILE_VERSION(DX_4_0, DX)
-PROFILE_VERSION(DX_4_0_Level_9_0, DX)
-PROFILE_VERSION(DX_4_0_Level_9_1, DX)
-PROFILE_VERSION(DX_4_0_Level_9_3, DX)
PROFILE_VERSION(DX_4_1, DX)
PROFILE_VERSION(DX_5_0, DX)
PROFILE_VERSION(DX_5_1, DX)
@@ -106,10 +101,6 @@ PROFILE_VERSION(DX_6_5, DX)
PROFILE_VERSION(DX_6_6, DX)
PROFILE_VERSION(DX_6_7, DX)
-PROFILE_VERSION(GLSL_110, GLSL)
-PROFILE_VERSION(GLSL_120, GLSL)
-PROFILE_VERSION(GLSL_130, GLSL)
-PROFILE_VERSION(GLSL_140, GLSL)
PROFILE_VERSION(GLSL_150, GLSL)
PROFILE_VERSION(GLSL_330, GLSL)
PROFILE_VERSION(GLSL_400, GLSL)
@@ -122,7 +113,6 @@ PROFILE_VERSION(GLSL_460, GLSL)
// Specific profiles
-
PROFILE(DX_Compute_4_0, cs_4_0, Compute, DX_4_0)
PROFILE(DX_Compute_4_1, cs_4_1, Compute, DX_4_1)
PROFILE(DX_Compute_5_0, cs_5_0, Compute, DX_5_0)
@@ -160,7 +150,6 @@ PROFILE(DX_Geometry_6_5, gs_6_5, Geometry, DX_6_5)
PROFILE(DX_Geometry_6_6, gs_6_6, Geometry, DX_6_6)
PROFILE(DX_Geometry_6_7, gs_6_7, Geometry, DX_6_7)
-
PROFILE(DX_Hull_5_0, hs_5_0, Hull, DX_5_0)
PROFILE(DX_Hull_5_1, hs_5_1, Hull, DX_5_1)
PROFILE(DX_Hull_6_0, hs_6_0, Hull, DX_6_0)
@@ -172,13 +161,9 @@ PROFILE(DX_Hull_6_5, hs_6_5, Hull, DX_6_5)
PROFILE(DX_Hull_6_6, hs_6_6, Hull, DX_6_6)
PROFILE(DX_Hull_6_7, hs_6_7, Hull, DX_6_7)
-
-PROFILE(DX_Fragment_4_0, ps_4_0, Fragment, DX_4_0)
-PROFILE(DX_Fragment_4_0_Level_9_0, ps_4_0_level_9_0, Fragment, DX_4_0_Level_9_0)
-PROFILE(DX_Fragment_4_0_Level_9_1, ps_4_0_level_9_1, Fragment, DX_4_0_Level_9_1)
-PROFILE(DX_Fragment_4_0_Level_9_3, ps_4_0_level_9_3, Fragment, DX_4_0_Level_9_3)
-PROFILE(DX_Fragment_4_1, ps_4_1, Fragment, DX_4_1)
-PROFILE(DX_Fragment_5_0, ps_5_0, Fragment, DX_5_0)
+PROFILE(DX_Fragment_4_0, ps_4_0, Fragment, DX_4_0)
+PROFILE(DX_Fragment_4_1, ps_4_1, Fragment, DX_4_1)
+PROFILE(DX_Fragment_5_0, ps_5_0, Fragment, DX_5_0)
PROFILE(DX_Fragment_5_1, ps_5_1, Fragment, DX_5_1)
PROFILE(DX_Fragment_6_0, ps_6_0, Fragment, DX_6_0)
PROFILE(DX_Fragment_6_1, ps_6_1, Fragment, DX_6_1)
@@ -189,11 +174,7 @@ PROFILE(DX_Fragment_6_5, ps_6_5, Fragment, DX_6_5)
PROFILE(DX_Fragment_6_6, ps_6_6, Fragment, DX_6_6)
PROFILE(DX_Fragment_6_7, ps_6_7, Fragment, DX_6_7)
-
PROFILE(DX_Vertex_4_0, vs_4_0, Vertex, DX_4_0)
-PROFILE(DX_Vertex_4_0_Level_9_0, vs_4_0_level_9_0, Vertex, DX_4_0_Level_9_0)
-PROFILE(DX_Vertex_4_0_Level_9_1, vs_4_0_level_9_1, Vertex, DX_4_0_Level_9_1)
-PROFILE(DX_Vertex_4_0_Level_9_3, vs_4_0_level_9_3, Vertex, DX_4_0_Level_9_3)
PROFILE(DX_Vertex_4_1, vs_4_1, Vertex, DX_4_1)
PROFILE(DX_Vertex_5_0, vs_5_0, Vertex, DX_5_0)
PROFILE(DX_Vertex_5_1, vs_5_1, Vertex, DX_5_1)
@@ -216,9 +197,6 @@ PROFILE(DX_Amplification_6_7, as_6_7, Amplification, DX_6_7)
// TODO: consider making `lib_*_*` alias these...
PROFILE(DX_None_4_0, sm_4_0, Unknown, DX_4_0)
-PROFILE(DX_None_4_0_Level_9_0, sm_4_0_level_9_0, Unknown, DX_4_0_Level_9_0)
-PROFILE(DX_None_4_0_Level_9_1, sm_4_0_level_9_1, Unknown, DX_4_0_Level_9_1)
-PROFILE(DX_None_4_0_Level_9_3, sm_4_0_level_9_3, Unknown, DX_4_0_Level_9_3)
PROFILE(DX_None_4_1, sm_4_1, Unknown, DX_4_1)
PROFILE(DX_None_5_0, sm_5_0, Unknown, DX_5_0)
PROFILE(DX_None_5_1, sm_5_1, Unknown, DX_5_1)
@@ -254,10 +232,6 @@ PROFILE_ALIAS(DX_None_6_7, DX_Lib_6_7, sm_6_7)
// Define all the GLSL profiles
-PROFILE(GLSL_None_110, glsl_110, Unknown, GLSL_110)
-PROFILE(GLSL_None_120, glsl_120, Unknown, GLSL_120)
-PROFILE(GLSL_None_130, glsl_130, Unknown, GLSL_130)
-PROFILE(GLSL_None_140, glsl_140, Unknown, GLSL_140)
PROFILE(GLSL_None_150, glsl_150, Unknown, GLSL_150)
PROFILE(GLSL_None_330, glsl_330, Unknown, GLSL_330)
PROFILE(GLSL_None_400, glsl_400, Unknown, GLSL_400)
@@ -271,10 +245,6 @@ PROFILE(GLSL_None_460, glsl_460, Unknown, GLSL_460)
#define P(UPPER, LOWER, VERSION) \
PROFILE(GLSL_##UPPER##_##VERSION, glsl_##LOWER##_##VERSION, UPPER, GLSL_##VERSION)
-P(Vertex, vertex, 110)
-P(Vertex, vertex, 120)
-P(Vertex, vertex, 130)
-P(Vertex, vertex, 140)
P(Vertex, vertex, 150)
P(Vertex, vertex, 330)
P(Vertex, vertex, 400)
@@ -284,10 +254,6 @@ P(Vertex, vertex, 430)
P(Vertex, vertex, 440)
P(Vertex, vertex, 450)
-P(Fragment, fragment, 110)
-P(Fragment, fragment, 120)
-P(Fragment, fragment, 130)
-P(Fragment, fragment, 140)
P(Fragment, fragment, 150)
P(Fragment, fragment, 330)
P(Fragment, fragment, 400)
diff --git a/source/slang/slang-profile.h b/source/slang/slang-profile.h
index a0284215f..a1c08fe6a 100644
--- a/source/slang/slang-profile.h
+++ b/source/slang/slang-profile.h
@@ -3,6 +3,7 @@
#include "../core/slang-basic.h"
#include "../../slang.h"
+#include "slang-capability.h"
namespace Slang
{
@@ -109,6 +110,8 @@ namespace Slang
static Profile lookUp(char const* name);
char const* getName();
+ List<CapabilityName> getCapabilityName();
+
RawVal raw = Unknown;
};
diff --git a/source/slang/slang-serialize-ast-type-info.h b/source/slang/slang-serialize-ast-type-info.h
index f5d636b01..f7d8cab08 100644
--- a/source/slang/slang-serialize-ast-type-info.h
+++ b/source/slang/slang-serialize-ast-type-info.h
@@ -78,6 +78,50 @@ struct PtrSerialTypeInfo<T, std::enable_if_t<std::is_base_of_v<Val, T>>>
template <typename T>
struct SerialTypeInfo<DeclRef<T>> : public SerialTypeInfo<DeclRefBase*> {};
+template<>
+struct SerialTypeInfo<CapabilitySet>
+{
+ typedef CapabilitySet NativeType;
+ typedef SerialIndex SerialType;
+ enum { SerialAlignment = SLANG_ALIGN_OF(SerialType) };
+ static void toSerial(SerialWriter* writer, const void* native, void* serial)
+ {
+ auto& src = *(const NativeType*)native;
+ auto& dst = *(SerialType*)serial;
+
+ dst = writer->addArray(src.getExpandedAtoms().getBuffer(), src.getExpandedAtoms().getCount());
+ }
+ static void toNative(SerialReader* reader, const void* serial, void* native)
+ {
+ auto& dst = *(NativeType*)native;
+ auto& src = *(const SerialType*)serial;
+
+ reader->getArray(src, dst.getExpandedAtoms());
+ }
+};
+
+template<>
+struct SerialTypeInfo<CapabilityConjunctionSet>
+{
+ typedef CapabilityConjunctionSet NativeType;
+ typedef SerialIndex SerialType;
+ enum { SerialAlignment = SLANG_ALIGN_OF(SerialType) };
+ static void toSerial(SerialWriter* writer, const void* native, void* serial)
+ {
+ auto& src = *(const NativeType*)native;
+ auto& dst = *(SerialType*)serial;
+
+ dst = writer->addArray(src.getExpandedAtoms().getBuffer(), src.getExpandedAtoms().getCount());
+ }
+ static void toNative(SerialReader* reader, const void* serial, void* native)
+ {
+ auto& dst = *(NativeType*)native;
+ auto& src = *(const SerialType*)serial;
+
+ reader->getArray(src, dst.getExpandedAtoms());
+ }
+};
+
// ValNodeOperand
template <>
struct SerialTypeInfo<ValNodeOperand>
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index b95b21bb5..e99f94484 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -858,7 +858,7 @@ Profile getEffectiveProfile(EntryPoint* entryPoint, TargetRequest* target)
case CodeGenTarget::SPIRVAssembly:
if(targetProfile.getFamily() != ProfileFamily::GLSL)
{
- targetProfile.setVersion(ProfileVersion::GLSL_110);
+ targetProfile.setVersion(ProfileVersion::GLSL_150);
}
break;
@@ -869,7 +869,7 @@ Profile getEffectiveProfile(EntryPoint* entryPoint, TargetRequest* target)
case CodeGenTarget::DXILAssembly:
if(targetProfile.getFamily() != ProfileFamily::DX)
{
- targetProfile.setVersion(ProfileVersion::DX_4_0);
+ targetProfile.setVersion(ProfileVersion::DX_5_1);
}
break;
}
@@ -1608,6 +1608,8 @@ CapabilitySet TargetRequest::getTargetCaps()
break;
}
+ CapabilitySet targetCap = CapabilitySet(atoms);
+
CapabilitySet latestSpirvCapSet = CapabilitySet(CapabilityName::spirv_latest);
CapabilityName latestSpirvAtom = (CapabilityName)latestSpirvCapSet.getExpandedAtoms()[0].getExpandedAtoms().getLast();
for (auto atom : rawCapabilities)
@@ -1623,7 +1625,11 @@ CapabilitySet TargetRequest::getTargetCaps()
atom = (CapabilityName)((Int)CapabilityName::glsl_spirv_1_0 + ((Int)atom - (Int)CapabilityName::spirv_1_0));
}
}
- atoms.add(atom);
+ if (!targetCap.isIncompatibleWith(atom))
+ {
+ // Only add atoms that are compatible with the current target.
+ atoms.add(atom);
+ }
}
cookedCapabilities = CapabilitySet(atoms);
diff --git a/tests/bugs/vk-structured-buffer-load.hlsl.glsl b/tests/bugs/vk-structured-buffer-load.hlsl.glsl
index 35fad779b..93181d0a3 100644
--- a/tests/bugs/vk-structured-buffer-load.hlsl.glsl
+++ b/tests/bugs/vk-structured-buffer-load.hlsl.glsl
@@ -1,12 +1,10 @@
#version 460
-#extension GL_NV_ray_tracing : require
+#extension GL_EXT_ray_tracing : require
layout(row_major) uniform;
layout(row_major) buffer;
-
layout(std430, binding = 1) readonly buffer StructuredBuffer_float_t_0 {
float _data[];
} gParamBlock_sbuf_0;
-
float rcp_0(float x_0)
{
return 1.0 / x_0;
@@ -17,24 +15,22 @@ struct RayHitInfoPacked_0
vec4 PackedHitInfoA_0;
};
-rayPayloadInNV RayHitInfoPacked_0 _S1;
+rayPayloadInEXT RayHitInfoPacked_0 _S1;
struct BuiltInTriangleIntersectionAttributes_0
{
vec2 barycentrics_0;
};
-hitAttributeNV BuiltInTriangleIntersectionAttributes_0 _S2;
+hitAttributeEXT BuiltInTriangleIntersectionAttributes_0 _S2;
void main()
{
- float HitT_0 = ((gl_RayTmaxNV));
+ float HitT_0 = ((gl_RayTmaxEXT));
_S1.PackedHitInfoA_0[0] = HitT_0;
float offsfloat_0 = gParamBlock_sbuf_0._data[0];
-
uint use_rcp_0 = 0U | uint(HitT_0 > 0.0);
-
if(use_rcp_0 != 0U)
{
_S1.PackedHitInfoA_0[1] = rcp_0(offsfloat_0);
diff --git a/tests/cross-compile/barycentrics-nv.slang b/tests/cross-compile/barycentrics-nv.slang
index fb0272679..60070f913 100644
--- a/tests/cross-compile/barycentrics-nv.slang
+++ b/tests/cross-compile/barycentrics-nv.slang
@@ -1,4 +1,9 @@
-//TEST:CROSS_COMPILE: -target spirv-assembly -capability GL_NV_fragment_shader_barycentric -entry main -stage fragment
+//TEST:SIMPLE(filecheck=CHECK): -target spirv-assembly -capability GL_NV_fragment_shader_barycentric -entry main -stage fragment
+
+// CHECK: OpCapability FragmentBarycentricKHR
+// CHECK: OpDecorate [[NAME:%[A-Za-z0-9_]+]] BuiltIn BaryCoordKHR
+// CHECK: [[NAME]] = OpVariable {{.*}} Input
+// CHECK: {{.*}} = OpLoad %v3float [[NAME]]
float4 main(float3 bary : SV_Barycentrics) : SV_Target
{
diff --git a/tests/cross-compile/barycentrics-nv.slang.glsl b/tests/cross-compile/barycentrics-nv.slang.glsl
deleted file mode 100644
index 583310125..000000000
--- a/tests/cross-compile/barycentrics-nv.slang.glsl
+++ /dev/null
@@ -1,12 +0,0 @@
-#version 450
-
-#extension GL_NV_fragment_shader_barycentric : enable
-
-layout(location = 0)
-out vec4 main_0;
-
-void main()
-{
- main_0 = vec4(gl_BaryCoordNV, float(0));
- return;
-}
diff --git a/tests/diagnostics/discard-in-compute.slang b/tests/diagnostics/discard-in-compute.slang
new file mode 100644
index 000000000..e530881bd
--- /dev/null
+++ b/tests/diagnostics/discard-in-compute.slang
@@ -0,0 +1,13 @@
+//TEST:SIMPLE(filecheck=CHECK): -target hlsl -entry main -profile cs_6_1
+void test()
+{
+ discard; // This should lead to `test` having `fragment` capability requirement.
+}
+
+[shader("compute")]
+[numthreads(1,1,1)]
+void main()
+{
+ // CHECK: error 36107
+ test(); // compute shader cannot call `test` that require capabiltiy `fragment`.
+}
diff --git a/tests/language-feature/capability/capability1.slang b/tests/language-feature/capability/capability1.slang
new file mode 100644
index 000000000..bccccb964
--- /dev/null
+++ b/tests/language-feature/capability/capability1.slang
@@ -0,0 +1,28 @@
+//TEST:SIMPLE(filecheck=CHECK): -target spirv -emit-spirv-directly -entry main2 -stage compute
+
+[require(spvShaderClockKHR)]
+void leafFunc1() {}
+
+[require(spvShaderNonUniform)]
+void leafFunc2() {}
+
+void caller()
+{
+ leafFunc1();
+ leafFunc2();
+}
+
+[require(spirv, shaderclock)]
+// CHECK: ([[# @LINE+1]]): error 36104:
+void main1()
+{
+ caller(); // Error, shaderclock does not imply spvShaderNonUniform.
+}
+
+
+[require(spirv, shaderclock)]
+void main2()
+{
+ // CHECK-NOT: error
+ leafFunc1(); // OK, shaderclock implies spvShaderClockKHR.
+}
diff --git a/tests/language-feature/capability/capability2.slang b/tests/language-feature/capability/capability2.slang
new file mode 100644
index 000000000..743f998cf
--- /dev/null
+++ b/tests/language-feature/capability/capability2.slang
@@ -0,0 +1,61 @@
+//TEST:SIMPLE(filecheck=CHECK): -target spirv -emit-spirv-directly -entry main -stage compute
+module test;
+
+[require(spvAtomicFloat16AddEXT)]
+interface IFoo
+{
+ [require(spvRayQueryKHR)]
+ void method1();
+
+ void method2();
+}
+
+[require(spvGroupNonUniformArithmetic)]
+void useNonUniformArithmetic()
+{}
+
+[require(spvRayQueryKHR)]
+void useRayQueryKHR()
+{}
+
+[require(spvAtomicFloat16AddEXT)]
+void useAtomicFloat16()
+{}
+
+// This should be OK, uses nothing past what is declared in the interface.
+struct Impl1 : IFoo
+{
+ void method1()
+ {
+ useAtomicFloat16();
+ useRayQueryKHR();
+ }
+
+ void method2()
+ {
+ useAtomicFloat16();
+ }
+}
+
+// CHECK-NOT: error 361
+
+struct Impl2 : IFoo
+{
+ // CHECK: ([[# @LINE+1]]): error 36104: {{.*}}spvGroupNonUniformArithmetic
+ void method1()
+ {
+ useRayQueryKHR(); // OK.
+ useNonUniformArithmetic(); // error.
+ }
+ // CHECK-NOT: error 361
+
+ // CHECK: ([[# @LINE+1]]): error 36104: {{.*}}spvGroupNonUniformArithmetic
+ void method2()
+ {
+ useAtomicFloat16();
+ useNonUniformArithmetic(); // error.
+ }
+}
+
+void main()
+{}
diff --git a/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang.glsl b/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang.glsl
index 1897c6467..820918d8b 100644
--- a/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang.glsl
+++ b/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang.glsl
@@ -2,11 +2,11 @@
//TEST_IGNORE_FILE:
#version 450
-#extension GL_NV_fragment_shader_barycentric : require
+#extension GL_EXT_fragment_shader_barycentric : require
layout(row_major) uniform;
layout(row_major) buffer;
-pervertexNV layout(location = 0)
+pervertexEXT layout(location = 0)
in vec4 color_0[3];
layout(location = 0)
@@ -14,6 +14,7 @@ out vec4 result_0;
void main()
{
- result_0 = gl_BaryCoordNV.x * ((color_0)[(0U)]) + gl_BaryCoordNV.y * ((color_0)[(1U)]) + gl_BaryCoordNV.z * ((color_0)[(2U)]);
+ result_0 = gl_BaryCoordEXT.x * ((color_0)[(0U)]) + gl_BaryCoordEXT.y * ((color_0)[(1U)]) + gl_BaryCoordEXT.z * ((color_0)[(2U)]);
return;
}
+
diff --git a/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang.hlsl b/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang.hlsl
index ce23492c9..a6b45eab4 100644
--- a/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang.hlsl
+++ b/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang.hlsl
@@ -2,6 +2,8 @@
//TEST_IGNORE_FILE:
+#pragma warning(disable: 3557)
+
[shader("pixel")]
void main(
nointerpolation vector<float,4> color_0 : COLOR,
diff --git a/tests/pipeline/rasterization/get-attribute-at-vertex.slang.hlsl b/tests/pipeline/rasterization/get-attribute-at-vertex.slang.hlsl
index ce23492c9..9322964d5 100644
--- a/tests/pipeline/rasterization/get-attribute-at-vertex.slang.hlsl
+++ b/tests/pipeline/rasterization/get-attribute-at-vertex.slang.hlsl
@@ -11,4 +11,4 @@ void main(
result_0 = bary_0.x * GetAttributeAtVertex(color_0, 0U)
+ bary_0.y * GetAttributeAtVertex(color_0, 1U)
+ bary_0.z * GetAttributeAtVertex(color_0, 2U);
-}
+} \ No newline at end of file
diff --git a/tests/vkray/callable-caller.slang b/tests/vkray/callable-caller.slang
index 64311988a..6a0c85c38 100644
--- a/tests/vkray/callable-caller.slang
+++ b/tests/vkray/callable-caller.slang
@@ -1,6 +1,8 @@
// callable-caller.slang
-//TEST:CROSS_COMPILE: -profile glsl_460 -capability GL_NV_ray_tracing -stage raygeneration -entry main -target spirv-assembly
+//TEST:SIMPLE(filecheck=CHECK): -profile glsl_460 -capability GL_NV_ray_tracing -stage raygeneration -entry main -target spirv-assembly
+//TEST:SIMPLE(filecheck=CHECK): -profile glsl_460 -capability GL_NV_ray_tracing -stage raygeneration -entry main -target spirv-assembly -emit-spirv-directly
+
import callable_shared;
@@ -16,7 +18,7 @@ void main()
MaterialPayload payload;
payload.albedo = 0;
payload.uv = float2(DispatchRaysIndex().xy) / float2(DispatchRaysDimensions().xy);
-
+ // CHECK: OpExecuteCallable
CallShader(shaderIndex, payload);
gImage[DispatchRaysIndex().xy] = payload.albedo;
diff --git a/tests/vkray/callable-caller.slang.glsl b/tests/vkray/callable-caller.slang.glsl
deleted file mode 100644
index a42e6eaf3..000000000
--- a/tests/vkray/callable-caller.slang.glsl
+++ /dev/null
@@ -1,49 +0,0 @@
-#version 460
-#extension GL_NV_ray_tracing : require
-layout(row_major) uniform;
-layout(row_major) buffer;
-struct MaterialPayload_0
-{
- vec4 albedo_0;
- vec2 uv_0;
-};
-
-layout(location = 0)
-callableDataNV
-MaterialPayload_0 p_0;
-
-struct SLANG_ParameterGroup_C_0
-{
- uint shaderIndex_0;
-};
-
-layout(binding = 0)
-layout(std140) uniform _S1
-{
- uint shaderIndex_0;
-} C_0;
-void CallShader_0(uint shaderIndex_1, inout MaterialPayload_0 payload_0)
-{
- p_0 = payload_0;
- executeCallableNV(shaderIndex_1, (0));
- payload_0 = p_0;
- return;
-}
-
-layout(rgba32f)
-layout(binding = 1)
-uniform image2D gImage_0;
-
-void main()
-{
- MaterialPayload_0 payload_1;
- payload_1.albedo_0 = vec4(0.0);
- uvec3 _S2 = ((gl_LaunchIDNV));
- vec2 _S3 = vec2(_S2.xy);
- uvec3 _S4 = ((gl_LaunchSizeNV));
- payload_1.uv_0 = _S3 / vec2(_S4.xy);
- CallShader_0(C_0.shaderIndex_0, payload_1);
- uvec3 _S5 = ((gl_LaunchIDNV));
- imageStore((gImage_0), ivec2((_S5.xy)), payload_1.albedo_0);
- return;
-}
diff --git a/tests/vkray/miss.slang.glsl b/tests/vkray/miss.slang.glsl
index df7647411..1bc6af5b3 100644
--- a/tests/vkray/miss.slang.glsl
+++ b/tests/vkray/miss.slang.glsl
@@ -1,17 +1,7 @@
//TEST_IGNORE_FILE:
#version 460
-#if USE_NV_RT
-#extension GL_NV_ray_tracing : require
-#define callableDataInEXT callableDataInNV
-#define gl_LaunchIDEXT gl_LaunchIDNV
-#define hitAttributeEXT hitAttributeNV
-#define ignoreIntersectionEXT ignoreIntersectionNV
-#define rayPayloadInEXT rayPayloadInNV
-#define terminateRayEXT terminateRayNV
-#else
#extension GL_EXT_ray_tracing : require
-#endif
struct ShadowRay_0
{
diff --git a/tools/slang-test/slang-test-main.cpp b/tools/slang-test/slang-test-main.cpp
index 5de9a8dc6..c93bebf33 100644
--- a/tools/slang-test/slang-test-main.cpp
+++ b/tools/slang-test/slang-test-main.cpp
@@ -2698,6 +2698,8 @@ TestResult generateActualOutput(TestContext* const context, const TestInput& inp
return TestResult::Pass;
}
+ actualOutput = getOutput(actualExeRes);
+
// Always fail if the compilation produced a failure, just
// to catch situations where, e.g., command-line options parsing
// caused the same error in both the Slang and glslang cases.
@@ -2707,7 +2709,6 @@ TestResult generateActualOutput(TestContext* const context, const TestInput& inp
return TestResult::Fail;
}
- actualOutput = getOutput(actualExeRes);
return TestResult::Pass;
}
diff --git a/tools/slang-test/slangc-tool.cpp b/tools/slang-test/slangc-tool.cpp
index a62ea6975..4c2a3244c 100644
--- a/tools/slang-test/slangc-tool.cpp
+++ b/tools/slang-test/slangc-tool.cpp
@@ -49,7 +49,7 @@ SlangResult SlangCTool::innerMain(StdWriters* stdWriters, slang::IGlobalSession*
const SlangResult res = compileRequest->processCommandLineArguments(&argv[1], argc - 1);
if (SLANG_FAILED(res))
{
- // TODO: print usage message
+ StdWriters::getOut().print("%s", compileRequest->getDiagnosticOutput());
return res;
}
}