diff options
| author | Theresa Foley <tfoleyNV@users.noreply.github.com> | 2021-11-10 10:48:44 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2021-11-10 10:48:44 -0800 |
| commit | 95e82acc0b32c81a9c6ac39708d18a423d8c7b1e (patch) | |
| tree | 068ad53eafd6d4b15734af0ce51e219b1a51d1e2 /source | |
| parent | 4d4cd569ad7fcc88693c18f848603f18894e24be (diff) | |
Fix a bug with CUDA entry-point params (#2007)
Recent work that added support for translating DXR-style ray tracing shaders to work with OptiX seems to have accidentally applied its transformations even when compute shaders are translated for CUDA. As a result, compute entry points with `uniform` parameters at entry-point scope would be miscompiled to use OptiX calls that are not available for non-OptiX compiles.
This change fixes the relevant pass so that it correctly opts out on compute entry points, and also unifies some pieces of code that were being shared between a few different IR passes but that had gotten copy-pasted for the OptiX case.
The fix has been confirmed by running relevant CUDA tests locally, but CUDA is still disabled in the default CI builds, so this change is not yet covered by automated testing that would guard against further regressions.
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-entry-point-pass.cpp | 52 | ||||
| -rw-r--r-- | source/slang/slang-ir-entry-point-pass.h | 39 | ||||
| -rw-r--r-- | source/slang/slang-ir-entry-point-uniforms.cpp | 85 | ||||
| -rw-r--r-- | source/slang/slang-ir-optix-entry-point-uniforms.cpp | 92 |
4 files changed, 122 insertions, 146 deletions
diff --git a/source/slang/slang-ir-entry-point-pass.cpp b/source/slang/slang-ir-entry-point-pass.cpp new file mode 100644 index 000000000..e5bf55a44 --- /dev/null +++ b/source/slang/slang-ir-entry-point-pass.cpp @@ -0,0 +1,52 @@ +// slang-ir-entry-point-pass.cpp +#include "slang-ir-entry-point-pass.h" + +namespace Slang +{ + +void PerEntryPointPass::processModule(IRModule* module) +{ + m_module = module; + + SharedIRBuilder sharedBuilder(module); + m_sharedBuilder = &sharedBuilder; + + // Note that we are only looking at true global-scope + // functions and not functions nested inside of + // IR generics. When using generic entry points, this + // pass should be run after the entry point(s) have + // been specialized to their generic type parameters. + + for (auto inst : module->getGlobalInsts()) + { + // We are only interested in entry points. + // + // Every entry point must be a function. + // + auto func = as<IRFunc>(inst); + if (!func) + continue; + + // Entry points will always have the `[entryPoint]` + // decoration to differentiate them from ordinary + // functions. + // + auto entryPointDecoration = func->findDecoration<IREntryPointDecoration>(); + if (!entryPointDecoration) + continue; + + // If we find a candidate entry point, then we + // will process it. 
+ // + processEntryPoint(func, entryPointDecoration); + } +} + +void PerEntryPointPass::processEntryPoint(IRFunc* entryPointFunc, IREntryPointDecoration* entryPointDecoration) +{ + m_entryPoint.func = entryPointFunc; + m_entryPoint.decoration = entryPointDecoration; + processEntryPointImpl(m_entryPoint); +} + +} diff --git a/source/slang/slang-ir-entry-point-pass.h b/source/slang/slang-ir-entry-point-pass.h new file mode 100644 index 000000000..f8a1c9888 --- /dev/null +++ b/source/slang/slang-ir-entry-point-pass.h @@ -0,0 +1,39 @@ +// ir-entry-point-pass.h +#pragma once + +#include "slang-ir.h" +#include "slang-ir-insts.h" + +namespace Slang +{ + +struct PerEntryPointPass +{ +public: + // We will process a whole module by visiting all + // its global functions, looking for entry points. + // + void processModule(IRModule* module); + + struct EntryPointInfo + { + IRFunc* func = nullptr; + IREntryPointDecoration* decoration = nullptr; + }; + +protected: + void processEntryPoint(IRFunc* entryPointFunc, IREntryPointDecoration* entryPointDecoration); + + virtual void processEntryPointImpl(EntryPointInfo const& info) = 0; + + // We'll hang on to the module we are processing, + // so that we can refer to it when setting up `IRBuilder`s. 
+ // + IRModule* m_module = nullptr; + + SharedIRBuilder* m_sharedBuilder = nullptr; + + EntryPointInfo m_entryPoint; +}; + +} diff --git a/source/slang/slang-ir-entry-point-uniforms.cpp b/source/slang/slang-ir-entry-point-uniforms.cpp index 47e361d07..d98f39515 100644 --- a/source/slang/slang-ir-entry-point-uniforms.cpp +++ b/source/slang/slang-ir-entry-point-uniforms.cpp @@ -3,6 +3,7 @@ #include "slang-ir.h" #include "slang-ir-insts.h" +#include "slang-ir-entry-point-pass.h" #include "slang-mangle.h" @@ -163,74 +164,6 @@ bool isVaryingParameter(IRVarLayout* varLayout) return isVaryingParameter(varLayout->getTypeLayout()); } -// Our two passes have a fair amount in common in terms of -// how they traverse the IR, so we will factor out the -// shared logic into a base type. - -struct PerEntryPointPass -{ - // We'll hang on to the module we are processing, - // so that we can refer to it when setting up `IRBuilder`s. - // - IRModule* module; - - - SharedIRBuilder* m_sharedBuilder = nullptr; - - // We will process a whole module by visiting all - // its global functions, looking for entry points. - // - void processModule() - { - SharedIRBuilder sharedBuilder(module); - m_sharedBuilder = &sharedBuilder; - - // Note that we are only looking at true global-scope - // functions and not functions nested inside of - // IR generics. When using generic entry points, this - // pass should be run after the entry point(s) have - // been specialized to their generic type parameters. - - for( auto inst : module->getGlobalInsts() ) - { - // We are only interested in entry points. - // - // Every entry point must be a function. - // - auto func = as<IRFunc>(inst); - if( !func ) - continue; - - // Entry points will always have the `[entryPoint]` - // decoration to differentiate them from ordinary - // functions. - // - // TODO: we could make `IREntryPoint` a subclass of - // `IRFunc` if desired, to avoid having to attach - // an explicit decoration to identify them. 
- // - if( !func->findDecorationImpl(kIROp_EntryPointDecoration) ) - continue; - - // If we find a candidate entry point, then we - // will process it. - // - processEntryPoint(func); - } - } - - void processEntryPoint(IRFunc* entryPointFunc) - { - m_entryPointFunc = entryPointFunc; - processEntryPointImpl(entryPointFunc); - } - - IRFunc* m_entryPointFunc = nullptr; - - virtual void processEntryPointImpl(IRFunc* entryPointFunc) = 0; -}; - - struct CollectEntryPointUniformParams : PerEntryPointPass { CollectEntryPointUniformParamsOptions m_options; @@ -248,8 +181,10 @@ struct CollectEntryPointUniformParams : PerEntryPointPass IRVarLayout* entryPointParamsLayout = nullptr; bool needConstantBuffer = false; - void processEntryPointImpl(IRFunc* entryPointFunc) SLANG_OVERRIDE + void processEntryPointImpl(EntryPointInfo const& info) SLANG_OVERRIDE { + auto entryPointFunc = info.func; + // This pass object may be used across multiple entry points, // so we need to make sure to reset state that could have been // left over from a previous entry point. @@ -449,7 +384,7 @@ struct CollectEntryPointUniformParams : PerEntryPointPass // First we create the structure to hold the parameters. // - builder.setInsertBefore(m_entryPointFunc); + builder.setInsertBefore(m_entryPoint.func); paramStructType = builder.createStructType(); builder.addNameHintDecoration(paramStructType, UnownedTerminatedStringSlice("EntryPointParams")); @@ -484,8 +419,10 @@ struct CollectEntryPointUniformParams : PerEntryPointPass struct MoveEntryPointUniformParametersToGlobalScope : PerEntryPointPass { - void processEntryPointImpl(IRFunc* entryPointFunc) SLANG_OVERRIDE + void processEntryPointImpl(EntryPointInfo const& info) SLANG_OVERRIDE { + auto entryPointFunc = info.func; + // We will set up an IR builder so that we are ready to generate code. 
// IRBuilder builderStorage(m_sharedBuilder); @@ -561,17 +498,15 @@ void collectEntryPointUniformParams( CollectEntryPointUniformParamsOptions const& options) { CollectEntryPointUniformParams context; - context.module = module; context.m_options = options; - context.processModule(); + context.processModule(module); } void moveEntryPointUniformParamsToGlobalScope( IRModule* module) { MoveEntryPointUniformParametersToGlobalScope context; - context.module = module; - context.processModule(); + context.processModule(module); } } diff --git a/source/slang/slang-ir-optix-entry-point-uniforms.cpp b/source/slang/slang-ir-optix-entry-point-uniforms.cpp index 0ec8e3adb..91aa30f47 100644 --- a/source/slang/slang-ir-optix-entry-point-uniforms.cpp +++ b/source/slang/slang-ir-optix-entry-point-uniforms.cpp @@ -6,79 +6,13 @@ #include "slang-ir-optix-entry-point-uniforms.h" #include "slang-ir.h" +#include "slang-ir-entry-point-pass.h" #include "slang-ir-insts.h" #include "slang-ir-restructure.h" namespace Slang { -struct PerEntryPointPass -{ - // We'll hang on to the module we are processing, - // so that we can refer to it when setting up `IRBuilder`s. - IRModule* module; - - SharedIRBuilder* m_sharedBuilder = nullptr; - - // We will process a whole module by visiting all - // its global functions, looking for entry points. - void processModule() - { - SharedIRBuilder sharedBuilder(module); - m_sharedBuilder = &sharedBuilder; - - // Note that we are only looking at true global-scope - // functions and not functions nested inside of - // IR generics. When using generic entry points, this - // pass should be run after the entry point(s) have - // been specialized to their generic type parameters. - - for( auto inst : module->getGlobalInsts() ) - { - // We are only interested in entry points. - // - // Every entry point must be a function. 
- // - auto func = as<IRFunc>(inst); - if( !func ) - continue; - - // Entry points will always have the `[entryPoint]` - // decoration to differentiate them from ordinary - // functions. - // - auto entryPointDecor = func->findDecoration<IREntryPointDecoration>(); - if(!entryPointDecor) - continue; - - // Check the IREntryPointDecoration for raytracing entry points - // (as SBT records are only relevant to raytracing) - if (!( - entryPointDecor->getProfile().getStage() == Stage::RayGeneration || - entryPointDecor->getProfile().getStage() == Stage::Intersection || - entryPointDecor->getProfile().getStage() == Stage::AnyHit || - entryPointDecor->getProfile().getStage() == Stage::ClosestHit || - entryPointDecor->getProfile().getStage() == Stage::Miss || - entryPointDecor->getProfile().getStage() == Stage::Callable - )) continue; - - // If we find a candidate entry point, then we - // will process it. - processEntryPoint(func); - } - } - - void processEntryPoint(IRFunc* entryPointFunc) - { - m_entryPointFunc = entryPointFunc; - processEntryPointImpl(entryPointFunc); - } - - IRFunc* m_entryPointFunc = nullptr; - - virtual void processEntryPointImpl(IRFunc* entryPointFunc) = 0; -}; - struct CollectOptixEntryPointUniformParams : PerEntryPointPass { // *If* the entry point has any uniform parameter then we want to create a @@ -91,8 +25,11 @@ struct CollectOptixEntryPointUniformParams : PerEntryPointPass { IRParam* collectedParam = nullptr; IRVarLayout* entryPointParamsLayout = nullptr; - void processEntryPointImpl(IRFunc* entryPointFunc) SLANG_OVERRIDE + void processEntryPointImpl(EntryPointInfo const& info) SLANG_OVERRIDE { + auto entryPointFunc = info.func; + auto entryPointDecoration = info.decoration; + // This pass object may be used across multiple entry points, // so we need to make sure to reset state that could have been // left over from a previous entry point. 
@@ -100,6 +37,20 @@ struct CollectOptixEntryPointUniformParams : PerEntryPointPass { paramStructType = nullptr; collectedParam = nullptr; + // We only want to process entry points that are used in OptiX/ray-tracing + // stages, and not ordinary compute entry points (the entry-point `uniform` + // parameters of an ordinary compute entry point will translate to CUDA + // launch parameters). + // + switch( entryPointDecoration->getProfile().getStage() ) + { + default: + break; + + case Stage::Compute: + return; + } + // We expect all entry points to have explicit layout information attached. // // We will assert that we have the information we need, but try to be @@ -287,7 +238,7 @@ struct CollectOptixEntryPointUniformParams : PerEntryPointPass { // First we create the structure to hold the parameters. // - builder.setInsertBefore(m_entryPointFunc); + builder.setInsertBefore(m_entryPoint.func); paramStructType = builder.createStructType(); builder.addNameHintDecoration(paramStructType, UnownedTerminatedStringSlice("ShaderRecordParams")); @@ -318,8 +269,7 @@ void collectOptiXEntryPointUniformParams( // Insts of the module. For any ray tracing entry points, collect all uniform parameters into one // common struct, and replace parameter usage with SBT record accesses. CollectOptixEntryPointUniformParams context; - context.module = module; - context.processModule(); + context.processModule(module); } } |
