diff options
| author | Harsh Aggarwal (NVIDIA) <haaggarwal@nvidia.com> | 2025-05-26 21:00:38 +0530 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-05-26 15:30:38 +0000 |
| commit | 83538e0b4b97425ecdae6f72f9c8fd44cb255aac (patch) | |
| tree | 8f27c47fb7c1614fa916c2da6ab9996655e29da1 /prelude | |
| parent | 8ecb2c70437292ef6fa34f7122df44067de6a4de (diff) | |
Implement shader execution reordering support for OptiX (#7211)
* Implement shader execution reordering support for OptiX
Adds OptiX backend support for the Shader Execution Reordering (SER) features outlined in issue #6647. This implementation:
1. Adds CUDA target support for the HitObject API
2. Implements core SER functionality (TraceRay, MakeHit/Miss, Invoke)
3. Adds OptiX-specific hit object handling functions
4. Adds a test case for OptiX SER functionality
* format code
---------
Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Diffstat (limited to 'prelude')
| -rw-r--r-- | prelude/slang-cuda-prelude.h | 325 |
1 files changed, 323 insertions, 2 deletions
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h
index c641025d4..d2c9fce9d 100644
--- a/prelude/slang-cuda-prelude.h
+++ b/prelude/slang-cuda-prelude.h
@@ -3193,7 +3193,7 @@ static __forceinline__ __device__ void* getOptiXRayPayloadPtr()
 }
 
 template<typename T>
-__forceinline__ __device__ void* traceOptiXRay(
+__forceinline__ __device__ void* optixTrace(
     OptixTraversableHandle AccelerationStructure,
     uint32_t RayFlags,
     uint32_t InstanceInclusionMask,
@@ -3221,8 +3221,329 @@ __forceinline__ __device__ void* traceOptiXRay(
         r1);
 }
 
-#endif
+template<typename T>
+__forceinline__ __device__ void* optixTraverse(
+    OptixTraversableHandle AccelerationStructure,
+    uint32_t RayFlags,
+    uint32_t InstanceInclusionMask,
+    uint32_t RayContributionToHitGroupIndex,
+    uint32_t MultiplierForGeometryContributionToHitGroupIndex,
+    uint32_t MissShaderIndex,
+    RayDesc Ray,
+    T* Payload,
+    OptixTraversableHandle* hitObj)
+{
+    uint32_t r0, r1;
+    packOptiXRayPayloadPointer((void*)Payload, r0, r1);
+    optixTraverse(
+        AccelerationStructure,
+        Ray.Origin,
+        Ray.Direction,
+        Ray.TMin,
+        Ray.TMax,
+        0.f, /* Time for motion blur, currently unsupported in slang */
+        InstanceInclusionMask,
+        RayFlags,
+        RayContributionToHitGroupIndex,
+        MultiplierForGeometryContributionToHitGroupIndex,
+        MissShaderIndex,
+        r0,
+        r1);
+}
+
+template<typename T>
+__forceinline__ __device__ void* optixTraverse(
+    OptixTraversableHandle AccelerationStructure,
+    uint32_t RayFlags,
+    uint32_t InstanceInclusionMask,
+    uint32_t RayContributionToHitGroupIndex,
+    uint32_t MultiplierForGeometryContributionToHitGroupIndex,
+    uint32_t MissShaderIndex,
+    RayDesc Ray,
+    float RayTime,
+    T* Payload,
+    OptixTraversableHandle* hitObj)
+{
+    uint32_t r0, r1;
+    packOptiXRayPayloadPointer((void*)Payload, r0, r1);
+    optixTraverse(
+        AccelerationStructure,
+        Ray.Origin,
+        Ray.Direction,
+        Ray.TMin,
+        Ray.TMax,
+        RayTime,
+        InstanceInclusionMask,
+        RayFlags,
+        RayContributionToHitGroupIndex,
+        MultiplierForGeometryContributionToHitGroupIndex,
+        MissShaderIndex,
+        r0,
+        r1);
+}
+
+static __forceinline__ __device__ bool optixHitObjectIsHit(OptixTraversableHandle* hitObj)
+{
+    return optixHitObjectIsHit();
+}
+
+static __forceinline__ __device__ bool optixHitObjectIsMiss(OptixTraversableHandle* hitObj)
+{
+    return optixHitObjectIsMiss();
+}
+
+static __forceinline__ __device__ bool optixHitObjectIsNop(OptixTraversableHandle* hitObj)
+{
+    return optixHitObjectIsNop();
+}
+
+static __forceinline__ __device__ uint optixHitObjectGetClusterId(OptixTraversableHandle* hitObj)
+{
+    return optixHitObjectGetClusterId();
+}
+
+static __forceinline__ __device__ void optixMakeMissHitObject(
+    uint MissShaderIndex,
+    RayDesc Ray,
+    OptixTraversableHandle* missObj)
+{
+
+    optixMakeMissHitObject(
+        MissShaderIndex,
+        Ray.Origin,
+        Ray.Direction,
+        Ray.TMin,
+        Ray.TMax,
+        0.f, /* rayTime */
+        OPTIX_RAY_FLAG_NONE /* rayFlags*/);
+}
+
+static __forceinline__ __device__ void optixMakeMissHitObject(
+    uint MissShaderIndex,
+    RayDesc Ray,
+    float CurrentTime,
+    OptixTraversableHandle* missObj)
+{
+
+    optixMakeMissHitObject(
+        MissShaderIndex,
+        Ray.Origin,
+        Ray.Direction,
+        Ray.TMin,
+        Ray.TMax,
+        CurrentTime,
+        OPTIX_RAY_FLAG_NONE /* rayFlags*/);
+}
+template<typename T>
+static __forceinline__ __device__ void optixMakeHitObject(
+    OptixTraversableHandle AccelerationStructure,
+    uint InstanceIndex,
+    uint GeometryIndex,
+    uint PrimitiveIndex,
+    uint HitKind,
+    uint RayContributionToHitGroupIndex,
+    uint MultiplierForGeometryContributionToHitGroupIndex,
+    RayDesc Ray,
+    T attr,
+    OptixTraversableHandle* handle)
+{
+
+    OptixTraverseData data{};
+    optixHitObjectGetTraverseData(&data);
+    optixMakeHitObject(
+        AccelerationStructure,
+        Ray.Origin,
+        Ray.Direction,
+        Ray.TMin,
+        0.f,
+        OPTIX_RAY_FLAG_NONE, /* rayFlags*/
+        data,
+        nullptr, /*OptixTraversableHandle* transforms*/
+        0 /*numTransforms */);
+}
+
+template<typename T>
+static __forceinline__ __device__ void optixMakeHitObject(
+    uint HitGroupRecordIndex,
+    OptixTraversableHandle AccelerationStructure,
+    uint InstanceIndex,
+    uint GeometryIndex,
+    uint PrimitiveIndex,
+    uint HitKind,
+    RayDesc Ray,
+    T attr,
+    OptixTraversableHandle* handle)
+{
+
+    OptixTraverseData data{};
+    optixHitObjectGetTraverseData(&data);
+    optixMakeHitObject(
+        AccelerationStructure,
+        Ray.Origin,
+        Ray.Direction,
+        Ray.TMin,
+        0.f,
+        OPTIX_RAY_FLAG_NONE, /* rayFlags*/
+        data,
+        nullptr, /*OptixTraversableHandle* transforms*/
+        0 /*numTransforms */);
+}
+
+template<typename T>
+static __forceinline__ __device__ void optixMakeHitObject(
+    OptixTraversableHandle AccelerationStructure,
+    uint InstanceIndex,
+    uint GeometryIndex,
+    uint PrimitiveIndex,
+    uint HitKind,
+    uint RayContributionToHitGroupIndex,
+    uint MultiplierForGeometryContributionToHitGroupIndex,
+    RayDesc Ray,
+    float CurrentTime,
+    T attr,
+    OptixTraversableHandle* handle)
+{
+
+    OptixTraverseData data{};
+    optixHitObjectGetTraverseData(&data);
+    optixMakeHitObject(
+        AccelerationStructure,
+        Ray.Origin,
+        Ray.Direction,
+        Ray.TMin,
+        CurrentTime,
+        OPTIX_RAY_FLAG_NONE, /* rayFlags*/
+        data,
+        nullptr, /*OptixTraversableHandle* transforms*/
+        0 /*numTransforms */);
+}
+
+template<typename T>
+static __forceinline__ __device__ void optixMakeHitObject(
+    uint HitGroupRecordIndex,
+    OptixTraversableHandle AccelerationStructure,
+    uint InstanceIndex,
+    uint GeometryIndex,
+    uint PrimitiveIndex,
+    uint HitKind,
+    RayDesc Ray,
+    float CurrentTime,
+    T attr,
+    OptixTraversableHandle* handle)
+{
+
+    OptixTraverseData data{};
+    optixHitObjectGetTraverseData(&data);
+    optixMakeHitObject(
+        AccelerationStructure,
+        Ray.Origin,
+        Ray.Direction,
+        Ray.TMin,
+        CurrentTime,
+        OPTIX_RAY_FLAG_NONE, /* rayFlags*/
+        data,
+        nullptr, /*OptixTraversableHandle* transforms*/
+        0 /*numTransforms */);
+}
+
+static __forceinline__ __device__ void optixMakeNopHitObject(OptixTraversableHandle* Obj)
+{
+    optixMakeNopHitObject();
+}
+
+template<typename T>
+static __forceinline__ __device__ void optixInvoke(
+    OptixTraversableHandle AccelerationStructure,
+    OptixTraversableHandle* HitOrMiss,
+    T Payload)
+{
+    uint32_t r0, r1;
+    packOptiXRayPayloadPointer((void*)Payload, r0, r1);
+    optixInvoke(r0, r1);
+}
+static __forceinline__ __device__ RayDesc optixHitObjectGetRayDesc(OptixTraversableHandle* obj)
+{
+    RayDesc ray = {
+        optixHitObjectGetWorldRayOrigin(),
+        optixHitObjectGetRayTmin(),
+        optixHitObjectGetWorldRayDirection(),
+        optixHitObjectGetRayTmax()};
+    return ray;
+}
+
+static __forceinline__ __device__ uint optixHitObjectGetInstanceIndex(OptixTraversableHandle* Obj)
+{
+    return optixHitObjectGetInstanceIndex();
+}
+
+static __forceinline__ __device__ uint optixHitObjectGetInstanceId(OptixTraversableHandle* Obj)
+{
+    return optixHitObjectGetInstanceId();
+}
+
+static __forceinline__ __device__ uint optixHitObjectGetSbtGASIndex(OptixTraversableHandle* Obj)
+{
+    return optixHitObjectGetSbtGASIndex();
+}
+
+static __forceinline__ __device__ uint optixHitObjectGetPrimitiveIndex(OptixTraversableHandle* Obj)
+{
+    return optixHitObjectGetPrimitiveIndex();
+}
+
+template<typename T>
+static __forceinline__ __device__ T optixHitObjectGetAttribute(OptixTraversableHandle* Obj)
+{
+    constexpr size_t numInts = (sizeof(T) + sizeof(uint32_t) - 1) /
+                               sizeof(uint32_t); // Number of 32-bit values, rounded up
+    static_assert(numInts <= 8, "Attribute type is too large");
+
+    // Create an array to hold the attribute values
+    uint32_t values[numInts == 0 ? 1 : numInts] = {0}; // Ensure we have at least one element
+
+    // Read the appropriate number of attribute registers
+    if constexpr (numInts > 0)
+        values[0] = optixHitObjectGetAttribute_0();
+    if constexpr (numInts > 1)
+        values[1] = optixHitObjectGetAttribute_1();
+    if constexpr (numInts > 2)
+        values[2] = optixHitObjectGetAttribute_2();
+    if constexpr (numInts > 3)
+        values[3] = optixHitObjectGetAttribute_3();
+    if constexpr (numInts > 4)
+        values[4] = optixHitObjectGetAttribute_4();
+    if constexpr (numInts > 5)
+        values[5] = optixHitObjectGetAttribute_5();
+    if constexpr (numInts > 6)
+        values[6] = optixHitObjectGetAttribute_6();
+    if constexpr (numInts > 7)
+        values[7] = optixHitObjectGetAttribute_7();
+
+    // Reinterpret the array as the desired type
+    T result;
+    memcpy(&result, values, sizeof(T));
+    return result;
+}
+
+static __forceinline__ __device__ uint optixHitObjectGetSbtRecordIndex(OptixTraversableHandle* Obj)
+{
+    return optixHitObjectGetSbtRecordIndex();
+}
+
+static __forceinline__ __device__ uint
+optixHitObjectSetSbtRecordIndex(OptixTraversableHandle* Obj, uint sbtRecordIndex)
+{
+    optixHitObjectSetSbtRecordIndex(sbtRecordIndex); // returns void
+    return 0;
+}
+static __forceinline__ __device__ uint
+optixHitObjectGetSbtDataPointer(OptixTraversableHandle* Obj, uint sbtRecordIndex)
+{
+    optixHitObjectGetSbtDataPointer(); // returns void
+    return 0;
+}
+#endif
 
 static const int kSlangTorchTensorMaxDim = 5;
 
 // TensorView
