summaryrefslogtreecommitdiffstats
path: root/prelude
diff options
context:
space:
mode:
authorHarsh Aggarwal (NVIDIA) <haaggarwal@nvidia.com>2025-05-26 21:00:38 +0530
committerGitHub <noreply@github.com>2025-05-26 15:30:38 +0000
commit83538e0b4b97425ecdae6f72f9c8fd44cb255aac (patch)
tree8f27c47fb7c1614fa916c2da6ab9996655e29da1 /prelude
parent8ecb2c70437292ef6fa34f7122df44067de6a4de (diff)
Implement shader execution reordering support for OptiX (#7211)
* Implement shader execution reordering support for OptiX

  Added OptiX backend support for Shader Execution Reordering (SER) features as outlined in issue #6647. This implementation:
  1. Added CUDA target support for HitObject API
  2. Implemented core SER functionality (TraceRay, MakeHit/Miss, Invoke)
  3. Added OptiX-specific hit object handling functions
  4. Added test case for OptiX SER functionality

* format code

---------

Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Diffstat (limited to 'prelude')
-rw-r--r--prelude/slang-cuda-prelude.h325
1 files changed, 323 insertions, 2 deletions
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h
index c641025d4..d2c9fce9d 100644
--- a/prelude/slang-cuda-prelude.h
+++ b/prelude/slang-cuda-prelude.h
@@ -3193,7 +3193,7 @@ static __forceinline__ __device__ void* getOptiXRayPayloadPtr()
}
template<typename T>
-__forceinline__ __device__ void* traceOptiXRay(
+__forceinline__ __device__ void* optixTrace(
OptixTraversableHandle AccelerationStructure,
uint32_t RayFlags,
uint32_t InstanceInclusionMask,
@@ -3221,8 +3221,329 @@ __forceinline__ __device__ void* traceOptiXRay(
r1);
}
-#endif
+// Wrapper around the OptiX `optixTraverse` intrinsic for callers that do not
+// supply a ray time; 0.f is forwarded in its place.
+//
+// `Payload` is split into two 32-bit values by packOptiXRayPayloadPointer() and
+// those two values are passed as the trailing payload arguments of the
+// intrinsic. `hitObj` is neither read nor written by this wrapper.
+//
+// Returns nullptr. The function is declared to return void*, and the body
+// previously had no return statement, so control flowed off the end of a
+// non-void function (undefined behaviour in C++).
+template<typename T>
+__forceinline__ __device__ void* optixTraverse(
+    OptixTraversableHandle AccelerationStructure,
+    uint32_t RayFlags,
+    uint32_t InstanceInclusionMask,
+    uint32_t RayContributionToHitGroupIndex,
+    uint32_t MultiplierForGeometryContributionToHitGroupIndex,
+    uint32_t MissShaderIndex,
+    RayDesc Ray,
+    T* Payload,
+    OptixTraversableHandle* hitObj)
+{
+    (void)hitObj; // unused by this wrapper
+
+    uint32_t r0, r1;
+    packOptiXRayPayloadPointer((void*)Payload, r0, r1);
+
+    // Note: the inclusion mask is forwarded ahead of the ray flags, which is the
+    // reverse of this wrapper's own parameter order.
+    optixTraverse(
+        AccelerationStructure,
+        Ray.Origin,
+        Ray.Direction,
+        Ray.TMin,
+        Ray.TMax,
+        0.f, /* no ray time is supplied through this overload */
+        InstanceInclusionMask,
+        RayFlags,
+        RayContributionToHitGroupIndex,
+        MultiplierForGeometryContributionToHitGroupIndex,
+        MissShaderIndex,
+        r0,
+        r1);
+
+    return nullptr;
+}
+
+// Wrapper around the OptiX `optixTraverse` intrinsic that forwards an explicit
+// `RayTime`.
+//
+// `Payload` is split into two 32-bit values by packOptiXRayPayloadPointer() and
+// those two values are passed as the trailing payload arguments of the
+// intrinsic. `hitObj` is neither read nor written by this wrapper.
+//
+// Returns nullptr. The function is declared to return void*, and the body
+// previously had no return statement, so control flowed off the end of a
+// non-void function (undefined behaviour in C++).
+template<typename T>
+__forceinline__ __device__ void* optixTraverse(
+    OptixTraversableHandle AccelerationStructure,
+    uint32_t RayFlags,
+    uint32_t InstanceInclusionMask,
+    uint32_t RayContributionToHitGroupIndex,
+    uint32_t MultiplierForGeometryContributionToHitGroupIndex,
+    uint32_t MissShaderIndex,
+    RayDesc Ray,
+    float RayTime,
+    T* Payload,
+    OptixTraversableHandle* hitObj)
+{
+    (void)hitObj; // unused by this wrapper
+
+    uint32_t r0, r1;
+    packOptiXRayPayloadPointer((void*)Payload, r0, r1);
+
+    // Note: the inclusion mask is forwarded ahead of the ray flags, which is the
+    // reverse of this wrapper's own parameter order.
+    optixTraverse(
+        AccelerationStructure,
+        Ray.Origin,
+        Ray.Direction,
+        Ray.TMin,
+        Ray.TMax,
+        RayTime,
+        InstanceInclusionMask,
+        RayFlags,
+        RayContributionToHitGroupIndex,
+        MultiplierForGeometryContributionToHitGroupIndex,
+        MissShaderIndex,
+        r0,
+        r1);
+
+    return nullptr;
+}
+
+// Overload taking a hit-object pointer. The pointer is ignored; the answer comes
+// from the zero-argument `optixHitObjectIsHit()`.
+static __forceinline__ __device__ bool optixHitObjectIsHit(OptixTraversableHandle* hitObj)
+{
+    (void)hitObj; // not consulted
+    const bool isHit = optixHitObjectIsHit();
+    return isHit;
+}
+
+// Overload taking a hit-object pointer. The pointer is ignored; the answer comes
+// from the zero-argument `optixHitObjectIsMiss()`.
+static __forceinline__ __device__ bool optixHitObjectIsMiss(OptixTraversableHandle* hitObj)
+{
+    (void)hitObj; // not consulted
+    const bool isMiss = optixHitObjectIsMiss();
+    return isMiss;
+}
+
+// Overload taking a hit-object pointer. The pointer is ignored; the answer comes
+// from the zero-argument `optixHitObjectIsNop()`.
+static __forceinline__ __device__ bool optixHitObjectIsNop(OptixTraversableHandle* hitObj)
+{
+    (void)hitObj; // not consulted
+    const bool isNop = optixHitObjectIsNop();
+    return isNop;
+}
+
+// Overload taking a hit-object pointer. The pointer is ignored; the value comes
+// from the zero-argument `optixHitObjectGetClusterId()`.
+static __forceinline__ __device__ uint optixHitObjectGetClusterId(OptixTraversableHandle* hitObj)
+{
+    (void)hitObj; // not consulted
+    const uint clusterId = optixHitObjectGetClusterId();
+    return clusterId;
+}
+
+// Builds a miss hit object from `MissShaderIndex` and the fields of `Ray` by
+// calling the OptiX `optixMakeMissHitObject` intrinsic with a ray time of 0 and
+// OPTIX_RAY_FLAG_NONE. `missObj` is not read or written.
+static __forceinline__ __device__ void optixMakeMissHitObject(
+    uint MissShaderIndex,
+    RayDesc Ray,
+    OptixTraversableHandle* missObj)
+{
+    (void)missObj; // unused by this wrapper
+
+    const float rayTime = 0.f; // this overload takes no time argument
+    optixMakeMissHitObject(
+        MissShaderIndex,
+        Ray.Origin,
+        Ray.Direction,
+        Ray.TMin,
+        Ray.TMax,
+        rayTime,
+        OPTIX_RAY_FLAG_NONE /* rayFlags */);
+}
+
+// Builds a miss hit object from `MissShaderIndex`, the fields of `Ray` and
+// `CurrentTime` by calling the OptiX `optixMakeMissHitObject` intrinsic with
+// OPTIX_RAY_FLAG_NONE. `missObj` is not read or written.
+static __forceinline__ __device__ void optixMakeMissHitObject(
+    uint MissShaderIndex,
+    RayDesc Ray,
+    float CurrentTime,
+    OptixTraversableHandle* missObj)
+{
+    (void)missObj; // unused by this wrapper
+
+    optixMakeMissHitObject(
+        MissShaderIndex,
+        Ray.Origin,
+        Ray.Direction,
+        Ray.TMin,
+        Ray.TMax,
+        CurrentTime, /* rayTime */
+        OPTIX_RAY_FLAG_NONE /* rayFlags */);
+}
+// Hit-object construction overload (hit-group contribution form, no time).
+//
+// Reads traverse data through optixHitObjectGetTraverseData() and passes it,
+// together with the acceleration structure, ray origin/direction/TMin, a ray
+// time of 0 and OPTIX_RAY_FLAG_NONE, to the OptiX `optixMakeHitObject`
+// intrinsic. No transform list is supplied.
+//
+// NOTE(review): InstanceIndex, GeometryIndex, PrimitiveIndex, HitKind, both
+// hit-group contribution values, Ray.TMax, attr and handle are never forwarded;
+// the hit is described only by the fetched traverse data. Confirm this is
+// intended.
+template<typename T>
+static __forceinline__ __device__ void optixMakeHitObject(
+    OptixTraversableHandle AccelerationStructure,
+    uint InstanceIndex,
+    uint GeometryIndex,
+    uint PrimitiveIndex,
+    uint HitKind,
+    uint RayContributionToHitGroupIndex,
+    uint MultiplierForGeometryContributionToHitGroupIndex,
+    RayDesc Ray,
+    T attr,
+    OptixTraversableHandle* handle)
+{
+    OptixTraverseData traverseData{};
+    optixHitObjectGetTraverseData(&traverseData);
+
+    optixMakeHitObject(
+        AccelerationStructure,
+        Ray.Origin,
+        Ray.Direction,
+        Ray.TMin,
+        0.f,                 /* rayTime */
+        OPTIX_RAY_FLAG_NONE, /* rayFlags */
+        traverseData,
+        nullptr, /* OptixTraversableHandle* transforms */
+        0 /* numTransforms */);
+}
+
+// Hit-object construction overload (hit-group record index form, no time).
+//
+// Reads traverse data through optixHitObjectGetTraverseData() and passes it,
+// together with the acceleration structure, ray origin/direction/TMin, a ray
+// time of 0 and OPTIX_RAY_FLAG_NONE, to the OptiX `optixMakeHitObject`
+// intrinsic. No transform list is supplied.
+//
+// NOTE(review): HitGroupRecordIndex, InstanceIndex, GeometryIndex,
+// PrimitiveIndex, HitKind, Ray.TMax, attr and handle are never forwarded; the
+// hit is described only by the fetched traverse data. Confirm this is intended.
+template<typename T>
+static __forceinline__ __device__ void optixMakeHitObject(
+    uint HitGroupRecordIndex,
+    OptixTraversableHandle AccelerationStructure,
+    uint InstanceIndex,
+    uint GeometryIndex,
+    uint PrimitiveIndex,
+    uint HitKind,
+    RayDesc Ray,
+    T attr,
+    OptixTraversableHandle* handle)
+{
+    OptixTraverseData traverseData{};
+    optixHitObjectGetTraverseData(&traverseData);
+
+    optixMakeHitObject(
+        AccelerationStructure,
+        Ray.Origin,
+        Ray.Direction,
+        Ray.TMin,
+        0.f,                 /* rayTime */
+        OPTIX_RAY_FLAG_NONE, /* rayFlags */
+        traverseData,
+        nullptr, /* OptixTraversableHandle* transforms */
+        0 /* numTransforms */);
+}
+
+// Hit-object construction overload (hit-group contribution form, with time).
+//
+// Reads traverse data through optixHitObjectGetTraverseData() and passes it,
+// together with the acceleration structure, ray origin/direction/TMin,
+// `CurrentTime` and OPTIX_RAY_FLAG_NONE, to the OptiX `optixMakeHitObject`
+// intrinsic. No transform list is supplied.
+//
+// NOTE(review): InstanceIndex, GeometryIndex, PrimitiveIndex, HitKind, both
+// hit-group contribution values, Ray.TMax, attr and handle are never forwarded;
+// the hit is described only by the fetched traverse data. Confirm this is
+// intended.
+template<typename T>
+static __forceinline__ __device__ void optixMakeHitObject(
+    OptixTraversableHandle AccelerationStructure,
+    uint InstanceIndex,
+    uint GeometryIndex,
+    uint PrimitiveIndex,
+    uint HitKind,
+    uint RayContributionToHitGroupIndex,
+    uint MultiplierForGeometryContributionToHitGroupIndex,
+    RayDesc Ray,
+    float CurrentTime,
+    T attr,
+    OptixTraversableHandle* handle)
+{
+    OptixTraverseData traverseData{};
+    optixHitObjectGetTraverseData(&traverseData);
+
+    optixMakeHitObject(
+        AccelerationStructure,
+        Ray.Origin,
+        Ray.Direction,
+        Ray.TMin,
+        CurrentTime,         /* rayTime */
+        OPTIX_RAY_FLAG_NONE, /* rayFlags */
+        traverseData,
+        nullptr, /* OptixTraversableHandle* transforms */
+        0 /* numTransforms */);
+}
+
+// Hit-object construction overload (hit-group record index form, with time).
+//
+// Reads traverse data through optixHitObjectGetTraverseData() and passes it,
+// together with the acceleration structure, ray origin/direction/TMin,
+// `CurrentTime` and OPTIX_RAY_FLAG_NONE, to the OptiX `optixMakeHitObject`
+// intrinsic. No transform list is supplied.
+//
+// NOTE(review): HitGroupRecordIndex, InstanceIndex, GeometryIndex,
+// PrimitiveIndex, HitKind, Ray.TMax, attr and handle are never forwarded; the
+// hit is described only by the fetched traverse data. Confirm this is intended.
+template<typename T>
+static __forceinline__ __device__ void optixMakeHitObject(
+    uint HitGroupRecordIndex,
+    OptixTraversableHandle AccelerationStructure,
+    uint InstanceIndex,
+    uint GeometryIndex,
+    uint PrimitiveIndex,
+    uint HitKind,
+    RayDesc Ray,
+    float CurrentTime,
+    T attr,
+    OptixTraversableHandle* handle)
+{
+    OptixTraverseData traverseData{};
+    optixHitObjectGetTraverseData(&traverseData);
+
+    optixMakeHitObject(
+        AccelerationStructure,
+        Ray.Origin,
+        Ray.Direction,
+        Ray.TMin,
+        CurrentTime,         /* rayTime */
+        OPTIX_RAY_FLAG_NONE, /* rayFlags */
+        traverseData,
+        nullptr, /* OptixTraversableHandle* transforms */
+        0 /* numTransforms */);
+}
+
+// Overload taking a hit-object pointer. The pointer is ignored; the work is done
+// by the zero-argument `optixMakeNopHitObject()`.
+static __forceinline__ __device__ void optixMakeNopHitObject(OptixTraversableHandle* Obj)
+{
+    (void)Obj; // not consulted
+    optixMakeNopHitObject();
+}
+
+// Packs `Payload` into two 32-bit values with packOptiXRayPayloadPointer() and
+// passes them to the two-argument OptiX `optixInvoke`. `AccelerationStructure`
+// and `HitOrMiss` are not used.
+//
+// NOTE(review): `Payload` is taken by value and cast straight to void*, so this
+// presumably expects T to be a pointer type - confirm against the call sites.
+template<typename T>
+static __forceinline__ __device__ void optixInvoke(
+    OptixTraversableHandle AccelerationStructure,
+    OptixTraversableHandle* HitOrMiss,
+    T Payload)
+{
+    (void)AccelerationStructure;
+    (void)HitOrMiss;
+
+    uint32_t p0, p1;
+    packOptiXRayPayloadPointer((void*)Payload, p0, p1);
+    optixInvoke(p0, p1);
+}
+// Assembles a RayDesc from the OptiX hit-object ray queries (world-space origin,
+// tmin, world-space direction, tmax). `obj` is not used.
+static __forceinline__ __device__ RayDesc optixHitObjectGetRayDesc(OptixTraversableHandle* obj)
+{
+    (void)obj; // not consulted
+
+    // Queried in the same order as the aggregate initializer below consumes them.
+    const auto origin = optixHitObjectGetWorldRayOrigin();
+    const auto tMin = optixHitObjectGetRayTmin();
+    const auto direction = optixHitObjectGetWorldRayDirection();
+    const auto tMax = optixHitObjectGetRayTmax();
+
+    RayDesc ray = {origin, tMin, direction, tMax};
+    return ray;
+}
+
+// Overload taking a hit-object pointer. The pointer is ignored; the value comes
+// from the zero-argument `optixHitObjectGetInstanceIndex()`.
+static __forceinline__ __device__ uint optixHitObjectGetInstanceIndex(OptixTraversableHandle* Obj)
+{
+    (void)Obj; // not consulted
+    const uint instanceIndex = optixHitObjectGetInstanceIndex();
+    return instanceIndex;
+}
+
+// Overload taking a hit-object pointer. The pointer is ignored; the value comes
+// from the zero-argument `optixHitObjectGetInstanceId()`.
+static __forceinline__ __device__ uint optixHitObjectGetInstanceId(OptixTraversableHandle* Obj)
+{
+    (void)Obj; // not consulted
+    const uint instanceId = optixHitObjectGetInstanceId();
+    return instanceId;
+}
+
+// Overload taking a hit-object pointer. The pointer is ignored; the value comes
+// from the zero-argument `optixHitObjectGetSbtGASIndex()`.
+static __forceinline__ __device__ uint optixHitObjectGetSbtGASIndex(OptixTraversableHandle* Obj)
+{
+    (void)Obj; // not consulted
+    const uint sbtGasIndex = optixHitObjectGetSbtGASIndex();
+    return sbtGasIndex;
+}
+
+// Overload taking a hit-object pointer. The pointer is ignored; the value comes
+// from the zero-argument `optixHitObjectGetPrimitiveIndex()`.
+static __forceinline__ __device__ uint optixHitObjectGetPrimitiveIndex(OptixTraversableHandle* Obj)
+{
+    (void)Obj; // not consulted
+    const uint primitiveIndex = optixHitObjectGetPrimitiveIndex();
+    return primitiveIndex;
+}
+
+// Reads a value of type T out of the hit-object attribute registers.
+//
+// sizeof(T) is rounded up to a whole number of 32-bit words (at most 8, enforced
+// at compile time). That many `optixHitObjectGetAttribute_N()` values are read
+// into a zero-initialized scratch array, and the first sizeof(T) bytes of the
+// array are copied into the result. `Obj` is not used.
+template<typename T>
+static __forceinline__ __device__ T optixHitObjectGetAttribute(OptixTraversableHandle* Obj)
+{
+    (void)Obj; // not consulted
+
+    // Word count, rounded up so a partially filled trailing word is still read.
+    constexpr size_t kWordBytes = sizeof(uint32_t);
+    constexpr size_t kWordCount = (sizeof(T) + kWordBytes - 1) / kWordBytes;
+    static_assert(kWordCount <= 8, "Attribute type is too large");
+
+    // Scratch storage; never zero-length, and zero-filled before any register read.
+    uint32_t words[kWordCount == 0 ? 1 : kWordCount] = {0};
+
+    // Each register has its own accessor, so the reads are selected at compile time.
+    if constexpr (kWordCount > 0)
+    {
+        words[0] = optixHitObjectGetAttribute_0();
+    }
+    if constexpr (kWordCount > 1)
+    {
+        words[1] = optixHitObjectGetAttribute_1();
+    }
+    if constexpr (kWordCount > 2)
+    {
+        words[2] = optixHitObjectGetAttribute_2();
+    }
+    if constexpr (kWordCount > 3)
+    {
+        words[3] = optixHitObjectGetAttribute_3();
+    }
+    if constexpr (kWordCount > 4)
+    {
+        words[4] = optixHitObjectGetAttribute_4();
+    }
+    if constexpr (kWordCount > 5)
+    {
+        words[5] = optixHitObjectGetAttribute_5();
+    }
+    if constexpr (kWordCount > 6)
+    {
+        words[6] = optixHitObjectGetAttribute_6();
+    }
+    if constexpr (kWordCount > 7)
+    {
+        words[7] = optixHitObjectGetAttribute_7();
+    }
+
+    // Copy the raw bytes into a T rather than casting the array pointer.
+    T attribute;
+    memcpy(&attribute, words, sizeof(T));
+    return attribute;
+}
+
+// Overload taking a hit-object pointer. The pointer is ignored; the value comes
+// from the zero-argument `optixHitObjectGetSbtRecordIndex()`.
+static __forceinline__ __device__ uint optixHitObjectGetSbtRecordIndex(OptixTraversableHandle* Obj)
+{
+    (void)Obj; // not consulted
+    const uint sbtRecordIndex = optixHitObjectGetSbtRecordIndex();
+    return sbtRecordIndex;
+}
+
+// Forwards `sbtRecordIndex` to the one-argument `optixHitObjectSetSbtRecordIndex`
+// and always returns 0. `Obj` is not used.
+static __forceinline__ __device__ uint
+optixHitObjectSetSbtRecordIndex(OptixTraversableHandle* Obj, uint sbtRecordIndex)
+{
+    (void)Obj; // not consulted
+
+    // The forwarded call yields no value, so a constant is returned to satisfy
+    // this wrapper's uint return type.
+    optixHitObjectSetSbtRecordIndex(sbtRecordIndex);
+    const uint kNoResult = 0;
+    return kNoResult;
+}
+// Calls the zero-argument `optixHitObjectGetSbtDataPointer()`, discards whatever
+// it yields, and always returns 0. Neither `Obj` nor `sbtRecordIndex` is used.
+//
+// NOTE(review): the OptiX device API documents optixHitObjectGetSbtDataPointer()
+// as returning a CUdeviceptr, not void. Confirm whether callers need the pointer;
+// it does not fit in this wrapper's uint return type.
+static __forceinline__ __device__ uint
+optixHitObjectGetSbtDataPointer(OptixTraversableHandle* Obj, uint sbtRecordIndex)
+{
+    (void)Obj;
+    (void)sbtRecordIndex;
+
+    optixHitObjectGetSbtDataPointer(); // result intentionally unused here
+    const uint kNoResult = 0;
+    return kNoResult;
+}
+#endif
static const int kSlangTorchTensorMaxDim = 5;
// TensorView