summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-emit-spirv.cpp72
-rw-r--r--tests/vkray/raygen.slang4
-rw-r--r--tests/vkray/rayquery-compute.slang19
3 files changed, 93 insertions, 2 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 8ecbe1bc7..fda7b098d 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -1382,6 +1382,36 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
return result;
}
+ List<List<UnownedStringSlice>> m_anyExtension;
+ void ensureAnyExtensionDeclaration(List<UnownedStringSlice> extensions)
+ {
+ if (!m_anyExtension.contains(extensions))
+ {
+ m_anyExtension.add(extensions);
+ }
+ }
+
+ void emitSPIRVAnyExtension()
+ {
+ for (const auto& options : m_anyExtension)
+ {
+ bool found = false;
+ for (UnownedStringSlice option : options)
+ {
+ if (m_extensionInsts.tryGetValue(option))
+ {
+ found = true;
+ break;
+ }
+ }
+
+ if (!found)
+ {
+ ensureExtensionDeclaration(options[0]);
+ }
+ }
+ }
+
SpvInst* ensureExtensionDeclarationBeforeSpv14(UnownedStringSlice name)
{
if (isSpirv14OrLater())
@@ -1710,8 +1740,10 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
return emitOpTypeSampler(inst);
case kIROp_RaytracingAccelerationStructureType:
- requireSPIRVCapability(SpvCapabilityRayTracingKHR);
- ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_ray_tracing"));
+ requireSPIRVAnyCapability({SpvCapabilityRayTracingKHR, SpvCapabilityRayQueryKHR});
+ ensureAnyExtensionDeclaration(
+ {UnownedStringSlice("SPV_KHR_ray_tracing"),
+ UnownedStringSlice("SPV_KHR_ray_query")});
return emitOpTypeAccelerationStructure(inst);
case kIROp_RayQueryType:
@@ -8224,6 +8256,36 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
}
}
+ List<List<SpvCapability>> m_anyCapability;
+ void requireSPIRVAnyCapability(List<SpvCapability> capabilities)
+ {
+ if (!m_anyCapability.contains(capabilities))
+ {
+ m_anyCapability.add(capabilities);
+ }
+ }
+
+ void emitSPIRVAnyCapabilities()
+ {
+ for (const auto& options : m_anyCapability)
+ {
+ bool found = false;
+ for (SpvCapability option : options)
+ {
+ if (m_capabilities.contains(option))
+ {
+ found = true;
+ break;
+ }
+ }
+
+ if (!found)
+ {
+ requireSPIRVCapability(options[0]);
+ }
+ }
+ }
+
void requireVariableBufferCapabilityIfNeeded(IRInst* type)
{
if (auto ptrType = as<IRPtrTypeBase>(type))
@@ -8416,6 +8478,12 @@ SlangResult emitSPIRVFromIR(
}
} while (context.m_forwardDeclaredPointers.getCount() != 0);
+ // Emit extensions and capabilities for which there are multiple options available.
+ // This is delayed to avoid emitting unnecessary extensions and capabilities if
+ // one of the options is already required by some other op.
+ context.emitSPIRVAnyExtension();
+ context.emitSPIRVAnyCapabilities();
+
context.emitFrontMatter();
context.emitPhysicalLayout();
diff --git a/tests/vkray/raygen.slang b/tests/vkray/raygen.slang
index 28bad734a..653435a65 100644
--- a/tests/vkray/raygen.slang
+++ b/tests/vkray/raygen.slang
@@ -117,6 +117,10 @@ void main()
outputImage[int2(gl_LaunchIDNV.xy)] = float4(color, 1.0);
}
+// CHECK_SPV: OpCapability RayTracingKHR
+// CHECK_SPV-NOT: OpCapability RayQueryKHR
+// CHECK_SPV: OpExtension "SPV_KHR_ray_tracing"
+// CHECK_SPV-NOT: OpExtension "SPV_KHR_ray_query"
// CHECK_SPV: %{{.*}} = OpVariable %_ptr_RayPayload{{NV|KHR}}_ReflectionRay{{.*}} RayPayload
// CHECK_SPV: OpTraceRayKHR
// CHECK_SPV: OpTraceRayKHR
diff --git a/tests/vkray/rayquery-compute.slang b/tests/vkray/rayquery-compute.slang
new file mode 100644
index 000000000..2de53cdcc
--- /dev/null
+++ b/tests/vkray/rayquery-compute.slang
@@ -0,0 +1,19 @@
+// rayquery-compute.slang
+//TEST:SIMPLE(filecheck=CHECK): -stage compute -entry main -target spirv-assembly -emit-spirv-directly
+
+RaytracingAccelerationStructure accelerationStructure;
+
+[numthreads(1, 1, 1)]
+void main(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ RayQuery<RAY_FLAG_NONE> rayQuery;
+
+ uint instanceInclusionMask = 0x00;
+ RayDesc rayDesc;
+ rayQuery.TraceRayInline(accelerationStructure, RAY_FLAG_NONE, instanceInclusionMask, rayDesc);
+}
+
+// CHECK: OpCapability RayQueryKHR
+// CHECK-NOT: OpCapability RayTracingKHR
+// CHECK: OpExtension "SPV_KHR_ray_query"
+// CHECK-NOT: OpExtension "SPV_KHR_ray_tracing"