diff options
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 44 |
1 files changed, 44 insertions, 0 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 06e5f0766..106248ef8 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -4,6 +4,7 @@ #include "slang-emit-base.h" #include "slang-ir-util.h" +#include "slang-ir-call-graph.h" #include "slang-ir.h" #include "slang-ir-insts.h" #include "slang-ir-layout.h" @@ -437,6 +438,7 @@ constexpr bool isPlural<IRUse*> = true; template<typename T> constexpr bool isSingular = !isPlural<T>; + // Now that we've defined the intermediate data structures we will // use to represent SPIR-V code during emission, we will move on // to defining the main context type that will drive SPIR-V @@ -1278,6 +1280,11 @@ struct SPIRVEmitContext return result; } + bool hasExtensionDeclaration(const UnownedStringSlice& name) + { + return m_extensionInsts.containsKey(name); + } + struct SpvTypeInstKey { List<SpvWord> words; @@ -2732,6 +2739,43 @@ struct SPIRVEmitContext result = inner; break; } + case kIROp_RequireComputeDerivative: + { + auto parentFunc = getParentFunc(inst); + + HashSet<IRFunc*>* entryPointsUsingInst = getReferencingEntryPoints(m_referencingEntryPoints, parentFunc); + for (IRFunc* entryPoint : *entryPointsUsingInst) + { + bool isQuad = true; + IREntryPointDecoration* entryPointDecor = nullptr; + for(auto dec : entryPoint->getDecorations()) + { + if(auto maybeEntryPointDecor = as<IREntryPointDecoration>(dec)) + entryPointDecor = maybeEntryPointDecor; + if(as<IRDerivativeGroupLinearDecoration>(dec)) + isQuad = false; + } + if (!entryPointDecor || entryPointDecor->getProfile().getStage() != Stage::Compute) + continue; + + ensureExtensionDeclaration(UnownedStringSlice("SPV_NV_compute_shader_derivatives")); + auto numThreadsDecor = entryPointDecor->findDecoration<IRNumThreadsDecoration>(); + if (isQuad) + { + verifyComputeDerivativeGroupModifiers(this->m_sink, inst->sourceLoc, true, false, numThreadsDecor); + emitOpExecutionMode(getSection(SpvLogicalSectionID::ExecutionModes), nullptr, entryPoint, SpvExecutionModeDerivativeGroupQuadsNV); + emitOpCapability(getSection(SpvLogicalSectionID::Capabilities), nullptr, SpvCapabilityComputeDerivativeGroupQuadsNV); + } + else + { + verifyComputeDerivativeGroupModifiers(this->m_sink, inst->sourceLoc, false, true, numThreadsDecor); + emitOpExecutionMode(getSection(SpvLogicalSectionID::ExecutionModes), nullptr, entryPoint, SpvExecutionModeDerivativeGroupLinearNV); + emitOpCapability(getSection(SpvLogicalSectionID::Capabilities), nullptr, SpvCapabilityComputeDerivativeGroupLinearNV); + } + } + + break; + } case kIROp_Return: if (as<IRReturn>(inst)->getVal()->getOp() == kIROp_VoidLit) result = emitOpReturn(parent, inst); |
