summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHarsh Aggarwal (NVIDIA) <haaggarwal@nvidia.com>2025-05-26 21:00:38 +0530
committerGitHub <noreply@github.com>2025-05-26 15:30:38 +0000
commit83538e0b4b97425ecdae6f72f9c8fd44cb255aac (patch)
tree8f27c47fb7c1614fa916c2da6ab9996655e29da1
parent8ecb2c70437292ef6fa34f7122df44067de6a4de (diff)
Implement shader execution reordering support for OptiX (#7211)
* Implement shader execution reordering support for OptiX Added OptiX backend support for Shader Execution Reordering (SER) features as outlined in issue #6647. This implementation: 1. Added CUDA target support for HitObject API 2. Implemented core SER functionality (TraceRay, MakeHit/Miss, Invoke) 3. Added OptiX-specific hit object handling functions 4. Added test case for OptiX SER functionality * format code --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
-rw-r--r--prelude/slang-cuda-prelude.h325
-rw-r--r--source/slang/hlsl.meta.slang94
-rw-r--r--source/slang/slang-emit-cuda.cpp5
-rw-r--r--tests/cuda/optix-ser.slang146
-rw-r--r--tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-ray.slang4
5 files changed, 546 insertions, 28 deletions
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h
index c641025d4..d2c9fce9d 100644
--- a/prelude/slang-cuda-prelude.h
+++ b/prelude/slang-cuda-prelude.h
@@ -3193,7 +3193,7 @@ static __forceinline__ __device__ void* getOptiXRayPayloadPtr()
}
template<typename T>
-__forceinline__ __device__ void* traceOptiXRay(
+__forceinline__ __device__ void* optixTrace(
OptixTraversableHandle AccelerationStructure,
uint32_t RayFlags,
uint32_t InstanceInclusionMask,
@@ -3221,8 +3221,329 @@ __forceinline__ __device__ void* traceOptiXRay(
r1);
}
-#endif
+// Slang-facing wrapper selected by HitObject::TraceRay on the CUDA target.
+// Packs the payload pointer into two 32-bit payload registers and forwards the
+// ray to the OptiX optixTraverse intrinsic. This overload has no ray-time
+// parameter, so 0 is passed for the time.
+//
+// hitObj is accepted for signature compatibility but is not consulted here.
+// Returns nullptr unconditionally: the function is declared void*, so it must
+// return a value rather than fall off the end (undefined behaviour).
+template<typename T>
+__forceinline__ __device__ void* optixTraverse(
+    OptixTraversableHandle AccelerationStructure,
+    uint32_t RayFlags,
+    uint32_t InstanceInclusionMask,
+    uint32_t RayContributionToHitGroupIndex,
+    uint32_t MultiplierForGeometryContributionToHitGroupIndex,
+    uint32_t MissShaderIndex,
+    RayDesc Ray,
+    T* Payload,
+    OptixTraversableHandle* hitObj)
+{
+    (void)hitObj;
+    uint32_t r0, r1;
+    packOptiXRayPayloadPointer((void*)Payload, r0, r1);
+    optixTraverse(
+        AccelerationStructure,
+        Ray.Origin,
+        Ray.Direction,
+        Ray.TMin,
+        Ray.TMax,
+        0.f, /* Time for motion blur, currently unsupported in slang */
+        InstanceInclusionMask,
+        RayFlags,
+        RayContributionToHitGroupIndex,
+        MultiplierForGeometryContributionToHitGroupIndex,
+        MissShaderIndex,
+        r0,
+        r1);
+    return nullptr;
+}
+
+// Slang-facing wrapper selected by HitObject::TraceMotionRay on the CUDA target.
+// Same as the overload without RayTime, except that the caller-supplied RayTime
+// is forwarded to the OptiX optixTraverse intrinsic.
+//
+// hitObj is accepted for signature compatibility but is not consulted here.
+// Returns nullptr unconditionally: the function is declared void*, so it must
+// return a value rather than fall off the end (undefined behaviour).
+template<typename T>
+__forceinline__ __device__ void* optixTraverse(
+    OptixTraversableHandle AccelerationStructure,
+    uint32_t RayFlags,
+    uint32_t InstanceInclusionMask,
+    uint32_t RayContributionToHitGroupIndex,
+    uint32_t MultiplierForGeometryContributionToHitGroupIndex,
+    uint32_t MissShaderIndex,
+    RayDesc Ray,
+    float RayTime,
+    T* Payload,
+    OptixTraversableHandle* hitObj)
+{
+    (void)hitObj;
+    uint32_t r0, r1;
+    packOptiXRayPayloadPointer((void*)Payload, r0, r1);
+    optixTraverse(
+        AccelerationStructure,
+        Ray.Origin,
+        Ray.Direction,
+        Ray.TMin,
+        Ray.TMax,
+        RayTime,
+        InstanceInclusionMask,
+        RayFlags,
+        RayContributionToHitGroupIndex,
+        MultiplierForGeometryContributionToHitGroupIndex,
+        MissShaderIndex,
+        r0,
+        r1);
+    return nullptr;
+}
+
+// Backs HitObject::IsHit on the CUDA target. The handle argument is ignored;
+// the answer comes from the zero-argument OptiX intrinsic of the same name.
+static __forceinline__ __device__ bool optixHitObjectIsHit(OptixTraversableHandle* hitObj)
+{
+    const bool isHit = optixHitObjectIsHit();
+    return isHit;
+}
+
+// Backs HitObject::IsMiss on the CUDA target. The handle argument is ignored;
+// the answer comes from the zero-argument OptiX intrinsic of the same name.
+static __forceinline__ __device__ bool optixHitObjectIsMiss(OptixTraversableHandle* hitObj)
+{
+    const bool isMiss = optixHitObjectIsMiss();
+    return isMiss;
+}
+
+// Backs HitObject::IsNop on the CUDA target. The handle argument is ignored;
+// the answer comes from the zero-argument OptiX intrinsic of the same name.
+static __forceinline__ __device__ bool optixHitObjectIsNop(OptixTraversableHandle* hitObj)
+{
+    const bool isNop = optixHitObjectIsNop();
+    return isNop;
+}
+
+// Returns the cluster id reported by the zero-argument OptiX intrinsic of the
+// same name. The handle argument is ignored.
+static __forceinline__ __device__ uint optixHitObjectGetClusterId(OptixTraversableHandle* hitObj)
+{
+    const uint clusterId = optixHitObjectGetClusterId();
+    return clusterId;
+}
+
+// Backs HitObject::MakeMiss on the CUDA target: unpacks the RayDesc and forwards
+// it to the OptiX optixMakeMissHitObject intrinsic with a ray time of 0 and no
+// ray flags. missObj is not written by this wrapper.
+static __forceinline__ __device__ void optixMakeMissHitObject(
+    uint MissShaderIndex,
+    RayDesc Ray,
+    OptixTraversableHandle* missObj)
+{
+    // Non-motion variant: the time slot is always zero.
+    const float rayTime = 0.f;
+    optixMakeMissHitObject(
+        MissShaderIndex,
+        Ray.Origin,
+        Ray.Direction,
+        Ray.TMin,
+        Ray.TMax,
+        rayTime,
+        OPTIX_RAY_FLAG_NONE);
+}
+
+// Backs HitObject::MakeMotionMiss on the CUDA target: unpacks the RayDesc and
+// forwards it, together with CurrentTime, to the OptiX optixMakeMissHitObject
+// intrinsic with no ray flags. missObj is not written by this wrapper.
+static __forceinline__ __device__ void optixMakeMissHitObject(
+    uint MissShaderIndex,
+    RayDesc Ray,
+    float CurrentTime,
+    OptixTraversableHandle* missObj)
+{
+    const float rayTime = CurrentTime;
+    optixMakeMissHitObject(
+        MissShaderIndex,
+        Ray.Origin,
+        Ray.Direction,
+        Ray.TMin,
+        Ray.TMax,
+        rayTime,
+        OPTIX_RAY_FLAG_NONE);
+}
+// Backs HitObject::MakeHit (overload taking ray/multiplier contributions) on the
+// CUDA target. Only AccelerationStructure, Ray.Origin, Ray.Direction and Ray.TMin
+// reach the OptiX intrinsic; the index/kind parameters, attr, handle and Ray.TMax
+// are accepted but unused. The ray time is fixed at 0 and no transforms are passed.
+template<typename T>
+static __forceinline__ __device__ void optixMakeHitObject(
+    OptixTraversableHandle AccelerationStructure,
+    uint InstanceIndex,
+    uint GeometryIndex,
+    uint PrimitiveIndex,
+    uint HitKind,
+    uint RayContributionToHitGroupIndex,
+    uint MultiplierForGeometryContributionToHitGroupIndex,
+    RayDesc Ray,
+    T attr,
+    OptixTraversableHandle* handle)
+{
+    // Traverse data is filled in by optixHitObjectGetTraverseData before being passed on.
+    OptixTraverseData traverseData{};
+    optixHitObjectGetTraverseData(&traverseData);
+
+    const float rayTime = 0.f;
+    optixMakeHitObject(
+        AccelerationStructure,
+        Ray.Origin,
+        Ray.Direction,
+        Ray.TMin,
+        rayTime,
+        OPTIX_RAY_FLAG_NONE,
+        traverseData,
+        nullptr, /* transforms */
+        0 /* numTransforms */);
+}
+
+// Backs HitObject::MakeHit (overload taking HitGroupRecordIndex) on the CUDA
+// target. Only AccelerationStructure, Ray.Origin, Ray.Direction and Ray.TMin reach
+// the OptiX intrinsic; HitGroupRecordIndex, the index/kind parameters, attr,
+// handle and Ray.TMax are accepted but unused. The ray time is fixed at 0 and no
+// transforms are passed.
+template<typename T>
+static __forceinline__ __device__ void optixMakeHitObject(
+    uint HitGroupRecordIndex,
+    OptixTraversableHandle AccelerationStructure,
+    uint InstanceIndex,
+    uint GeometryIndex,
+    uint PrimitiveIndex,
+    uint HitKind,
+    RayDesc Ray,
+    T attr,
+    OptixTraversableHandle* handle)
+{
+    // Traverse data is filled in by optixHitObjectGetTraverseData before being passed on.
+    OptixTraverseData traverseData{};
+    optixHitObjectGetTraverseData(&traverseData);
+
+    const float rayTime = 0.f;
+    optixMakeHitObject(
+        AccelerationStructure,
+        Ray.Origin,
+        Ray.Direction,
+        Ray.TMin,
+        rayTime,
+        OPTIX_RAY_FLAG_NONE,
+        traverseData,
+        nullptr, /* transforms */
+        0 /* numTransforms */);
+}
+
+// Backs HitObject::MakeMotionHit (overload taking ray/multiplier contributions)
+// on the CUDA target. AccelerationStructure, Ray.Origin, Ray.Direction, Ray.TMin
+// and CurrentTime reach the OptiX intrinsic; the index/kind parameters, attr,
+// handle and Ray.TMax are accepted but unused. No transforms are passed.
+template<typename T>
+static __forceinline__ __device__ void optixMakeHitObject(
+    OptixTraversableHandle AccelerationStructure,
+    uint InstanceIndex,
+    uint GeometryIndex,
+    uint PrimitiveIndex,
+    uint HitKind,
+    uint RayContributionToHitGroupIndex,
+    uint MultiplierForGeometryContributionToHitGroupIndex,
+    RayDesc Ray,
+    float CurrentTime,
+    T attr,
+    OptixTraversableHandle* handle)
+{
+    // Traverse data is filled in by optixHitObjectGetTraverseData before being passed on.
+    OptixTraverseData traverseData{};
+    optixHitObjectGetTraverseData(&traverseData);
+
+    const float rayTime = CurrentTime;
+    optixMakeHitObject(
+        AccelerationStructure,
+        Ray.Origin,
+        Ray.Direction,
+        Ray.TMin,
+        rayTime,
+        OPTIX_RAY_FLAG_NONE,
+        traverseData,
+        nullptr, /* transforms */
+        0 /* numTransforms */);
+}
+
+// Backs HitObject::MakeMotionHit (overload taking HitGroupRecordIndex) on the
+// CUDA target. AccelerationStructure, Ray.Origin, Ray.Direction, Ray.TMin and
+// CurrentTime reach the OptiX intrinsic; HitGroupRecordIndex, the index/kind
+// parameters, attr, handle and Ray.TMax are accepted but unused. No transforms
+// are passed.
+template<typename T>
+static __forceinline__ __device__ void optixMakeHitObject(
+    uint HitGroupRecordIndex,
+    OptixTraversableHandle AccelerationStructure,
+    uint InstanceIndex,
+    uint GeometryIndex,
+    uint PrimitiveIndex,
+    uint HitKind,
+    RayDesc Ray,
+    float CurrentTime,
+    T attr,
+    OptixTraversableHandle* handle)
+{
+    // Traverse data is filled in by optixHitObjectGetTraverseData before being passed on.
+    OptixTraverseData traverseData{};
+    optixHitObjectGetTraverseData(&traverseData);
+
+    const float rayTime = CurrentTime;
+    optixMakeHitObject(
+        AccelerationStructure,
+        Ray.Origin,
+        Ray.Direction,
+        Ray.TMin,
+        rayTime,
+        OPTIX_RAY_FLAG_NONE,
+        traverseData,
+        nullptr, /* transforms */
+        0 /* numTransforms */);
+}
+
+// Backs HitObject::MakeNop on the CUDA target by calling the zero-argument OptiX
+// intrinsic of the same name. Obj is not written by this wrapper.
+static __forceinline__ __device__ void optixMakeNopHitObject(OptixTraversableHandle* Obj)
+{
+    (void)Obj;
+    optixMakeNopHitObject();
+}
+
+// Backs HitObject::Invoke on the CUDA target. Payload is cast to void*, packed
+// into two 32-bit payload registers and handed to the OptiX optixInvoke
+// intrinsic. AccelerationStructure and HitOrMiss are accepted but unused.
+// NOTE(review): the cast implies T is a pointer type -- confirm against the emitter.
+template<typename T>
+static __forceinline__ __device__ void optixInvoke(
+    OptixTraversableHandle AccelerationStructure,
+    OptixTraversableHandle* HitOrMiss,
+    T Payload)
+{
+    uint32_t payloadReg0;
+    uint32_t payloadReg1;
+    packOptiXRayPayloadPointer((void*)Payload, payloadReg0, payloadReg1);
+    optixInvoke(payloadReg0, payloadReg1);
+}
+// Backs HitObject::GetRayDesc on the CUDA target. Builds a RayDesc from the
+// world-space origin, tmin, world-space direction and tmax reported by the
+// OptiX hit-object intrinsics, in that order. obj is not consulted.
+static __forceinline__ __device__ RayDesc optixHitObjectGetRayDesc(OptixTraversableHandle* obj)
+{
+    // Queried in the same order as the RayDesc initializer list.
+    const auto origin = optixHitObjectGetWorldRayOrigin();
+    const auto tMin = optixHitObjectGetRayTmin();
+    const auto direction = optixHitObjectGetWorldRayDirection();
+    const auto tMax = optixHitObjectGetRayTmax();
+
+    RayDesc ray = {origin, tMin, direction, tMax};
+    return ray;
+}
+
+// Backs HitObject::GetInstanceIndex on the CUDA target. The handle argument is
+// ignored; the value comes from the zero-argument OptiX intrinsic of the same name.
+static __forceinline__ __device__ uint optixHitObjectGetInstanceIndex(OptixTraversableHandle* Obj)
+{
+    const uint instanceIndex = optixHitObjectGetInstanceIndex();
+    return instanceIndex;
+}
+
+// Backs HitObject::GetInstanceID on the CUDA target. The handle argument is
+// ignored; the value comes from the zero-argument OptiX intrinsic of the same name.
+static __forceinline__ __device__ uint optixHitObjectGetInstanceId(OptixTraversableHandle* Obj)
+{
+    const uint instanceId = optixHitObjectGetInstanceId();
+    return instanceId;
+}
+
+// Backs HitObject::GetGeometryIndex on the CUDA target. The handle argument is
+// ignored; the value comes from the zero-argument OptiX intrinsic of the same name.
+static __forceinline__ __device__ uint optixHitObjectGetSbtGASIndex(OptixTraversableHandle* Obj)
+{
+    const uint sbtGasIndex = optixHitObjectGetSbtGASIndex();
+    return sbtGasIndex;
+}
+
+// Backs HitObject::GetPrimitiveIndex on the CUDA target. The handle argument is
+// ignored; the value comes from the zero-argument OptiX intrinsic of the same name.
+static __forceinline__ __device__ uint optixHitObjectGetPrimitiveIndex(OptixTraversableHandle* Obj)
+{
+    const uint primitiveIndex = optixHitObjectGetPrimitiveIndex();
+    return primitiveIndex;
+}
+
+// Backs HitObject::GetAttributes<T> on the CUDA target. Reads as many 32-bit
+// attribute registers as are needed to cover sizeof(T) (at most eight, enforced
+// at compile time) and copies the bytes into a T. Obj is not consulted.
+template<typename T>
+static __forceinline__ __device__ T optixHitObjectGetAttribute(OptixTraversableHandle* Obj)
+{
+    // Count of 32-bit words covering T, rounded up.
+    constexpr size_t kWordCount = (sizeof(T) + sizeof(uint32_t) - 1) / sizeof(uint32_t);
+    static_assert(kWordCount <= 8, "Attribute type is too large");
+
+    // Zero-length arrays are not allowed, so keep at least one slot.
+    constexpr size_t kSlotCount = (kWordCount == 0) ? 1 : kWordCount;
+    uint32_t regs[kSlotCount] = {};
+
+    // Each attribute register has its own intrinsic, so the reads are spelled
+    // out and pruned at compile time.
+    if constexpr (kWordCount > 0)
+    {
+        regs[0] = optixHitObjectGetAttribute_0();
+    }
+    if constexpr (kWordCount > 1)
+    {
+        regs[1] = optixHitObjectGetAttribute_1();
+    }
+    if constexpr (kWordCount > 2)
+    {
+        regs[2] = optixHitObjectGetAttribute_2();
+    }
+    if constexpr (kWordCount > 3)
+    {
+        regs[3] = optixHitObjectGetAttribute_3();
+    }
+    if constexpr (kWordCount > 4)
+    {
+        regs[4] = optixHitObjectGetAttribute_4();
+    }
+    if constexpr (kWordCount > 5)
+    {
+        regs[5] = optixHitObjectGetAttribute_5();
+    }
+    if constexpr (kWordCount > 6)
+    {
+        regs[6] = optixHitObjectGetAttribute_6();
+    }
+    if constexpr (kWordCount > 7)
+    {
+        regs[7] = optixHitObjectGetAttribute_7();
+    }
+
+    // Byte-copy the register words into the requested type.
+    T attributes;
+    memcpy(&attributes, regs, sizeof(T));
+    return attributes;
+}
+
+// Backs HitObject::GetShaderTableIndex on the CUDA target. The handle argument is
+// ignored; the value comes from the zero-argument OptiX intrinsic of the same name.
+static __forceinline__ __device__ uint optixHitObjectGetSbtRecordIndex(OptixTraversableHandle* Obj)
+{
+    const uint sbtRecordIndex = optixHitObjectGetSbtRecordIndex();
+    return sbtRecordIndex;
+}
+
+// Backs HitObject::SetShaderTableIndex on the CUDA target. Forwards
+// sbtRecordIndex to the single-argument OptiX intrinsic of the same name, which
+// returns nothing; this wrapper therefore always returns 0. Obj is not consulted.
+static __forceinline__ __device__ uint
+optixHitObjectSetSbtRecordIndex(OptixTraversableHandle* Obj, uint sbtRecordIndex)
+{
+    (void)Obj;
+    optixHitObjectSetSbtRecordIndex(sbtRecordIndex);
+
+    // Nothing to report back; the Slang-side signature still expects a uint.
+    const uint result = 0;
+    return result;
+}
+// Backs HitObject::LoadLocalRootArgumentsConstant on the CUDA target.
+//
+// The zero-argument OptiX intrinsic optixHitObjectGetSbtDataPointer() returns a
+// CUdeviceptr to the SBT record data; it does not return void. The Slang caller
+// passes RootConstantOffsetInBytes in the second parameter (named sbtRecordIndex
+// here for interface compatibility), so this wrapper loads one 32-bit value at
+// that byte offset from the SBT data pointer.
+//
+// Returns 0 when the SBT data pointer is null.
+// NOTE(review): the offset is assumed to be 4-byte aligned and within the record
+// -- confirm against callers.
+// Obj is not consulted.
+static __forceinline__ __device__ uint
+optixHitObjectGetSbtDataPointer(OptixTraversableHandle* Obj, uint sbtRecordIndex)
+{
+    (void)Obj;
+    const CUdeviceptr sbtData = optixHitObjectGetSbtDataPointer();
+    if (sbtData == 0)
+    {
+        // No SBT data to read from; keep the previous "return 0" behaviour.
+        return 0;
+    }
+    return *reinterpret_cast<const uint*>(sbtData + sbtRecordIndex);
+}
+#endif
static const int kSlangTorchTensorMaxDim = 5;
// TensorView
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index cb050dd51..fd7c7cfc7 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -17250,7 +17250,7 @@ void TraceRay(
Ray,
__forceVarIntoRayPayloadStructTemporarily(Payload));
return;
- case cuda: __intrinsic_asm "traceOptiXRay";
+ case cuda: __intrinsic_asm "optixTrace";
case glsl:
{
[__vulkanRayPayload]
@@ -19576,7 +19576,7 @@ struct HitObject
/// Executes ray traversal (including anyhit and intersection shaders) like TraceRay, but returns the
/// resulting hit information as a HitObject and does not trigger closesthit or miss shaders.
[ForceInline]
- [require(glsl_hlsl_spirv, ser_raygen_closesthit_miss)]
+ [require(cuda_glsl_hlsl_spirv, ser_raygen_closesthit_miss)]
static HitObject TraceRay<payload_t>(
RaytracingAccelerationStructure AccelerationStructure,
uint RayFlags,
@@ -19629,6 +19629,7 @@ struct HitObject
// Write the payload out
Payload = p;
}
+ case cuda: __intrinsic_asm "optixTraverse";
case spirv:
{
[__vulkanRayPayload]
@@ -19669,7 +19670,7 @@ struct HitObject
/// Executes motion ray traversal (including anyhit and intersection shaders) like TraceRay, but returns the
/// resulting hit information as a HitObject and does not trigger closesthit or miss shaders.
[ForceInline]
- [require(glsl_hlsl_spirv, ser_motion_raygen_closesthit_miss)]
+ [require(cuda_glsl_hlsl_spirv, ser_motion_raygen_closesthit_miss)]
static HitObject TraceMotionRay<payload_t>(
RaytracingAccelerationStructure AccelerationStructure,
uint RayFlags,
@@ -19720,6 +19721,7 @@ struct HitObject
// Write the payload out
Payload = p;
}
+ case cuda: __intrinsic_asm "optixTraverse";
case spirv:
{
[__vulkanRayPayload]
@@ -19768,7 +19770,7 @@ struct HitObject
/// Attributes parameter must either be an attribute struct, such as
/// BuiltInTriangleIntersectionAttributes, or another HitObject to copy the attributes from.
[ForceInline]
- [require(glsl_hlsl_spirv, ser_raygen_closesthit_miss)]
+ [require(cuda_glsl_hlsl_spirv, ser_raygen_closesthit_miss)]
static HitObject MakeHit<attr_t>(
RaytracingAccelerationStructure AccelerationStructure,
uint InstanceIndex,
@@ -19816,6 +19818,7 @@ struct HitObject
Ray.TMax,
__hitObjectAttributesLocation(__hitObjectAttributes<attr_t>()));
}
+ case cuda: __intrinsic_asm "optixMakeHitObject";
case spirv:
{
// Save the attributes
@@ -19853,7 +19856,7 @@ struct HitObject
/// See MakeHit but handles Motion
/// Currently only supported on VK
[ForceInline]
- [require(glsl_hlsl_spirv, ser_motion_raygen_closesthit_miss)]
+ [require(cuda_glsl_hlsl_spirv, ser_motion_raygen_closesthit_miss)]
static HitObject MakeMotionHit<attr_t>(
RaytracingAccelerationStructure AccelerationStructure,
uint InstanceIndex,
@@ -19890,6 +19893,7 @@ struct HitObject
CurrentTime,
__hitObjectAttributesLocation(__hitObjectAttributes<attr_t>()));
}
+ case cuda: __intrinsic_asm "optixMakeHitObject";
case spirv:
{
// Save the attributes
@@ -19935,7 +19939,7 @@ struct HitObject
/// attribute struct, such as BuiltInTriangleIntersectionAttributes, or another HitObject to copy the
/// attributes from.
[ForceInline]
- [require(glsl_hlsl_spirv, ser_raygen_closesthit_miss)]
+ [require(cuda_glsl_hlsl_spirv, ser_raygen_closesthit_miss)]
static HitObject MakeHit<attr_t>(
uint HitGroupRecordIndex,
RaytracingAccelerationStructure AccelerationStructure,
@@ -19980,6 +19984,7 @@ struct HitObject
Ray.TMax,
__hitObjectAttributesLocation(__hitObjectAttributes<attr_t>()));
}
+ case cuda: __intrinsic_asm "optixMakeHitObject";
case spirv:
{
// Save the attributes
@@ -20013,7 +20018,7 @@ struct HitObject
/// See MakeHit but handles Motion
/// Currently only supported on VK
[ForceInline]
- [require(glsl_spirv, ser_motion_raygen_closesthit_miss)]
+ [require(cuda_glsl_spirv, ser_motion_raygen_closesthit_miss)]
static HitObject MakeMotionHit<attr_t>(
uint HitGroupRecordIndex,
RaytracingAccelerationStructure AccelerationStructure,
@@ -20047,6 +20052,7 @@ struct HitObject
CurrentTime,
__hitObjectAttributesLocation(__hitObjectAttributes<attr_t>()));
}
+ case cuda: __intrinsic_asm "optixMakeHitObject";
case spirv:
{
// Save the attributes
@@ -20084,7 +20090,7 @@ struct HitObject
/// table.
[__requiresNVAPI]
[ForceInline]
- [require(glsl_hlsl_spirv, ser_raygen_closesthit_miss)]
+ [require(cuda_glsl_hlsl_spirv, ser_raygen_closesthit_miss)]
static HitObject MakeMiss(
uint MissShaderIndex,
RayDesc Ray)
@@ -20094,6 +20100,7 @@ struct HitObject
case hlsl: __intrinsic_asm "($2=NvMakeMiss($0,$1))";
case glsl:
__glslMakeMiss(__return_val, MissShaderIndex, Ray.Origin, Ray.TMin, Ray.Direction, Ray.TMax);
+ case cuda: __intrinsic_asm "optixMakeMissHitObject";
case spirv:
{
let origin = Ray.Origin;
@@ -20119,7 +20126,7 @@ struct HitObject
/// See MakeMiss but handles Motion
/// Currently only supported on VK
[ForceInline]
- [require(glsl_hlsl_spirv, ser_motion_raygen_closesthit_miss)]
+ [require(cuda_glsl_hlsl_spirv, ser_motion_raygen_closesthit_miss)]
static HitObject MakeMotionMiss(
uint MissShaderIndex,
RayDesc Ray,
@@ -20130,6 +20137,7 @@ struct HitObject
case hlsl: __intrinsic_asm "($3=NvMakeMotionMiss($0,$1,$2))";
case glsl:
__glslMakeMotionMiss(__return_val, MissShaderIndex, Ray.Origin, Ray.TMin, Ray.Direction, Ray.TMax, CurrentTime);
+ case cuda: __intrinsic_asm "optixMakeMissHitObject";
case spirv:
{
let origin = Ray.Origin;
@@ -20162,7 +20170,7 @@ struct HitObject
/// miss.
[__requiresNVAPI]
[ForceInline]
- [require(glsl_hlsl_spirv, ser_raygen_closesthit_miss)]
+ [require(cuda_glsl_hlsl_spirv, ser_raygen_closesthit_miss)]
static HitObject MakeNop()
{
__target_switch
@@ -20171,6 +20179,7 @@ struct HitObject
__intrinsic_asm "($0 = NvMakeNop())";
case glsl:
__glslMakeNop(__return_val);
+ case cuda: __intrinsic_asm "optixMakeNopHitObject";
case spirv:
spirv_asm
{
@@ -20199,7 +20208,7 @@ struct HitObject
/// shader is invoked.
[__requiresNVAPI]
[ForceInline]
- [require(glsl_hlsl_spirv, ser_raygen_closesthit_miss)]
+ [require(cuda_glsl_hlsl_spirv, ser_raygen_closesthit_miss)]
static void Invoke<payload_t>(
RaytracingAccelerationStructure AccelerationStructure,
HitObject HitOrMiss,
@@ -20225,6 +20234,7 @@ struct HitObject
// Write payload result
Payload = p;
}
+ case cuda: __intrinsic_asm "optixInvoke";
case spirv:
{
[__vulkanRayPayload]
@@ -20251,13 +20261,14 @@ struct HitObject
/// Returns true if the HitObject encodes a miss, otherwise returns false.
[__requiresNVAPI]
[ForceInline]
- [require(glsl_hlsl_spirv, ser_raygen_closesthit_miss)]
+ [require(cuda_glsl_hlsl_spirv, ser_raygen_closesthit_miss)]
bool IsMiss()
{
__target_switch
{
case hlsl: __intrinsic_asm ".IsMiss";
case glsl: __intrinsic_asm "hitObjectIsMissNV($0)";
+ case cuda: __intrinsic_asm "optixHitObjectIsMiss";
case spirv:
return spirv_asm
{
@@ -20271,13 +20282,14 @@ struct HitObject
/// Returns true if the HitObject encodes a hit, otherwise returns false.
[__requiresNVAPI]
[ForceInline]
- [require(glsl_hlsl_spirv, ser_raygen_closesthit_miss)]
+ [require(cuda_glsl_hlsl_spirv, ser_raygen_closesthit_miss)]
bool IsHit()
{
__target_switch
{
case hlsl: __intrinsic_asm ".IsHit";
case glsl: __intrinsic_asm "hitObjectIsHitNV($0)";
+ case cuda: __intrinsic_asm "optixHitObjectIsHit";
case spirv:
return spirv_asm
{
@@ -20291,13 +20303,14 @@ struct HitObject
/// Returns true if the HitObject encodes a nop, otherwise returns false.
[__requiresNVAPI]
[ForceInline]
- [require(glsl_hlsl_spirv, ser_raygen_closesthit_miss)]
+ [require(cuda_glsl_hlsl_spirv, ser_raygen_closesthit_miss)]
bool IsNop()
{
__target_switch
{
case hlsl: __intrinsic_asm ".IsNop";
case glsl: __intrinsic_asm "hitObjectIsEmptyNV($0)";
+ case cuda: __intrinsic_asm "optixHitObjectIsNop";
case spirv:
return spirv_asm
{
@@ -20311,7 +20324,7 @@ struct HitObject
/// Queries ray properties from HitObject. Valid if the hit object represents a hit or a miss.
[__requiresNVAPI]
[ForceInline]
- [require(glsl_hlsl_spirv, ser_raygen_closesthit_miss)]
+ [require(cuda_glsl_hlsl_spirv, ser_raygen_closesthit_miss)]
RayDesc GetRayDesc()
{
__target_switch
@@ -20323,6 +20336,7 @@ struct HitObject
RayDesc ray = { __glslGetRayWorldOrigin(), __glslGetTMin(), __glslGetRayWorldDirection(), __glslGetTMax() };
return ray;
}
+ case cuda: __intrinsic_asm "optixHitObjectGetRayDesc";
case spirv:
return spirv_asm
{
@@ -20341,13 +20355,14 @@ struct HitObject
[__requiresNVAPI]
__glsl_extension(GL_EXT_ray_tracing)
[ForceInline]
- [require(glsl_hlsl_spirv, ser_raygen_closesthit_miss)]
+ [require(cuda_glsl_hlsl_spirv, ser_raygen_closesthit_miss)]
uint GetShaderTableIndex()
{
__target_switch
{
case hlsl: __intrinsic_asm ".GetShaderTableIndex";
case glsl: __intrinsic_asm "hitObjectGetShaderBindingTableRecordIndexNV($0)";
+ case cuda: __intrinsic_asm "optixHitObjectGetSbtRecordIndex";
case spirv:
return spirv_asm
{
@@ -20358,17 +20373,41 @@ struct HitObject
}
}
+ /// Sets the shader table record index of the HitObject to `RecordIndex`.
+ /// Available on the HLSL and CUDA targets only; each case is a pure intrinsic
+ /// mapping (`.SetShaderTableIndex` on HLSL, `optixHitObjectSetSbtRecordIndex` on CUDA).
+ /// NOTE(review): the meaning of the returned uint is not established here -- confirm per target.
+ [__requiresNVAPI]
+ __glsl_extension(GL_EXT_ray_tracing)
+ [ForceInline]
+ [require(cuda_hlsl, ser_raygen_closesthit_miss)]
+ uint SetShaderTableIndex(uint RecordIndex)
+ {
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm ".SetShaderTableIndex";
+ case cuda: __intrinsic_asm "optixHitObjectSetSbtRecordIndex";
+ }
+ }
+
+ // TODO - Add other targets [__requiresNVAPI] __glsl_extension(GL_EXT_ray_tracing)
+ /// Loads a uint constant for the HitObject, addressed by `RootConstantOffsetInBytes`.
+ /// Only the CUDA target is implemented; it maps to the prelude helper
+ /// `optixHitObjectGetSbtDataPointer`, which receives the byte offset as its second argument.
+ [ForceInline]
+ [require(cuda, ser_raygen_closesthit_miss)]
+ uint LoadLocalRootArgumentsConstant(uint RootConstantOffsetInBytes)
+ {
+ __target_switch
+ {
+ case cuda: __intrinsic_asm "optixHitObjectGetSbtDataPointer";
+ }
+ }
/// Returns the instance index of a hit. Valid if the hit object represents a hit.
[__requiresNVAPI]
__glsl_extension(GL_EXT_ray_tracing)
[ForceInline]
- [require(glsl_hlsl_spirv, ser_raygen_closesthit_miss)]
+ [require(cuda_glsl_hlsl_spirv, ser_raygen_closesthit_miss)]
uint GetInstanceIndex()
{
__target_switch
{
case hlsl: __intrinsic_asm ".GetInstanceIndex";
case glsl: __intrinsic_asm "hitObjectGetInstanceIdNV($0)";
+ case cuda: __intrinsic_asm "optixHitObjectGetInstanceIndex";
case spirv:
return spirv_asm
{
@@ -20383,13 +20422,14 @@ struct HitObject
[__requiresNVAPI]
__glsl_extension(GL_EXT_ray_tracing)
[ForceInline]
- [require(glsl_hlsl_spirv, ser_raygen_closesthit_miss)]
+ [require(cuda_glsl_hlsl_spirv, ser_raygen_closesthit_miss)]
uint GetInstanceID()
{
__target_switch
{
case hlsl: __intrinsic_asm ".GetInstanceID";
case glsl: __intrinsic_asm "hitObjectGetInstanceCustomIndexNV($0)";
+ case cuda: __intrinsic_asm "optixHitObjectGetInstanceId";
case spirv:
return spirv_asm
{
@@ -20404,13 +20444,14 @@ struct HitObject
[__requiresNVAPI]
__glsl_extension(GL_EXT_ray_tracing)
[ForceInline]
- [require(glsl_hlsl_spirv, ser_raygen_closesthit_miss)]
+ [require(cuda_glsl_hlsl_spirv, ser_raygen_closesthit_miss)]
uint GetGeometryIndex()
{
__target_switch
{
case hlsl: __intrinsic_asm ".GetGeometryIndex";
case glsl: __intrinsic_asm "hitObjectGetGeometryIndexNV($0)";
+ case cuda: __intrinsic_asm "optixHitObjectGetSbtGASIndex";
case spirv:
return spirv_asm
{
@@ -20425,13 +20466,14 @@ struct HitObject
[__requiresNVAPI]
__glsl_extension(GL_EXT_ray_tracing)
[ForceInline]
- [require(glsl_hlsl_spirv, ser_raygen_closesthit_miss)]
+ [require(cuda_glsl_hlsl_spirv, ser_raygen_closesthit_miss)]
uint GetPrimitiveIndex()
{
__target_switch
{
case hlsl: __intrinsic_asm ".GetPrimitiveIndex";
case glsl: __intrinsic_asm "hitObjectGetPrimitiveIndexNV($0)";
+ case cuda: __intrinsic_asm "optixHitObjectGetPrimitiveIndex";
case spirv:
return spirv_asm
{
@@ -20596,7 +20638,7 @@ struct HitObject
/// Returns the attributes of a hit. Valid if the hit object represents a hit or a miss.
[ForceInline]
- [require(glsl_hlsl_spirv, ser_raygen_closesthit_miss)]
+ [require(cuda_glsl_hlsl_spirv, ser_raygen_closesthit_miss)]
attr_t GetAttributes<attr_t>()
{
__target_switch
@@ -20618,6 +20660,7 @@ struct HitObject
// Return the attributes
return __hitObjectAttributes<attr_t>();
}
+ case cuda: __intrinsic_asm "optixHitObjectGetAttribute<$TR>($0)";
case spirv:
{
__Addr<attr_t> attr = __allocHitObjectAttributes<attr_t>();
@@ -21008,13 +21051,14 @@ struct HitObject
__glsl_extension(GL_EXT_ray_tracing)
__glsl_extension(GL_NV_shader_invocation_reorder)
[ForceInline]
-[require(glsl_hlsl_spirv, ser_raygen)]
+[require(cuda_glsl_hlsl_spirv, ser_raygen)]
void ReorderThread( uint CoherenceHint, uint NumCoherenceHintBitsFromLSB )
{
__target_switch
{
case hlsl: __intrinsic_asm "NvReorderThread";
case glsl: __intrinsic_asm "reorderThreadNV";
+ case cuda: __intrinsic_asm "optixReorder";
case spirv:
spirv_asm
{
@@ -21045,13 +21089,14 @@ void ReorderThread( uint CoherenceHint, uint NumCoherenceHintBitsFromLSB )
__glsl_extension(GL_EXT_ray_tracing)
__glsl_extension(GL_NV_shader_invocation_reorder)
[ForceInline]
-[require(glsl_hlsl_spirv, ser_raygen)]
+[require(cuda_glsl_hlsl_spirv, ser_raygen)]
void ReorderThread( HitObject HitOrMiss, uint CoherenceHint, uint NumCoherenceHintBitsFromLSB )
{
__target_switch
{
case hlsl: __intrinsic_asm "NvReorderThread";
case glsl: __intrinsic_asm "reorderThreadNV";
+ case cuda: __intrinsic_asm "optixReorder($1, $2)";
case spirv:
spirv_asm
{
@@ -21072,13 +21117,14 @@ void ReorderThread( HitObject HitOrMiss, uint CoherenceHint, uint NumCoherenceHi
__glsl_extension(GL_EXT_ray_tracing)
__glsl_extension(GL_NV_shader_invocation_reorder)
[ForceInline]
-[require(glsl_hlsl_spirv, ser_raygen)]
+[require(cuda_glsl_hlsl_spirv, ser_raygen)]
void ReorderThread( HitObject HitOrMiss )
{
__target_switch
{
case hlsl: __intrinsic_asm "NvReorderThread";
case glsl: __intrinsic_asm "reorderThreadNV";
+ case cuda: __intrinsic_asm "optixReorder()";
case spirv:
spirv_asm
{
diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp
index 74133fcf0..e5169ba38 100644
--- a/source/slang/slang-emit-cuda.cpp
+++ b/source/slang/slang-emit-cuda.cpp
@@ -214,6 +214,11 @@ SlangResult CUDASourceEmitter::calcTypeName(IRType* type, CodeGenTarget target,
out << "TensorView";
return SLANG_OK;
}
+ case kIROp_HitObjectType:
+ {
+ out << "OptixTraversableHandle";
+ return SLANG_OK;
+ }
default:
{
if (isNominalOp(type->getOp()))
diff --git a/tests/cuda/optix-ser.slang b/tests/cuda/optix-ser.slang
new file mode 100644
index 000000000..54f300706
--- /dev/null
+++ b/tests/cuda/optix-ser.slang
@@ -0,0 +1,146 @@
+// optix-ser.slang
+
+
+//TEST:SIMPLE(filecheck=CHECK): -target cuda -entry rayGenerationMain -stage raygeneration
+
+//TEST_INPUT: set scene = AccelerationStructure
+uniform RaytracingAccelerationStructure scene;
+
+//TEST_INPUT:set outputBuffer = out ubuffer(data=[0, 0, 0, 0], stride=4)
+RWStructuredBuffer<uint> outputBuffer;
+
+// Small two-field struct used in this test both as the ray payload passed to
+// HitObject::TraceRay / HitObject::Invoke and as the attribute type requested
+// through GetAttributes in calcValue.
+struct SomeValues
+{
+ int a;
+ float b;
+};
+
+// Exercises the HitObject query methods and folds their results into one value.
+// Returns 0 unless `hit` reports IsHit(); otherwise returns the sum of the
+// queried indices, the `a` field of the attributes, and the results of the
+// shader-table / local-root-constant calls.
+uint calcValue(HitObject hit)
+{
+ uint r = 0;
+
+ if (hit.IsHit())
+ {
+ uint instanceIndex = hit.GetInstanceIndex();
+ uint instanceID = hit.GetInstanceID();
+ uint geometryIndex = hit.GetGeometryIndex();
+ uint primitiveIndex = hit.GetPrimitiveIndex();
+ int clusterID = hit.GetClusterID();
+ uint shaderTableIndex = hit.GetShaderTableIndex();
+ // SPIR-V and GLSL lack these methods
+ uint setShaderTableIndex = hit.SetShaderTableIndex(0);
+ uint ialbedo = hit.LoadLocalRootTableConstant(0);
+ SomeValues objSomeValues = hit.GetAttributes<SomeValues>();
+
+ r += instanceIndex;
+ r += instanceID;
+ r += geometryIndex;
+ r += primitiveIndex;
+ r += objSomeValues.a;
+ r += clusterID;
+ r += shaderTableIndex;
+ r += setShaderTableIndex;
+ r += ialbedo;
+ }
+
+ return r;
+}
+
+// Ray generation entry point for this test. It traces one ray into a HitObject,
+// calls every ReorderThread overload, builds miss / hit / nop / motion-hit
+// objects, invokes the traced hit with a second payload, and accumulates
+// calcValue results into outputBuffer[idx].
+void rayGenerationMain()
+{
+ int2 launchID = int2(DispatchRaysIndex().xy);
+ // launchSize is computed but not read anywhere below.
+ int2 launchSize = int2(DispatchRaysDimensions().xy);
+
+ int idx = launchID.x;
+
+ SomeValues someValues = { idx, idx * 2.0f };
+
+ RayDesc ray;
+ ray.Origin = float3(idx, 0, 0);
+ ray.TMin = 0.01f;
+ ray.Direction = float3(0, 1, 0);
+ ray.TMax = 1e4f;
+
+ RAY_FLAG rayFlags = RAY_FLAG_ACCEPT_FIRST_HIT_AND_END_SEARCH | RAY_FLAG_CULL_BACK_FACING_TRIANGLES;
+ uint instanceInclusionMask = 0xff;
+ uint rayContributionToHitGroupIndex = 0;
+ uint multiplierForGeometryContributionToHitGroupIndex = 4;
+ uint missShaderIndex = 0;
+ // SPIRV: OpHitObjectTraceRayNV
+ // CHECK: optixTraverse
+ HitObject hit = HitObject::TraceRay(scene,
+ rayFlags,
+ instanceInclusionMask,
+ rayContributionToHitGroupIndex,
+ multiplierForGeometryContributionToHitGroupIndex,
+ missShaderIndex,
+ ray,
+ someValues);
+
+ // One call per ReorderThread overload.
+ ReorderThread( hit );
+ ReorderThread(hit, uint(idx & 3), 2);
+ ReorderThread(uint(idx & 1), 1);
+
+ outputBuffer[idx] = calcValue(hit);
+ HitObject miss[2];
+ miss[0] = HitObject::MakeMiss(0u, ray);
+ miss[1] = HitObject::MakeMotionMiss(0u, ray, 1.f);
+
+ uint hitGroupRecordIndex = 0;
+ uint instanceIndex = 0xff;
+ uint geometryIndex = 0;
+ uint primitiveIndex = 0;
+ uint hitKind = 0;
+ BuiltInTriangleIntersectionAttributes attr = {0.01f, 0.2f};
+
+ HitObject hitObj = HitObject::MakeHit(hitGroupRecordIndex, scene,
+ instanceIndex,
+ geometryIndex,
+ primitiveIndex,
+ hitKind,
+ ray,
+ attr);
+ HitObject nopObj = HitObject::MakeNop();
+ // NOTE(review): plain assignment replaces the calcValue(hit) result stored
+ // above; calcValue(hit) is added again just below -- confirm this is intended.
+ outputBuffer[idx] = uint(nopObj.IsNop());
+
+ outputBuffer[idx] += calcValue(hit);
+ outputBuffer[idx] += calcValue(miss[0]);
+ outputBuffer[idx] += calcValue(miss[1]);
+ outputBuffer[idx] += calcValue(hitObj);
+ outputBuffer[idx] += calcValue(nopObj);
+
+ // Change the payload
+ SomeValues otherValues = { idx * -2, idx * 8.0f };
+
+ HitObject::Invoke( scene, hit, otherValues );
+ HitObject motionHitObj[2];
+ // Both MakeMotionHit overloads: with ray/multiplier contributions, then with
+ // an explicit hit group record index.
+ motionHitObj[0] = HitObject::MakeMotionHit(
+ scene,
+ instanceIndex,
+ geometryIndex,
+ primitiveIndex,
+ hitKind,
+ rayContributionToHitGroupIndex,
+ multiplierForGeometryContributionToHitGroupIndex,
+ ray,
+ 0.f,
+ attr);
+ motionHitObj[1] = HitObject::MakeMotionHit(
+ hitGroupRecordIndex,
+ scene,
+ instanceIndex,
+ geometryIndex,
+ primitiveIndex,
+ hitKind,
+ ray,
+ 0.f,
+ attr);
+ outputBuffer[idx] += calcValue(motionHitObj[0]);
+ outputBuffer[idx] += calcValue(motionHitObj[1]);
+
+ RayDesc rayD = hit.GetRayDesc();
+
+ outputBuffer[idx] += uint(rayD.TMin > 0);
+ outputBuffer[idx] += uint(rayD.TMax < ray.TMin);
+
+}
diff --git a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-ray.slang b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-ray.slang
index 877e41977..71c113934 100644
--- a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-ray.slang
+++ b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-ray.slang
@@ -4,8 +4,7 @@
//TEST:SIMPLE(filecheck=SPIRV): -target spirv -entry rayGenerationMain -stage raygeneration -profile sm_6_5 -line-directive-mode none
//TEST:SIMPLE(filecheck=SPIRV): -target spirv -entry rayGenerationMain -stage raygeneration -emit-spirv-directly
-// Note: HitObject::TraceRay is not supported in raygen stage for cuda target
-//DISABLE_TEST:SIMPLE: -target cuda -entry rayGenerationMain -stage raygeneration
+//TEST:SIMPLE(filecheck=CHECK): -target cuda -entry rayGenerationMain -stage raygeneration
//DISABLE_TEST(compute):COMPARE_COMPUTE:-d3d12 -output-using-type -use-dxil -profile sm_6_6 -render-feature ray-query
//DISABLE_TEST(compute):COMPARE_COMPUTE:-vk -output-using-type -render-feature ray-query
@@ -68,6 +67,7 @@ void rayGenerationMain()
uint multiplierForGeometryContributionToHitGroupIndex = 4;
uint missShaderIndex = 0;
// SPIRV: OpHitObjectTraceRayNV
+ // CHECK: optixTraverse
HitObject hit = HitObject::TraceRay(scene,
rayFlags,
instanceInclusionMask,