diff options
| author | Simon Kallweit <64953474+skallweitNV@users.noreply.github.com> | 2025-10-14 17:21:03 +0200 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-10-14 15:21:03 +0000 |
| commit | 5978f934ee9a8a3e710dc743a4af92191639b718 (patch) | |
| tree | 5ae88284c70b8af40ef2b08298b39e7fe8b37392 | |
| parent | 96df31a9fa53e3d897a2b7c4eef021f37f421c91 (diff) | |
Add support targeting older OptiX versions (#8700)
Currently, the emitted CUDA code does only compile with latest OptiX
9.0. This change allows code to be compiled with OptiX 8.0 upwards by
not emitting OptiX calls that are not available. In a later step we
should add proper capabilities for the various OptiX versions.
| -rw-r--r-- | prelude/slang-cuda-prelude.h | 81 |
1 files changed, 71 insertions, 10 deletions
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h index 2c8faf922..9508ea796 100644 --- a/prelude/slang-cuda-prelude.h +++ b/prelude/slang-cuda-prelude.h @@ -3618,6 +3618,7 @@ struct UniformState; // ---------------------- OptiX Ray Payload -------------------------------------- #ifdef SLANG_CUDA_ENABLE_OPTIX + struct RayDesc { float3 Origin; @@ -3679,13 +3680,16 @@ __forceinline__ __device__ void* optixTrace( r1); } +#if (OPTIX_VERSION >= 90000) __forceinline__ __device__ float4 optixGetSpherePositionAndRadius() { float4 data[1]; optixGetSphereData(data); return data[0]; } +#endif +#if (OPTIX_VERSION >= 90000) __forceinline__ __device__ float4 optixHitObjectGetSpherePositionAndRadius(OptixTraversableHandle* Obj) { @@ -3693,14 +3697,18 @@ optixHitObjectGetSpherePositionAndRadius(OptixTraversableHandle* Obj) optixHitObjectGetSphereData(data); return data[0]; } +#endif +#if (OPTIX_VERSION >= 90000) __forceinline__ __device__ Matrix<float, 2, 4> optixGetLssPositionsAndRadii() { float4 data[2]; optixGetLinearCurveVertexData(data); return makeMatrix<float, 2, 4>(data[0], data[1]); } +#endif +#if (OPTIX_VERSION >= 90000) __forceinline__ __device__ Matrix<float, 2, 4> optixHitObjectGetLssPositionsAndRadii( OptixTraversableHandle* Obj) { @@ -3708,26 +3716,35 @@ __forceinline__ __device__ Matrix<float, 2, 4> optixHitObjectGetLssPositionsAndR optixHitObjectGetLinearCurveVertexData(data); return makeMatrix<float, 2, 4>(data[0], data[1]); } +#endif +#if (OPTIX_VERSION >= 90000) __forceinline__ __device__ bool optixIsSphereHit() { return optixGetPrimitiveType() == OPTIX_PRIMITIVE_TYPE_SPHERE; } +#endif +#if (OPTIX_VERSION >= 90000) __forceinline__ __device__ bool optixHitObjectIsSphereHit(OptixTraversableHandle* Obj) { return optixGetPrimitiveType(optixHitObjectGetHitKind()) == OPTIX_PRIMITIVE_TYPE_SPHERE; } +#endif +#if (OPTIX_VERSION >= 90000) __forceinline__ __device__ bool optixIsLSSHit() { return optixGetPrimitiveType() == OPTIX_PRIMITIVE_TYPE_ROUND_LINEAR; } +#endif +#if (OPTIX_VERSION >= 90000) __forceinline__ __device__ bool optixHitObjectIsLSSHit(OptixTraversableHandle* Obj) { return optixGetPrimitiveType(optixHitObjectGetHitKind()) == OPTIX_PRIMITIVE_TYPE_ROUND_LINEAR; } +#endif template<typename T> __forceinline__ __device__ void* optixTraverse( @@ -3790,60 +3807,79 @@ __forceinline__ __device__ void* optixTraverse( r1); } +#if (OPTIX_VERSION >= 80100) static __forceinline__ __device__ bool slangOptixHitObjectIsHit(OptixTraversableHandle* hitObj) { return optixHitObjectIsHit(); } +#endif +#if (OPTIX_VERSION >= 80100) static __forceinline__ __device__ bool slangOptixHitObjectIsMiss(OptixTraversableHandle* hitObj) { return optixHitObjectIsMiss(); } +#endif +#if (OPTIX_VERSION >= 80100) static __forceinline__ __device__ bool slangOptixHitObjectIsNop(OptixTraversableHandle* hitObj) { return optixHitObjectIsNop(); } +#endif +#if (OPTIX_VERSION >= 90000) static __forceinline__ __device__ uint slangOptixHitObjectGetClusterId(OptixTraversableHandle* hitObj) { return optixHitObjectGetClusterId(); } +#endif +#if (OPTIX_VERSION >= 80100) static __forceinline__ __device__ void optixMakeMissHitObject( uint MissShaderIndex, RayDesc Ray, OptixTraversableHandle* missObj) { - optixMakeMissHitObject( MissShaderIndex, Ray.Origin, Ray.Direction, Ray.TMin, Ray.TMax, - 0.f, /* rayTime */ - OPTIX_RAY_FLAG_NONE /* rayFlags*/); + 0.f /* rayTime */ +#if (OPTIX_VERSION >= 90000) + , + OPTIX_RAY_FLAG_NONE /* rayFlags*/ +#endif + ); } +#endif +#if (OPTIX_VERSION >= 80100) static __forceinline__ __device__ void optixMakeMissHitObject( uint MissShaderIndex, RayDesc Ray, float CurrentTime, OptixTraversableHandle* missObj) { - optixMakeMissHitObject( MissShaderIndex, Ray.Origin, Ray.Direction, Ray.TMin, Ray.TMax, - CurrentTime, - OPTIX_RAY_FLAG_NONE /* rayFlags*/); + CurrentTime +#if (OPTIX_VERSION >= 90000) + , + OPTIX_RAY_FLAG_NONE /* rayFlags*/ +#endif + ); } +#endif +#if (OPTIX_VERSION >= 80100) template<typename T> static __forceinline__ __device__ void optixMakeHitObject( OptixTraversableHandle AccelerationStructure, @@ -3857,7 +3893,6 @@ static __forceinline__ __device__ void optixMakeHitObject( T attr, OptixTraversableHandle* handle) { - OptixTraverseData data{}; optixHitObjectGetTraverseData(&data); optixMakeHitObject( @@ -3871,7 +3906,9 @@ static __forceinline__ __device__ void optixMakeHitObject( nullptr, /*OptixTraversableHandle* transforms*/ 0 /*numTransforms */); } +#endif +#if (OPTIX_VERSION >= 80100) template<typename T> static __forceinline__ __device__ void optixMakeHitObject( uint HitGroupRecordIndex, @@ -3884,7 +3921,6 @@ static __forceinline__ __device__ void optixMakeHitObject( T attr, OptixTraversableHandle* handle) { - OptixTraverseData data{}; optixHitObjectGetTraverseData(&data); optixMakeHitObject( @@ -3898,7 +3934,9 @@ static __forceinline__ __device__ void optixMakeHitObject( nullptr, /*OptixTraversableHandle* transforms*/ 0 /*numTransforms */); } +#endif +#if (OPTIX_VERSION >= 80100) template<typename T> static __forceinline__ __device__ void optixMakeHitObject( OptixTraversableHandle AccelerationStructure, @@ -3913,7 +3951,6 @@ static __forceinline__ __device__ void optixMakeHitObject( T attr, OptixTraversableHandle* handle) { - OptixTraverseData data{}; optixHitObjectGetTraverseData(&data); optixMakeHitObject( @@ -3927,7 +3964,9 @@ static __forceinline__ __device__ void optixMakeHitObject( nullptr, /*OptixTraversableHandle* transforms*/ 0 /*numTransforms */); } +#endif +#if (OPTIX_VERSION >= 80100) template<typename T> static __forceinline__ __device__ void optixMakeHitObject( uint HitGroupRecordIndex, @@ -3941,7 +3980,6 @@ static __forceinline__ __device__ void optixMakeHitObject( T attr, OptixTraversableHandle* handle) { - OptixTraverseData data{}; optixHitObjectGetTraverseData(&data); optixMakeHitObject( @@ -3955,12 +3993,16 @@ static __forceinline__ __device__ void optixMakeHitObject( nullptr, /*OptixTraversableHandle* transforms*/ 0 /*numTransforms */); } +#endif +#if (OPTIX_VERSION >= 80100) static __forceinline__ __device__ void slangOptixMakeNopHitObject(OptixTraversableHandle* Obj) { optixMakeNopHitObject(); } +#endif +#if (OPTIX_VERSION >= 80100) template<typename T> static __forceinline__ __device__ void optixInvoke( OptixTraversableHandle AccelerationStructure, @@ -3971,6 +4013,9 @@ static __forceinline__ __device__ void optixInvoke( packOptiXRayPayloadPointer((void*)Payload, r0, r1); optixInvoke(r0, r1); } +#endif + +#if (OPTIX_VERSION >= 80100) static __forceinline__ __device__ RayDesc optixHitObjectGetRayDesc(OptixTraversableHandle* obj) { RayDesc ray = { @@ -3980,30 +4025,40 @@ static __forceinline__ __device__ RayDesc optixHitObjectGetRayDesc(OptixTraversa optixHitObjectGetRayTmax()}; return ray; } +#endif +#if (OPTIX_VERSION >= 80100) static __forceinline__ __device__ uint slangOptixHitObjectGetInstanceIndex(OptixTraversableHandle* Obj) { return optixHitObjectGetInstanceIndex(); } +#endif +#if (OPTIX_VERSION >= 80100) static __forceinline__ __device__ uint slangOptixHitObjectGetInstanceId(OptixTraversableHandle* Obj) { return optixHitObjectGetInstanceId(); } +#endif +#if (OPTIX_VERSION >= 80100) static __forceinline__ __device__ uint slangOptixHitObjectGetSbtGASIndex(OptixTraversableHandle* Obj) { return optixHitObjectGetSbtGASIndex(); } +#endif +#if (OPTIX_VERSION >= 80100) static __forceinline__ __device__ uint slangOptixHitObjectGetPrimitiveIndex(OptixTraversableHandle* Obj) { return optixHitObjectGetPrimitiveIndex(); } +#endif +#if (OPTIX_VERSION >= 80100) template<typename T> static __forceinline__ __device__ T optixHitObjectGetAttribute(OptixTraversableHandle* Obj) { @@ -4037,19 +4092,25 @@ static __forceinline__ __device__ T optixHitObjectGetAttribute(OptixTraversableH memcpy(&result, values, sizeof(T)); return result; } +#endif +#if (OPTIX_VERSION >= 80100) static __forceinline__ __device__ uint slangOptixHitObjectGetSbtRecordIndex(OptixTraversableHandle* Obj) { return optixHitObjectGetSbtRecordIndex(); } +#endif +#if (OPTIX_VERSION >= 90000) static __forceinline__ __device__ uint slangOptixHitObjectSetSbtRecordIndex(OptixTraversableHandle* Obj, uint sbtRecordIndex) { optixHitObjectSetSbtRecordIndex(sbtRecordIndex); // returns void return sbtRecordIndex; } +#endif + #else // Define OptixTraversableHandle even if OptiX is not enabled. // This allows RaytracingAccelerationStructure to be properly reflected in non-OptiX code. |
