summaryrefslogtreecommitdiffstats
path: root/tools
diff options
context:
space:
mode:
Diffstat (limited to 'tools')
-rw-r--r--tools/gfx/render.h15
-rw-r--r--tools/render-test/cuda/cuda-compute-util.cpp273
-rw-r--r--tools/render-test/options.cpp4
-rw-r--r--tools/render-test/options.h3
-rw-r--r--tools/render-test/render-test-main.cpp25
-rw-r--r--tools/render-test/slang-support.cpp267
-rw-r--r--tools/render-test/slang-support.h5
7 files changed, 455 insertions, 137 deletions
diff --git a/tools/gfx/render.h b/tools/gfx/render.h
index a4d042a9a..423820a0b 100644
--- a/tools/gfx/render.h
+++ b/tools/gfx/render.h
@@ -45,6 +45,7 @@ enum class PipelineType
Unknown,
Graphics,
Compute,
+ RayTracing,
CountOf,
};
@@ -57,6 +58,12 @@ enum class StageType
Geometry,
Fragment,
Compute,
+ RayGeneration,
+ Intersection,
+ AnyHit,
+ ClosestHit,
+ Miss,
+ Callable,
CountOf,
};
@@ -102,6 +109,7 @@ public:
StageType stage;
void const* codeBegin;
void const* codeEnd;
+ char const* entryPointName;
UInt getCodeSize() const { return (char const*)codeEnd - (char const*)codeBegin; }
};
@@ -141,13 +149,12 @@ struct ShaderCompileRequest
struct EntryPoint
{
char const* name = nullptr;
- SourceInfo source;
+ SlangStage slangStage;
};
SourceInfo source;
- EntryPoint vertexShader;
- EntryPoint fragmentShader;
- EntryPoint computeShader;
+ Slang::List<EntryPoint> entryPoints;
+
Slang::List<Slang::String> globalSpecializationArgs;
Slang::List<Slang::String> entryPointSpecializationArgs;
diff --git a/tools/render-test/cuda/cuda-compute-util.cpp b/tools/render-test/cuda/cuda-compute-util.cpp
index 48d73fa93..b2006a7e8 100644
--- a/tools/render-test/cuda/cuda-compute-util.cpp
+++ b/tools/render-test/cuda/cuda-compute-util.cpp
@@ -13,6 +13,14 @@
#include <cuda_runtime_api.h>
+// TODO: should conditionalize this on OptiX being present
+#ifdef RENDER_TEST_OPTIX
+#include <optix.h>
+#include <optix_function_table_definition.h>
+#define _CRT_SECURE_NO_WARNINGS 1
+#include <optix_stubs.h>
+#endif
+
namespace renderer_test {
using namespace Slang;
@@ -111,6 +119,36 @@ static SlangResult _handleCUDAError(cudaError_t error)
#define SLANG_CUDA_ASSERT_ON_FAIL(x) { auto _res = x; if (_isError(_res)) { SLANG_ASSERT(!"Failed CUDA call"); }; }
+#ifdef RENDER_TEST_OPTIX
+
+static bool _isError(OptixResult result) { return result != OPTIX_SUCCESS; }
+
+#if 1
+static SlangResult _handleOptixError(OptixResult result, char const* file, int line)
+{
+ fprintf(stderr, "%s(%d): optix: %s (%s)\n",
+ file,
+ line,
+ optixGetErrorString(result),
+ optixGetErrorName(result));
+ return SLANG_FAIL;
+}
+#define SLANG_OPTIX_HANDLE_ERROR(RESULT) _handleOptixError(RESULT, __FILE__, __LINE__)
+#else
+#define SLANG_OPTIX_HANDLE_ERROR(RESULT) SLANG_FAIL
+#endif
+
+#define SLANG_OPTIX_RETURN_ON_FAIL(EXPR) do { auto _res = EXPR; if(_isError(_res)) return SLANG_OPTIX_HANDLE_ERROR(_res); } while(0)
+
+void _optixLogCallback(unsigned int level, const char* tag, const char* message, void* userData)
+{
+ fprintf(stderr, "optix: %s (%s)\n",
+ message,
+ tag);
+}
+
+#endif
+
class MemoryCUDAResource : public CUDAResource
{
public:
@@ -1202,17 +1240,238 @@ static SlangResult _compute(CUcontext context, CUmodule module, const ShaderComp
ScopeCUDAContext cudaContext;
SLANG_RETURN_ON_FAIL(cudaContext.init(0));
- const Index index = outputAndLayout.output.findKernelDescIndex(StageType::Compute);
- if (index < 0)
+
+ switch( outputAndLayout.output.desc.pipelineType )
{
+ default:
return SLANG_FAIL;
- }
- const auto& kernel = outputAndLayout.output.kernelDescs[index];
+ case PipelineType::Compute:
+ {
+ const Index index = outputAndLayout.output.findKernelDescIndex(StageType::Compute);
+ if (index < 0)
+ {
+ return SLANG_FAIL;
+ }
+
+ const auto& kernel = outputAndLayout.output.kernelDescs[index];
+
+ ScopeCUDAModule cudaModule;
+ SLANG_RETURN_ON_FAIL(cudaModule.load(kernel.codeBegin));
+ SLANG_RETURN_ON_FAIL(_compute(cudaContext, cudaModule, outputAndLayout, dispatchSize, outContext));
+ }
+ break;
+
+ case PipelineType::RayTracing:
+ {
+#ifdef RENDER_TEST_OPTIX
+ SLANG_OPTIX_RETURN_ON_FAIL(optixInit());
+
+ OptixDeviceContextOptions optixOptions = {};
+
+ // TODO: set log callback
+ optixOptions.logCallbackFunction = &_optixLogCallback;
+ optixOptions.logCallbackLevel = 4;
+
+ OptixDeviceContext optixContext = nullptr;
+ SLANG_OPTIX_RETURN_ON_FAIL(optixDeviceContextCreate(cudaContext, &optixOptions, &optixContext));
+
+ enum { kOptixLogSize = 2*1024 };
+ char log[kOptixLogSize];
+ size_t logSize = sizeof(log);
+
+ OptixPipelineCompileOptions optixPipelineCompileOptions = {};
+
+ // We need to load modules from the PTX code available to us,
+ // and then also create program groups from the kernels
+ // in those modules.
+ //
+ // For now we will only support program groups with a single
+ // kernel in them, and will create one per entry point.
+ //
+ Index entryPointCount = outputAndLayout.output.kernelDescs.getCount();
+ List<OptixProgramGroup> optixProgramGroups;
+ List<String> names;
+
+ OptixShaderBindingTable optixSBT = {};
+
+ for( Index ee = 0; ee < entryPointCount; ++ee )
+ {
+ auto& kernel = outputAndLayout.output.kernelDescs[ee];
- ScopeCUDAModule cudaModule;
- SLANG_RETURN_ON_FAIL(cudaModule.load(kernel.codeBegin));
- SLANG_RETURN_ON_FAIL(_compute(cudaContext, cudaModule, outputAndLayout, dispatchSize, outContext));
+ OptixModuleCompileOptions optixModuleCompileOptions = {};
+
+ OptixModule optixModule;
+ SLANG_OPTIX_RETURN_ON_FAIL(optixModuleCreateFromPTX(
+ optixContext,
+ &optixModuleCompileOptions,
+ &optixPipelineCompileOptions,
+ (char const*) kernel.codeBegin,
+ kernel.getCodeSize(),
+ log,
+ &logSize,
+ &optixModule));
+
+
+ OptixProgramGroupOptions optixProgramGroupOptions = {};
+
+ OptixProgramGroupDesc optixProgramGroupDesc = {};
+ optixProgramGroupDesc.kind = OPTIX_PROGRAM_GROUP_KIND_RAYGEN;
+ optixProgramGroupDesc.raygen.module = optixModule;
+
+ String name = String("__raygen__") + kernel.entryPointName;
+ names.add(name);
+ optixProgramGroupDesc.raygen.entryFunctionName = name.begin();
+
+ OptixProgramGroup optixProgramGroup = nullptr;
+ SLANG_OPTIX_RETURN_ON_FAIL(optixProgramGroupCreate(
+ optixContext,
+ &optixProgramGroupDesc,
+ 1,
+ &optixProgramGroupOptions,
+ log,
+ &logSize,
+ &optixProgramGroup));
+
+ optixProgramGroups.add(optixProgramGroup);
+
+ {
+ CUdeviceptr rayGenRecordPtr;
+ size_t rayGenRecordSize = OPTIX_SBT_RECORD_HEADER_SIZE;
+
+ SLANG_CUDA_RETURN_ON_FAIL(cudaMalloc((void**) &rayGenRecordPtr, rayGenRecordSize));
+
+ struct { char data[OPTIX_SBT_RECORD_HEADER_SIZE]; } rayGenRecordData;
+ SLANG_OPTIX_RETURN_ON_FAIL(optixSbtRecordPackHeader(optixProgramGroup, &rayGenRecordData));
+
+ SLANG_CUDA_RETURN_ON_FAIL(cudaMemcpy(
+ (void*) rayGenRecordPtr,
+ &rayGenRecordData,
+ rayGenRecordSize,
+ cudaMemcpyHostToDevice));
+
+ optixSBT.raygenRecord = rayGenRecordPtr;
+ }
+ }
+
+
+
+ OptixPipeline optixPipeline = nullptr;
+
+ OptixPipelineLinkOptions optixPipelineLinkOptions = {};
+ optixPipelineLinkOptions.maxTraceDepth = 5;
+ optixPipelineLinkOptions.debugLevel = OPTIX_COMPILE_DEBUG_LEVEL_FULL;
+ optixPipelineLinkOptions.overrideUsesMotionBlur = false;
+ SLANG_OPTIX_RETURN_ON_FAIL(optixPipelineCreate(
+ optixContext,
+ &optixPipelineCompileOptions,
+ &optixPipelineLinkOptions,
+ optixProgramGroups.getBuffer(),
+ (unsigned int)optixProgramGroups.getCount(),
+ log,
+ &logSize,
+ &optixPipeline));
+
+
+ {
+ // The OptiX API complains if we don't fill in a miss record
+ // in the SBT, so we will create a dummy one here to represent
+ // the lack of any miss shaders.
+ //
+ OptixProgramGroupOptions optixProgramGroupOptions = {};
+ OptixProgramGroupDesc missGroupDesc = {};
+ missGroupDesc.kind = OPTIX_PROGRAM_GROUP_KIND_MISS;
+ OptixProgramGroup missProgramGroup;
+ SLANG_OPTIX_RETURN_ON_FAIL(optixProgramGroupCreate(
+ optixContext,
+ &missGroupDesc,
+ 1,
+ &optixProgramGroupOptions,
+ log,
+ &logSize,
+ &missProgramGroup));
+
+
+ CUdeviceptr missRecordPtr;
+ size_t missRecordSize = OPTIX_SBT_RECORD_HEADER_SIZE;
+
+ SLANG_CUDA_RETURN_ON_FAIL(cudaMalloc((void**) &missRecordPtr, missRecordSize));
+
+ struct { char data[OPTIX_SBT_RECORD_HEADER_SIZE]; } missRecordData;
+ SLANG_OPTIX_RETURN_ON_FAIL(optixSbtRecordPackHeader(missProgramGroup, &missRecordData));
+
+ SLANG_CUDA_RETURN_ON_FAIL(cudaMemcpy(
+ (void*) missRecordPtr,
+ &missRecordData,
+ missRecordSize,
+ cudaMemcpyHostToDevice));
+
+ optixSBT.missRecordBase = missRecordPtr;
+ optixSBT.missRecordCount = 1;
+ optixSBT.missRecordStrideInBytes = missRecordSize;
+ }
+ {
+ // Okay, we also need a dummy hit group.
+
+ OptixProgramGroupOptions optixProgramGroupOptions = {};
+ OptixProgramGroupDesc hitGroupDesc = {};
+ hitGroupDesc.kind = OPTIX_PROGRAM_GROUP_KIND_HITGROUP;
+ OptixProgramGroup programGroup;
+ SLANG_OPTIX_RETURN_ON_FAIL(optixProgramGroupCreate(
+ optixContext,
+ &hitGroupDesc,
+ 1,
+ &optixProgramGroupOptions,
+ log,
+ &logSize,
+ &programGroup));
+
+
+ CUdeviceptr recordPtr;
+ size_t recordSize = OPTIX_SBT_RECORD_HEADER_SIZE;
+
+ SLANG_CUDA_RETURN_ON_FAIL(cudaMalloc((void**) &recordPtr, recordSize));
+
+ struct { char data[OPTIX_SBT_RECORD_HEADER_SIZE]; } recordData;
+ SLANG_OPTIX_RETURN_ON_FAIL(optixSbtRecordPackHeader(programGroup, &recordData));
+
+ SLANG_CUDA_RETURN_ON_FAIL(cudaMemcpy(
+ (void*) recordPtr,
+ &recordData,
+ recordSize,
+ cudaMemcpyHostToDevice));
+
+ optixSBT.hitgroupRecordBase = recordPtr;
+ optixSBT.hitgroupRecordCount = 1;
+ optixSBT.hitgroupRecordStrideInBytes = recordSize;
+ }
+
+ ScopeCUDAStream cudaStream;
+
+ CUdeviceptr globalParams = 0;
+ size_t globalParamsSize = 0;
+
+ unsigned int gridSizeX = 1;
+ unsigned int gridSizeY = 1;
+ unsigned int gridSizeZ = 1;
+
+
+ SLANG_OPTIX_RETURN_ON_FAIL(optixLaunch(
+ optixPipeline,
+ cudaStream,
+ globalParams,
+ globalParamsSize,
+ &optixSBT,
+ gridSizeX,
+ gridSizeY,
+ gridSizeZ));
+
+
+ SLANG_RETURN_ON_FAIL(cudaStream.sync());
+#endif
+ }
+ break;
+ }
return SLANG_OK;
}
diff --git a/tools/render-test/options.cpp b/tools/render-test/options.cpp
index 8331de07e..c2afe78ac 100644
--- a/tools/render-test/options.cpp
+++ b/tools/render-test/options.cpp
@@ -152,6 +152,10 @@ SlangResult parseOptions(int argc, const char*const* argv, Slang::WriterHelper s
{
gOptions.shaderType = ShaderProgramType::GraphicsCompute;
}
+ else if (strcmp(arg, "-rt") == 0)
+ {
+ gOptions.shaderType = ShaderProgramType::RayTracing;
+ }
else if( strcmp(arg, "-use-dxil") == 0 )
{
gOptions.useDXIL = true;
diff --git a/tools/render-test/options.h b/tools/render-test/options.h
index a8b7d5884..f2f0a8ab6 100644
--- a/tools/render-test/options.h
+++ b/tools/render-test/options.h
@@ -36,7 +36,8 @@ struct Options
{
Graphics,
Compute,
- GraphicsCompute
+ GraphicsCompute,
+ RayTracing,
};
char const* appName = "render-test";
diff --git a/tools/render-test/render-test-main.cpp b/tools/render-test/render-test-main.cpp
index 1d88ee500..f966d150e 100644
--- a/tools/render-test/render-test-main.cpp
+++ b/tools/render-test/render-test-main.cpp
@@ -212,7 +212,7 @@ SlangResult RenderTestApp::initialize(SlangSession* session, Renderer* renderer,
Result RenderTestApp::_initializeShaders(SlangSession* session, Renderer* renderer, Options::ShaderProgramType shaderType, const ShaderCompilerUtil::Input& input)
{
- SLANG_RETURN_ON_FAIL(ShaderCompilerUtil::compileWithLayout(session, gOptions.sourcePath, gOptions.compileArgs, gOptions.shaderType, input, m_compilationOutput));
+ SLANG_RETURN_ON_FAIL(ShaderCompilerUtil::compileWithLayout(session, gOptions, input, m_compilationOutput));
m_shaderInputLayout = m_compilationOutput.layout;
m_shaderProgram = renderer->createProgram(m_compilationOutput.output.desc);
return m_shaderProgram ? SLANG_OK : SLANG_FAIL;
@@ -500,6 +500,25 @@ static SlangResult _innerMain(Slang::StdWriters* stdWriters, SlangSession* sessi
break;
}
+ switch( gOptions.shaderType )
+ {
+ case Options::ShaderProgramType::Graphics:
+ case Options::ShaderProgramType::GraphicsCompute:
+ input.pipelineType = PipelineType::Graphics;
+ break;
+
+ case Options::ShaderProgramType::Compute:
+ input.pipelineType = PipelineType::Compute;
+ break;
+
+ case Options::ShaderProgramType::RayTracing:
+ input.pipelineType = PipelineType::RayTracing;
+ break;
+
+ default:
+ break;
+ }
+
if (gOptions.sourceLanguage != SLANG_SOURCE_LANGUAGE_UNKNOWN)
{
input.sourceLanguage = gOptions.sourceLanguage;
@@ -554,7 +573,7 @@ static SlangResult _innerMain(Slang::StdWriters* stdWriters, SlangSession* sessi
}
ShaderCompilerUtil::OutputAndLayout compilationAndLayout;
- SLANG_RETURN_ON_FAIL(ShaderCompilerUtil::compileWithLayout(session, gOptions.sourcePath, gOptions.compileArgs, gOptions.shaderType, input, compilationAndLayout));
+ SLANG_RETURN_ON_FAIL(ShaderCompilerUtil::compileWithLayout(session, gOptions, input, compilationAndLayout));
{
// Get the shared library -> it contains the executable code, we need to keep around if we recompile
@@ -575,7 +594,7 @@ static SlangResult _innerMain(Slang::StdWriters* stdWriters, SlangSession* sessi
// We just want CPP, so we get suitable reflection
slangInput.target = SLANG_CPP_SOURCE;
- SLANG_RETURN_ON_FAIL(ShaderCompilerUtil::compileWithLayout(session, gOptions.sourcePath, gOptions.compileArgs, gOptions.shaderType, slangInput, compilationAndLayout));
+ SLANG_RETURN_ON_FAIL(ShaderCompilerUtil::compileWithLayout(session, gOptions, slangInput, compilationAndLayout));
}
// calculate binding
diff --git a/tools/render-test/slang-support.cpp b/tools/render-test/slang-support.cpp
index 5afcc6d24..3e5cc9a1c 100644
--- a/tools/render-test/slang-support.cpp
+++ b/tools/render-test/slang-support.cpp
@@ -18,6 +18,37 @@ using namespace Slang;
static const char vertexEntryPointName[] = "vertexMain";
static const char fragmentEntryPointName[] = "fragmentMain";
static const char computeEntryPointName[] = "computeMain";
+static const char rtEntryPointName[] = "raygenMain";
+
+static gfx::StageType _translateStage(SlangStage slangStage)
+{
+ switch(slangStage)
+ {
+ default:
+ SLANG_ASSERT(!"unhandled case");
+ return gfx::StageType::Unknown;
+
+#define CASE(FROM, TO) \
+ case SLANG_STAGE_##FROM: return gfx::StageType::TO
+
+ CASE(VERTEX, Vertex);
+ CASE(HULL, Hull);
+ CASE(DOMAIN, Domain);
+ CASE(GEOMETRY, Geometry);
+ CASE(FRAGMENT, Fragment);
+
+ CASE(COMPUTE, Compute);
+
+ CASE(RAY_GENERATION, RayGeneration);
+ CASE(INTERSECTION, Intersection);
+ CASE(ANY_HIT, AnyHit);
+ CASE(CLOSEST_HIT, ClosestHit);
+ CASE(MISS, Miss);
+ CASE(CALLABLE, Callable);
+
+#undef CASE
+ }
+}
/* static */ SlangResult ShaderCompilerUtil::compileProgram(SlangSession* session, const Input& input, const ShaderCompileRequest& request, Output& out)
{
@@ -80,46 +111,12 @@ static const char computeEntryPointName[] = "computeMain";
// the `-xslang <arg>` option to `render-test`.
SLANG_RETURN_ON_FAIL(spProcessCommandLineArguments(slangRequest, input.args, input.argCount));
- int computeTranslationUnit = 0;
- int vertexTranslationUnit = 0;
- int fragmentTranslationUnit = 0;
- char const* vertexEntryPointName = request.vertexShader.name;
- char const* fragmentEntryPointName = request.fragmentShader.name;
- char const* computeEntryPointName = request.computeShader.name;
-
const auto sourceLanguage = input.sourceLanguage;
- if (sourceLanguage == SLANG_SOURCE_LANGUAGE_GLSL)
- {
- // GLSL presents unique challenges because, frankly, it got the whole
- // compilation model wrong. One aspect of working around this is that
- // we will compile the same source file multiple times: once per
- // entry point, and we will have different preprocessor definitions
- // active in each case.
-
- vertexTranslationUnit = spAddTranslationUnit(slangRequest, sourceLanguage, nullptr);
- spAddTranslationUnitSourceString(slangRequest, vertexTranslationUnit, request.source.path, request.source.dataBegin);
- spTranslationUnit_addPreprocessorDefine(slangRequest, vertexTranslationUnit, "__GLSL_VERTEX__", "1");
- vertexEntryPointName = "main";
-
- fragmentTranslationUnit = spAddTranslationUnit(slangRequest, sourceLanguage, nullptr);
- spAddTranslationUnitSourceString(slangRequest, fragmentTranslationUnit, request.source.path, request.source.dataBegin);
- spTranslationUnit_addPreprocessorDefine(slangRequest, fragmentTranslationUnit, "__GLSL_FRAGMENT__", "1");
- fragmentEntryPointName = "main";
-
- computeTranslationUnit = spAddTranslationUnit(slangRequest, sourceLanguage, nullptr);
- spAddTranslationUnitSourceString(slangRequest, computeTranslationUnit, request.source.path, request.source.dataBegin);
- spTranslationUnit_addPreprocessorDefine(slangRequest, computeTranslationUnit, "__GLSL_COMPUTE__", "1");
- computeEntryPointName = "main";
- }
- else
+ int translationUnitIndex = 0;
{
- int translationUnit = spAddTranslationUnit(slangRequest, sourceLanguage, nullptr);
- spAddTranslationUnitSourceString(slangRequest, translationUnit, request.source.path, request.source.dataBegin);
-
- vertexTranslationUnit = translationUnit;
- fragmentTranslationUnit = translationUnit;
- computeTranslationUnit = translationUnit;
+ translationUnitIndex = spAddTranslationUnit(slangRequest, sourceLanguage, nullptr);
+ spAddTranslationUnitSourceString(slangRequest, translationUnitIndex, request.source.path, request.source.dataBegin);
}
const int globalSpecializationArgCount = int(request.globalSpecializationArgs.getCount());
@@ -137,105 +134,100 @@ static const char computeEntryPointName[] = "computeMain";
}
};
- if (request.computeShader.name)
- {
- int computeEntryPointIndex = 0;
- if(!gOptions.dontAddDefaultEntryPoints)
+ Index explicitEntryPointCount = request.entryPoints.getCount();
+ for(Index ee = 0; ee < explicitEntryPointCount; ++ee)
+ {
+ if(gOptions.dontAddDefaultEntryPoints)
{
- computeEntryPointIndex = spAddEntryPoint(slangRequest, computeTranslationUnit,
- computeEntryPointName,
- SLANG_STAGE_COMPUTE);
-
- setEntryPointSpecializationArgs(computeEntryPointIndex);
+ // If default entry points are not to be added, then
+ // the `request.entryPoints` array should have been
+ // left empty.
+ //
+ SLANG_ASSERT(false);
}
- spSetLineDirectiveMode(slangRequest, SLANG_LINE_DIRECTIVE_MODE_NONE);
+ auto& entryPointInfo = request.entryPoints[ee];
+ int entryPointIndex = spAddEntryPoint(
+ slangRequest,
+ translationUnitIndex,
+ entryPointInfo.name,
+ entryPointInfo.slangStage);
+ SLANG_ASSERT(entryPointIndex == ee);
- const SlangResult res = spCompile(slangRequest);
+ setEntryPointSpecializationArgs(entryPointIndex);
+ }
- if (auto diagnostics = spGetDiagnosticOutput(slangRequest))
- {
- fprintf(stderr, "%s", diagnostics);
- }
+ spSetLineDirectiveMode(slangRequest, SLANG_LINE_DIRECTIVE_MODE_NONE);
- SLANG_RETURN_ON_FAIL(res);
+ const SlangResult res = spCompile(slangRequest);
- // We are going to get the entry point count... lets check what we have
- if (input.passThrough == SLANG_PASS_THROUGH_NONE)
- {
- auto reflection = spGetReflection(slangRequest);
- // Get the amount of entry points in reflection
- const int entryPointCount = int(spReflection_getEntryPointCount(reflection));
+ if (auto diagnostics = spGetDiagnosticOutput(slangRequest))
+ {
+ fprintf(stderr, "%s", diagnostics);
+ }
- // Above code assumes there is an entry point
- SLANG_ASSERT(entryPointCount && computeEntryPointIndex < entryPointCount);
+ SLANG_RETURN_ON_FAIL(res);
- auto entryPoint = spReflection_getEntryPointByIndex(reflection, computeEntryPointIndex);
+
+ List<ShaderCompileRequest::EntryPoint> actualEntryPoints;
+ if(input.passThrough == SLANG_PASS_THROUGH_NONE)
+ {
+ // In the case where pass-through compilation is not being used,
+ // we can use the Slang reflection information to discover what
+ // the entry points were, and then use those to drive the
+ // loading of code.
+ //
+ auto reflection = slang::ProgramLayout::get(slangRequest);
- // Get the entry point name
- const char* entryPointName = spReflectionEntryPoint_getName(entryPoint);
+ // Get the amount of entry points in reflection
+ Index entryPointCount = Index(reflection->getEntryPointCount());
- SLANG_ASSERT(entryPointName);
- }
+ // We must have at least one entry point (whether explicit or implicit)
+ SLANG_ASSERT(entryPointCount);
+ for(Index ee = 0; ee < entryPointCount; ++ee)
{
- size_t codeSize = 0;
- char const* code = (char const*) spGetEntryPointCode(slangRequest, computeEntryPointIndex, &codeSize);
+ auto entryPoint = reflection->getEntryPointByIndex(ee);
+ const char* entryPointName = entryPoint->getName();
+ SLANG_ASSERT(entryPointName);
+
+ auto slangStage = entryPoint->getStage();
- ShaderProgram::KernelDesc kernelDesc;
- kernelDesc.stage = StageType::Compute;
- kernelDesc.codeBegin = code;
- kernelDesc.codeEnd = code + codeSize;
+ ShaderCompileRequest::EntryPoint entryPointInfo;
+ entryPointInfo.name = entryPointName;
+ entryPointInfo.slangStage = slangStage;
- out.set(PipelineType::Compute, &kernelDesc, 1);
+ actualEntryPoints.add(entryPointInfo);
}
}
else
{
- int vertexEntryPoint = 0;
- int fragmentEntryPoint = 1;
- if( !gOptions.dontAddDefaultEntryPoints )
- {
- vertexEntryPoint = spAddEntryPoint(slangRequest, vertexTranslationUnit, vertexEntryPointName, SLANG_STAGE_VERTEX);
- fragmentEntryPoint = spAddEntryPoint(slangRequest, fragmentTranslationUnit, fragmentEntryPointName, SLANG_STAGE_FRAGMENT);
-
- setEntryPointSpecializationArgs(vertexEntryPoint);
- setEntryPointSpecializationArgs(fragmentEntryPoint);
- }
-
- const SlangResult res = spCompile(slangRequest);
- if (auto diagnostics = spGetDiagnosticOutput(slangRequest))
- {
- // TODO(tfoley): re-enable when I get a logging solution in place
-// OutputDebugStringA(diagnostics);
- fprintf(stderr, "%s", diagnostics);
- }
-
- SLANG_RETURN_ON_FAIL(res);
-
- {
- size_t vertexCodeSize = 0;
- char const* vertexCode = (char const*) spGetEntryPointCode(slangRequest, vertexEntryPoint, &vertexCodeSize);
+ actualEntryPoints = request.entryPoints;
+ }
- size_t fragmentCodeSize = 0;
- char const* fragmentCode = (char const*) spGetEntryPointCode(slangRequest, fragmentEntryPoint, &fragmentCodeSize);
+ List<ShaderProgram::KernelDesc> kernelDescs;
- static const int kDescCount = 2;
+ Index actualEntryPointCount = actualEntryPoints.getCount();
+ for(Index ee = 0; ee < actualEntryPointCount; ++ee)
+ {
+ auto& actualEntryPoint = actualEntryPoints[ee];
- ShaderProgram::KernelDesc kernelDescs[kDescCount];
+ size_t codeSize = 0;
+ char const* code = (char const*) spGetEntryPointCode(slangRequest, int(ee), &codeSize);
- kernelDescs[0].stage = StageType::Vertex;
- kernelDescs[0].codeBegin = vertexCode;
- kernelDescs[0].codeEnd = vertexCode + vertexCodeSize;
+ auto gfxStage = _translateStage(actualEntryPoint.slangStage);
- kernelDescs[1].stage = StageType::Fragment;
- kernelDescs[1].codeBegin = fragmentCode;
- kernelDescs[1].codeEnd = fragmentCode + fragmentCodeSize;
+ ShaderProgram::KernelDesc kernelDesc;
+ kernelDesc.stage = gfxStage;
+ kernelDesc.codeBegin = code;
+ kernelDesc.codeEnd = code + codeSize;
+ kernelDesc.entryPointName = actualEntryPoint.name;
- out.set(PipelineType::Graphics, kernelDescs, kDescCount);
- }
+ kernelDescs.add(kernelDesc);
}
+ out.set(input.pipelineType, kernelDescs.getBuffer(), kernelDescs.getCount());
+
return SLANG_OK;
}
@@ -260,8 +252,12 @@ static const char computeEntryPointName[] = "computeMain";
return SLANG_OK;
}
-/* static */SlangResult ShaderCompilerUtil::compileWithLayout(SlangSession* session, const String& sourcePath, const Slang::List<Slang::CommandLine::Arg>& compileArgs, Options::ShaderProgramType shaderType, const ShaderCompilerUtil::Input& input, OutputAndLayout& output)
+/* static */SlangResult ShaderCompilerUtil::compileWithLayout(SlangSession* session, const Options& options, const ShaderCompilerUtil::Input& input, OutputAndLayout& output)
{
+ String sourcePath = options.sourcePath;
+ auto& compileArgs = options.compileArgs;
+ auto shaderType = options.shaderType;
+
List<char> sourceText;
SLANG_RETURN_ON_FAIL(readSource(sourcePath, sourceText));
@@ -294,6 +290,7 @@ static const char computeEntryPointName[] = "computeMain";
break;
case Options::ShaderProgramType::Compute:
+ case Options::ShaderProgramType::RayTracing:
layout.numRenderTargets = 0;
break;
}
@@ -317,17 +314,47 @@ static const char computeEntryPointName[] = "computeMain";
compileRequest.compileArgs = compileArgs;
compileRequest.source = sourceInfo;
- if (shaderType == Options::ShaderProgramType::Graphics || shaderType == Options::ShaderProgramType::GraphicsCompute)
- {
- compileRequest.vertexShader.source = sourceInfo;
- compileRequest.vertexShader.name = vertexEntryPointName;
- compileRequest.fragmentShader.source = sourceInfo;
- compileRequest.fragmentShader.name = fragmentEntryPointName;
- }
- else
+
+ // Now we will add the "default" entry point names/stages that
+ // are appropriate to the pipeline type being targetted, *unless*
+ // the options specify that we should leave out the default
+ // entry points and instead rely on the Slang compiler's built-in
+ // mechanisms for discovering entry points (e.g., `[shader(...)]`
+ // attributes).
+ //
+ if( !options.dontAddDefaultEntryPoints )
{
- compileRequest.computeShader.source = sourceInfo;
- compileRequest.computeShader.name = computeEntryPointName;
+ if (shaderType == Options::ShaderProgramType::Graphics || shaderType == Options::ShaderProgramType::GraphicsCompute)
+ {
+ ShaderCompileRequest::EntryPoint vertexEntryPoint;
+ vertexEntryPoint.name = vertexEntryPointName;
+ vertexEntryPoint.slangStage = SLANG_STAGE_VERTEX;
+ compileRequest.entryPoints.add(vertexEntryPoint);
+
+ ShaderCompileRequest::EntryPoint fragmentEntryPoint;
+ fragmentEntryPoint.name = fragmentEntryPointName;
+ fragmentEntryPoint.slangStage = SLANG_STAGE_FRAGMENT;
+ compileRequest.entryPoints.add(fragmentEntryPoint);
+ }
+ else if( shaderType == Options::ShaderProgramType::RayTracing )
+ {
+ // Note: Current GPU ray tracing pipelines allow for an
+ // almost arbitrary mix of entry points for different stages
+ // to be used together (e.g., a single "program" might
+ // have multiple any-hit shaders, multiple miss shaders, etc.)
+ //
+ // Rather than try to define a fixed set of entry point
+ // names and stages that the testing will support, we will
+ // instead rely on `[shader(...)]` annotations to tell us
+ // what entry points are present in the input code.
+ }
+ else
+ {
+ ShaderCompileRequest::EntryPoint computeEntryPoint;
+ computeEntryPoint.name = computeEntryPointName;
+ computeEntryPoint.slangStage = SLANG_STAGE_COMPUTE;
+ compileRequest.entryPoints.add(computeEntryPoint);
+ }
}
compileRequest.globalSpecializationArgs = layout.globalSpecializationArgs;
compileRequest.entryPointSpecializationArgs = layout.entryPointSpecializationArgs;
diff --git a/tools/render-test/slang-support.h b/tools/render-test/slang-support.h
index 97b85ff8f..99509914e 100644
--- a/tools/render-test/slang-support.h
+++ b/tools/render-test/slang-support.h
@@ -17,6 +17,7 @@ struct ShaderCompilerUtil
SlangCompileTarget target;
SlangSourceLanguage sourceLanguage;
SlangPassThrough passThrough;
+ PipelineType pipelineType = PipelineType::Unknown;
char const* profile;
const char** args;
int argCount;
@@ -24,7 +25,7 @@ struct ShaderCompilerUtil
struct Output
{
- void set(PipelineType pipelineType, const ShaderProgram::KernelDesc* inKernelDescs, int kernelDescCount)
+ void set(PipelineType pipelineType, const ShaderProgram::KernelDesc* inKernelDescs, Slang::Index kernelDescCount)
{
kernelDescs.clear();
kernelDescs.addRange(inKernelDescs, kernelDescCount);
@@ -82,7 +83,7 @@ struct ShaderCompilerUtil
Slang::String sourcePath;
};
- static SlangResult compileWithLayout(SlangSession* session, const Slang::String& sourcePath, const Slang::List<Slang::CommandLine::Arg>& compileArgs, Options::ShaderProgramType shaderType, const ShaderCompilerUtil::Input& input, OutputAndLayout& output);
+ static SlangResult compileWithLayout(SlangSession* session, const Options& options, const ShaderCompilerUtil::Input& input, OutputAndLayout& output);
static SlangResult readSource(const Slang::String& inSourcePath, List<char>& outSourceText);