summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: jarcherNV <jarcher@nvidia.com> 2025-06-10 09:44:08 -0700
committer: GitHub <noreply@github.com> 2025-06-10 09:44:08 -0700
commit: 3fa382505271834514d47612efee8e51a06204c5 (patch)
tree: a76ff3a3969ed229bfbe4452326335d1db62418a
parent: e37202002276b679c5241b2678af612552b06d2c (diff)
Allow checking capabilities in specific stages (#7375)
This allows checking for a capability within a specific stage rather than across all stages. It is needed specifically for the `hlsl_2018` capability, which is defined for `sm_5_1` and above. With a stage-specific capability such as `cs_5_1`, `hlsl_2018` would not be found in any stage other than compute, so the check must be restricted to only the desired stages.
-rw-r--r-- source/slang/slang-capability.cpp 9
-rw-r--r-- source/slang/slang-capability.h 2
-rw-r--r-- source/slang/slang-compiler.cpp 43
-rw-r--r-- source/slang/slang-emit-hlsl.cpp 27
-rw-r--r-- source/slang/slang-emit-hlsl.h 12
-rw-r--r-- source/slang/slang-profile.h 2
-rw-r--r-- tests/hlsl/hlsl-capability.slang 49
7 files changed, 137 insertions, 7 deletions
diff --git a/source/slang/slang-capability.cpp b/source/slang/slang-capability.cpp
index 1eb0cae31..a2fef9f8a 100644
--- a/source/slang/slang-capability.cpp
+++ b/source/slang/slang-capability.cpp
@@ -718,17 +718,17 @@ bool CapabilityTargetSet::tryJoin(const CapabilityTargetSets& other)
return true;
}
-void CapabilitySet::join(const CapabilitySet& other)
+CapabilitySet& CapabilitySet::join(const CapabilitySet& other)
{
if (this->isEmpty() || other.isInvalid())
{
*this = other;
- return;
+ return *this;
}
if (this->isInvalid())
- return;
+ return *this;
if (other.isEmpty())
- return;
+ return *this;
List<CapabilityAtom> destroySet;
destroySet.reserve(this->m_targetSets.getCount());
@@ -746,6 +746,7 @@ void CapabilitySet::join(const CapabilitySet& other)
// join made a invalid CapabilitySet
if (this->m_targetSets.getCount() == 0)
this->m_targetSets[CapabilityAtom::Invalid].target = CapabilityAtom::Invalid;
+ return *this;
}
static uint32_t _calcAtomListDifferenceScore(
diff --git a/source/slang/slang-capability.h b/source/slang/slang-capability.h
index 7c429d825..4bf0704a0 100644
--- a/source/slang/slang-capability.h
+++ b/source/slang/slang-capability.h
@@ -157,7 +157,7 @@ public:
/// Join two capability sets to form ('this' & 'other').
/// Destroy incompatible targets/sets apart of 'this' between ('this' & 'other').
/// `this` may be made invalid if other is fully disjoint.
- void join(const CapabilitySet& other);
+ CapabilitySet& join(const CapabilitySet& other);
/// Join two capability sets to form ('this' & 'other').
/// If a target/set has an incompatible atom, do not destroy the target/set.
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp
index 8cb50c1e9..4ba937992 100644
--- a/source/slang/slang-compiler.cpp
+++ b/source/slang/slang-compiler.cpp
@@ -476,6 +476,49 @@ Stage getStageFromAtom(CapabilityAtom atom)
}
}
+CapabilityAtom getAtomFromStage(Stage stage)
+{
+ // Convert Slang::Stage to CapabilityAtom.
+ // Note that capabilities do not share the same values as Slang::Stage
+ // and must be explicitly converted.
+ switch (stage)
+ {
+ case Stage::Compute:
+ return CapabilityAtom::compute;
+ case Stage::Vertex:
+ return CapabilityAtom::vertex;
+ case Stage::Fragment:
+ return CapabilityAtom::fragment;
+ case Stage::Geometry:
+ return CapabilityAtom::geometry;
+ case Stage::Hull:
+ return CapabilityAtom::hull;
+ case Stage::Domain:
+ return CapabilityAtom::domain;
+ case Stage::Mesh:
+ return CapabilityAtom::_mesh;
+ case Stage::Amplification:
+ return CapabilityAtom::_amplification;
+ case Stage::RayGeneration:
+ return CapabilityAtom::_raygen;
+ case Stage::AnyHit:
+ return CapabilityAtom::_anyhit;
+ case Stage::ClosestHit:
+ return CapabilityAtom::_closesthit;
+ case Stage::Miss:
+ return CapabilityAtom::_miss;
+ case Stage::Intersection:
+ return CapabilityAtom::_intersection;
+ case Stage::Callable:
+ return CapabilityAtom::_callable;
+ case Stage::Dispatch:
+ return CapabilityAtom::dispatch;
+ default:
+ SLANG_UNEXPECTED("unknown stage");
+ UNREACHABLE_RETURN(CapabilityAtom::Invalid);
+ }
+}
+
SlangResult checkExternalCompilerSupport(Session* session, PassThroughMode passThrough)
{
// Check if the type is supported on this compile
diff --git a/source/slang/slang-emit-hlsl.cpp b/source/slang/slang-emit-hlsl.cpp
index ba167676a..b022a2db9 100644
--- a/source/slang/slang-emit-hlsl.cpp
+++ b/source/slang/slang-emit-hlsl.cpp
@@ -785,6 +785,29 @@ bool HLSLSourceEmitter::tryEmitInstStmtImpl(IRInst* inst)
}
}
+static bool isTargetHLSL2018(HLSLSourceEmitter* emitter, CapabilitySet targetCaps, Stage stage)
+{
+ auto stageAtom = getAtomFromStage(stage);
+
+ // Cache the result of this function for easier lookup.
+ auto result = emitter->getCachedCapability(stageAtom);
+ if (result)
+ return *result;
+
+ // Here we check for presence of the `hlsl_2018` capability for the
+ // current target+stage.
+ auto capabilitySetForStageOfEntryPoint = CapabilitySet(CapabilityName(stageAtom));
+ auto hlsl2018CapabilitySet =
+ CapabilitySet(CapabilityName::hlsl_2018).join(capabilitySetForStageOfEntryPoint);
+ if (targetCaps.join(capabilitySetForStageOfEntryPoint).implies(hlsl2018CapabilitySet))
+ {
+ emitter->addCachedCapability(stageAtom, false);
+ return false;
+ }
+ emitter->addCachedCapability(stageAtom, true);
+ return true;
+}
+
bool HLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec)
{
switch (inst->getOp())
@@ -827,7 +850,7 @@ bool HLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
if (targetProfile.getVersion() < ProfileVersion::DX_6_0)
return false;
auto targetCaps = getTargetReq()->getTargetCaps();
- if (targetCaps.implies(CapabilityAtom::hlsl_2018))
+ if (!isTargetHLSL2018(this, targetCaps, m_entryPointStage))
return false;
if (as<IRBasicType>(inst->getDataType()))
@@ -855,7 +878,7 @@ bool HLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
if (targetProfile.getVersion() < ProfileVersion::DX_6_0)
return false;
auto targetCaps = getTargetReq()->getTargetCaps();
- if (targetCaps.implies(CapabilityAtom::hlsl_2018))
+ if (!isTargetHLSL2018(this, targetCaps, m_entryPointStage))
return false;
if (as<IRBasicType>(inst->getDataType()))
diff --git a/source/slang/slang-emit-hlsl.h b/source/slang/slang-emit-hlsl.h
index 28319f93a..ddd72c8cc 100644
--- a/source/slang/slang-emit-hlsl.h
+++ b/source/slang/slang-emit-hlsl.h
@@ -26,9 +26,21 @@ public:
virtual RefObject* getExtensionTracker() SLANG_OVERRIDE { return m_extensionTracker; }
+ const bool* getCachedCapability(CapabilityAtom stage) const
+ {
+ return m_capabilityCache.tryGetValue(stage);
+ }
+ void addCachedCapability(CapabilityAtom stage, bool value)
+ {
+ m_capabilityCache.addIfNotExists(stage, value);
+ }
+
protected:
RefPtr<HLSLExtensionTracker> m_extensionTracker;
+ // Allow caching of capability results for easier lookup.
+ Dictionary<CapabilityAtom, bool> m_capabilityCache{};
+
virtual void emitLayoutSemanticsImpl(
IRInst* inst,
char const* uniformSemanticSpelling,
diff --git a/source/slang/slang-profile.h b/source/slang/slang-profile.h
index 9bd905c37..ca7b8b2ae 100644
--- a/source/slang/slang-profile.h
+++ b/source/slang/slang-profile.h
@@ -128,6 +128,8 @@ Stage findStageByName(String const& name);
UnownedStringSlice getStageText(Stage stage);
Stage getStageFromAtom(CapabilityAtom atom);
+CapabilityAtom getAtomFromStage(Stage stage);
+
} // namespace Slang
#endif
diff --git a/tests/hlsl/hlsl-capability.slang b/tests/hlsl/hlsl-capability.slang
new file mode 100644
index 000000000..3a950034e
--- /dev/null
+++ b/tests/hlsl/hlsl-capability.slang
@@ -0,0 +1,49 @@
+//TEST:SIMPLE(filecheck=CHECK_CS): -target hlsl -stage compute -entry computeMain -profile cs_6_3
+//TEST:SIMPLE(filecheck=CHECK_SM): -target hlsl -stage compute -entry computeMain -profile sm_6_3
+//TEST:SIMPLE(filecheck=CHECK_CS_CAP): -target hlsl -stage compute -entry computeMain -profile cs_6_3 -capability hlsl_2018
+//TEST:SIMPLE(filecheck=CHECK_SM_CAP): -target hlsl -stage compute -entry computeMain -profile sm_6_3 -capability hlsl_2018
+
+// Test IR code generation for the `?:` "select" operator with the hlsl_2018 capability and cs_6_3 profile.
+
+// Verify that select is emitted for cs_6_3 and sm_6_3.
+// CHECK_CS: select({{.*}})
+// CHECK_SM: select({{.*}})
+// CHECK_CS-NOT: {{.*}}?{{.*}}:{{.*}}
+// CHECK_SM-NOT: {{.*}}?{{.*}}:{{.*}}
+
+// Verify that select is not emitted for cs_6_3 and sm_6_3 with the hlsl_2018 capability.
+// CHECK_CS_CAP-NOT: select({{.*}})
+// CHECK_SM_CAP-NOT: select({{.*}})
+// CHECK_CS_CAP: {{.*}}?{{.*}}:{{.*}}
+// CHECK_SM_CAP: {{.*}}?{{.*}}:{{.*}}
+
+RWStructuredBuffer<int> outputBuffer;
+static int result = 0;
+bool2 assignFunc(int index)
+{
+ result++;
+ return bool2(true);
+}
+
+[numthreads(4, 1, 1)]
+void computeMain(int3 dispatchThreadID : SV_DispatchThreadID)
+{
+ int index = dispatchThreadID.x;
+
+ if (all(bool2(index >= 1) && assignFunc(index)))
+ {
+ result++;
+ }
+
+ if (all(bool2(index >= 2) || !assignFunc(index)))
+ {
+ result++;
+ }
+
+ if (all(bool2(index >= 3) ? assignFunc(index) : bool2(false)))
+ {
+ result++;
+ }
+
+ outputBuffer[index] = result;
+}