diff options
| author | Harsh Aggarwal (NVIDIA) <haaggarwal@nvidia.com> | 2025-05-26 21:00:38 +0530 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-05-26 15:30:38 +0000 |
| commit | 83538e0b4b97425ecdae6f72f9c8fd44cb255aac (patch) | |
| tree | 8f27c47fb7c1614fa916c2da6ab9996655e29da1 | |
| parent | 8ecb2c70437292ef6fa34f7122df44067de6a4de (diff) | |
Implement shader execution reordering support for OptiX (#7211)
* Implement shader execution reordering support for OptiX
Added OptiX backend support for Shader Execution Reordering (SER) features as outlined in issue #6647. This implementation:
1. Added CUDA target support for HitObject API
2. Implemented core SER functionality (TraceRay, MakeHit/Miss, Invoke)
3. Added OptiX-specific hit object handling functions
4. Added test case for OptiX SER functionality
* format code
---------
Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
| -rw-r--r-- | prelude/slang-cuda-prelude.h | 325 | ||||
| -rw-r--r-- | source/slang/hlsl.meta.slang | 94 | ||||
| -rw-r--r-- | source/slang/slang-emit-cuda.cpp | 5 | ||||
| -rw-r--r-- | tests/cuda/optix-ser.slang | 146 | ||||
| -rw-r--r-- | tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-ray.slang | 4 |
5 files changed, 546 insertions, 28 deletions
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h index c641025d4..d2c9fce9d 100644 --- a/prelude/slang-cuda-prelude.h +++ b/prelude/slang-cuda-prelude.h @@ -3193,7 +3193,7 @@ static __forceinline__ __device__ void* getOptiXRayPayloadPtr() } template<typename T> -__forceinline__ __device__ void* traceOptiXRay( +__forceinline__ __device__ void* optixTrace( OptixTraversableHandle AccelerationStructure, uint32_t RayFlags, uint32_t InstanceInclusionMask, @@ -3221,8 +3221,329 @@ __forceinline__ __device__ void* traceOptiXRay( r1); } -#endif +template<typename T> +__forceinline__ __device__ void* optixTraverse( + OptixTraversableHandle AccelerationStructure, + uint32_t RayFlags, + uint32_t InstanceInclusionMask, + uint32_t RayContributionToHitGroupIndex, + uint32_t MultiplierForGeometryContributionToHitGroupIndex, + uint32_t MissShaderIndex, + RayDesc Ray, + T* Payload, + OptixTraversableHandle* hitObj) +{ + uint32_t r0, r1; + packOptiXRayPayloadPointer((void*)Payload, r0, r1); + optixTraverse( + AccelerationStructure, + Ray.Origin, + Ray.Direction, + Ray.TMin, + Ray.TMax, + 0.f, /* Time for motion blur, currently unsupported in slang */ + InstanceInclusionMask, + RayFlags, + RayContributionToHitGroupIndex, + MultiplierForGeometryContributionToHitGroupIndex, + MissShaderIndex, + r0, + r1); +} + +template<typename T> +__forceinline__ __device__ void* optixTraverse( + OptixTraversableHandle AccelerationStructure, + uint32_t RayFlags, + uint32_t InstanceInclusionMask, + uint32_t RayContributionToHitGroupIndex, + uint32_t MultiplierForGeometryContributionToHitGroupIndex, + uint32_t MissShaderIndex, + RayDesc Ray, + float RayTime, + T* Payload, + OptixTraversableHandle* hitObj) +{ + uint32_t r0, r1; + packOptiXRayPayloadPointer((void*)Payload, r0, r1); + optixTraverse( + AccelerationStructure, + Ray.Origin, + Ray.Direction, + Ray.TMin, + Ray.TMax, + RayTime, + InstanceInclusionMask, + RayFlags, + RayContributionToHitGroupIndex, + MultiplierForGeometryContributionToHitGroupIndex, + MissShaderIndex, + r0, + r1); +} + +static __forceinline__ __device__ bool optixHitObjectIsHit(OptixTraversableHandle* hitObj) +{ + return optixHitObjectIsHit(); +} + +static __forceinline__ __device__ bool optixHitObjectIsMiss(OptixTraversableHandle* hitObj) +{ + return optixHitObjectIsMiss(); +} + +static __forceinline__ __device__ bool optixHitObjectIsNop(OptixTraversableHandle* hitObj) +{ + return optixHitObjectIsNop(); +} + +static __forceinline__ __device__ uint optixHitObjectGetClusterId(OptixTraversableHandle* hitObj) +{ + return optixHitObjectGetClusterId(); +} + +static __forceinline__ __device__ void optixMakeMissHitObject( + uint MissShaderIndex, + RayDesc Ray, + OptixTraversableHandle* missObj) +{ + + optixMakeMissHitObject( + MissShaderIndex, + Ray.Origin, + Ray.Direction, + Ray.TMin, + Ray.TMax, + 0.f, /* rayTime */ + OPTIX_RAY_FLAG_NONE /* rayFlags*/); +} + +static __forceinline__ __device__ void optixMakeMissHitObject( + uint MissShaderIndex, + RayDesc Ray, + float CurrentTime, + OptixTraversableHandle* missObj) +{ + + optixMakeMissHitObject( + MissShaderIndex, + Ray.Origin, + Ray.Direction, + Ray.TMin, + Ray.TMax, + CurrentTime, + OPTIX_RAY_FLAG_NONE /* rayFlags*/); +} +template<typename T> +static __forceinline__ __device__ void optixMakeHitObject( + OptixTraversableHandle AccelerationStructure, + uint InstanceIndex, + uint GeometryIndex, + uint PrimitiveIndex, + uint HitKind, + uint RayContributionToHitGroupIndex, + uint MultiplierForGeometryContributionToHitGroupIndex, + RayDesc Ray, + T attr, + OptixTraversableHandle* handle) +{ + + OptixTraverseData data{}; + optixHitObjectGetTraverseData(&data); + optixMakeHitObject( + AccelerationStructure, + Ray.Origin, + Ray.Direction, + Ray.TMin, + 0.f, + OPTIX_RAY_FLAG_NONE, /* rayFlags*/ + data, + nullptr, /*OptixTraversableHandle* transforms*/ + 0 /*numTransforms */); +} + +template<typename T> +static __forceinline__ __device__ void optixMakeHitObject( + uint HitGroupRecordIndex, + OptixTraversableHandle AccelerationStructure, + uint InstanceIndex, + uint GeometryIndex, + uint PrimitiveIndex, + uint HitKind, + RayDesc Ray, + T attr, + OptixTraversableHandle* handle) +{ + + OptixTraverseData data{}; + optixHitObjectGetTraverseData(&data); + optixMakeHitObject( + AccelerationStructure, + Ray.Origin, + Ray.Direction, + Ray.TMin, + 0.f, + OPTIX_RAY_FLAG_NONE, /* rayFlags*/ + data, + nullptr, /*OptixTraversableHandle* transforms*/ + 0 /*numTransforms */); +} + +template<typename T> +static __forceinline__ __device__ void optixMakeHitObject( + OptixTraversableHandle AccelerationStructure, + uint InstanceIndex, + uint GeometryIndex, + uint PrimitiveIndex, + uint HitKind, + uint RayContributionToHitGroupIndex, + uint MultiplierForGeometryContributionToHitGroupIndex, + RayDesc Ray, + float CurrentTime, + T attr, + OptixTraversableHandle* handle) +{ + + OptixTraverseData data{}; + optixHitObjectGetTraverseData(&data); + optixMakeHitObject( + AccelerationStructure, + Ray.Origin, + Ray.Direction, + Ray.TMin, + CurrentTime, + OPTIX_RAY_FLAG_NONE, /* rayFlags*/ + data, + nullptr, /*OptixTraversableHandle* transforms*/ + 0 /*numTransforms */); +} + +template<typename T> +static __forceinline__ __device__ void optixMakeHitObject( + uint HitGroupRecordIndex, + OptixTraversableHandle AccelerationStructure, + uint InstanceIndex, + uint GeometryIndex, + uint PrimitiveIndex, + uint HitKind, + RayDesc Ray, + float CurrentTime, + T attr, + OptixTraversableHandle* handle) +{ + + OptixTraverseData data{}; + optixHitObjectGetTraverseData(&data); + optixMakeHitObject( + AccelerationStructure, + Ray.Origin, + Ray.Direction, + Ray.TMin, + CurrentTime, + OPTIX_RAY_FLAG_NONE, /* rayFlags*/ + data, + nullptr, /*OptixTraversableHandle* transforms*/ + 0 /*numTransforms */); +} + +static __forceinline__ __device__ void optixMakeNopHitObject(OptixTraversableHandle* Obj) +{ + optixMakeNopHitObject(); +} + +template<typename T> +static __forceinline__ __device__ void optixInvoke( + OptixTraversableHandle AccelerationStructure, + OptixTraversableHandle* HitOrMiss, + T Payload) +{ + uint32_t r0, r1; + packOptiXRayPayloadPointer((void*)Payload, r0, r1); + optixInvoke(r0, r1); +} +static __forceinline__ __device__ RayDesc optixHitObjectGetRayDesc(OptixTraversableHandle* obj) +{ + RayDesc ray = { + optixHitObjectGetWorldRayOrigin(), + optixHitObjectGetRayTmin(), + optixHitObjectGetWorldRayDirection(), + optixHitObjectGetRayTmax()}; + return ray; +} + +static __forceinline__ __device__ uint optixHitObjectGetInstanceIndex(OptixTraversableHandle* Obj) +{ + return optixHitObjectGetInstanceIndex(); +} + +static __forceinline__ __device__ uint optixHitObjectGetInstanceId(OptixTraversableHandle* Obj) +{ + return optixHitObjectGetInstanceId(); +} + +static __forceinline__ __device__ uint optixHitObjectGetSbtGASIndex(OptixTraversableHandle* Obj) +{ + return optixHitObjectGetSbtGASIndex(); +} + +static __forceinline__ __device__ uint optixHitObjectGetPrimitiveIndex(OptixTraversableHandle* Obj) +{ + return optixHitObjectGetPrimitiveIndex(); +} + +template<typename T> +static __forceinline__ __device__ T optixHitObjectGetAttribute(OptixTraversableHandle* Obj) +{ + constexpr size_t numInts = (sizeof(T) + sizeof(uint32_t) - 1) / + sizeof(uint32_t); // Number of 32-bit values, rounded up + static_assert(numInts <= 8, "Attribute type is too large"); + + // Create an array to hold the attribute values + uint32_t values[numInts == 0 ? 1 : numInts] = {0}; // Ensure we have at least one element + + // Read the appropriate number of attribute registers + if constexpr (numInts > 0) + values[0] = optixHitObjectGetAttribute_0(); + if constexpr (numInts > 1) + values[1] = optixHitObjectGetAttribute_1(); + if constexpr (numInts > 2) + values[2] = optixHitObjectGetAttribute_2(); + if constexpr (numInts > 3) + values[3] = optixHitObjectGetAttribute_3(); + if constexpr (numInts > 4) + values[4] = optixHitObjectGetAttribute_4(); + if constexpr (numInts > 5) + values[5] = optixHitObjectGetAttribute_5(); + if constexpr (numInts > 6) + values[6] = optixHitObjectGetAttribute_6(); + if constexpr (numInts > 7) + values[7] = optixHitObjectGetAttribute_7(); + + // Reinterpret the array as the desired type + T result; + memcpy(&result, values, sizeof(T)); + return result; +} + +static __forceinline__ __device__ uint optixHitObjectGetSbtRecordIndex(OptixTraversableHandle* Obj) +{ + return optixHitObjectGetSbtRecordIndex(); +} + +static __forceinline__ __device__ uint +optixHitObjectSetSbtRecordIndex(OptixTraversableHandle* Obj, uint sbtRecordIndex) +{ + optixHitObjectSetSbtRecordIndex(sbtRecordIndex); // returns void + return 0; +} +static __forceinline__ __device__ uint +optixHitObjectGetSbtDataPointer(OptixTraversableHandle* Obj, uint sbtRecordIndex) +{ + optixHitObjectGetSbtDataPointer(); // returns void + return 0; +} +#endif static const int kSlangTorchTensorMaxDim = 5; // TensorView diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index cb050dd51..fd7c7cfc7 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -17250,7 +17250,7 @@ void TraceRay( Ray, __forceVarIntoRayPayloadStructTemporarily(Payload)); return; - case cuda: __intrinsic_asm "traceOptiXRay"; + case cuda: __intrinsic_asm "optixTrace"; case glsl: { [__vulkanRayPayload] @@ -19576,7 +19576,7 @@ struct HitObject /// Executes ray traversal (including anyhit and intersection shaders) like TraceRay, but returns the /// resulting hit information as a HitObject and does not trigger closesthit or miss shaders. [ForceInline] - [require(glsl_hlsl_spirv, ser_raygen_closesthit_miss)] + [require(cuda_glsl_hlsl_spirv, ser_raygen_closesthit_miss)] static HitObject TraceRay<payload_t>( RaytracingAccelerationStructure AccelerationStructure, uint RayFlags, @@ -19629,6 +19629,7 @@ struct HitObject // Write the payload out Payload = p; } + case cuda: __intrinsic_asm "optixTraverse"; case spirv: { [__vulkanRayPayload] @@ -19669,7 +19670,7 @@ struct HitObject /// Executes motion ray traversal (including anyhit and intersection shaders) like TraceRay, but returns the /// resulting hit information as a HitObject and does not trigger closesthit or miss shaders. [ForceInline] - [require(glsl_hlsl_spirv, ser_motion_raygen_closesthit_miss)] + [require(cuda_glsl_hlsl_spirv, ser_motion_raygen_closesthit_miss)] static HitObject TraceMotionRay<payload_t>( RaytracingAccelerationStructure AccelerationStructure, uint RayFlags, @@ -19720,6 +19721,7 @@ struct HitObject // Write the payload out Payload = p; } + case cuda: __intrinsic_asm "optixTraverse"; case spirv: { [__vulkanRayPayload] @@ -19768,7 +19770,7 @@ struct HitObject /// Attributes parameter must either be an attribute struct, such as /// BuiltInTriangleIntersectionAttributes, or another HitObject to copy the attributes from. [ForceInline] - [require(glsl_hlsl_spirv, ser_raygen_closesthit_miss)] + [require(cuda_glsl_hlsl_spirv, ser_raygen_closesthit_miss)] static HitObject MakeHit<attr_t>( RaytracingAccelerationStructure AccelerationStructure, uint InstanceIndex, @@ -19816,6 +19818,7 @@ struct HitObject Ray.TMax, __hitObjectAttributesLocation(__hitObjectAttributes<attr_t>())); } + case cuda: __intrinsic_asm "optixMakeHitObject"; case spirv: { // Save the attributes @@ -19853,7 +19856,7 @@ struct HitObject /// See MakeHit but handles Motion /// Currently only supported on VK [ForceInline] - [require(glsl_hlsl_spirv, ser_motion_raygen_closesthit_miss)] + [require(cuda_glsl_hlsl_spirv, ser_motion_raygen_closesthit_miss)] static HitObject MakeMotionHit<attr_t>( RaytracingAccelerationStructure AccelerationStructure, uint InstanceIndex, @@ -19890,6 +19893,7 @@ struct HitObject CurrentTime, __hitObjectAttributesLocation(__hitObjectAttributes<attr_t>())); } + case cuda: __intrinsic_asm "optixMakeHitObject"; case spirv: { // Save the attributes @@ -19935,7 +19939,7 @@ struct HitObject /// attribute struct, such as BuiltInTriangleIntersectionAttributes, or another HitObject to copy the /// attributes from. [ForceInline] - [require(glsl_hlsl_spirv, ser_raygen_closesthit_miss)] + [require(cuda_glsl_hlsl_spirv, ser_raygen_closesthit_miss)] static HitObject MakeHit<attr_t>( uint HitGroupRecordIndex, RaytracingAccelerationStructure AccelerationStructure, @@ -19980,6 +19984,7 @@ struct HitObject Ray.TMax, __hitObjectAttributesLocation(__hitObjectAttributes<attr_t>())); } + case cuda: __intrinsic_asm "optixMakeHitObject"; case spirv: { // Save the attributes @@ -20013,7 +20018,7 @@ struct HitObject /// See MakeHit but handles Motion /// Currently only supported on VK [ForceInline] - [require(glsl_spirv, ser_motion_raygen_closesthit_miss)] + [require(cuda_glsl_spirv, ser_motion_raygen_closesthit_miss)] static HitObject MakeMotionHit<attr_t>( uint HitGroupRecordIndex, RaytracingAccelerationStructure AccelerationStructure, @@ -20047,6 +20052,7 @@ struct HitObject CurrentTime, __hitObjectAttributesLocation(__hitObjectAttributes<attr_t>())); } + case cuda: __intrinsic_asm "optixMakeHitObject"; case spirv: { // Save the attributes @@ -20084,7 +20090,7 @@ struct HitObject /// table. [__requiresNVAPI] [ForceInline] - [require(glsl_hlsl_spirv, ser_raygen_closesthit_miss)] + [require(cuda_glsl_hlsl_spirv, ser_raygen_closesthit_miss)] static HitObject MakeMiss( uint MissShaderIndex, RayDesc Ray) @@ -20094,6 +20100,7 @@ struct HitObject case hlsl: __intrinsic_asm "($2=NvMakeMiss($0,$1))"; case glsl: __glslMakeMiss(__return_val, MissShaderIndex, Ray.Origin, Ray.TMin, Ray.Direction, Ray.TMax); + case cuda: __intrinsic_asm "optixMakeMissHitObject"; case spirv: { let origin = Ray.Origin; @@ -20119,7 +20126,7 @@ struct HitObject /// See MakeMiss but handles Motion /// Currently only supported on VK [ForceInline] - [require(glsl_hlsl_spirv, ser_motion_raygen_closesthit_miss)] + [require(cuda_glsl_hlsl_spirv, ser_motion_raygen_closesthit_miss)] static HitObject MakeMotionMiss( uint MissShaderIndex, RayDesc Ray, @@ -20130,6 +20137,7 @@ struct HitObject case hlsl: __intrinsic_asm "($3=NvMakeMotionMiss($0,$1,$2))"; case glsl: __glslMakeMotionMiss(__return_val, MissShaderIndex, Ray.Origin, Ray.TMin, Ray.Direction, Ray.TMax, CurrentTime); + case cuda: __intrinsic_asm "optixMakeMissHitObject"; case spirv: { let origin = Ray.Origin; @@ -20162,7 +20170,7 @@ struct HitObject /// miss. [__requiresNVAPI] [ForceInline] - [require(glsl_hlsl_spirv, ser_raygen_closesthit_miss)] + [require(cuda_glsl_hlsl_spirv, ser_raygen_closesthit_miss)] static HitObject MakeNop() { __target_switch @@ -20171,6 +20179,7 @@ struct HitObject __intrinsic_asm "($0 = NvMakeNop())"; case glsl: __glslMakeNop(__return_val); + case cuda: __intrinsic_asm "optixMakeNopHitObject"; case spirv: spirv_asm { @@ -20199,7 +20208,7 @@ struct HitObject /// shader is invoked. [__requiresNVAPI] [ForceInline] - [require(glsl_hlsl_spirv, ser_raygen_closesthit_miss)] + [require(cuda_glsl_hlsl_spirv, ser_raygen_closesthit_miss)] static void Invoke<payload_t>( RaytracingAccelerationStructure AccelerationStructure, HitObject HitOrMiss, @@ -20225,6 +20234,7 @@ struct HitObject // Write payload result Payload = p; } + case cuda: __intrinsic_asm "optixInvoke"; case spirv: { [__vulkanRayPayload] @@ -20251,13 +20261,14 @@ struct HitObject /// Returns true if the HitObject encodes a miss, otherwise returns false. [__requiresNVAPI] [ForceInline] - [require(glsl_hlsl_spirv, ser_raygen_closesthit_miss)] + [require(cuda_glsl_hlsl_spirv, ser_raygen_closesthit_miss)] bool IsMiss() { __target_switch { case hlsl: __intrinsic_asm ".IsMiss"; case glsl: __intrinsic_asm "hitObjectIsMissNV($0)"; + case cuda: __intrinsic_asm "optixHitObjectIsMiss"; case spirv: return spirv_asm { @@ -20271,13 +20282,14 @@ struct HitObject /// Returns true if the HitObject encodes a hit, otherwise returns false. [__requiresNVAPI] [ForceInline] - [require(glsl_hlsl_spirv, ser_raygen_closesthit_miss)] + [require(cuda_glsl_hlsl_spirv, ser_raygen_closesthit_miss)] bool IsHit() { __target_switch { case hlsl: __intrinsic_asm ".IsHit"; case glsl: __intrinsic_asm "hitObjectIsHitNV($0)"; + case cuda: __intrinsic_asm "optixHitObjectIsHit"; case spirv: return spirv_asm { @@ -20291,13 +20303,14 @@ struct HitObject /// Returns true if the HitObject encodes a nop, otherwise returns false. [__requiresNVAPI] [ForceInline] - [require(glsl_hlsl_spirv, ser_raygen_closesthit_miss)] + [require(cuda_glsl_hlsl_spirv, ser_raygen_closesthit_miss)] bool IsNop() { __target_switch { case hlsl: __intrinsic_asm ".IsNop"; case glsl: __intrinsic_asm "hitObjectIsEmptyNV($0)"; + case cuda: __intrinsic_asm "optixHitObjectIsNop"; case spirv: return spirv_asm { @@ -20311,7 +20324,7 @@ struct HitObject /// Queries ray properties from HitObject. Valid if the hit object represents a hit or a miss. [__requiresNVAPI] [ForceInline] - [require(glsl_hlsl_spirv, ser_raygen_closesthit_miss)] + [require(cuda_glsl_hlsl_spirv, ser_raygen_closesthit_miss)] RayDesc GetRayDesc() { __target_switch @@ -20323,6 +20336,7 @@ struct HitObject RayDesc ray = { __glslGetRayWorldOrigin(), __glslGetTMin(), __glslGetRayWorldDirection(), __glslGetTMax() }; return ray; } + case cuda: __intrinsic_asm "optixHitObjectGetRayDesc"; case spirv: return spirv_asm { @@ -20341,13 +20355,14 @@ struct HitObject [__requiresNVAPI] __glsl_extension(GL_EXT_ray_tracing) [ForceInline] - [require(glsl_hlsl_spirv, ser_raygen_closesthit_miss)] + [require(cuda_glsl_hlsl_spirv, ser_raygen_closesthit_miss)] uint GetShaderTableIndex() { __target_switch { case hlsl: __intrinsic_asm ".GetShaderTableIndex"; case glsl: __intrinsic_asm "hitObjectGetShaderBindingTableRecordIndexNV($0)"; + case cuda: __intrinsic_asm "optixHitObjectGetSbtRecordIndex"; case spirv: return spirv_asm { @@ -20358,17 +20373,41 @@ struct HitObject } } + [__requiresNVAPI] + __glsl_extension(GL_EXT_ray_tracing) + [ForceInline] + [require(cuda_hlsl, ser_raygen_closesthit_miss)] + uint SetShaderTableIndex(uint RecordIndex) + { + __target_switch + { + case hlsl: __intrinsic_asm ".SetShaderTableIndex"; + case cuda: __intrinsic_asm "optixHitObjectSetSbtRecordIndex"; + } + } + + // TODO - Add other targets [__requiresNVAPI] __glsl_extension(GL_EXT_ray_tracing) + [ForceInline] + [require(cuda, ser_raygen_closesthit_miss)] + uint LoadLocalRootArgumentsConstant(uint RootConstantOffsetInBytes) + { + __target_switch + { + case cuda: __intrinsic_asm "optixHitObjectGetSbtDataPointer"; + } + } /// Returns the instance index of a hit. Valid if the hit object represents a hit. [__requiresNVAPI] __glsl_extension(GL_EXT_ray_tracing) [ForceInline] - [require(glsl_hlsl_spirv, ser_raygen_closesthit_miss)] + [require(cuda_glsl_hlsl_spirv, ser_raygen_closesthit_miss)] uint GetInstanceIndex() { __target_switch { case hlsl: __intrinsic_asm ".GetInstanceIndex"; case glsl: __intrinsic_asm "hitObjectGetInstanceIdNV($0)"; + case cuda: __intrinsic_asm "optixHitObjectGetInstanceIndex"; case spirv: return spirv_asm { @@ -20383,13 +20422,14 @@ struct HitObject [__requiresNVAPI] __glsl_extension(GL_EXT_ray_tracing) [ForceInline] - [require(glsl_hlsl_spirv, ser_raygen_closesthit_miss)] + [require(cuda_glsl_hlsl_spirv, ser_raygen_closesthit_miss)] uint GetInstanceID() { __target_switch { case hlsl: __intrinsic_asm ".GetInstanceID"; case glsl: __intrinsic_asm "hitObjectGetInstanceCustomIndexNV($0)"; + case cuda: __intrinsic_asm "optixHitObjectGetInstanceId"; case spirv: return spirv_asm { @@ -20404,13 +20444,14 @@ struct HitObject [__requiresNVAPI] __glsl_extension(GL_EXT_ray_tracing) [ForceInline] - [require(glsl_hlsl_spirv, ser_raygen_closesthit_miss)] + [require(cuda_glsl_hlsl_spirv, ser_raygen_closesthit_miss)] uint GetGeometryIndex() { __target_switch { case hlsl: __intrinsic_asm ".GetGeometryIndex"; case glsl: __intrinsic_asm "hitObjectGetGeometryIndexNV($0)"; + case cuda: __intrinsic_asm "optixHitObjectGetSbtGASIndex"; case spirv: return spirv_asm { @@ -20425,13 +20466,14 @@ struct HitObject [__requiresNVAPI] __glsl_extension(GL_EXT_ray_tracing) [ForceInline] - [require(glsl_hlsl_spirv, ser_raygen_closesthit_miss)] + [require(cuda_glsl_hlsl_spirv, ser_raygen_closesthit_miss)] uint GetPrimitiveIndex() { __target_switch { case hlsl: __intrinsic_asm ".GetPrimitiveIndex"; case glsl: __intrinsic_asm "hitObjectGetPrimitiveIndexNV($0)"; + case cuda: __intrinsic_asm "optixHitObjectGetPrimitiveIndex"; case spirv: return spirv_asm { @@ -20596,7 +20638,7 @@ struct HitObject /// Returns the attributes of a hit. Valid if the hit object represents a hit or a miss. [ForceInline] - [require(glsl_hlsl_spirv, ser_raygen_closesthit_miss)] + [require(cuda_glsl_hlsl_spirv, ser_raygen_closesthit_miss)] attr_t GetAttributes<attr_t>() { __target_switch @@ -20618,6 +20660,7 @@ struct HitObject // Return the attributes return __hitObjectAttributes<attr_t>(); } + case cuda: __intrinsic_asm "optixHitObjectGetAttribute<$TR>($0)"; case spirv: { __Addr<attr_t> attr = __allocHitObjectAttributes<attr_t>(); @@ -21008,13 +21051,14 @@ struct HitObject __glsl_extension(GL_EXT_ray_tracing) __glsl_extension(GL_NV_shader_invocation_reorder) [ForceInline] -[require(glsl_hlsl_spirv, ser_raygen)] +[require(cuda_glsl_hlsl_spirv, ser_raygen)] void ReorderThread( uint CoherenceHint, uint NumCoherenceHintBitsFromLSB ) { __target_switch { case hlsl: __intrinsic_asm "NvReorderThread"; case glsl: __intrinsic_asm "reorderThreadNV"; + case cuda: __intrinsic_asm "optixReorder"; case spirv: spirv_asm { @@ -21045,13 +21089,14 @@ void ReorderThread( uint CoherenceHint, uint NumCoherenceHintBitsFromLSB ) __glsl_extension(GL_EXT_ray_tracing) __glsl_extension(GL_NV_shader_invocation_reorder) [ForceInline] -[require(glsl_hlsl_spirv, ser_raygen)] +[require(cuda_glsl_hlsl_spirv, ser_raygen)] void ReorderThread( HitObject HitOrMiss, uint CoherenceHint, uint NumCoherenceHintBitsFromLSB ) { __target_switch { case hlsl: __intrinsic_asm "NvReorderThread"; case glsl: __intrinsic_asm "reorderThreadNV"; + case cuda: __intrinsic_asm "optixReorder($1, $2)"; case spirv: spirv_asm { @@ -21072,13 +21117,14 @@ void ReorderThread( HitObject HitOrMiss, uint CoherenceHint, uint NumCoherenceHi __glsl_extension(GL_EXT_ray_tracing) __glsl_extension(GL_NV_shader_invocation_reorder) [ForceInline] -[require(glsl_hlsl_spirv, ser_raygen)] +[require(cuda_glsl_hlsl_spirv, ser_raygen)] void ReorderThread( HitObject HitOrMiss ) { __target_switch { case hlsl: __intrinsic_asm "NvReorderThread"; case glsl: __intrinsic_asm "reorderThreadNV"; + case cuda: __intrinsic_asm "optixReorder()"; case spirv: spirv_asm { diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp index 74133fcf0..e5169ba38 100644 --- a/source/slang/slang-emit-cuda.cpp +++ b/source/slang/slang-emit-cuda.cpp @@ -214,6 +214,11 @@ SlangResult CUDASourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, out << "TensorView"; return SLANG_OK; } + case kIROp_HitObjectType: + { + out << "OptixTraversableHandle"; + return SLANG_OK; + } default: { if (isNominalOp(type->getOp())) diff --git a/tests/cuda/optix-ser.slang b/tests/cuda/optix-ser.slang new file mode 100644 index 000000000..54f300706 --- /dev/null +++ b/tests/cuda/optix-ser.slang @@ -0,0 +1,146 @@ +// optix-ser.slang + + +//TEST:SIMPLE(filecheck=CHECK): -target cuda -entry rayGenerationMain -stage raygeneration + +//TEST_INPUT: set scene = AccelerationStructure +uniform RaytracingAccelerationStructure scene; + +//TEST_INPUT:set outputBuffer = out ubuffer(data=[0, 0, 0, 0], stride=4) +RWStructuredBuffer<uint> outputBuffer; + +struct SomeValues +{ + int a; + float b; +}; + +uint calcValue(HitObject hit) +{ + uint r = 0; + + if (hit.IsHit()) + { + uint instanceIndex = hit.GetInstanceIndex(); + uint instanceID = hit.GetInstanceID(); + uint geometryIndex = hit.GetGeometryIndex(); + uint primitiveIndex = hit.GetPrimitiveIndex(); + int clusterID = hit.GetClusterID(); + uint shaderTableIndex = hit.GetShaderTableIndex(); + // spriv and glsl lack these methods + uint setShaderTableIndex = hit.SetShaderTableIndex(0); + uint ialbedo = hit.LoadLocalRootTableConstant(0); + SomeValues objSomeValues = hit.GetAttributes<SomeValues>(); + + r += instanceIndex; + r += instanceID; + r += geometryIndex; + r += primitiveIndex; + r += objSomeValues.a; + r += clusterID; + r += shaderTableIndex; + r += setShaderTableIndex; + r += ialbedo; + } + + return r; +} + +void rayGenerationMain() +{ + int2 launchID = int2(DispatchRaysIndex().xy); + int2 launchSize = int2(DispatchRaysDimensions().xy); + + int idx = launchID.x; + + SomeValues someValues = { idx, idx * 2.0f }; + + RayDesc ray; + ray.Origin = float3(idx, 0, 0); + ray.TMin = 0.01f; + ray.Direction = float3(0, 1, 0); + ray.TMax = 1e4f; + + RAY_FLAG rayFlags = RAY_FLAG_ACCEPT_FIRST_HIT_AND_END_SEARCH | RAY_FLAG_CULL_BACK_FACING_TRIANGLES; + uint instanceInclusionMask = 0xff; + uint rayContributionToHitGroupIndex = 0; + uint multiplierForGeometryContributionToHitGroupIndex = 4; + uint missShaderIndex = 0; + // SPIRV: OpHitObjectTraceRayNV + // CHECK: optixTraverse + HitObject hit = HitObject::TraceRay(scene, + rayFlags, + instanceInclusionMask, + rayContributionToHitGroupIndex, + multiplierForGeometryContributionToHitGroupIndex, + missShaderIndex, + ray, + someValues); + + ReorderThread( hit ); + ReorderThread(hit, uint(idx & 3), 2); + ReorderThread(uint(idx & 1), 1); + + outputBuffer[idx] = calcValue(hit); + HitObject miss[2]; + miss[0] = HitObject::MakeMiss(0u, ray); + miss[1] = HitObject::MakeMotionMiss(0u, ray, 1.f); + + uint hitGroupRecordIndex = 0; + uint instanceIndex = 0xff; + uint geometryIndex = 0; + uint primitiveIndex = 0; + uint hitKind = 0; + BuiltInTriangleIntersectionAttributes attr = {0.01f, 0.2f}; + + HitObject hitObj = HitObject::MakeHit(hitGroupRecordIndex, scene, + instanceIndex, + geometryIndex, + primitiveIndex, + hitKind, + ray, + attr); + HitObject nopObj = HitObject::MakeNop(); + outputBuffer[idx] = uint(nopObj.IsNop()); + + outputBuffer[idx] += calcValue(hit); + outputBuffer[idx] += calcValue(miss[0]); + outputBuffer[idx] += calcValue(miss[1]); + outputBuffer[idx] += calcValue(hitObj); + outputBuffer[idx] += calcValue(nopObj); + + // Change the payload + SomeValues otherValues = { idx * -2, idx * 8.0f }; + + HitObject::Invoke( scene, hit, otherValues ); + HitObject motionHitObj[2]; + motionHitObj[0] = HitObject::MakeMotionHit( + scene, + instanceIndex, + geometryIndex, + primitiveIndex, + hitKind, + rayContributionToHitGroupIndex, + multiplierForGeometryContributionToHitGroupIndex, + ray, + 0.f, + attr); + motionHitObj[1] = HitObject::MakeMotionHit( + hitGroupRecordIndex, + scene, + instanceIndex, + geometryIndex, + primitiveIndex, + hitKind, + ray, + 0.f, + attr); + outputBuffer[idx] += calcValue(motionHitObj[0]); + outputBuffer[idx] += calcValue(motionHitObj[1]); + + RayDesc rayD = hit.GetRayDesc(); + + outputBuffer[idx] += uint(rayD.TMin > 0); + outputBuffer[idx] += uint(rayD.TMax < ray.TMin); + +} diff --git a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-ray.slang b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-ray.slang index 877e41977..71c113934 100644 --- a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-ray.slang +++ b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-ray.slang @@ -4,8 +4,7 @@ //TEST:SIMPLE(filecheck=SPIRV): -target spirv -entry rayGenerationMain -stage raygeneration -profile sm_6_5 -line-directive-mode none //TEST:SIMPLE(filecheck=SPIRV): -target spirv -entry rayGenerationMain -stage raygeneration -emit-spirv-directly -// Note: HitObject::TraceRay is not supported in raygen stage for cuda target -//DISABLE_TEST:SIMPLE: -target cuda -entry rayGenerationMain -stage raygeneration +//TEST:SIMPLE(filecheck=CHECK): -target cuda -entry rayGenerationMain -stage raygeneration //DISABLE_TEST(compute):COMPARE_COMPUTE:-d3d12 -output-using-type -use-dxil -profile sm_6_6 -render-feature ray-query //DISABLE_TEST(compute):COMPARE_COMPUTE:-vk -output-using-type -render-feature ray-query @@ -68,6 +67,7 @@ void rayGenerationMain() uint multiplierForGeometryContributionToHitGroupIndex = 4; uint missShaderIndex = 0; // SPIRV: OpHitObjectTraceRayNV + // CHECK: optixTraverse HitObject hit = HitObject::TraceRay(scene, rayFlags, instanceInclusionMask, |
