summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSimon Kallweit <64953474+skallweitNV@users.noreply.github.com>2025-10-14 17:21:03 +0200
committerGitHub <noreply@github.com>2025-10-14 15:21:03 +0000
commit5978f934ee9a8a3e710dc743a4af92191639b718 (patch)
tree5ae88284c70b8af40ef2b08298b39e7fe8b37392
parent96df31a9fa53e3d897a2b7c4eef021f37f421c91 (diff)
Add support targeting older OptiX versions (#8700)
Currently, the emitted CUDA code does only compile with latest OptiX 9.0. This change allows code to be compiled with OptiX 8.0 upwards by not emitting OptiX calls that are not available. In a later step we should add proper capabilities for the various OptiX versions.
-rw-r--r--prelude/slang-cuda-prelude.h81
1 files changed, 71 insertions, 10 deletions
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h
index 2c8faf922..9508ea796 100644
--- a/prelude/slang-cuda-prelude.h
+++ b/prelude/slang-cuda-prelude.h
@@ -3618,6 +3618,7 @@ struct UniformState;
// ---------------------- OptiX Ray Payload --------------------------------------
#ifdef SLANG_CUDA_ENABLE_OPTIX
+
struct RayDesc
{
float3 Origin;
@@ -3679,13 +3680,16 @@ __forceinline__ __device__ void* optixTrace(
r1);
}
+#if (OPTIX_VERSION >= 90000)
__forceinline__ __device__ float4 optixGetSpherePositionAndRadius()
{
float4 data[1];
optixGetSphereData(data);
return data[0];
}
+#endif
+#if (OPTIX_VERSION >= 90000)
__forceinline__ __device__ float4
optixHitObjectGetSpherePositionAndRadius(OptixTraversableHandle* Obj)
{
@@ -3693,14 +3697,18 @@ optixHitObjectGetSpherePositionAndRadius(OptixTraversableHandle* Obj)
optixHitObjectGetSphereData(data);
return data[0];
}
+#endif
+#if (OPTIX_VERSION >= 90000)
__forceinline__ __device__ Matrix<float, 2, 4> optixGetLssPositionsAndRadii()
{
float4 data[2];
optixGetLinearCurveVertexData(data);
return makeMatrix<float, 2, 4>(data[0], data[1]);
}
+#endif
+#if (OPTIX_VERSION >= 90000)
__forceinline__ __device__ Matrix<float, 2, 4> optixHitObjectGetLssPositionsAndRadii(
OptixTraversableHandle* Obj)
{
@@ -3708,26 +3716,35 @@ __forceinline__ __device__ Matrix<float, 2, 4> optixHitObjectGetLssPositionsAndR
optixHitObjectGetLinearCurveVertexData(data);
return makeMatrix<float, 2, 4>(data[0], data[1]);
}
+#endif
+#if (OPTIX_VERSION >= 90000)
__forceinline__ __device__ bool optixIsSphereHit()
{
return optixGetPrimitiveType() == OPTIX_PRIMITIVE_TYPE_SPHERE;
}
+#endif
+#if (OPTIX_VERSION >= 90000)
__forceinline__ __device__ bool optixHitObjectIsSphereHit(OptixTraversableHandle* Obj)
{
return optixGetPrimitiveType(optixHitObjectGetHitKind()) == OPTIX_PRIMITIVE_TYPE_SPHERE;
}
+#endif
+#if (OPTIX_VERSION >= 90000)
__forceinline__ __device__ bool optixIsLSSHit()
{
return optixGetPrimitiveType() == OPTIX_PRIMITIVE_TYPE_ROUND_LINEAR;
}
+#endif
+#if (OPTIX_VERSION >= 90000)
__forceinline__ __device__ bool optixHitObjectIsLSSHit(OptixTraversableHandle* Obj)
{
return optixGetPrimitiveType(optixHitObjectGetHitKind()) == OPTIX_PRIMITIVE_TYPE_ROUND_LINEAR;
}
+#endif
template<typename T>
__forceinline__ __device__ void* optixTraverse(
@@ -3790,60 +3807,79 @@ __forceinline__ __device__ void* optixTraverse(
r1);
}
+#if (OPTIX_VERSION >= 80100)
static __forceinline__ __device__ bool slangOptixHitObjectIsHit(OptixTraversableHandle* hitObj)
{
return optixHitObjectIsHit();
}
+#endif
+#if (OPTIX_VERSION >= 80100)
static __forceinline__ __device__ bool slangOptixHitObjectIsMiss(OptixTraversableHandle* hitObj)
{
return optixHitObjectIsMiss();
}
+#endif
+#if (OPTIX_VERSION >= 80100)
static __forceinline__ __device__ bool slangOptixHitObjectIsNop(OptixTraversableHandle* hitObj)
{
return optixHitObjectIsNop();
}
+#endif
+#if (OPTIX_VERSION >= 90000)
static __forceinline__ __device__ uint
slangOptixHitObjectGetClusterId(OptixTraversableHandle* hitObj)
{
return optixHitObjectGetClusterId();
}
+#endif
+#if (OPTIX_VERSION >= 80100)
static __forceinline__ __device__ void optixMakeMissHitObject(
uint MissShaderIndex,
RayDesc Ray,
OptixTraversableHandle* missObj)
{
-
optixMakeMissHitObject(
MissShaderIndex,
Ray.Origin,
Ray.Direction,
Ray.TMin,
Ray.TMax,
- 0.f, /* rayTime */
- OPTIX_RAY_FLAG_NONE /* rayFlags*/);
+ 0.f /* rayTime */
+#if (OPTIX_VERSION >= 90000)
+ ,
+ OPTIX_RAY_FLAG_NONE /* rayFlags*/
+#endif
+ );
}
+#endif
+#if (OPTIX_VERSION >= 80100)
static __forceinline__ __device__ void optixMakeMissHitObject(
uint MissShaderIndex,
RayDesc Ray,
float CurrentTime,
OptixTraversableHandle* missObj)
{
-
optixMakeMissHitObject(
MissShaderIndex,
Ray.Origin,
Ray.Direction,
Ray.TMin,
Ray.TMax,
- CurrentTime,
- OPTIX_RAY_FLAG_NONE /* rayFlags*/);
+ CurrentTime
+#if (OPTIX_VERSION >= 90000)
+ ,
+ OPTIX_RAY_FLAG_NONE /* rayFlags*/
+#endif
+ );
}
+#endif
+#if (OPTIX_VERSION >= 80100)
template<typename T>
static __forceinline__ __device__ void optixMakeHitObject(
OptixTraversableHandle AccelerationStructure,
@@ -3857,7 +3893,6 @@ static __forceinline__ __device__ void optixMakeHitObject(
T attr,
OptixTraversableHandle* handle)
{
-
OptixTraverseData data{};
optixHitObjectGetTraverseData(&data);
optixMakeHitObject(
@@ -3871,7 +3906,9 @@ static __forceinline__ __device__ void optixMakeHitObject(
nullptr, /*OptixTraversableHandle* transforms*/
0 /*numTransforms */);
}
+#endif
+#if (OPTIX_VERSION >= 80100)
template<typename T>
static __forceinline__ __device__ void optixMakeHitObject(
uint HitGroupRecordIndex,
@@ -3884,7 +3921,6 @@ static __forceinline__ __device__ void optixMakeHitObject(
T attr,
OptixTraversableHandle* handle)
{
-
OptixTraverseData data{};
optixHitObjectGetTraverseData(&data);
optixMakeHitObject(
@@ -3898,7 +3934,9 @@ static __forceinline__ __device__ void optixMakeHitObject(
nullptr, /*OptixTraversableHandle* transforms*/
0 /*numTransforms */);
}
+#endif
+#if (OPTIX_VERSION >= 80100)
template<typename T>
static __forceinline__ __device__ void optixMakeHitObject(
OptixTraversableHandle AccelerationStructure,
@@ -3913,7 +3951,6 @@ static __forceinline__ __device__ void optixMakeHitObject(
T attr,
OptixTraversableHandle* handle)
{
-
OptixTraverseData data{};
optixHitObjectGetTraverseData(&data);
optixMakeHitObject(
@@ -3927,7 +3964,9 @@ static __forceinline__ __device__ void optixMakeHitObject(
nullptr, /*OptixTraversableHandle* transforms*/
0 /*numTransforms */);
}
+#endif
+#if (OPTIX_VERSION >= 80100)
template<typename T>
static __forceinline__ __device__ void optixMakeHitObject(
uint HitGroupRecordIndex,
@@ -3941,7 +3980,6 @@ static __forceinline__ __device__ void optixMakeHitObject(
T attr,
OptixTraversableHandle* handle)
{
-
OptixTraverseData data{};
optixHitObjectGetTraverseData(&data);
optixMakeHitObject(
@@ -3955,12 +3993,16 @@ static __forceinline__ __device__ void optixMakeHitObject(
nullptr, /*OptixTraversableHandle* transforms*/
0 /*numTransforms */);
}
+#endif
+#if (OPTIX_VERSION >= 80100)
static __forceinline__ __device__ void slangOptixMakeNopHitObject(OptixTraversableHandle* Obj)
{
optixMakeNopHitObject();
}
+#endif
+#if (OPTIX_VERSION >= 80100)
template<typename T>
static __forceinline__ __device__ void optixInvoke(
OptixTraversableHandle AccelerationStructure,
@@ -3971,6 +4013,9 @@ static __forceinline__ __device__ void optixInvoke(
packOptiXRayPayloadPointer((void*)Payload, r0, r1);
optixInvoke(r0, r1);
}
+#endif
+
+#if (OPTIX_VERSION >= 80100)
static __forceinline__ __device__ RayDesc optixHitObjectGetRayDesc(OptixTraversableHandle* obj)
{
RayDesc ray = {
@@ -3980,30 +4025,40 @@ static __forceinline__ __device__ RayDesc optixHitObjectGetRayDesc(OptixTraversa
optixHitObjectGetRayTmax()};
return ray;
}
+#endif
+#if (OPTIX_VERSION >= 80100)
static __forceinline__ __device__ uint
slangOptixHitObjectGetInstanceIndex(OptixTraversableHandle* Obj)
{
return optixHitObjectGetInstanceIndex();
}
+#endif
+#if (OPTIX_VERSION >= 80100)
static __forceinline__ __device__ uint slangOptixHitObjectGetInstanceId(OptixTraversableHandle* Obj)
{
return optixHitObjectGetInstanceId();
}
+#endif
+#if (OPTIX_VERSION >= 80100)
static __forceinline__ __device__ uint
slangOptixHitObjectGetSbtGASIndex(OptixTraversableHandle* Obj)
{
return optixHitObjectGetSbtGASIndex();
}
+#endif
+#if (OPTIX_VERSION >= 80100)
static __forceinline__ __device__ uint
slangOptixHitObjectGetPrimitiveIndex(OptixTraversableHandle* Obj)
{
return optixHitObjectGetPrimitiveIndex();
}
+#endif
+#if (OPTIX_VERSION >= 80100)
template<typename T>
static __forceinline__ __device__ T optixHitObjectGetAttribute(OptixTraversableHandle* Obj)
{
@@ -4037,19 +4092,25 @@ static __forceinline__ __device__ T optixHitObjectGetAttribute(OptixTraversableH
memcpy(&result, values, sizeof(T));
return result;
}
+#endif
+#if (OPTIX_VERSION >= 80100)
static __forceinline__ __device__ uint
slangOptixHitObjectGetSbtRecordIndex(OptixTraversableHandle* Obj)
{
return optixHitObjectGetSbtRecordIndex();
}
+#endif
+#if (OPTIX_VERSION >= 90000)
static __forceinline__ __device__ uint
slangOptixHitObjectSetSbtRecordIndex(OptixTraversableHandle* Obj, uint sbtRecordIndex)
{
optixHitObjectSetSbtRecordIndex(sbtRecordIndex); // returns void
return sbtRecordIndex;
}
+#endif
+
#else
// Define OptixTraversableHandle even if OptiX is not enabled.
// This allows RaytracingAccelerationStructure to be properly reflected in non-OptiX code.