summaryrefslogtreecommitdiffstats
path: root/tools/render-test/cuda/cuda-compute-util.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tools/render-test/cuda/cuda-compute-util.cpp')
-rw-r--r--tools/render-test/cuda/cuda-compute-util.cpp273
1 files changed, 266 insertions, 7 deletions
diff --git a/tools/render-test/cuda/cuda-compute-util.cpp b/tools/render-test/cuda/cuda-compute-util.cpp
index 48d73fa93..b2006a7e8 100644
--- a/tools/render-test/cuda/cuda-compute-util.cpp
+++ b/tools/render-test/cuda/cuda-compute-util.cpp
@@ -13,6 +13,14 @@
#include <cuda_runtime_api.h>
+// TODO: should conditionalize this on OptiX being present
+#ifdef RENDER_TEST_OPTIX
+#include <optix.h>
+#include <optix_function_table_definition.h>
+#define _CRT_SECURE_NO_WARNINGS 1
+#include <optix_stubs.h>
+#endif
+
namespace renderer_test {
using namespace Slang;
@@ -111,6 +119,36 @@ static SlangResult _handleCUDAError(cudaError_t error)
#define SLANG_CUDA_ASSERT_ON_FAIL(x) { auto _res = x; if (_isError(_res)) { SLANG_ASSERT(!"Failed CUDA call"); }; }
+#ifdef RENDER_TEST_OPTIX
+
+static bool _isError(OptixResult result) { return result != OPTIX_SUCCESS; }
+
+#if 1
+// Reports a failed OptiX call on stderr and maps it to a Slang failure code.
+//
+// result : the failing OptiX result code
+// file   : source file of the call site (supplied via __FILE__ by the macro below)
+// line   : source line of the call site (supplied via __LINE__ by the macro below)
+//
+// The message has the form "<file>(<line>): optix: <error string> (<error name>)".
+// Always returns SLANG_FAIL.
+static SlangResult _handleOptixError(OptixResult result, char const* file, int line)
+{
+    char const* const errorText = optixGetErrorString(result);
+    char const* const errorName = optixGetErrorName(result);
+
+    fprintf(stderr, "%s(%d): optix: %s (%s)\n", file, line, errorText, errorName);
+
+    return SLANG_FAIL;
+}
+#define SLANG_OPTIX_HANDLE_ERROR(RESULT) _handleOptixError(RESULT, __FILE__, __LINE__)
+#else
+#define SLANG_OPTIX_HANDLE_ERROR(RESULT) SLANG_FAIL
+#endif
+
+#define SLANG_OPTIX_RETURN_ON_FAIL(EXPR) do { auto _res = EXPR; if(_isError(_res)) return SLANG_OPTIX_HANDLE_ERROR(_res); } while(0)
+
+// Log callback installed on the OptiX device context
+// (see OptixDeviceContextOptions::logCallbackFunction).
+//
+// Writes each message to stderr as "optix: <message> (<tag>)".
+//
+// level    : message severity reported by OptiX; not used for filtering here
+// tag      : short category string supplied by OptiX
+// message  : the log text
+// userData : opaque pointer registered alongside the callback; unused
+//
+// Declared `static` so that it has internal linkage like the other
+// file-local `_`-prefixed helpers in this translation unit.
+static void _optixLogCallback(unsigned int level, const char* tag, const char* message, void* userData)
+{
+    // These parameters are required by the callback signature but are not needed.
+    (void)level;
+    (void)userData;
+
+    fprintf(stderr, "optix: %s (%s)\n",
+        message,
+        tag);
+}
+
+#endif
+
class MemoryCUDAResource : public CUDAResource
{
public:
@@ -1202,17 +1240,238 @@ static SlangResult _compute(CUcontext context, CUmodule module, const ShaderComp
ScopeCUDAContext cudaContext;
SLANG_RETURN_ON_FAIL(cudaContext.init(0));
- const Index index = outputAndLayout.output.findKernelDescIndex(StageType::Compute);
- if (index < 0)
+
+ switch( outputAndLayout.output.desc.pipelineType )
{
+ default:
return SLANG_FAIL;
- }
- const auto& kernel = outputAndLayout.output.kernelDescs[index];
+ case PipelineType::Compute:
+ {
+ const Index index = outputAndLayout.output.findKernelDescIndex(StageType::Compute);
+ if (index < 0)
+ {
+ return SLANG_FAIL;
+ }
+
+ const auto& kernel = outputAndLayout.output.kernelDescs[index];
+
+ ScopeCUDAModule cudaModule;
+ SLANG_RETURN_ON_FAIL(cudaModule.load(kernel.codeBegin));
+ SLANG_RETURN_ON_FAIL(_compute(cudaContext, cudaModule, outputAndLayout, dispatchSize, outContext));
+ }
+ break;
+
+ case PipelineType::RayTracing:
+ {
+#ifdef RENDER_TEST_OPTIX
+ SLANG_OPTIX_RETURN_ON_FAIL(optixInit());
+
+ OptixDeviceContextOptions optixOptions = {};
+
+ // Forward OptiX log messages to stderr via _optixLogCallback.
+ optixOptions.logCallbackFunction = &_optixLogCallback;
+ optixOptions.logCallbackLevel = 4;
+
+ OptixDeviceContext optixContext = nullptr;
+ SLANG_OPTIX_RETURN_ON_FAIL(optixDeviceContextCreate(cudaContext, &optixOptions, &optixContext));
+
+ enum { kOptixLogSize = 2*1024 };
+ char log[kOptixLogSize];
+ size_t logSize = sizeof(log);
+
+ OptixPipelineCompileOptions optixPipelineCompileOptions = {};
+
+ // We need to load modules from the PTX code available to us,
+ // and then also create program groups from the kernels
+ // in those modules.
+ //
+ // For now we will only support program groups with a single
+ // kernel in them, and will create one per entry point.
+ //
+ Index entryPointCount = outputAndLayout.output.kernelDescs.getCount();
+ List<OptixProgramGroup> optixProgramGroups;
+ List<String> names;
+
+ OptixShaderBindingTable optixSBT = {};
+
+ for( Index ee = 0; ee < entryPointCount; ++ee )
+ {
+ auto& kernel = outputAndLayout.output.kernelDescs[ee];
- ScopeCUDAModule cudaModule;
- SLANG_RETURN_ON_FAIL(cudaModule.load(kernel.codeBegin));
- SLANG_RETURN_ON_FAIL(_compute(cudaContext, cudaModule, outputAndLayout, dispatchSize, outContext));
+ OptixModuleCompileOptions optixModuleCompileOptions = {};
+
+ OptixModule optixModule;
+ SLANG_OPTIX_RETURN_ON_FAIL(optixModuleCreateFromPTX(
+ optixContext,
+ &optixModuleCompileOptions,
+ &optixPipelineCompileOptions,
+ (char const*) kernel.codeBegin,
+ kernel.getCodeSize(),
+ log,
+ &logSize,
+ &optixModule));
+
+
+ OptixProgramGroupOptions optixProgramGroupOptions = {};
+
+ OptixProgramGroupDesc optixProgramGroupDesc = {};
+ optixProgramGroupDesc.kind = OPTIX_PROGRAM_GROUP_KIND_RAYGEN;
+ optixProgramGroupDesc.raygen.module = optixModule;
+
+ String name = String("__raygen__") + kernel.entryPointName;
+ names.add(name);
+ optixProgramGroupDesc.raygen.entryFunctionName = name.begin();
+
+ OptixProgramGroup optixProgramGroup = nullptr;
+ SLANG_OPTIX_RETURN_ON_FAIL(optixProgramGroupCreate(
+ optixContext,
+ &optixProgramGroupDesc,
+ 1,
+ &optixProgramGroupOptions,
+ log,
+ &logSize,
+ &optixProgramGroup));
+
+ optixProgramGroups.add(optixProgramGroup);
+
+ {
+ CUdeviceptr rayGenRecordPtr;
+ size_t rayGenRecordSize = OPTIX_SBT_RECORD_HEADER_SIZE;
+
+ SLANG_CUDA_RETURN_ON_FAIL(cudaMalloc((void**) &rayGenRecordPtr, rayGenRecordSize));
+
+ struct { char data[OPTIX_SBT_RECORD_HEADER_SIZE]; } rayGenRecordData;
+ SLANG_OPTIX_RETURN_ON_FAIL(optixSbtRecordPackHeader(optixProgramGroup, &rayGenRecordData));
+
+ SLANG_CUDA_RETURN_ON_FAIL(cudaMemcpy(
+ (void*) rayGenRecordPtr,
+ &rayGenRecordData,
+ rayGenRecordSize,
+ cudaMemcpyHostToDevice));
+
+ optixSBT.raygenRecord = rayGenRecordPtr;
+ }
+ }
+
+
+
+ OptixPipeline optixPipeline = nullptr;
+
+ OptixPipelineLinkOptions optixPipelineLinkOptions = {};
+ optixPipelineLinkOptions.maxTraceDepth = 5;
+ optixPipelineLinkOptions.debugLevel = OPTIX_COMPILE_DEBUG_LEVEL_FULL;
+ optixPipelineLinkOptions.overrideUsesMotionBlur = false;
+ SLANG_OPTIX_RETURN_ON_FAIL(optixPipelineCreate(
+ optixContext,
+ &optixPipelineCompileOptions,
+ &optixPipelineLinkOptions,
+ optixProgramGroups.getBuffer(),
+ (unsigned int)optixProgramGroups.getCount(),
+ log,
+ &logSize,
+ &optixPipeline));
+
+
+ {
+ // The OptiX API complains if we don't fill in a miss record
+ // in the SBT, so we will create a dummy one here to represent
+ // the lack of any miss shaders.
+ //
+ OptixProgramGroupOptions optixProgramGroupOptions = {};
+ OptixProgramGroupDesc missGroupDesc = {};
+ missGroupDesc.kind = OPTIX_PROGRAM_GROUP_KIND_MISS;
+ OptixProgramGroup missProgramGroup;
+ SLANG_OPTIX_RETURN_ON_FAIL(optixProgramGroupCreate(
+ optixContext,
+ &missGroupDesc,
+ 1,
+ &optixProgramGroupOptions,
+ log,
+ &logSize,
+ &missProgramGroup));
+
+
+ CUdeviceptr missRecordPtr;
+ size_t missRecordSize = OPTIX_SBT_RECORD_HEADER_SIZE;
+
+ SLANG_CUDA_RETURN_ON_FAIL(cudaMalloc((void**) &missRecordPtr, missRecordSize));
+
+ struct { char data[OPTIX_SBT_RECORD_HEADER_SIZE]; } missRecordData;
+ SLANG_OPTIX_RETURN_ON_FAIL(optixSbtRecordPackHeader(missProgramGroup, &missRecordData));
+
+ SLANG_CUDA_RETURN_ON_FAIL(cudaMemcpy(
+ (void*) missRecordPtr,
+ &missRecordData,
+ missRecordSize,
+ cudaMemcpyHostToDevice));
+
+ optixSBT.missRecordBase = missRecordPtr;
+ optixSBT.missRecordCount = 1;
+ optixSBT.missRecordStrideInBytes = missRecordSize;
+ }
+ {
+ // Okay, we also need a dummy hit group.
+
+ OptixProgramGroupOptions optixProgramGroupOptions = {};
+ OptixProgramGroupDesc hitGroupDesc = {};
+ hitGroupDesc.kind = OPTIX_PROGRAM_GROUP_KIND_HITGROUP;
+ OptixProgramGroup programGroup;
+ SLANG_OPTIX_RETURN_ON_FAIL(optixProgramGroupCreate(
+ optixContext,
+ &hitGroupDesc,
+ 1,
+ &optixProgramGroupOptions,
+ log,
+ &logSize,
+ &programGroup));
+
+
+ CUdeviceptr recordPtr;
+ size_t recordSize = OPTIX_SBT_RECORD_HEADER_SIZE;
+
+ SLANG_CUDA_RETURN_ON_FAIL(cudaMalloc((void**) &recordPtr, recordSize));
+
+ struct { char data[OPTIX_SBT_RECORD_HEADER_SIZE]; } recordData;
+ SLANG_OPTIX_RETURN_ON_FAIL(optixSbtRecordPackHeader(programGroup, &recordData));
+
+ SLANG_CUDA_RETURN_ON_FAIL(cudaMemcpy(
+ (void*) recordPtr,
+ &recordData,
+ recordSize,
+ cudaMemcpyHostToDevice));
+
+ optixSBT.hitgroupRecordBase = recordPtr;
+ optixSBT.hitgroupRecordCount = 1;
+ optixSBT.hitgroupRecordStrideInBytes = recordSize;
+ }
+
+ ScopeCUDAStream cudaStream;
+
+ CUdeviceptr globalParams = 0;
+ size_t globalParamsSize = 0;
+
+ unsigned int gridSizeX = 1;
+ unsigned int gridSizeY = 1;
+ unsigned int gridSizeZ = 1;
+
+
+ SLANG_OPTIX_RETURN_ON_FAIL(optixLaunch(
+ optixPipeline,
+ cudaStream,
+ globalParams,
+ globalParamsSize,
+ &optixSBT,
+ gridSizeX,
+ gridSizeY,
+ gridSizeZ));
+
+
+ SLANG_RETURN_ON_FAIL(cudaStream.sync());
+#endif
+ }
+ break;
+ }
return SLANG_OK;
}