summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--prelude/slang-cuda-prelude.h21
-rw-r--r--source/slang/hlsl.meta.slang4
-rw-r--r--tests/cuda/lss-test.slang14
3 files changed, 21 insertions, 18 deletions
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h
index d1160cdd3..a1d3da082 100644
--- a/prelude/slang-cuda-prelude.h
+++ b/prelude/slang-cuda-prelude.h
@@ -31,7 +31,6 @@
#ifdef SLANG_CUDA_ENABLE_OPTIX
#include <optix.h>
-#include <optix_device.h>
#endif
// Define slang offsetof implementation
@@ -3226,28 +3225,30 @@ __forceinline__ __device__ float4 optixGetSpherePositionAndRadius()
{
float4 data[1];
optixGetSphereData(data);
- return data;
+ return data[0];
}
-__forceinline__ __device__ float4 optixHitObjectGetSpherePositionAndRadius()
+__forceinline__ __device__ float4
+optixHitObjectGetSpherePositionAndRadius(OptixTraversableHandle* Obj)
{
float4 data[1];
optixHitObjectGetSphereData(data);
- return data;
+ return data[0];
}
-__forceinline__ __device__ Matrix<float, 2, 4> optixGetSpherePositionAndRadius()
+__forceinline__ __device__ Matrix<float, 2, 4> optixGetLssPositionsAndRadii()
{
float4 data[2];
optixGetLinearCurveVertexData(data);
- return Matrix<float, 2, 4>(data[0], data[1]);
+ return makeMatrix<float, 2, 4>(data[0], data[1]);
}
-__forceinline__ __device__ float2x4 optixHitObjectGetSpherePositionAndRadius()
+__forceinline__ __device__ Matrix<float, 2, 4> optixHitObjectGetLssPositionsAndRadii(
+ OptixTraversableHandle* Obj)
{
float4 data[2];
optixHitObjectGetLinearCurveVertexData(data);
- return Matrix<float, 2, 4>(data[0], data[1]);
+ return makeMatrix<float, 2, 4>(data[0], data[1]);
}
__forceinline__ __device__ bool optixIsSphereHit()
@@ -3255,7 +3256,7 @@ __forceinline__ __device__ bool optixIsSphereHit()
return optixGetPrimitiveType() == OPTIX_PRIMITIVE_TYPE_SPHERE;
}
-__forceinline__ __device__ bool optixHitObjectIsSphereHit()
+__forceinline__ __device__ bool optixHitObjectIsSphereHit(OptixTraversableHandle* Obj)
{
return optixGetPrimitiveType(optixHitObjectGetHitKind()) == OPTIX_PRIMITIVE_TYPE_SPHERE;
}
@@ -3265,7 +3266,7 @@ __forceinline__ __device__ bool optixIsLSSHit()
return optixGetPrimitiveType() == OPTIX_PRIMITIVE_TYPE_ROUND_LINEAR;
}
-__forceinline__ __device__ bool optixHitObjectIsLSSHit()
+__forceinline__ __device__ bool optixHitObjectIsLSSHit(OptixTraversableHandle* Obj)
{
return optixGetPrimitiveType(optixHitObjectGetHitKind()) == OPTIX_PRIMITIVE_TYPE_ROUND_LINEAR;
}
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 677b5d7bf..ad1a983dd 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -17986,7 +17986,7 @@ float2x4 GetLssPositionsAndRadii()
case hlsl: __intrinsic_asm "NvRtLssObjectPositionsAndRadii";
case cuda:
{
- __intrinsic_asm "optixObjectPositionsAndRadii";
+ __intrinsic_asm "optixGetLssPositionsAndRadii";
}
case spirv:
return spirv_asm
@@ -20671,7 +20671,7 @@ struct HitObject
case hlsl: __intrinsic_asm "NvRtLssObjectPositionsAndRadii";
case cuda:
{
- __intrinsic_asm "optixHitObjectGetSpherePositionAndRadius";
+ __intrinsic_asm "optixHitObjectGetLssPositionsAndRadii";
}
case spirv:
return spirv_asm
diff --git a/tests/cuda/lss-test.slang b/tests/cuda/lss-test.slang
index 4b0512cb1..f3052cacb 100644
--- a/tests/cuda/lss-test.slang
+++ b/tests/cuda/lss-test.slang
@@ -1,10 +1,15 @@
//TEST:SIMPLE(filecheck=CHECK): -target cuda
//CHECK_: __global__ void __closesthit__closestHitShaderLss
//CHECK: optixGetSpherePositionAndRadius
-//CHECK: optixObjectPositionsAndRadii
+//CHECK: optixGetLssPositionsAndRadii
//CHECK: optixIsSphereHit
//CHECK: optixIsLSSHit
+//CHECK: optixHitObjectGetSpherePositionAndRadius
+//CHECK: optixHitObjectGetLssPositionsAndRadii
+//CHECK: optixHitObjectIsSphereHit
+//CHECK: optixHitObjectIsLSSHit
+
struct RayPayload
{
float4 color;
@@ -22,13 +27,10 @@ void closestHitShaderLss(inout RayPayload payload, in BuiltInTriangleIntersectio
payload.isSphere = IsSphereHit();
payload.isLss = IsLssHit();
-// TODO: This will be enabled once issue #6647 is completed.
-#if 0
// Test HitObject API functions
HitObject hitObj;
- float4 sphereData = hitObj.GetSphereObjectPositionAndRadius();
- float2x4 lssData = hitObj.GetLssObjectPositionsAndRadii();
+ float4 sphereData = hitObj.GetSpherePositionAndRadius();
+ float2x4 lssData = hitObj.GetLssPositionsAndRadii();
bool isSphereHit = hitObj.IsSphereHit();
bool isLssHit = hitObj.IsLssHit();
-#endif
} \ No newline at end of file