summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorArielG-NV <159081215+ArielG-NV@users.noreply.github.com>2024-08-28 15:06:23 -0400
committerGitHub <noreply@github.com>2024-08-28 12:06:23 -0700
commit65240d074b4ddec55e56962ebf8de46207bcf5fa (patch)
treefa887d3de8ab55c7498eae2d5bf61966818135a1 /source
parent638e5fb000d4e242a91e8b653da4a72daec0efda (diff)
Allow capabilities to be used with `[shader("...")]` (#4928)
* Allow capabilities to be used with `[shader("...")]` Fixes: #4917 Changes: 1. Allow using capabilities instead of `Stage`s with `EntryPointAttribute`. 2. When resolving capabilities for an entrypoint+profile (per entrypoint) in `resolveStageOfProfileWithEntryPoint` add our `EntryPointAttribute` and resolved capability 3. Added tests and some capabilities related clean-up * fix a warning made by a mistake in syntax * change fineStageByName to assume it is passed a stage without a '_' * test with and without prefix '_' * cleanup some profiles and reprisentation to work better with 'Stage' and 'Profile' This use case is why we need to clean all profile-usage into `CapabilityName`s directly. * change how we compare * only change profiles * let all capabilities be resolved by 'shader' profile for now * fix warning checks I accidently broke * meshshading_internal to _meshshading --------- Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ast-modifier.h16
-rw-r--r--source/slang/slang-capabilities.capdef98
-rw-r--r--source/slang/slang-check-decl.cpp19
-rw-r--r--source/slang/slang-check-modifier.cpp42
-rw-r--r--source/slang/slang-check-shader.cpp17
-rw-r--r--source/slang/slang-compiler.cpp24
-rw-r--r--source/slang/slang-diagnostic-defs.h7
-rw-r--r--source/slang/slang-profile.h4
-rw-r--r--source/slang/slang.cpp2
9 files changed, 134 insertions, 95 deletions
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 14e945e25..8c9cb484f 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -924,17 +924,15 @@ class InstanceAttribute : public Attribute
int32_t value;
};
-// A `[shader("stageName")]` attribute, which marks an entry point
-// to be compiled, and specifies the stage for that entry point
-class EntryPointAttribute : public Attribute
+// A `[shader("stageName")]`/`[shader("capability")]` attribute which
+// marks an entry point for compiling. This attribute also specifies
+// the 'capabilities' implicitly supported by an entry point
+class EntryPointAttribute : public Attribute
{
SLANG_AST_CLASS(EntryPointAttribute)
-
- // The resolved stage that the entry point is targetting.
- //
- // TODO: This should be an accessor that uses the
- // ordinary `args` list, rather than side data.
- Stage stage;
+
+ // The resolved capailities for our entry point.
+ CapabilitySet capabilitySet;
};
// A `[__vulkanRayPayload(location)]` attribute, which is used in the
diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef
index 220e4a424..de44c98a5 100644
--- a/source/slang/slang-capabilities.capdef
+++ b/source/slang/slang-capabilities.capdef
@@ -172,40 +172,14 @@ def compute : stage;
def hull : stage;
def domain : stage;
def geometry : stage;
-def raygen : stage;
-def intersection : stage;
-def anyhit : stage;
-def closesthit: stage;
-def miss : stage;
-def mesh : stage;
-def amplification : stage;
-def callable : stage;
-alias any_stage = vertex | fragment | compute | hull | domain | geometry
- | raygen | intersection | anyhit | closesthit | miss | mesh
- | amplification | callable
- ;
-
-// shader stage alias's
-alias pixel = fragment;
-alias raygeneration = raygen;
-alias tesscontrol = hull;
-alias tesseval = domain;
-alias amplification_mesh = amplification | mesh;
-alias raytracing_stages = raygen | intersection | anyhit | closesthit | miss | callable;
-alias anyhit_closesthit = anyhit | closesthit;
-alias raygen_closesthit_miss = raygen | closesthit | miss;
-alias anyhit_closesthit_intersection = anyhit | closesthit | intersection;
-alias anyhit_closesthit_intersection_miss = anyhit | closesthit | intersection | miss;
-alias raygen_closesthit_miss_callable = raygen | closesthit | miss | callable;
-alias compute_tesscontrol_tesseval = compute | tesscontrol | tesseval;
-alias compute_fragment = compute | fragment;
-alias compute_fragment_geometry_vertex = compute | fragment | geometry | vertex;
-alias domain_hull = domain | hull;
-alias raytracingstages_fragment = raytracing_stages | fragment;
-alias raytracingstages_compute = raytracing_stages | compute;
-alias raytracingstages_compute_amplification_mesh = raytracingstages_compute | amplification_mesh;
-alias raytracingstages_compute_fragment = raytracing_stages | compute_fragment;
-alias raytracingstages_compute_fragment_geometry_vertex = raytracing_stages | compute_fragment_geometry_vertex;
+def _raygen : stage;
+def _intersection : stage;
+def _anyhit : stage;
+def _closesthit: stage;
+def _callable : stage;
+def _miss : stage;
+def _mesh : stage;
+def _amplification : stage;
// SPIRV extensions.
@@ -403,7 +377,7 @@ alias GL_NV_shader_invocation_reorder = _GL_NV_shader_invocation_reorder + _GL_E
alias GL_NV_shader_subgroup_partitioned = _GL_NV_shader_subgroup_partitioned | spvGroupNonUniformPartitionedNV;
alias GL_NV_shader_texture_footprint = _GL_NV_shader_texture_footprint | spvImageFootprintNV;
-// Define feature names
+// Define feature names not reliant on shader stages
alias nvapi = hlsl_nvapi;
alias raytracing = GL_EXT_ray_tracing | _sm_6_3 | cuda;
@@ -413,20 +387,66 @@ alias rayquery = GL_EXT_ray_query | _sm_6_3;
alias raytracing_motionblur = raytracing + motionblur | cuda;
alias ser_motion = ser + motionblur;
alias shaderclock = GL_EXT_shader_realtime_clock | hlsl_nvapi | cpp | cuda;
-alias meshshading_internal = GL_EXT_mesh_shader | _sm_6_5 | metal;
-alias meshshading = amplification_mesh + meshshading_internal;
+alias _meshshading = GL_EXT_mesh_shader | _sm_6_5 | metal;
alias fragmentshaderinterlock = _GL_ARB_fragment_shader_interlock | hlsl_nvapi | spvFragmentShaderPixelInterlockEXT;
alias atomic64 = GL_EXT_shader_atomic_int64 | _sm_6_6 | cpp | cuda;
alias atomicfloat = GL_EXT_shader_atomic_float | _sm_6_0 + hlsl_nvapi | cpp | cuda;
alias atomicfloat2 = GL_EXT_shader_atomic_float2 | _sm_6_6 + hlsl_nvapi | cpp | cuda;
alias fragmentshaderbarycentric = GL_EXT_fragment_shader_barycentric | _sm_6_1;
alias shadermemorycontrol = glsl | _spirv_1_0 | _sm_5_0;
-alias shadermemorycontrol_compute = raytracingstages_compute + shadermemorycontrol;
-alias subpass = fragment + _sm_6_0 | fragment + any_gfx_target;
alias waveprefix = _sm_6_5 | _cuda_sm_7_0 | GL_KHR_shader_subgroup_arithmetic;
alias bufferreference = GL_EXT_buffer_reference;
alias bufferreference_int64 = bufferreference + GL_EXT_shader_explicit_arithmetic_types_int64;
+// non-internal shader stages
+
+alias pixel = fragment;
+
+alias tesscontrol = hull;
+alias tesseval = domain;
+
+alias _raygeneration = _raygen;
+alias raygen = _raygen + raytracing;
+alias raygeneration = _raygeneration + raytracing;
+alias intersection = _intersection + raytracing;
+alias anyhit = _anyhit + raytracing;
+alias closesthit = _closesthit + raytracing;
+alias callable = _callable + raytracing;
+alias miss = _miss + raytracing;
+
+alias mesh = _mesh + _meshshading;
+alias amplification = _amplification + _meshshading;
+
+// shader stage groups
+
+alias any_stage = vertex | fragment | compute | hull | domain | geometry
+ | raygen | intersection | anyhit | closesthit | miss | mesh
+ | amplification | callable
+ ;
+alias amplification_mesh = amplification | mesh;
+alias raytracing_stages = raygen | intersection | anyhit | closesthit | miss | callable;
+alias anyhit_closesthit = anyhit | closesthit;
+alias raygen_closesthit_miss = raygen | closesthit | miss;
+alias anyhit_closesthit_intersection = anyhit | closesthit | intersection;
+alias anyhit_closesthit_intersection_miss = anyhit | closesthit | intersection | miss;
+alias raygen_closesthit_miss_callable = raygen | closesthit | miss | callable;
+alias compute_tesscontrol_tesseval = compute | tesscontrol | tesseval;
+alias compute_fragment = compute | fragment;
+alias compute_fragment_geometry_vertex = compute | fragment | geometry | vertex;
+alias domain_hull = domain | hull;
+alias raytracingstages_fragment = raytracing_stages | fragment;
+alias raytracingstages_compute = raytracing_stages | compute;
+alias raytracingstages_compute_amplification_mesh = raytracingstages_compute | amplification_mesh;
+alias raytracingstages_compute_fragment = raytracing_stages | compute_fragment;
+alias raytracingstages_compute_fragment_geometry_vertex = raytracing_stages | compute_fragment_geometry_vertex;
+
+// Define feature names reliant on shader stages
+
+alias meshshading = amplification_mesh + _meshshading;
+
+alias shadermemorycontrol_compute = raytracingstages_compute + shadermemorycontrol;
+alias subpass = fragment + _sm_6_0 | fragment + any_gfx_target;
+
// Define what each shader model means on different targets.
// spirv profile
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 3bd6bd327..190433e2f 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -10644,6 +10644,7 @@ namespace Slang
void SemanticsDeclCapabilityVisitor::visitFunctionDeclBase(FunctionDeclBase* funcDecl)
{
+ // If the function is an entrypoint and specifies a target stage, add the capabilities to our function capabilities.
_dispatchCapabilitiesVisitorOfFunctionDecl(this, funcDecl,
[this, funcDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
{
@@ -10657,30 +10658,12 @@ namespace Slang
auto declaredCaps = getDeclaredCapabilitySet(funcDecl);
- if (!declaredCaps.isEmpty())
- {
- // If the function is an entrypoint, add the stage to declaredCaps.
- if (auto entryPointAttr = funcDecl->findModifier<EntryPointAttribute>())
- {
- auto stageCaps = CapabilitySet(Profile(entryPointAttr->stage).getCapabilityName());
- if (declaredCaps.isIncompatibleWith(stageCaps))
- {
- maybeDiagnose(getSink(), this->getOptionSet(), DiagnosticCategory::Capability, funcDecl->loc, Diagnostics::stageIsIncompatibleWithCapabilityDefinition, funcDecl, stageCaps, declaredCaps);
- }
- else
- {
- declaredCaps.join(stageCaps);
- }
- }
- }
-
auto vis = getDeclVisibility(funcDecl);
// If 0 capabilities were annotated on a function, capabilities are inferred from the function body
if (declaredCaps.isEmpty())
{
declaredCaps = funcDecl->inferredCapabilityRequirements;
- return;
}
else
{
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index 705d0bb3b..d7f879c51 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -548,19 +548,47 @@ namespace Slang
{
SLANG_ASSERT(attr->args.getCount() == 1);
- String stageName;
- if (!checkLiteralStringVal(attr->args[0], &stageName))
+ String capNameString;
+ if (!checkLiteralStringVal(attr->args[0], &capNameString))
{
return false;
}
- auto stage = findStageByName(stageName);
- if (stage == Stage::Unknown)
+ CapabilityName capName = findCapabilityName(capNameString.getUnownedSlice());
+ if (capName != CapabilityName::Invalid)
{
- getSink()->diagnose(attr->args[0], Diagnostics::unknownStageName, stageName);
- }
+ if (isInternalCapabilityName(capName))
+ maybeDiagnose(getSink(), this->getOptionSet(), DiagnosticCategory::Capability, attr, Diagnostics::usingInternalCapabilityName, attr, capName);
+
+ // Ensure this capability only defines 1 stage per target, else diagnose an error.
+ // This is a fatal error, do not allow toggling this error off.
+ entryPointAttr->capabilitySet = CapabilitySet(capName);
+ HashSet<CapabilityAtom> stageToBeUsed;
+ for (auto& targetSet : entryPointAttr->capabilitySet.getCapabilityTargetSets())
+ {
+ for(auto& stageSet : targetSet.second.shaderStageSets)
+ stageToBeUsed.add(stageSet.first);
+ }
- entryPointAttr->stage = stage;
+ // TODO: Once profiles are removed in favor for `CapabilitySet`s we will beable to use more complex relationships,
+ // Until then we have an artificial limit that any capabilites used inside '[shader(...)]' must only specify 1 stage type
+ // uniformly across targets.
+ if (stageToBeUsed.getCount() > 1)
+ {
+ List<CapabilityAtom> atomsToPrint;
+ atomsToPrint.reserve(stageToBeUsed.getCount());
+ for (auto i : stageToBeUsed)
+ atomsToPrint.add(i);
+ getSink()->diagnose(attr, Diagnostics::capabilityHasMultipleStages, capNameString, atomsToPrint);
+ }
+ return entryPointAttr;
+ }
+ else
+ {
+ // always diagnose this error since nothing can compile with an invalid capability
+ getSink()->diagnose(attr, Diagnostics::unknownCapability, capNameString);
+ return false;
+ }
}
else if ((as<DomainAttribute>(attr)) ||
(as<MaxTessFactorAttribute>(attr)) ||
diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp
index 99205e522..3a1e4c7f6 100644
--- a/source/slang/slang-check-shader.cpp
+++ b/source/slang/slang-check-shader.cpp
@@ -557,7 +557,7 @@ namespace Slang
for (auto target : linkage->targets)
{
auto targetCaps = target->getTargetCaps();
- auto stageCapabilitySet = CapabilitySet(entryPoint->getProfile().getCapabilityName());
+ auto stageCapabilitySet = entryPoint->getProfile().getCapabilityName();
targetCaps.join(stageCapabilitySet);
if (targetCaps.isIncompatibleWith(entryPointFuncDecl->inferredCapabilityRequirements))
{
@@ -613,20 +613,23 @@ namespace Slang
if (auto entryPointAttr = entryPointFuncDecl->findModifier<EntryPointAttribute>())
{
auto entryPointProfileStage = entryPointProfile.getStage();
- // Ensure every target is specifying the same stage as an entry` point
+ auto entryPointStage = getStageFromAtom(entryPointAttr->capabilitySet.getTargetStage());
+
+ // Ensure every target is specifying the same stage as an entry-point
// if a profile+stage was set, else user will not be aware that their
// code is requiring `fragment` on a `vertex` shader
for (auto target : targets)
{
auto targetProfile = target->getOptionSet().getProfile();
auto profileStage = targetProfile.getStage();
- if (profileStage != Stage::Unknown && profileStage != entryPointAttr->stage)
- maybeDiagnose(sink, optionSet, DiagnosticCategory::Capability, entryPointAttr, Diagnostics::entryPointAndProfileAreIncompatible, entryPointFuncDecl, entryPointAttr->stage, targetProfile.getName());
+ if (profileStage != Stage::Unknown && profileStage != entryPointStage)
+ maybeDiagnose(sink, optionSet, DiagnosticCategory::Capability, entryPointAttr, Diagnostics::entryPointAndProfileAreIncompatible, entryPointFuncDecl, entryPointStage, targetProfile.getName());
}
if (entryPointProfileStage == Stage::Unknown)
- entryPointProfile.setStage(entryPointAttr->stage);
- else if (entryPointProfileStage != Stage::Unknown && entryPointProfileStage != entryPointAttr->stage)
- maybeDiagnose(sink, optionSet, DiagnosticCategory::Capability, entryPointFuncDecl, Diagnostics::specifiedStageDoesntMatchAttribute, entryPointFuncDecl->getName(), entryPointProfileStage, entryPointAttr->stage);
+ entryPointProfile = Profile(entryPointStage);
+ else if (entryPointProfileStage != Stage::Unknown && entryPointProfileStage != entryPointStage)
+ maybeDiagnose(sink, optionSet, DiagnosticCategory::Capability, entryPointFuncDecl, Diagnostics::specifiedStageDoesntMatchAttribute, entryPointFuncDecl->getName(), entryPointProfileStage, entryPointStage);
+ entryPointProfile.additionalCapabilities.add(entryPointAttr->capabilitySet);
return true;
}
return false;
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp
index f5b7ff428..428532658 100644
--- a/source/slang/slang-compiler.cpp
+++ b/source/slang/slang-compiler.cpp
@@ -361,7 +361,7 @@ namespace Slang
return lookUp(UnownedTerminatedStringSlice(name));
}
- List<CapabilityName> Profile::getCapabilityName()
+ CapabilitySet Profile::getCapabilityName()
{
List<CapabilityName> result;
switch (getVersion())
@@ -378,7 +378,11 @@ namespace Slang
default:
break;
}
- return result;
+
+ CapabilitySet resultSet = CapabilitySet(result);
+ for(auto i : this->additionalCapabilities)
+ resultSet.join(i);
+ return resultSet;
}
char const* Profile::getName()
@@ -451,21 +455,21 @@ namespace Slang
return Stage::Fragment;
case CapabilityAtom::compute:
return Stage::Compute;
- case CapabilityAtom::mesh:
+ case CapabilityAtom::_mesh:
return Stage::Mesh;
- case CapabilityAtom::amplification:
+ case CapabilityAtom::_amplification:
return Stage::Amplification;
- case CapabilityAtom::anyhit:
+ case CapabilityAtom::_anyhit:
return Stage::AnyHit;
- case CapabilityAtom::closesthit:
+ case CapabilityAtom::_closesthit:
return Stage::ClosestHit;
- case CapabilityAtom::intersection:
+ case CapabilityAtom::_intersection:
return Stage::Intersection;
- case CapabilityAtom::raygen:
+ case CapabilityAtom::_raygen:
return Stage::RayGeneration;
- case CapabilityAtom::miss:
+ case CapabilityAtom::_miss:
return Stage::Miss;
- case CapabilityAtom::callable:
+ case CapabilityAtom::_callable:
return Stage::Callable;
default:
SLANG_UNEXPECTED("unknown stage atom");
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index d23ae8a3e..b35acd2c3 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -390,8 +390,8 @@ DIAGNOSTIC(30604, Error, useOfLessVisibleType, "'$0' references less visible typ
DIAGNOSTIC(36005, Error, invalidVisibilityModifierOnTypeOfDecl, "visibility modifier is not allowed on '$0'.")
// Capability
-DIAGNOSTIC(36100, Error, conflictingCapabilityDueToUseOfDecl, "'$0' requires capability '$1' that is conflicting with the '$2''s current capability requirement '$3'.")
-DIAGNOSTIC(36101, Error, conflictingCapabilityDueToStatement, "statement requires capability '$0' that is conflicting with the '$1''s current capability requirement '$2'.")
+DIAGNOSTIC(36100, Error, conflictingCapabilityDueToUseOfDecl, "'$0' requires capability '$1' that is conflicting with the '$2's current capability requirement '$3'.")
+DIAGNOSTIC(36101, Error, conflictingCapabilityDueToStatement, "statement requires capability '$0' that is conflicting with the '$1's current capability requirement '$2'.")
DIAGNOSTIC(36102, Error, conflictingCapabilityDueToStatementEnclosingFunc, "statement requires capability '$0' that is conflicting with the current function's capability requirement '$1'.")
DIAGNOSTIC(36103, Warning, missingCapabilityRequirementOnPublicDecl, "public symbol '$0' is missing capability requirement declaration, the symbol is assumed to require inferred capabilities '$1'.")
DIAGNOSTIC(36104, Error, useOfUndeclaredCapability, "'$0' uses undeclared capability '$1'.")
@@ -404,9 +404,10 @@ DIAGNOSTIC(36109, Error, invalidTargetSwitchCase, "'$0' cannot be used as a targ
DIAGNOSTIC(36110, Error, stageIsIncompatibleWithCapabilityDefinition, "'$0' is defined for stage '$1', which is incompatible with the declared capability set '$2'.")
DIAGNOSTIC(36111, Error, unexpectedCapability, "'$0' resolves into a disallowed `$1` Capability.")
DIAGNOSTIC(36112, Warning, entryPointAndProfileAreIncompatible, "'$0' is defined for stage '$1', which is incompatible with the declared profile '$2'.")
-DIAGNOSTIC(36113, Warning, usingInternalCapabilityName, "'$0' resolves into a '_Internal' `_$1' Capability, use '$1' instead.")
+DIAGNOSTIC(36113, Warning, usingInternalCapabilityName, "'$0' resolves into a '_Internal' '_$1' Capability, use '$1' instead.")
DIAGNOSTIC(36114, Warning, incompatibleWithPrecompileLib, "Precompiled library requires '$0', has `$1`, implicitly upgrading capabilities.")
DIAGNOSTIC(36115, Error, incompatibleWithPrecompileLibRestrictive, "Precompiled library requires '$0', has `$1`.")
+DIAGNOSTIC(36116, Error, capabilityHasMultipleStages, "Capability '$0' is targeting stages '$1', only allowed to use 1 unique stage here.")
// Attributes
diff --git a/source/slang/slang-profile.h b/source/slang/slang-profile.h
index d07a4555f..04d4f5112 100644
--- a/source/slang/slang-profile.h
+++ b/source/slang/slang-profile.h
@@ -113,7 +113,9 @@ namespace Slang
static Profile lookUp(char const* name);
char const* getName();
- List<CapabilityName> getCapabilityName();
+ CapabilitySet getCapabilityName();
+
+ List<CapabilitySet> additionalCapabilities;
RawVal raw = Unknown;
};
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index aa114e44d..23e25249d 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -1749,7 +1749,7 @@ CapabilitySet TargetRequest::getTargetCaps()
// If the user specified a explicit profile, we should pull
// a corresponding atom representing the target version from the profile.
- CapabilitySet profileCaps = CapabilitySet(optionSet.getProfile().getCapabilityName());
+ CapabilitySet profileCaps = optionSet.getProfile().getCapabilityName();
bool isGLSLTarget = false;
switch(getTarget())