diff options
| -rw-r--r-- | prelude/slang-cuda-prelude.h | 81 |
1 files changed, 71 insertions, 10 deletions
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h index 2c8faf922..9508ea796 100644 --- a/prelude/slang-cuda-prelude.h +++ b/prelude/slang-cuda-prelude.h @@ -3618,6 +3618,7 @@ struct UniformState; // ---------------------- OptiX Ray Payload -------------------------------------- #ifdef SLANG_CUDA_ENABLE_OPTIX + struct RayDesc { float3 Origin; @@ -3679,13 +3680,16 @@ __forceinline__ __device__ void* optixTrace( r1); } +#if (OPTIX_VERSION >= 90000) __forceinline__ __device__ float4 optixGetSpherePositionAndRadius() { float4 data[1]; optixGetSphereData(data); return data[0]; } +#endif +#if (OPTIX_VERSION >= 90000) __forceinline__ __device__ float4 optixHitObjectGetSpherePositionAndRadius(OptixTraversableHandle* Obj) { @@ -3693,14 +3697,18 @@ optixHitObjectGetSpherePositionAndRadius(OptixTraversableHandle* Obj) optixHitObjectGetSphereData(data); return data[0]; } +#endif +#if (OPTIX_VERSION >= 90000) __forceinline__ __device__ Matrix<float, 2, 4> optixGetLssPositionsAndRadii() { float4 data[2]; optixGetLinearCurveVertexData(data); return makeMatrix<float, 2, 4>(data[0], data[1]); } +#endif +#if (OPTIX_VERSION >= 90000) __forceinline__ __device__ Matrix<float, 2, 4> optixHitObjectGetLssPositionsAndRadii( OptixTraversableHandle* Obj) { @@ -3708,26 +3716,35 @@ __forceinline__ __device__ Matrix<float, 2, 4> optixHitObjectGetLssPositionsAndR optixHitObjectGetLinearCurveVertexData(data); return makeMatrix<float, 2, 4>(data[0], data[1]); } +#endif +#if (OPTIX_VERSION >= 90000) __forceinline__ __device__ bool optixIsSphereHit() { return optixGetPrimitiveType() == OPTIX_PRIMITIVE_TYPE_SPHERE; } +#endif +#if (OPTIX_VERSION >= 90000) __forceinline__ __device__ bool optixHitObjectIsSphereHit(OptixTraversableHandle* Obj) { return optixGetPrimitiveType(optixHitObjectGetHitKind()) == OPTIX_PRIMITIVE_TYPE_SPHERE; } +#endif +#if (OPTIX_VERSION >= 90000) __forceinline__ __device__ bool optixIsLSSHit() { return optixGetPrimitiveType() == OPTIX_PRIMITIVE_TYPE_ROUND_LINEAR; } +#endif +#if (OPTIX_VERSION >= 90000) __forceinline__ __device__ bool optixHitObjectIsLSSHit(OptixTraversableHandle* Obj) { return optixGetPrimitiveType(optixHitObjectGetHitKind()) == OPTIX_PRIMITIVE_TYPE_ROUND_LINEAR; } +#endif template<typename T> __forceinline__ __device__ void* optixTraverse( @@ -3790,60 +3807,79 @@ __forceinline__ __device__ void* optixTraverse( r1); } +#if (OPTIX_VERSION >= 80100) static __forceinline__ __device__ bool slangOptixHitObjectIsHit(OptixTraversableHandle* hitObj) { return optixHitObjectIsHit(); } +#endif +#if (OPTIX_VERSION >= 80100) static __forceinline__ __device__ bool slangOptixHitObjectIsMiss(OptixTraversableHandle* hitObj) { return optixHitObjectIsMiss(); } +#endif +#if (OPTIX_VERSION >= 80100) static __forceinline__ __device__ bool slangOptixHitObjectIsNop(OptixTraversableHandle* hitObj) { return optixHitObjectIsNop(); } +#endif +#if (OPTIX_VERSION >= 90000) static __forceinline__ __device__ uint slangOptixHitObjectGetClusterId(OptixTraversableHandle* hitObj) { return optixHitObjectGetClusterId(); } +#endif +#if (OPTIX_VERSION >= 80100) static __forceinline__ __device__ void optixMakeMissHitObject( uint MissShaderIndex, RayDesc Ray, OptixTraversableHandle* missObj) { - optixMakeMissHitObject( MissShaderIndex, Ray.Origin, Ray.Direction, Ray.TMin, Ray.TMax, - 0.f, /* rayTime */ - OPTIX_RAY_FLAG_NONE /* rayFlags*/); + 0.f /* rayTime */ +#if (OPTIX_VERSION >= 90000) + , + OPTIX_RAY_FLAG_NONE /* rayFlags*/ +#endif + ); } +#endif +#if (OPTIX_VERSION >= 80100) static __forceinline__ __device__ void optixMakeMissHitObject( uint MissShaderIndex, RayDesc Ray, float CurrentTime, OptixTraversableHandle* missObj) { - optixMakeMissHitObject( MissShaderIndex, Ray.Origin, Ray.Direction, Ray.TMin, Ray.TMax, - CurrentTime, - OPTIX_RAY_FLAG_NONE /* rayFlags*/); + CurrentTime +#if (OPTIX_VERSION >= 90000) + , + OPTIX_RAY_FLAG_NONE /* rayFlags*/ +#endif + ); } +#endif +#if (OPTIX_VERSION >= 80100) template<typename T> static __forceinline__ __device__ void optixMakeHitObject( OptixTraversableHandle AccelerationStructure, @@ -3857,7 +3893,6 @@ static __forceinline__ __device__ void optixMakeHitObject( T attr, OptixTraversableHandle* handle) { - OptixTraverseData data{}; optixHitObjectGetTraverseData(&data); optixMakeHitObject( @@ -3871,7 +3906,9 @@ static __forceinline__ __device__ void optixMakeHitObject( nullptr, /*OptixTraversableHandle* transforms*/ 0 /*numTransforms */); } +#endif +#if (OPTIX_VERSION >= 80100) template<typename T> static __forceinline__ __device__ void optixMakeHitObject( uint HitGroupRecordIndex, @@ -3884,7 +3921,6 @@ static __forceinline__ __device__ void optixMakeHitObject( T attr, OptixTraversableHandle* handle) { - OptixTraverseData data{}; optixHitObjectGetTraverseData(&data); optixMakeHitObject( @@ -3898,7 +3934,9 @@ static __forceinline__ __device__ void optixMakeHitObject( nullptr, /*OptixTraversableHandle* transforms*/ 0 /*numTransforms */); } +#endif +#if (OPTIX_VERSION >= 80100) template<typename T> static __forceinline__ __device__ void optixMakeHitObject( OptixTraversableHandle AccelerationStructure, @@ -3913,7 +3951,6 @@ static __forceinline__ __device__ void optixMakeHitObject( T attr, OptixTraversableHandle* handle) { - OptixTraverseData data{}; optixHitObjectGetTraverseData(&data); optixMakeHitObject( @@ -3927,7 +3964,9 @@ static __forceinline__ __device__ void optixMakeHitObject( nullptr, /*OptixTraversableHandle* transforms*/ 0 /*numTransforms */); } +#endif +#if (OPTIX_VERSION >= 80100) template<typename T> static __forceinline__ __device__ void optixMakeHitObject( uint HitGroupRecordIndex, @@ -3941,7 +3980,6 @@ static __forceinline__ __device__ void optixMakeHitObject( T attr, OptixTraversableHandle* handle) { - OptixTraverseData data{}; optixHitObjectGetTraverseData(&data); optixMakeHitObject( @@ -3955,12 +3993,16 @@ static __forceinline__ __device__ void optixMakeHitObject( nullptr, /*OptixTraversableHandle* transforms*/ 0 /*numTransforms */); } +#endif +#if (OPTIX_VERSION >= 80100) static __forceinline__ __device__ void slangOptixMakeNopHitObject(OptixTraversableHandle* Obj) { optixMakeNopHitObject(); } +#endif +#if (OPTIX_VERSION >= 80100) template<typename T> static __forceinline__ __device__ void optixInvoke( OptixTraversableHandle AccelerationStructure, @@ -3971,6 +4013,9 @@ static __forceinline__ __device__ void optixInvoke( packOptiXRayPayloadPointer((void*)Payload, r0, r1); optixInvoke(r0, r1); } +#endif + +#if (OPTIX_VERSION >= 80100) static __forceinline__ __device__ RayDesc optixHitObjectGetRayDesc(OptixTraversableHandle* obj) { RayDesc ray = { @@ -3980,30 +4025,40 @@ static __forceinline__ __device__ RayDesc optixHitObjectGetRayDesc(OptixTraversa optixHitObjectGetRayTmax()}; return ray; } +#endif +#if (OPTIX_VERSION >= 80100) static __forceinline__ __device__ uint slangOptixHitObjectGetInstanceIndex(OptixTraversableHandle* Obj) { return optixHitObjectGetInstanceIndex(); } +#endif +#if (OPTIX_VERSION >= 80100) static __forceinline__ __device__ uint slangOptixHitObjectGetInstanceId(OptixTraversableHandle* Obj) { return optixHitObjectGetInstanceId(); } +#endif +#if (OPTIX_VERSION >= 80100) static __forceinline__ __device__ uint slangOptixHitObjectGetSbtGASIndex(OptixTraversableHandle* Obj) { return optixHitObjectGetSbtGASIndex(); } +#endif +#if (OPTIX_VERSION >= 80100) static __forceinline__ __device__ uint slangOptixHitObjectGetPrimitiveIndex(OptixTraversableHandle* Obj) { return optixHitObjectGetPrimitiveIndex(); } +#endif +#if (OPTIX_VERSION >= 80100) template<typename T> static __forceinline__ __device__ T optixHitObjectGetAttribute(OptixTraversableHandle* Obj) { @@ -4037,19 +4092,25 @@ static __forceinline__ __device__ T optixHitObjectGetAttribute(OptixTraversableH memcpy(&result, values, sizeof(T)); return result; } +#endif +#if (OPTIX_VERSION >= 80100) static __forceinline__ __device__ uint slangOptixHitObjectGetSbtRecordIndex(OptixTraversableHandle* Obj) { return optixHitObjectGetSbtRecordIndex(); } +#endif +#if (OPTIX_VERSION >= 90000) static __forceinline__ __device__ uint slangOptixHitObjectSetSbtRecordIndex(OptixTraversableHandle* Obj, uint sbtRecordIndex) { optixHitObjectSetSbtRecordIndex(sbtRecordIndex); // returns void return sbtRecordIndex; } +#endif + #else // Define OptixTraversableHandle even if OptiX is not enabled. // This allows RaytracingAccelerationStructure to be properly reflected in non-OptiX code. |
