summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-09-28 18:08:06 -0700
committerGitHub <noreply@github.com>2023-09-28 18:08:06 -0700
commitaf61737e7ba107e9e92164bf39ce6ab34e05ce82 (patch)
tree37747b26fb951548256669124adb83c0246a8f32 /source
parentb7d318f48db2cb83a41d665f1727ae93fc555124 (diff)
[Direct SPIRV]: ray tracing pipeline intrinsics. (#3244)
* Use a dedicated inst opcode to retrieve ray payload locations. * [Direct SPIRV]: ray tracing pipeline intrinsics. * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/hlsl.meta.slang577
-rw-r--r--source/slang/slang-emit-c-like.cpp53
-rw-r--r--source/slang/slang-emit-c-like.h48
-rw-r--r--source/slang/slang-emit-glsl.cpp38
-rw-r--r--source/slang/slang-emit-spirv.cpp42
-rw-r--r--source/slang/slang-intrinsic-expand.cpp52
-rw-r--r--source/slang/slang-ir-glsl-legalize.cpp57
-rw-r--r--source/slang/slang-ir-inst-defs.h2
-rw-r--r--source/slang/slang-ir-lower-buffer-element-type.cpp2
-rw-r--r--source/slang/slang-ir-specialize-target-switch.cpp2
-rw-r--r--source/slang/slang-ir-spirv-legalize.cpp38
11 files changed, 575 insertions, 336 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index f26418855..49fcbe346 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -7243,9 +7243,6 @@ struct BuiltInTriangleIntersectionAttributes
// 10.3.1
-__target_intrinsic(hlsl)
-void CallShader<Payload>(uint shaderIndex, inout Payload payload);
-
// `executeCallableNV` is the GLSL intrinsic that will be used to implement
// `CallShader()` for GLSL-based targets.
//
@@ -7257,9 +7254,8 @@ void __executeCallable(uint shaderIndex, int payloadLocation);
// for a type being used in a `CallShader()` call for GLSL-based targets.
//
__generic<Payload>
-__target_intrinsic(__glslRayTracing, "$XC")
[__readNone]
-[__AlwaysFoldIntoUseSiteAttribute]
+__intrinsic_op($(kIROp_GetVulkanRayTracingPayloadLocation))
int __callablePayloadLocation(__ref Payload payload);
// Now we provide a hard-coded definition of `CallShader()` for GLSL-based
@@ -7267,29 +7263,35 @@ int __callablePayloadLocation(__ref Payload payload);
// GLSL equivalent.
//
__generic<Payload>
-__specialized_for_target(glsl)
void CallShader(uint shaderIndex, inout Payload payload)
{
- [__vulkanCallablePayload]
- static Payload p;
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm "CallShader";
+ case glsl:
+ {
+ [__vulkanCallablePayload]
+ static Payload p;
- p = payload;
- __executeCallable(shaderIndex, __callablePayloadLocation(p));
- payload = p;
+ p = payload;
+ __executeCallable(shaderIndex, __callablePayloadLocation(p));
+ payload = p;
+ }
+ case spirv:
+ {
+ [__vulkanCallablePayload]
+ static Payload p;
+
+ p = payload;
+ spirv_asm {
+ OpExecuteCallableKHR $shaderIndex &p
+ };
+ payload = p;
+ }
+ }
}
// 10.3.2
-__target_intrinsic(hlsl)
-__target_intrinsic(cuda, "traceOptiXRay")
-void TraceRay<payload_t>(
- RaytracingAccelerationStructure AccelerationStructure,
- uint RayFlags,
- uint InstanceInclusionMask,
- uint RayContributionToHitGroupIndex,
- uint MultiplierForGeometryContributionToHitGroupIndex,
- uint MissShaderIndex,
- RayDesc Ray,
- inout payload_t Payload);
__target_intrinsic(GL_NV_ray_tracing, "traceNV")
__target_intrinsic(GL_EXT_ray_tracing, "traceRayEXT")
@@ -7313,13 +7315,11 @@ void __traceRay(
// syntax works in a pinch.
//
__generic<Payload>
-__target_intrinsic(__glslRayTracing, "$XP")
[__readNone]
-[__AlwaysFoldIntoUseSiteAttribute]
+__intrinsic_op($(kIROp_GetVulkanRayTracingPayloadLocation))
int __rayPayloadLocation(__ref Payload payload);
__generic<payload_t>
-__specialized_for_target(glsl)
void TraceRay(
RaytracingAccelerationStructure AccelerationStructure,
uint RayFlags,
@@ -7330,25 +7330,58 @@ void TraceRay(
RayDesc Ray,
inout payload_t Payload)
{
- [__vulkanRayPayload]
- static payload_t p;
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm "TraceRay";
+ case cuda: __intrinsic_asm "traceOptiXRay";
+ case glsl:
+ {
+ [__vulkanRayPayload]
+ static payload_t p;
- p = Payload;
- __traceRay(
- AccelerationStructure,
- RayFlags,
- InstanceInclusionMask,
- RayContributionToHitGroupIndex,
- MultiplierForGeometryContributionToHitGroupIndex,
- MissShaderIndex,
- Ray.Origin,
- Ray.TMin,
- Ray.Direction,
- Ray.TMax,
- __rayPayloadLocation(p));
- Payload = p;
-}
+ p = Payload;
+ __traceRay(
+ AccelerationStructure,
+ RayFlags,
+ InstanceInclusionMask,
+ RayContributionToHitGroupIndex,
+ MultiplierForGeometryContributionToHitGroupIndex,
+ MissShaderIndex,
+ Ray.Origin,
+ Ray.TMin,
+ Ray.Direction,
+ Ray.TMax,
+ __rayPayloadLocation(p));
+ Payload = p;
+ }
+ case spirv:
+ {
+ [__vulkanRayPayload]
+ static payload_t p;
+ p = Payload;
+ let origin = Ray.Origin;
+ let direction = Ray.Direction;
+ let tmin = Ray.TMin;
+ let tmax = Ray.TMax;
+ spirv_asm {
+ OpTraceRayKHR
+ /**/ $AccelerationStructure
+ /**/ $RayFlags
+ /**/ $InstanceInclusionMask
+ /**/ $RayContributionToHitGroupIndex
+ /**/ $MultiplierForGeometryContributionToHitGroupIndex
+ /**/ $MissShaderIndex
+ /**/ $origin
+ /**/ $tmin
+ /**/ $direction
+ /**/ $tmax
+ /**/ &p;
+ };
+ Payload = p;
+ }
+ }
+}
// NOTE!
// The name of the following functions may change when DXR supports
@@ -7356,17 +7389,6 @@ void TraceRay(
//
// https://github.com/KhronosGroup/GLSL/blob/master/extensions/nv/GLSL_NV_ray_tracing_motion_blur.txt
-void TraceMotionRay<payload_t>(
- RaytracingAccelerationStructure AccelerationStructure,
- uint RayFlags,
- uint InstanceInclusionMask,
- uint RayContributionToHitGroupIndex,
- uint MultiplierForGeometryContributionToHitGroupIndex,
- uint MissShaderIndex,
- RayDesc Ray,
- float CurrentTime,
- inout payload_t Payload);
-
__target_intrinsic(glsl, "traceRayMotionNV")
__glsl_version(460)
__glsl_extension(GL_NV_ray_tracing_motion_blur)
@@ -7386,7 +7408,6 @@ void __traceMotionRay(
int PayloadLocation);
__generic<payload_t>
-__specialized_for_target(glsl)
void TraceMotionRay(
RaytracingAccelerationStructure AccelerationStructure,
uint RayFlags,
@@ -7398,36 +7419,84 @@ void TraceMotionRay(
float CurrentTime,
inout payload_t Payload)
{
- [__vulkanRayPayload]
- static payload_t p;
-
- p = Payload;
- __traceMotionRay(
- AccelerationStructure,
- RayFlags,
- InstanceInclusionMask,
- RayContributionToHitGroupIndex,
- MultiplierForGeometryContributionToHitGroupIndex,
- MissShaderIndex,
- Ray.Origin,
- Ray.TMin,
- Ray.Direction,
- Ray.TMax,
- CurrentTime,
- __rayPayloadLocation(p));
- Payload = p;
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm "TraceMotionRay";
+ case glsl:
+ {
+ [__vulkanRayPayload]
+ static payload_t p;
+
+ p = Payload;
+ __traceMotionRay(
+ AccelerationStructure,
+ RayFlags,
+ InstanceInclusionMask,
+ RayContributionToHitGroupIndex,
+ MultiplierForGeometryContributionToHitGroupIndex,
+ MissShaderIndex,
+ Ray.Origin,
+ Ray.TMin,
+ Ray.Direction,
+ Ray.TMax,
+ CurrentTime,
+ __rayPayloadLocation(p));
+ Payload = p;
+ }
+ case spirv:
+ {
+ [__vulkanRayPayload]
+ static payload_t p;
+
+ let origin = Ray.Origin;
+ let direction = Ray.Direction;
+ let tmin = Ray.TMin;
+ let tmax = Ray.TMax;
+
+ p = Payload;
+ spirv_asm {
+ OpCapability RayTracingMotionBlurNV;
+ OpExtension "SPV_NV_ray_tracing_motion_blur";
+
+ OpTraceRayMotionNV
+ /**/ $AccelerationStructure
+ /**/ $RayFlags
+ /**/ $InstanceInclusionMask
+ /**/ $RayContributionToHitGroupIndex
+ /**/ $MultiplierForGeometryContributionToHitGroupIndex
+ /**/ $MissShaderIndex
+ /**/ $origin
+ /**/ $tmin
+ /**/ $direction
+ /**/ $tmax
+ /**/ $CurrentTime
+ /**/ &p;
+ };
+ Payload = p;
+ }
+ }
}
// 10.3.3
__target_intrinsic(hlsl)
bool ReportHit<A>(float tHit, uint hitKind, A attributes);
-__target_intrinsic(GL_NV_ray_tracing, "reportIntersectionNV")
-__target_intrinsic(GL_EXT_ray_tracing, "reportIntersectionEXT")
-bool __reportIntersection(float tHit, uint hitKind);
+bool __reportIntersection(float tHit, uint hitKind)
+{
+ __target_switch
+ {
+ case GL_EXT_ray_tracing: __intrinsic_asm "reportIntersectionEXT";
+ case GL_NV_ray_tracing: __intrinsic_asm "reportIntersectionNV";
+ case spirv:
+ return spirv_asm {
+ result:$$bool = OpReportIntersectionKHR $tHit $hitKind;
+ };
+ }
+}
__generic<A>
__specialized_for_target(glsl)
+__specialized_for_target(spirv)
bool ReportHit(float tHit, uint hitKind, A attributes)
{
[__vulkanHitAttributes]
@@ -7438,18 +7507,30 @@ bool ReportHit(float tHit, uint hitKind, A attributes)
}
// 10.3.4
-__target_intrinsic(hlsl)
-__target_intrinsic(GL_NV_ray_tracing, ignoreIntersectionNV)
-__target_intrinsic(GL_EXT_ray_tracing, "ignoreIntersectionEXT;")
-__target_intrinsic(cuda, "optixIgnoreIntersection")
-void IgnoreHit();
+void IgnoreHit()
+{
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm "IgnoreHit";
+ case GL_EXT_ray_tracing: __intrinsic_asm "ignoreIntersectionEXT;";
+ case GL_NV_ray_tracing: __intrinsic_asm "ignoreIntersectionNV";
+ case cuda: __intrinsic_asm "optixIgnoreIntersection";
+ case spirv: spirv_asm { OpIgnoreIntersectionKHR; %_ = OpLabel };
+ }
+}
// 10.3.5
-__target_intrinsic(hlsl)
-__target_intrinsic(GL_NV_ray_tracing, terminateRayNV)
-__target_intrinsic(GL_EXT_ray_tracing, "terminateRayEXT;")
-__target_intrinsic(cuda, "optixTerminateRay")
-void AcceptHitAndEndSearch();
+void AcceptHitAndEndSearch()
+{
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm "AcceptHitAndEndSearch";
+ case GL_EXT_ray_tracing: __intrinsic_asm "terminateRayEXT;";
+ case GL_NV_ray_tracing: __intrinsic_asm "terminateRayNV";
+ case cuda: __intrinsic_asm "optixTerminateRay";
+ case spirv: spirv_asm { OpTerminateRayKHR; %_ = OpLabel };
+ }
+}
// 10.4 - System Values and Special Semantics
@@ -7458,32 +7539,82 @@ void AcceptHitAndEndSearch();
// 10.4.1 - Ray Dispatch System Values
-__target_intrinsic(GL_NV_ray_tracing, "(gl_LaunchIDNV)")
-__target_intrinsic(GL_EXT_ray_tracing, "(gl_LaunchIDEXT)")
-__target_intrinsic(cuda, "optixGetLaunchIndex")
-uint3 DispatchRaysIndex();
+uint3 DispatchRaysIndex()
+{
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm "DispatchRaysIndex";
+ case GL_EXT_ray_tracing: __intrinsic_asm "(gl_LaunchIDEXT)";
+ case GL_NV_ray_tracing: __intrinsic_asm "(gl_LaunchIDNV)";
+ case cuda: __intrinsic_asm "optixGetLaunchIndex";
+ case spirv:
+ return spirv_asm {
+ result:$$uint3 = OpLoad builtin(LaunchIdKHR:uint3);
+ };
+ }
+}
-__target_intrinsic(GL_NV_ray_tracing, "(gl_LaunchSizeNV)")
-__target_intrinsic(GL_EXT_ray_tracing, "(gl_LaunchSizeEXT)")
-__target_intrinsic(cuda, "optixGetLaunchDimensions")
-uint3 DispatchRaysDimensions();
+uint3 DispatchRaysDimensions()
+{
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm "DispatchRaysDimensions";
+ case GL_EXT_ray_tracing: __intrinsic_asm "(gl_LaunchSizeEXT)";
+ case GL_NV_ray_tracing: __intrinsic_asm "(gl_LaunchSizeNV)";
+ case cuda: __intrinsic_asm "optixGetLaunchDimensions";
+ case spirv:
+ return spirv_asm {
+ result:$$uint3 = OpLoad builtin(LaunchSizeKHR:uint3);
+ };
+ }
+}
// 10.4.2 - Ray System Values
-__target_intrinsic(GL_NV_ray_tracing, "(gl_WorldRayOriginNV)")
-__target_intrinsic(GL_EXT_ray_tracing, "(gl_WorldRayOriginEXT)")
-__target_intrinsic(cuda, "optixGetWorldRayOrigin")
-float3 WorldRayOrigin();
+float3 WorldRayOrigin()
+{
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm "WorldRayOrigin";
+ case GL_EXT_ray_tracing: __intrinsic_asm "(gl_WorldRayOriginEXT)";
+ case GL_NV_ray_tracing: __intrinsic_asm "(gl_WorldRayOriginNV)";
+ case cuda: __intrinsic_asm "optixGetWorldRayOrigin";
+ case spirv:
+ return spirv_asm {
+ result:$$float3 = OpLoad builtin(WorldRayOriginKHR:float3);
+ };
+ }
+}
-__target_intrinsic(GL_NV_ray_tracing, "(gl_WorldRayDirectionNV)")
-__target_intrinsic(GL_EXT_ray_tracing, "(gl_WorldRayDirectionEXT)")
-__target_intrinsic(cuda, "optixGetWorldRayDirection")
-float3 WorldRayDirection();
+float3 WorldRayDirection()
+{
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm "WorldRayDirection";
+ case GL_EXT_ray_tracing: __intrinsic_asm "(gl_WorldRayDirectionEXT)";
+ case GL_NV_ray_tracing: __intrinsic_asm "(gl_WorldRayDirectionNV)";
+ case cuda: __intrinsic_asm "optixGetWorldRayDirection";
+ case spirv:
+ return spirv_asm {
+ result:$$float3 = OpLoad builtin(WorldRayDirectionKHR:float3);
+ };
+ }
+}
-__target_intrinsic(GL_NV_ray_tracing, "(gl_RayTminNV)")
-__target_intrinsic(GL_EXT_ray_tracing, "(gl_RayTminEXT)")
-__target_intrinsic(cuda, "optixGetRayTmin")
-float RayTMin();
+float RayTMin()
+{
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm "RayTMin";
+ case GL_EXT_ray_tracing: __intrinsic_asm "(gl_RayTminEXT)";
+ case GL_NV_ray_tracing: __intrinsic_asm "(gl_RayTminNV)";
+ case cuda: __intrinsic_asm "optixGetRayTmin";
+ case spirv:
+ return spirv_asm {
+ result:$$float = OpLoad builtin(RayTminKHR:float);
+ };
+ }
+}
// Note: The `RayTCurrent()` intrinsic should translate to
// either `gl_HitTNV` (for hit shaders) or `gl_RayTmaxNV`
@@ -7495,68 +7626,190 @@ float RayTMin();
// we should simply provide two overloads here, specialized
// to the appropriate Vulkan stages.
//
-__target_intrinsic(GL_NV_ray_tracing, "(gl_RayTmaxNV)")
-__target_intrinsic(GL_EXT_ray_tracing, "(gl_RayTmaxEXT)")
-__target_intrinsic(cuda, "optixGetRayTmax")
-float RayTCurrent();
+float RayTCurrent()
+{
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm "RayTCurrent";
+ case GL_EXT_ray_tracing: __intrinsic_asm "(gl_RayTmaxEXT)";
+ case GL_NV_ray_tracing: __intrinsic_asm "(gl_RayTmaxNV)";
+ case cuda: __intrinsic_asm "optixGetRayTmax";
+ case spirv:
+ return spirv_asm {
+ result:$$float = OpLoad builtin(RayTmaxKHR:float);
+ };
+ }
+}
-__target_intrinsic(GL_NV_ray_tracing, "(gl_IncomingRayFlagsNV)")
-__target_intrinsic(GL_EXT_ray_tracing, "(gl_IncomingRayFlagsEXT)")
-__target_intrinsic(cuda, "optixGetRayFlags")
-uint RayFlags();
+uint RayFlags()
+{
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm "RayFlags";
+ case GL_EXT_ray_tracing: __intrinsic_asm "(gl_IncomingRayFlagsEXT)";
+ case GL_NV_ray_tracing: __intrinsic_asm "(gl_IncomingRayFlagsNV)";
+ case cuda: __intrinsic_asm "optixGetRayFlags";
+ case spirv:
+ return spirv_asm {
+ result:$$uint = OpLoad builtin(IncomingRayFlagsKHR:uint);
+ };
+ }
+}
// 10.4.3 - Primitive/Object Space System Values
-__target_intrinsic(__glslRayTracing, "(gl_InstanceID)")
-__target_intrinsic(cuda, "optixGetInstanceIndex")
-uint InstanceIndex();
+uint InstanceIndex()
+{
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm "InstanceIndex";
+ case __glslRayTracing: __intrinsic_asm "(gl_InstanceID)";
+ case cuda: __intrinsic_asm "optixGetInstanceIndex";
+ case spirv:
+ return spirv_asm {
+ result:$$uint = OpLoad builtin(InstanceId:uint);
+ };
+ }
+}
-__target_intrinsic(GL_NV_ray_tracing, "(gl_InstanceCustomIndexNV)")
-__target_intrinsic(GL_EXT_ray_tracing, "(gl_InstanceCustomIndexEXT)")
-__target_intrinsic(cuda, "optixGetInstanceId")
-uint InstanceID();
+uint InstanceID()
+{
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm "InstanceID";
+ case GL_EXT_ray_tracing: __intrinsic_asm "(gl_InstanceCustomIndexEXT)";
+ case GL_NV_ray_tracing: __intrinsic_asm "(gl_InstanceCustomIndexNV)";
+ case cuda: __intrinsic_asm "optixGetInstanceId";
+ case spirv:
+ return spirv_asm {
+ result:$$uint = OpLoad builtin(InstanceCustomIndexKHR:uint);
+ };
+ }
+}
-__target_intrinsic(__glslRayTracing, "(gl_PrimitiveID)")
-__target_intrinsic(cuda, "optixGetPrimitiveIndex")
-uint PrimitiveIndex();
+uint PrimitiveIndex()
+{
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm "PrimitiveIndex";
+ case __glslRayTracing: __intrinsic_asm "(gl_PrimitiveID)";
+ case cuda: __intrinsic_asm "optixGetPrimitiveIndex";
+ case spirv:
+ return spirv_asm {
+ result:$$uint = OpLoad builtin(PrimitiveId:uint);
+ };
+ }
+}
-__target_intrinsic(GL_NV_ray_tracing, "(gl_ObjectRayOriginNV)")
-__target_intrinsic(GL_EXT_ray_tracing, "(gl_ObjectRayOriginEXT)")
-__target_intrinsic(cuda, "optixGetObjectRayOrigin")
-float3 ObjectRayOrigin();
+float3 ObjectRayOrigin()
+{
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm "ObjectRayOrigin";
+ case GL_EXT_ray_tracing: __intrinsic_asm "(gl_ObjectRayOriginEXT)";
+ case GL_NV_ray_tracing: __intrinsic_asm "(gl_ObjectRayOriginNV)";
+ case cuda: __intrinsic_asm "optixGetObjectRayOrigin";
+ case spirv:
+ return spirv_asm {
+ result:$$float3 = OpLoad builtin(ObjectRayOriginKHR:float3);
+ };
+ }
+}
-__target_intrinsic(GL_NV_ray_tracing, "(gl_ObjectRayDirectionNV)")
-__target_intrinsic(GL_EXT_ray_tracing, "(gl_ObjectRayDirectionEXT)")
-__target_intrinsic(cuda, "optixGetObjectRayDirection")
-float3 ObjectRayDirection();
+float3 ObjectRayDirection()
+{
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm "ObjectRayDirection";
+ case GL_EXT_ray_tracing: __intrinsic_asm "(gl_ObjectRayDirectionEXT)";
+ case GL_NV_ray_tracing: __intrinsic_asm "(gl_ObjectRayDirectionNV)";
+ case cuda: __intrinsic_asm "optixGetObjectRayDirection";
+ case spirv:
+ return spirv_asm {
+ result:$$float3 = OpLoad builtin(ObjectRayDirectionKHR:float3);
+ };
+ }
+}
// TODO: optix has an optixGetObjectToWorldTransformMatrix function that returns 12
// floats by reference.
-__target_intrinsic(GL_NV_ray_tracing, "transpose(gl_ObjectToWorldNV)")
-__target_intrinsic(GL_EXT_ray_tracing, "transpose(gl_ObjectToWorldEXT)")
-float3x4 ObjectToWorld3x4();
+float3x4 ObjectToWorld3x4()
+{
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm "ObjectToWorld3x4";
+ case GL_EXT_ray_tracing: __intrinsic_asm "transpose(gl_ObjectToWorldEXT)";
+ case GL_NV_ray_tracing: __intrinsic_asm "transpose(gl_ObjectToWorldNV)";
+ case spirv:
+ return spirv_asm {
+ %mat = OpLoad builtin(ObjectToWorldKHR:float4x3);
+ result:$$float3x4 = OpTranspose %mat;
+ };
+ }
+}
-__target_intrinsic(GL_NV_ray_tracing, "transpose(gl_WorldToObjectNV)")
-__target_intrinsic(GL_EXT_ray_tracing, "transpose(gl_WorldToObjectEXT)")
-float3x4 WorldToObject3x4();
+float3x4 WorldToObject3x4()
+{
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm "WorldToObject3x4";
+ case GL_EXT_ray_tracing: __intrinsic_asm "transpose(gl_WorldToObjectEXT)";
+ case GL_NV_ray_tracing: __intrinsic_asm "transpose(gl_WorldToObjectNV)";
+ case spirv:
+ return spirv_asm {
+ %mat = OpLoad builtin(WorldToObjectKHR:float4x3);
+ result:$$float3x4 = OpTranspose %mat;
+ };
+ }
+}
-__target_intrinsic(GL_NV_ray_tracing, "transpose(gl_ObjectToWorldNV)")
-__target_intrinsic(GL_EXT_ray_tracing, "transpose(gl_ObjectToWorld3x4EXT)")
-float4x3 ObjectToWorld4x3();
+float4x3 ObjectToWorld4x3()
+{
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm "ObjectToWorld4x3";
+ case GL_EXT_ray_tracing: __intrinsic_asm "(gl_ObjectToWorldEXT)";
+ case GL_NV_ray_tracing: __intrinsic_asm "(gl_ObjectToWorldNV)";
+ case spirv:
+ return spirv_asm {
+ result:$$float4x3 = OpLoad builtin(ObjectToWorldKHR:float4x3);
+ };
+ }
+}
-__target_intrinsic(GL_NV_ray_tracing, "transpose(gl_WorldToObjectNV)")
-__target_intrinsic(GL_EXT_ray_tracing, "transpose(gl_WorldToObject3x4EXT)")
-float4x3 WorldToObject4x3();
+float4x3 WorldToObject4x3()
+{
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm "WorldToObject4x3";
+ case GL_EXT_ray_tracing: __intrinsic_asm "(gl_WorldToObject3x4EXT)";
+ case GL_NV_ray_tracing: __intrinsic_asm "(gl_WorldToObjectNV)";
+ case spirv:
+ return spirv_asm {
+ result:$$float4x3 = OpLoad builtin(WorldToObjectKHR:float4x3);
+ };
+ }
+}
// NOTE!
// The name of the following functions may change when DXR supports
// a feature similar to the `GL_NV_ray_tracing_motion_blur` extension
-__target_intrinsic(glsl, "(gl_CurrentRayTimeNV)")
__glsl_version(460)
__glsl_extension(GL_NV_ray_tracing_motion_blur)
__glsl_extension(GL_EXT_ray_tracing)
-float RayCurrentTime();
+float RayCurrentTime()
+{
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm "RayCurrentTime";
+ case glsl: __intrinsic_asm "(gl_CurrentRayTimeNV)";
+ case spirv:
+ return spirv_asm {
+ result:$$float = OpLoad builtin(CurrentRayTimeNV:float);
+ };
+ }
+}
// Note: The provisional DXR spec included these unadorned
// `ObjectToWorld()` and `WorldToObject()` functions, so
@@ -7571,10 +7824,20 @@ float3x4 ObjectToWorld() { return ObjectToWorld3x4(); }
float3x4 WorldToObject() { return WorldToObject3x4(); }
// 10.4.4 - Hit Specific System values
-__target_intrinsic(GL_NV_ray_tracing, "(gl_HitKindNV)")
-__target_intrinsic(GL_EXT_ray_tracing, "(gl_HitKindEXT)")
-__target_intrinsic(cuda, "optixGetHitKind")
-uint HitKind();
+uint HitKind()
+{
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm "HitKind";
+ case GL_EXT_ray_tracing: __intrinsic_asm "(gl_HitKindEXT)";
+ case GL_NV_ray_tracing: __intrinsic_asm "(gl_HitKindNV)";
+ case cuda: __intrinsic_asm "optixGetHitKind";
+ case spirv:
+ return spirv_asm {
+ result:$$uint = OpLoad builtin(HitKindKHR:uint);
+ };
+ }
+}
// Pre-defined hit kinds (not documented explicitly)
static const uint HIT_KIND_TRIANGLE_FRONT_FACE = 254;
@@ -7794,7 +8057,17 @@ struct FeedbackTexture2DArray<T : __BuiltinSamplerFeedbackType>
// Get the index of the geometry that was hit in an intersection, any-hit, or closest-hit shader
__target_intrinsic(GL_EXT_ray_tracing, "(gl_GeometryIndexEXT)")
-uint GeometryIndex();
+uint GeometryIndex()
+{
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm "GeometryIndex";
+ case glsl: __intrinsic_asm "(gl_GeometryIndexEXT)";
+ case spirv: return spirv_asm {
+ result:$$uint = OpLoad builtin(RayGeometryIndexKHR:uint);
+ };
+ }
+}
// Status of whether a (closest) hit has been committed in a `RayQuery`.
typedef uint COMMITTED_STATUS;
@@ -8327,9 +8600,7 @@ Ref<T> __hitObjectAttributes<T>()
// for GLSL-based targets.
//
__generic<Attributes>
-__target_intrinsic(__glslRayTracing, "$XH")
-[__readNone]
-[__AlwaysFoldIntoUseSiteAttribute]
+__intrinsic_op($(kIROp_GetVulkanRayTracingPayloadLocation))
int __hitObjectAttributesLocation(__ref Attributes attributes);
/// Immutable data type representing a ray hit or a miss. Can be used to invoke hit or miss shading,
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index 2559269d4..fb216ae01 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -41,56 +41,6 @@ struct CLikeSourceEmitter::ComputeEmitActionsContext
List<EmitAction>* actions;
};
-/* !!!!!!!!!!!!!!!!!!!!!!!!!!!! LocationTracker !!!!!!!!!!!!!!!!!!!!!!!!!! */
-
-/* static */LocationTracker::Kind LocationTracker::getKindFromDecoration(IRDecoration* decoration)
-{
- switch (decoration->getOp())
- {
- case kIROp_VulkanRayPayloadDecoration: return Kind::RayPayload;
- case kIROp_VulkanCallablePayloadDecoration: return Kind::CallablePayload;
- case kIROp_VulkanHitObjectAttributesDecoration: return Kind::HitObjectAttribute;
- default: break;
- }
- return Kind::Invalid;
-}
-
-Index LocationTracker::getValue(IRInst* inst, IRDecoration* decoration)
-{
- const Kind kind = getKindFromDecoration(decoration);
- SLANG_RELEASE_ASSERT(kind != Kind::Invalid);
- if (kind == Kind::Invalid)
- {
- return -1;
- }
-
- return getValue(kind, inst, decoration);
-}
-
-Index LocationTracker::getValue(Kind kind, IRInst* inst, IRDecoration* decoration)
-{
- if (decoration->getOperandCount() > 0)
- {
- // TODO(JS):
- // There could be a clash with the auto generated location, and the user set value/
- // Perhaps the implication in practice is that either all are marked or none.
- const int explicitLocation = int(getIntVal(decoration->getOperand(0)));
- if (explicitLocation >= 0)
- return UInt(explicitLocation);
- }
-
- auto& nextValue = m_nextValueForKind[Index(kind)];
-
- const Location defaultLocation{kind, nextValue};
- const Location foundLocation = m_mapIRToLocations.getOrAddValue(inst, defaultLocation);
-
- // Increase if it was the default
- nextValue += Index(defaultLocation == foundLocation);
-
- // Has to match the kind
- return (foundLocation.kind == kind) ? foundLocation.value : -1;
-}
-
/* !!!!!!!!!!!!!!!!!!!!!!!!!!!! CLikeSourceEmitter !!!!!!!!!!!!!!!!!!!!!!!!!! */
/* static */SourceLanguage CLikeSourceEmitter::getSourceLanguage(CodeGenTarget target)
@@ -1242,6 +1192,9 @@ bool CLikeSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst)
case kIROp_LookupWitness:
case kIROp_GetValueFromBoundInterface:
return true;
+
+ case kIROp_GetVulkanRayTracingPayloadLocation:
+ return true;
}
// Layouts and attributes are only present to annotate other
diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h
index 976ac8e19..ffb4c0244 100644
--- a/source/slang/slang-emit-c-like.h
+++ b/source/slang/slang-emit-c-like.h
@@ -17,48 +17,6 @@
namespace Slang
{
-class LocationTracker
-{
-public:
- enum class Kind
- {
- Invalid = -1,
- RayPayload, ///< GLSL rayPayload
- CallablePayload, ///< GLSL callableData
- HitObjectAttribute, ///< GLSL hitObjectAttribute
- CountOf,
- };
-
- /// Given a decoration returns the Kind, or Kind::Invalid if that is not appropriate
- static Kind getKindFromDecoration(IRDecoration* decoration);
-
- /// Get the location value associated with inst (and decoration).
- /// Will return -1, if no location is associated
- Index getValue(IRInst* inst, IRDecoration* decoration);
-
- /// Get the location value associated with inst (and decoration).
- /// The kind must match that for the decoration.
- /// Will return -1, if no location is associated
- Index getValue(Kind kind, IRInst* inst, IRDecoration* decoration);
-
-protected:
- struct Location
- {
- typedef Location ThisType;
-
- bool operator==(const ThisType& rhs) const { return kind == rhs.kind && value == rhs.value; }
- bool operator!=(const ThisType& rhs) const { return !(*this == rhs); }
-
- Kind kind; ///< The kind of location
- Index value; ///< The value of the location. Must be >= 0
- };
-
- Index m_nextValueForKind[Count(Kind::CountOf)] = { 0, };
-
- Dictionary<IRInst*, Location> m_mapIRToLocations;
-};
-
-
class CLikeSourceEmitter: public SourceEmitterBase
{
public:
@@ -278,8 +236,6 @@ public:
ComponentType* getProgram() { return m_codeGenContext->getProgram(); }
TargetProgram* getTargetProgram() { return m_codeGenContext->getTargetProgram(); }
- LocationTracker& getLocationTracker() { return m_locationTracker; }
-
//
// Types
//
@@ -614,10 +570,6 @@ public:
// Map an IR instruction to the name that we've decided
// to use for it when emitting code.
Dictionary<IRInst*, String> m_mapInstToName;
-
- // Maps instructions to locations. Used for GLSL output for locations, but could potentially
- // be used for other kinds of location.
- LocationTracker m_locationTracker;
};
}
diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp
index ebcad1873..e9472bc96 100644
--- a/source/slang/slang-emit-glsl.cpp
+++ b/source/slang/slang-emit-glsl.cpp
@@ -2594,9 +2594,6 @@ void GLSLSourceEmitter::emitVarDecorationsImpl(IRInst* varDecl)
for (auto decoration : varDecl->getDecorations())
{
- typedef LocationTracker::Kind LocationKind;
-
- LocationKind locationKind = LocationKind::Invalid;
UnownedStringSlice prefix;
if (as<IRVulkanHitAttributesDecoration>(decoration))
{
@@ -2604,37 +2601,34 @@ void GLSLSourceEmitter::emitVarDecorationsImpl(IRInst* varDecl)
}
else
{
- // Handle attributes that have location
- const LocationKind decorationLocationKind = LocationTracker::getKindFromDecoration(decoration);
- if (decorationLocationKind == LocationKind::Invalid)
+ IRIntegerValue locationValue = -1;
+ switch (decoration->getOp())
{
- // Next decoration
+ case kIROp_VulkanCallablePayloadDecoration:
+ prefix = toSlice("callableData");
+ locationValue = getIntVal(decoration->getOperand(0));
+ break;
+ case kIROp_VulkanRayPayloadDecoration:
+ prefix = toSlice("rayPayload");
+ locationValue = getIntVal(decoration->getOperand(0));
+ break;
+ case kIROp_VulkanHitObjectAttributesDecoration:
+ prefix = toSlice("hitObjectAttribute");
+ locationValue = getIntVal(decoration->getOperand(0));
+ break;
+ default:
continue;
}
-
- locationKind = decorationLocationKind;
-
- // Get the location value
- const auto locationValue = m_locationTracker.getValue(locationKind, varDecl, decoration);
-
m_writer->emit(toSlice("layout(location = "));
m_writer->emit(locationValue);
m_writer->emit(toSlice(")\n"));
-
- switch (locationKind)
- {
- case LocationKind::CallablePayload: prefix = toSlice("callableData"); break;
- case LocationKind::HitObjectAttribute: prefix = toSlice("hitObjectAttribute"); break;
- case LocationKind::RayPayload: prefix = toSlice("rayPayload"); break;
- default: break;
- }
}
SLANG_ASSERT(prefix.getLength());
m_writer->emit(prefix);
// Special case hitObjectAttribute as is only NV currently
- if (locationKind == LocationKind::HitObjectAttribute ||
+ if (decoration->getOp() == kIROp_VulkanHitObjectAttributesDecoration ||
getTargetCaps().implies(CapabilityAtom::GL_NV_ray_tracing))
{
m_writer->emit(toSlice("NV"));
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 7cfd9ffad..6bd781e62 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -1443,6 +1443,8 @@ struct SPIRVEmitContext
return emitGlobalParam(as<IRGlobalParam>(inst));
case kIROp_GlobalVar:
return emitGlobalVar(as<IRGlobalVar>(inst));
+ case kIROp_SPIRVAsmOperandBuiltinVar:
+ return emitBuiltinVar(inst);
case kIROp_Var:
return emitVar(getSection(SpvLogicalSectionID::GlobalVariables), inst);
// ...
@@ -1934,6 +1936,16 @@ struct SPIRVEmitContext
return varInst;
}
+ SpvInst* emitBuiltinVar(IRInst* spvAsmBuiltinVar)
+ {
+ const auto kind = (SpvBuiltIn)(getIntVal(spvAsmBuiltinVar->getOperand(0)));
+ IRBuilder builder(spvAsmBuiltinVar);
+ builder.setInsertBefore(spvAsmBuiltinVar);
+ auto varInst = getBuiltinGlobalVar(builder.getPtrType(kIROp_PtrType, spvAsmBuiltinVar->getDataType(), SpvStorageClassInput), kind);
+ registerInst(spvAsmBuiltinVar, varInst);
+ return varInst;
+ }
+
/// Emit the given `irFunc` to SPIR-V
SpvInst* emitFunc(IRFunc* irFunc)
{
@@ -2575,6 +2587,14 @@ struct SPIRVEmitContext
case Stage::Geometry:
requireSPIRVCapability(SpvCapabilityGeometry);
break;
+ case Stage::Miss:
+ case Stage::AnyHit:
+ case Stage::ClosestHit:
+ case Stage::Intersection:
+ case Stage::RayGeneration:
+ case Stage::Callable:
+ requireSPIRVCapability(SpvCapabilityRayTracingKHR);
+ ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_ray_tracing"));
default:
break;
}
@@ -2721,6 +2741,15 @@ struct SPIRVEmitContext
);
}
break;
+
+ case kIROp_VulkanCallablePayloadDecoration:
+ case kIROp_VulkanHitObjectAttributesDecoration:
+ case kIROp_VulkanRayPayloadDecoration:
+ emitOpDecorateLocation(getSection(SpvLogicalSectionID::Annotations),
+ decoration,
+ dstID,
+ SpvLiteralInteger::from32(int32_t(getIntVal(decoration->getOperand(0)))));
+ break;
// ...
}
}
@@ -2862,7 +2891,12 @@ struct SPIRVEmitContext
CASE(Compute, GLCompute);
CASE(Mesh, MeshEXT);
CASE(Amplification, TaskEXT);
-
+ CASE(ClosestHit, ClosestHitKHR);
+ CASE(AnyHit, AnyHitKHR);
+ CASE(Callable, CallableKHR);
+ CASE(Miss, MissKHR);
+ CASE(Intersection, IntersectionKHR);
+ CASE(RayGeneration, RayGenerationKHR);
// TODO: Extended execution models for ray tracing, etc.
#undef CASE
@@ -4437,11 +4471,7 @@ struct SPIRVEmitContext
}
case kIROp_SPIRVAsmOperandBuiltinVar:
{
- const auto kind = (SpvBuiltIn)(getIntVal(operand->getOperand(0)));
- IRBuilder builder(operand);
- builder.setInsertBefore(operand);
- auto varInst = getBuiltinGlobalVar(builder.getPtrType(kIROp_PtrType, operand->getDataType(), SpvStorageClassInput), kind);
- emitOperand(varInst);
+ emitOperand(ensureInst(operand));
break;
}
case kIROp_SPIRVAsmOperandGLSL450Set:
diff --git a/source/slang/slang-intrinsic-expand.cpp b/source/slang/slang-intrinsic-expand.cpp
index 8f74b1591..f17f564c8 100644
--- a/source/slang/slang-intrinsic-expand.cpp
+++ b/source/slang/slang-intrinsic-expand.cpp
@@ -739,58 +739,6 @@ const char* IntrinsicExpandContext::_emitSpecial(const char* cursor)
}
break;
- // We will use the `$X` case as a prefix for
- // special logic needed when cross-compiling ray-tracing
- // shaders.
- case 'X':
- {
- typedef LocationTracker::Kind LocationKind;
-
- SLANG_RELEASE_ASSERT(*cursor);
- const auto kindChar = *cursor++;
-
- LocationKind kind = LocationKind::Invalid;
-
- // The `$XP`/`$XC`/`$XH` case handles looking up
- // the associated `location` for a variable
- // used as the argument.
- switch (kindChar)
- {
- case 'P': kind = LocationKind::RayPayload; break;
- case 'C': kind = LocationKind::CallablePayload; break;
- case 'H': kind = LocationKind::HitObjectAttribute; break;
- default: break;
- }
-
- SLANG_ASSERT(kind != LocationKind::Invalid);
-
- if (kind != LocationKind::Invalid)
- {
- Index argIndex = 0;
- SLANG_RELEASE_ASSERT(m_argCount > argIndex);
- auto arg = m_args[argIndex].get();
-
- // Find the associated decoration
- IRDecoration* foundDecoration = nullptr;
- for (auto decoration : arg->getDecorations())
- {
- const auto curKind = LocationTracker::getKindFromDecoration(decoration);
- if (curKind == kind)
- {
- foundDecoration = decoration;
- break;
- }
- }
-
- // Must have found the decoration
- SLANG_ASSERT(foundDecoration);
-
- const auto location = m_emitter->getLocationTracker().getValue(kind, arg, foundDecoration);
- m_writer->emit(location);
- }
- }
- break;
-
case 'P':
// Type-based prefix as used for CUDA and C++ targets
{
diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp
index 2d24a0577..c4a71a3e9 100644
--- a/source/slang/slang-ir-glsl-legalize.cpp
+++ b/source/slang/slang-ir-glsl-legalize.cpp
@@ -2653,6 +2653,61 @@ bool shouldUseOriginalEntryPointName(CodeGenContext* codeGenContext)
return false;
}
+void assignRayPayloadHitObjectAttributeLocations(IRModule* module)
+{
+ IRIntegerValue rayPayloadCounter = 0;
+ IRIntegerValue callablePayloadCounter = 0;
+ IRIntegerValue hitObjectAttributeCounter = 0;
+
+ IRBuilder builder(module);
+ for (auto inst : module->getGlobalInsts())
+ {
+ auto globalVar = as<IRGlobalVar>(inst);
+ if (!globalVar)
+ continue;
+ IRInst* location = nullptr;
+ for (auto decor : globalVar->getDecorations())
+ {
+ switch (decor->getOp())
+ {
+ case kIROp_VulkanRayPayloadDecoration:
+ builder.setInsertBefore(inst);
+ location = builder.getIntValue(builder.getIntType(), rayPayloadCounter);
+ decor->setOperand(0, location);
+ rayPayloadCounter++;
+ goto end;
+ case kIROp_VulkanCallablePayloadDecoration:
+ builder.setInsertBefore(inst);
+ location = builder.getIntValue(builder.getIntType(), callablePayloadCounter);
+ decor->setOperand(0, location);
+ callablePayloadCounter++;
+ goto end;
+ case kIROp_VulkanHitObjectAttributesDecoration:
+ builder.setInsertBefore(inst);
+ location = builder.getIntValue(builder.getIntType(), hitObjectAttributeCounter);
+ decor->setOperand(0, location);
+ hitObjectAttributeCounter++;
+ goto end;
+ default:
+ break;
+ }
+ }
+ end:;
+ if (location)
+ {
+ traverseUses(globalVar, [&](IRUse* use)
+ {
+ auto user = use->getUser();
+ if (user->getOp() == kIROp_GetVulkanRayTracingPayloadLocation)
+ {
+ user->replaceUsesWith(location);
+ user->removeAndDeallocate();
+ }
+ });
+ }
+ }
+}
+
void legalizeEntryPointForGLSL(
Session* session,
IRModule* module,
@@ -2891,6 +2946,8 @@ void legalizeEntryPointsForGLSL(
{
legalizeEntryPointForGLSL(session, module, func, context, glslExtensionTracker);
}
+
+ assignRayPayloadHitObjectAttributeLocations(module);
}
void legalizeConstantBufferLoadForGLSL(IRModule* module)
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 0406f224e..e64dfdf47 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -633,6 +633,8 @@ INST(GetOptiXHitAttribute, getOptiXHitAttribute, 2, 0)
// using a pointer.
INST(GetOptiXSbtDataPtr, getOptiXSbtDataPointer, 0, 0)
+INST(GetVulkanRayTracingPayloadLocation, GetVulkanRayTracingPayloadLocation, 1, 0)
+
INST(MakeArrayList, makeArrayList, 0, 0)
INST(MakeTensorView, makeTensorView, 0, 0)
INST(AllocateTorchTensor, allocTorchTensor, 0, 0)
diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp
index a5227ed68..ffdd4584a 100644
--- a/source/slang/slang-ir-lower-buffer-element-type.cpp
+++ b/source/slang/slang-ir-lower-buffer-element-type.cpp
@@ -560,6 +560,8 @@ namespace Slang
traverseUses(ptrVal, [&](IRUse* use)
{
auto user = use->getUser();
+ if (as<IRDecoration>(user))
+ return;
switch (user->getOp())
{
case kIROp_Load:
diff --git a/source/slang/slang-ir-specialize-target-switch.cpp b/source/slang/slang-ir-specialize-target-switch.cpp
index 2be7c8194..fcd5ca10a 100644
--- a/source/slang/slang-ir-specialize-target-switch.cpp
+++ b/source/slang/slang-ir-specialize-target-switch.cpp
@@ -19,6 +19,8 @@ namespace Slang
for (UInt i = 0; i < targetSwitch->getCaseCount(); i++)
{
auto cap = (CapabilityAtom)getIntVal(targetSwitch->getCaseValue(i));
+ if (target->getTargetCaps().isIncompatibleWith(cap))
+ continue;
CapabilitySet capSet;
if (cap == CapabilityAtom::Invalid)
capSet = CapabilitySet::makeEmpty();
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp
index f0517ed98..121452533 100644
--- a/source/slang/slang-ir-spirv-legalize.cpp
+++ b/source/slang/slang-ir-spirv-legalize.cpp
@@ -199,8 +199,11 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
}
else
{
- auto val = builder.emitLoad(addr);
- builder.replaceOperand(use, val);
+ if (!as<IRDecoration>(use->getUser()))
+ {
+ auto val = builder.emitLoad(addr);
+ builder.replaceOperand(use, val);
+ }
}
}
@@ -546,10 +549,10 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
storageClass = SpvStorageClassPushConstant;
break;
case LayoutResourceKind::RayPayload:
- storageClass = SpvStorageClassRayPayloadKHR;
+ storageClass = SpvStorageClassIncomingRayPayloadKHR;
break;
case LayoutResourceKind::CallablePayload:
- storageClass = SpvStorageClassCallableDataKHR;
+ storageClass = SpvStorageClassIncomingCallableDataKHR;
break;
case LayoutResourceKind::HitAttributes:
storageClass = SpvStorageClassHitAttributeKHR;
@@ -565,6 +568,12 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
SpvStorageClass getGlobalParamStorageClass(IRVarLayout* varLayout)
{
+ auto typeLayout = varLayout->getTypeLayout()->unwrapArray();
+ if (auto parameterGroupTypeLayout = as<IRParameterGroupTypeLayout>(typeLayout))
+ {
+ varLayout = parameterGroupTypeLayout->getContainerVarLayout();
+ }
+
SpvStorageClass result = SpvStorageClassMax;
for (auto rr : varLayout->getOffsetAttrs())
{
@@ -612,7 +621,24 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
if (cls != SpvStorageClassMax)
storageClass = cls;
}
-
+ for (auto decor : inst->getDecorations())
+ {
+ switch (decor->getOp())
+ {
+ case kIROp_VulkanRayPayloadDecoration:
+ storageClass = SpvStorageClassRayPayloadKHR;
+ break;
+ case kIROp_VulkanCallablePayloadDecoration:
+ storageClass = SpvStorageClassCallableDataKHR;
+ break;
+ case kIROp_VulkanHitObjectAttributesDecoration:
+ storageClass = SpvStorageClassHitObjectAttributeNV;
+ break;
+ case kIROp_VulkanHitAttributesDecoration:
+ storageClass = SpvStorageClassHitAttributeKHR;
+ break;
+ }
+ }
IRBuilder builder(m_sharedContext->m_irModule);
builder.setInsertBefore(inst);
auto newPtrType =
@@ -1364,6 +1390,7 @@ void buildEntryPointReferenceGraph(SPIRVEmitSharedContext* context, IRModule* mo
switch (inst->getOp())
{
case kIROp_GlobalParam:
+ case kIROp_SPIRVAsmOperandBuiltinVar:
registerEntryPointReference(entryPoint, inst);
break;
case kIROp_Block:
@@ -1393,6 +1420,7 @@ void buildEntryPointReferenceGraph(SPIRVEmitSharedContext* context, IRModule* mo
{
case kIROp_GlobalParam:
case kIROp_GlobalVar:
+ case kIROp_SPIRVAsmOperandBuiltinVar:
addToWorkList({ entryPoint, operand });
break;
}