diff options
| author | Yong He <yonghe@outlook.com> | 2025-02-06 22:02:43 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-02-06 22:02:43 -0800 |
| commit | bae87afb20f95f9f27c64c4955bbc4464c576509 (patch) | |
| tree | 44d079bd76002d69be20efdbd03ac6ff62ef8caf | |
| parent | 075b10e69055acc6536d74c1cb3399e0fe75338d (diff) | |
Support stage_switch. (#6311)
* Support stage_switch.
* Update proposal status.
* Fix gl_InstanceID.
* Fix.
| -rw-r--r-- | docs/proposals/020-stage-switch.md | 93 | ||||
| -rw-r--r-- | source/slang/glsl.meta.slang | 29 | ||||
| -rw-r--r-- | source/slang/slang-ast-stmt.h | 5 | ||||
| -rw-r--r-- | source/slang/slang-capability.cpp | 60 | ||||
| -rw-r--r-- | source/slang/slang-capability.h | 5 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 26 | ||||
| -rw-r--r-- | source/slang/slang-check-stmt.cpp | 44 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 7 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize-stage-switch.cpp | 198 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize-stage-switch.h | 14 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 96 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 17 | ||||
| -rw-r--r-- | tests/language-feature/stage-switch.slang | 38 | ||||
| -rw-r--r-- | tests/spirv/primitive-id-2.slang | 13 |
19 files changed, 640 insertions, 28 deletions
diff --git a/docs/proposals/020-stage-switch.md b/docs/proposals/020-stage-switch.md new file mode 100644 index 000000000..d2cd94f75 --- /dev/null +++ b/docs/proposals/020-stage-switch.md @@ -0,0 +1,93 @@ +# SP#020: `stage_switch` + +## Status + +Author: Yong He + +Status: In Experiment + +Implementation: [PR 6311](https://github.com/shader-slang/slang/pull/6311) + +Reviewed by: Jay Kwak + +## Background + +We need to provide a mechanism for authoring stage-specific code that works with the capability system. For example, the user may want to define a function `ddx_or_zero(v)` that returns `ddx(v)` when called from a fragment shader, and return `0` when called from other shader stages. Without a mechanism for writing stage-specific code, there is no way to define a valid function that can be used from both a fragment shader and a compute shader in a single compilation. + +The user can workaround this problem with the preprocessor: + +``` +float ddx_or_zero(float v) +{ +#ifdef FRAGMENT_SHADER + return ddx(v); +#else + return 0.0; +#endif +} + +[shader("compute")] +[numthread(1,1,1)] +void computeMain() { ddx_or_zero(...); } + +[shader("fragment")] +float4 fragMain() { ddx_or_zero(...); } +``` + +However, this require the application to compile the source file twice with different pre-defined macros. It is impossible to use a single compilation to generate one SPIRV module that contains both the entrypoints. + +## Proposed Approach + +We propose to add a new construct, `__stage_switch` that works like `__target_switch` but switches on stages. With `__stage_switch` the above code can be written as: + +``` +float ddx_or_zero(float v) +{ + __stage_switch + { + case fragment: + return ddx(v); + default: + return 0.0; + } +} + +[shader("compute")] +[numthread(1,1,1)] +void computeMain() +{ + ddx_or_zero(...); // returns 0.0 +} + +[shader("fragment")] +float4 fragMain() +{ + ddx_or_zero(...); // returns ddx(...) +} +``` + +With `__stage_switch`, the two entrypoints can be compiled into a single SPIRV in one go, without requiring setting up any preprocessor macros. + +Unlike `switch`, there is no fallthrough between cases in a `__stage_switch`. All cases will implicitly end with a `break` if it is not written by the user. However, one special type of fallthrough is supported, that is when multiple `cases` are defined next to each other with nothing else in between, for example: + +``` +__stage_switch +{ +case fragment: +case vertex: +case geometry: + return 1.0; +case anyhit: + return 2.0; +default: + return 0.0; +} +``` + +## Alternatives Considered + +We considered to reuse the existing `__target_switch` and extend it to allow switching between different stages. However this turns out to be difficult to implement, if ordinary capabilities are mixed together with stages, because specialization to stages needs to happen at a much later time in the compilation pipeline compared to specialization to capabilities. Using a separate switch allows us to easily tell apart the code that requires specialization at different phases of compilation, and also allow us to provide cleaner error messages. + +## Conclusion + +`__stage_switch` adds the missing functionality from `__target_switch` that allows the user to write stage-specific code that gets specialized for each unique entrypoint stage. This works together with the capability system to provide early type-system checks to ensure the correctness of user code, without requiring use of preprocessor to protect calls to stage specific functions. diff --git a/source/slang/glsl.meta.slang b/source/slang/glsl.meta.slang index ef3bfd683..0bad3c681 100644 --- a/source/slang/glsl.meta.slang +++ b/source/slang/glsl.meta.slang @@ -4824,24 +4824,41 @@ public property uint3 gl_LaunchSizeEXT } } +internal in int __gl_PrimitiveID : SV_PrimitiveID; public property int gl_PrimitiveID { - [require(cuda_glsl_hlsl_spirv, raytracing_anyhit_closesthit_intersection)] + [require(cuda_glsl_hlsl_spirv)] get { - setupExtForRayTracingBuiltIn(); - return PrimitiveIndex(); + __stage_switch + { + case anyhit: + case closesthit: + case intersection: + setupExtForRayTracingBuiltIn(); + return PrimitiveIndex(); + default: + return __gl_PrimitiveID; + } } } public property int gl_InstanceID { - [require(cuda_glsl_hlsl_spirv, raytracing_anyhit_closesthit_intersection)] + [require(cuda_glsl_hlsl_spirv)] get { - setupExtForRayTracingBuiltIn(); - return InstanceIndex(); + __stage_switch + { + case anyhit: + case closesthit: + case intersection: + setupExtForRayTracingBuiltIn(); + return InstanceIndex(); + default: + return gl_InstanceIndex; + } } } diff --git a/source/slang/slang-ast-stmt.h b/source/slang/slang-ast-stmt.h index eba7027c2..09c491287 100644 --- a/source/slang/slang-ast-stmt.h +++ b/source/slang/slang-ast-stmt.h @@ -113,6 +113,11 @@ class TargetSwitchStmt : public Stmt List<TargetCaseStmt*> targetCases; }; +class StageSwitchStmt : public TargetSwitchStmt +{ + SLANG_AST_CLASS(StageSwitchStmt) +}; + class IntrinsicAsmStmt : public Stmt { SLANG_AST_CLASS(IntrinsicAsmStmt) diff --git a/source/slang/slang-capability.cpp b/source/slang/slang-capability.cpp index 7c4e34d32..1eb0cae31 100644 --- a/source/slang/slang-capability.cpp +++ b/source/slang/slang-capability.cpp @@ -100,6 +100,45 @@ bool isDirectChildOfAbstractAtom(CapabilityAtom name) return _getInfo(name).abstractBase != CapabilityName::Invalid; } +bool isStageAtom(CapabilityName name, CapabilityName& outCanonicalStage) +{ + auto& info = _getInfo(name); + if (info.abstractBase == CapabilityName::stage) + { + outCanonicalStage = name; + return true; + } + switch (name) + { + case CapabilityName::anyhit: + outCanonicalStage = CapabilityName::_anyhit; + return true; + case CapabilityName::closesthit: + outCanonicalStage = CapabilityName::_closesthit; + return true; + case CapabilityName::miss: + outCanonicalStage = CapabilityName::_miss; + return true; + case CapabilityName::intersection: + outCanonicalStage = CapabilityName::_intersection; + return true; + case CapabilityName::raygen: + outCanonicalStage = CapabilityName::_raygen; + return true; + case CapabilityName::callable: + outCanonicalStage = CapabilityName::_callable; + return true; + case CapabilityName::mesh: + outCanonicalStage = CapabilityName::_mesh; + return true; + case CapabilityName::amplification: + outCanonicalStage = CapabilityName::_amplification; + return true; + default: + return false; + } +} + bool isTargetVersionAtom(CapabilityAtom name) { if (name >= CapabilityAtom::_spirv_1_0 && name <= getLatestSpirvAtom()) @@ -620,7 +659,26 @@ CapabilitySet CapabilitySet::getTargetsThisHasButOtherDoesNot(const CapabilitySe if (other.m_targetSets.tryGetValue(i.first)) continue; - newSet.m_targetSets[i.first] = this->m_targetSets[i.first]; + newSet.m_targetSets[i.first] = i.second; + } + return newSet; +} + +CapabilitySet CapabilitySet::getStagesThisHasButOtherDoesNot(const CapabilitySet& other) +{ + CapabilitySet newSet{}; + for (auto& i : this->m_targetSets) + { + if (auto otherTarget = other.m_targetSets.tryGetValue(i.first)) + { + auto& thisTarget = m_targetSets[i.first]; + for (auto& stage : thisTarget.shaderStageSets) + { + if (otherTarget->shaderStageSets.containsKey(stage.first)) + continue; + newSet.m_targetSets[i.first].shaderStageSets[stage.first] = stage.second; + } + } } return newSet; } diff --git a/source/slang/slang-capability.h b/source/slang/slang-capability.h index 631cd307a..7c429d825 100644 --- a/source/slang/slang-capability.h +++ b/source/slang/slang-capability.h @@ -169,6 +169,9 @@ public: /// Return a capability set of 'target' atoms 'this' has, but 'other' does not. CapabilitySet getTargetsThisHasButOtherDoesNot(const CapabilitySet& other); + /// Return a capability set of 'stage' atoms 'this' has, but 'other' does not. + CapabilitySet getStagesThisHasButOtherDoesNot(const CapabilitySet& other); + /// Are these two capability sets equal? bool operator==(CapabilitySet const& that) const; @@ -359,7 +362,7 @@ void getCapabilityNames(List<UnownedStringSlice>& ioNames); UnownedStringSlice capabilityNameToString(CapabilityName name); bool isDirectChildOfAbstractAtom(CapabilityAtom name); - +bool isStageAtom(CapabilityName name, CapabilityName& outCanonicalStage); /// Return true if `name` represents an atom for a target version, e.g. spirv_1_5. bool isTargetVersionAtom(CapabilityAtom name); diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index ab3335bfb..39ab41421 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -12735,14 +12735,26 @@ struct CapabilityDeclReferenceVisitor std::swap(stmt->targetCases[i], stmt->targetCases[i + 1]); continue; } - - if (!maybeRequireCapability) - targetCap = (CapabilitySet(CapabilityName::any_target) - .getTargetsThisHasButOtherDoesNot(set)); + if (as<StageSwitchStmt>(stmt)) + { + if (!maybeRequireCapability) + targetCap = (CapabilitySet(CapabilityName::any_target) + .getStagesThisHasButOtherDoesNot(set)); + else + targetCap = + (maybeRequireCapability->capabilitySet.getStagesThisHasButOtherDoesNot( + set)); + } else - targetCap = - (maybeRequireCapability->capabilitySet.getTargetsThisHasButOtherDoesNot( - set)); + { + if (!maybeRequireCapability) + targetCap = (CapabilitySet(CapabilityName::any_target) + .getTargetsThisHasButOtherDoesNot(set)); + else + targetCap = + (maybeRequireCapability->capabilitySet.getTargetsThisHasButOtherDoesNot( + set)); + } } else { diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp index 151c9324a..db6f00d23 100644 --- a/source/slang/slang-check-stmt.cpp +++ b/source/slang/slang-check-stmt.cpp @@ -367,6 +367,40 @@ void SemanticsStmtVisitor::visitTargetSwitchStmt(TargetSwitchStmt* stmt) HashSet<Stmt*> checkedStmt; for (auto caseStmt : stmt->targetCases) { + CapabilitySet set((CapabilityName)caseStmt->capability); + + CapabilityName canonicalStage = CapabilityName::Invalid; + bool isStage = isStageAtom((CapabilityName)caseStmt->capability, canonicalStage); + if (as<StageSwitchStmt>(stmt)) + { + if (!isStage && caseStmt->capability != 0) + { + getSink()->diagnose( + caseStmt->capabilityToken.loc, + Diagnostics::unknownStageName, + caseStmt->capabilityToken); + } + caseStmt->capability = (int)canonicalStage; + } + else + { + if (isStage) + { + getSink()->diagnose( + caseStmt->capabilityToken.loc, + Diagnostics::targetSwitchCaseCannotBeAStage); + } + else if ( + caseStmt->capabilityToken.getContentLength() != 0 && + (set.getCapabilityTargetSets().getCount() != 1 || set.isInvalid() || set.isEmpty())) + { + getSink()->diagnose( + caseStmt->capabilityToken.loc, + Diagnostics::invalidTargetSwitchCase, + capabilityNameToString((CapabilityName)caseStmt->capability)); + } + } + if (checkedStmt.contains(caseStmt->body)) continue; subContext.checkStmt(caseStmt); @@ -377,7 +411,6 @@ void SemanticsStmtVisitor::visitTargetSwitchStmt(TargetSwitchStmt* stmt) void SemanticsStmtVisitor::visitTargetCaseStmt(TargetCaseStmt* stmt) { auto switchStmt = FindOuterStmt<TargetSwitchStmt>(); - CapabilitySet set((CapabilityName)stmt->capability); if (getShared()->isInLanguageServer() && getShared()->getSession()->getCompletionRequestTokenName() == stmt->capabilityToken.getName()) @@ -385,15 +418,6 @@ void SemanticsStmtVisitor::visitTargetCaseStmt(TargetCaseStmt* stmt) getShared()->getLinkage()->contentAssistInfo.completionSuggestions.scopeKind = CompletionSuggestions::ScopeKind::Capabilities; } - - if (stmt->capabilityToken.getContentLength() != 0 && - (set.getCapabilityTargetSets().getCount() != 1 || set.isInvalid() || set.isEmpty())) - { - getSink()->diagnose( - stmt->capabilityToken.loc, - Diagnostics::invalidTargetSwitchCase, - capabilityNameToString((CapabilityName)stmt->capability)); - } if (!switchStmt) { getSink()->diagnose(stmt, Diagnostics::caseOutsideSwitch); diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index a1768415e..59a1bbdb6 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -561,6 +561,13 @@ DIAGNOSTIC( Error, spirvUndefinedId, "SPIRV id '%$0' is not defined in the current assembly block location") + +DIAGNOSTIC( + 29115, + Error, + targetSwitchCaseCannotBeAStage, + "cannot use a stage name in '__target_switch', use '__stage_switch' for stage-specific code.") + // // 3xxxx - Semantic analysis // diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 58376bbc1..f093104bd 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -90,6 +90,7 @@ #include "slang-ir-specialize-buffer-load-arg.h" #include "slang-ir-specialize-matrix-layout.h" #include "slang-ir-specialize-resources.h" +#include "slang-ir-specialize-stage-switch.h" #include "slang-ir-specialize.h" #include "slang-ir-ssa-simplification.h" #include "slang-ir-ssa.h" @@ -322,6 +323,7 @@ struct RequiredLoweringPassSet bool dynamicResource; bool dynamicResourceHeap; bool resolveVaryingInputRef; + bool specializeStageSwitch; }; // Scan the IR module and determine which lowering/legalization passes are needed based @@ -444,6 +446,9 @@ void calcRequiredLoweringPassSet( case kIROp_ResolveVaryingInputRef: result.resolveVaryingInputRef = true; break; + case kIROp_GetCurrentStage: + result.specializeStageSwitch = true; + break; } if (!result.generics || !result.existentialTypeLayout) { @@ -1027,6 +1032,10 @@ Result linkAndOptimizeIR( cleanupGenerics(targetProgram, irModule, sink); dumpIRIfEnabled(codeGenContext, irModule, "AFTER-LOWER-GENERICS"); + // After dynamic dispatch logic is resolved into ordinary function calls, + // we can now run our stage specialization logic. + if (requiredLoweringPassSet.specializeStageSwitch) + specializeStageSwitch(irModule); if (sink->getErrorCount() != 0) return SLANG_FAIL; #if 0 diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 19bfa5e72..6fc7d56ad 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -2177,6 +2177,7 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_Printf: case kIROp_MakeCoopVector: case kIROp_MakeCoopVectorFromValuePack: + case kIROp_GetCurrentStage: return transcribeNonDiffInst(builder, origInst); // A call to createDynamicObject<T>(arbitraryData) cannot provide a diff value, diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 59a6852b1..4de7457a3 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -410,6 +410,9 @@ INST(InterfaceRequirementEntry, interface_req_entry, 2, GLOBAL) // We will materialize this inst during `translateGLSLGlobalVar`. INST(GetWorkGroupSize, GetWorkGroupSize, 0, HOISTABLE) +// An inst that returns the current stage of the calling entry point. +INST(GetCurrentStage, GetCurrentStage, 0, 0) + INST(Param, param, 0, 0) INST(StructField, field, 2, 0) INST(Var, var, 0, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index fcafe4bc6..3249f34b0 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2424,7 +2424,7 @@ struct IRCall : IRInst IR_LEAF_ISA(Call) IRInst* getCallee() { return getOperand(0); } - + IRUse* getCalleeUse() { return getOperands(); } UInt getArgCount() { return getOperandCount() - 1; } IRUse* getArgs() { return getOperands() + 1; } IROperandList<IRInst> getArgsList() @@ -3881,6 +3881,8 @@ public: // its rate, if any. void setDataType(IRInst* inst, IRType* dataType); + IRInst* emitGetCurrentStage(); + /// Extract the value wrapped inside an existential box. IRInst* emitGetValueFromBoundInterface(IRType* type, IRInst* boundInterfaceValue); diff --git a/source/slang/slang-ir-specialize-stage-switch.cpp b/source/slang/slang-ir-specialize-stage-switch.cpp new file mode 100644 index 000000000..f65aa4d4c --- /dev/null +++ b/source/slang/slang-ir-specialize-stage-switch.cpp @@ -0,0 +1,198 @@ +#include "slang-ir-specialize-stage-switch.h" + +#include "slang-capability.h" +#include "slang-compiler.h" +#include "slang-ir-call-graph.h" +#include "slang-ir-insts.h" +#include "slang-ir-util.h" +#include "slang-ir.h" + +namespace Slang +{ +bool funcHasGetCurrentStageInst(IRGlobalValueWithCode* func) +{ + for (auto block : func->getBlocks()) + { + for (auto inst : block->getChildren()) + { + if (inst->getOp() == kIROp_GetCurrentStage) + { + return true; + } + } + } + return false; +} + +void discoverStageSpecificFunctions(HashSet<IRInst*>& stageSpecificFunctions, IRModule* module) +{ + List<IRInst*> workList; + for (auto inst : module->getGlobalInsts()) + { + if (auto func = as<IRGlobalValueWithCode>(inst)) + { + if (funcHasGetCurrentStageInst(func)) + { + workList.add(inst); + stageSpecificFunctions.add(func); + } + } + } + for (Index i = 0; i < workList.getCount(); i++) + { + auto callee = workList[i]; + traverseUses( + callee, + [&](IRUse* use) + { + if (use->getUser()->getOp() == kIROp_Call) + { + auto parentFunc = getParentFunc(use->getUser()); + if (parentFunc && stageSpecificFunctions.add(parentFunc)) + { + workList.add(parentFunc); + } + } + }); + } +} + +// Given a func, replace all `GetCurrentStage` insts with the given stage, and rewrite all calls to +// stage specific functions to the specialized function for the given stage. +// +void specializeFuncToStage( + Stage stage, + IRGlobalValueWithCode* func, + Dictionary<IRInst*, Dictionary<Stage, IRInst*>>& mapFuncToStageSpecializedFunc) +{ + // Collect all insts that may need to be modified. + List<IRInst*> instsToModify; + for (auto block : func->getBlocks()) + { + for (auto inst : block->getChildren()) + { + switch (inst->getOp()) + { + case kIROp_GetCurrentStage: + case kIROp_Call: + instsToModify.add(inst); + break; + } + } + } + + IRInst* stageVal = nullptr; + IRBuilder builder(func); + for (auto inst : instsToModify) + { + builder.setInsertBefore(inst); + + switch (inst->getOp()) + { + case kIROp_GetCurrentStage: + { + // Replace `GetCurrentStage` with the stage it is specialized to. + if (!stageVal) + { + stageVal = builder.getIntValue((IRIntegerValue)stage); + } + inst->replaceUsesWith(stageVal); + inst->removeAndDeallocate(); + break; + } + case kIROp_Call: + { + // Replace calls to stage specific functions with the specialized function for the + // given stage. + auto callInst = static_cast<IRCall*>(inst); + auto callee = callInst->getCallee(); + auto specializedFuncs = mapFuncToStageSpecializedFunc.tryGetValue(callee); + if (specializedFuncs) + { + auto specializedFunc = specializedFuncs->tryGetValue(stage); + if (specializedFunc) + { + builder.replaceOperand(callInst->getCalleeUse(), *specializedFunc); + } + } + break; + } + } + } +} + +void specializeStageSwitch(IRModule* module) +{ + Dictionary<IRInst*, HashSet<IRFunc*>> mapInstToReferencingEntryPoints; + buildEntryPointReferenceGraph(mapInstToReferencingEntryPoints, module); + + HashSet<IRInst*> stageSpecificFunctions; + discoverStageSpecificFunctions(stageSpecificFunctions, module); + + // Clone all stage specific functions for each stage they are used in. + Dictionary<IRInst*, Dictionary<Stage, IRInst*>> mapFuncToStageSpecializedFunc; + for (auto func : stageSpecificFunctions) + { + auto referencingEntryPoints = mapInstToReferencingEntryPoints.tryGetValue(func); + if (!referencingEntryPoints) + continue; + if (func->findDecoration<IREntryPointDecoration>()) + continue; + Dictionary<Stage, IRInst*> specializedFuncs; + for (auto entryPoint : *referencingEntryPoints) + { + auto entryPointDecor = entryPoint->findDecoration<IREntryPointDecoration>(); + if (!entryPointDecor) + continue; + auto stage = entryPointDecor->getProfile().getStage(); + auto stageSpecializedFunc = specializedFuncs.tryGetValue(stage); + if (stageSpecializedFunc) + continue; + IRCloneEnv cloneEnv; + IRBuilder builder(func); + builder.setInsertBefore(func); + auto clonedFunc = cloneInst(&cloneEnv, &builder, func); + specializedFuncs[stage] = clonedFunc; + } + mapFuncToStageSpecializedFunc.add(func, _Move(specializedFuncs)); + } + + // Rewrite entrypoint and cloned functions to replace `GetCurrentStage` with the stage they are + // specialized to. + for (auto func : stageSpecificFunctions) + { + // Is this an entrypoint? + if (auto entryPointDecor = func->findDecoration<IREntryPointDecoration>()) + { + auto stage = entryPointDecor->getProfile().getStage(); + specializeFuncToStage( + stage, + as<IRGlobalValueWithCode>(func), + mapFuncToStageSpecializedFunc); + } + else + { + // Is this a cloned function? + auto specializedFuncs = mapFuncToStageSpecializedFunc.tryGetValue(func); + if (!specializedFuncs) + continue; + for (auto pair : *specializedFuncs) + { + auto stage = pair.first; + auto specializedFunc = pair.second; + specializeFuncToStage( + stage, + as<IRGlobalValueWithCode>(specializedFunc), + mapFuncToStageSpecializedFunc); + } + } + } + + // Remove all original stage specific functions. + for (auto f : mapFuncToStageSpecializedFunc) + { + f.first->removeAndDeallocate(); + } +} + +} // namespace Slang diff --git a/source/slang/slang-ir-specialize-stage-switch.h b/source/slang/slang-ir-specialize-stage-switch.h new file mode 100644 index 000000000..54d967e5d --- /dev/null +++ b/source/slang/slang-ir-specialize-stage-switch.h @@ -0,0 +1,14 @@ +#ifndef SLANG_IR_SPECIALIZE_STAGE_SWITCH_H +#define SLANG_IR_SPECIALIZE_STAGE_SWITCH_H + +namespace Slang +{ +struct IRModule; + +// Repalce all stage_switch insts with the case that matches current calling entrypoint. +// +void specializeStageSwitch(IRModule* module); + +} // namespace Slang + +#endif diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 3314567f1..6cf7a1786 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3209,6 +3209,11 @@ void IRBuilder::setDataType(IRInst* inst, IRType* dataType) } } +IRInst* IRBuilder::emitGetCurrentStage() +{ + return emitIntrinsicInst(getIntType(), kIROp_GetCurrentStage, 0, nullptr); +} + IRInst* IRBuilder::emitGetValueFromBoundInterface(IRType* type, IRInst* boundInterfaceValue) { auto inst = @@ -8301,6 +8306,7 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options) case kIROp_ResolveVaryingInputRef: case kIROp_GetPerVertexInputArray: case kIROp_MetalCastToDepthTexture: + case kIROp_GetCurrentStage: return false; case kIROp_ForwardDifferentiate: diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 36003dcb2..06cdf430b 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -6719,6 +6719,102 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> } } + void visitStageSwitchStmt(StageSwitchStmt* stmt) + { + if (!stmt->targetCases.getCount()) + return; + + // We will lower stage switch as a normal switch statement, so they can participate in all + // optimizations. + auto builder = getBuilder(); + startBlockIfNeeded(stmt); + + // First emit code to get the current stage to switch on: + auto conditionVal = builder->emitGetCurrentStage(); + + // Remember the initial block so that we can add to it + // after we've collected all the `case`s + auto initialBlock = builder->getBlock(); + + // Next, create a block to use as the target for any `break` statements + auto breakLabel = createBlock(); + + // Register the `break` label so + // that we can find it for nested statements. + context->shared->breakLabels.add(stmt, breakLabel); + + builder->setInsertInto(initialBlock->getParent()); + + // Iterate over the body of the statement, looking + // for `case` or `default` statements: + SwitchStmtInfo info; + info.initialBlock = initialBlock; + info.defaultLabel = nullptr; + + Dictionary<Stmt*, IRBlock*> mapCaseStmtToBlock; + for (auto targetCase : stmt->targetCases) + { + IRBlock* caseBlock = nullptr; + if (!mapCaseStmtToBlock.tryGetValue(targetCase->body, caseBlock)) + { + caseBlock = builder->emitBlock(); + lowerStmt(context, targetCase->body); + mapCaseStmtToBlock.add(targetCase->body, caseBlock); + if (!builder->getBlock()->getTerminator()) + builder->emitBranch(breakLabel); + } + if (targetCase->capability == 0) + { + info.defaultLabel = caseBlock; + } + else + { + auto stage = getStageFromAtom((CapabilityAtom)targetCase->capability); + info.cases.add(builder->getIntValue(builder->getIntType(), (IRIntegerValue)stage)); + info.cases.add(caseBlock); + } + } + + // If the current block (the end of the last + // `case`) is not terminated, then terminate with a + // `break` operation. + // + // Double check that we aren't in the initial + // block, so we don't get tripped up on an + // empty `switch`. + auto curBlock = builder->getBlock(); + if (curBlock != initialBlock) + { + // Is the block already terminated? + if (!curBlock->getTerminator()) + { + // Not terminated, so add one. + builder->emitBreak(breakLabel); + } + } + + // If there was no `default` statement, then the + // default case will just branch directly to the end. + auto defaultLabel = info.defaultLabel ? info.defaultLabel : breakLabel; + + // Now that we've collected the cases, we are + // prepared to emit the `switch` instruction + // itself. + builder->setInsertInto(initialBlock); + builder->emitSwitch( + conditionVal, + breakLabel, + defaultLabel, + info.cases.getCount(), + info.cases.getBuffer()); + + // Finally we insert the label that a `break` will jump to + // (and that control flow will fall through to otherwise). + // This is the block that subsequent code will go into. + insertBlock(breakLabel); + context->shared->breakLabels.remove(stmt); + } + void visitTargetSwitchStmt(TargetSwitchStmt* stmt) { if (!stmt->targetCases.getCount()) diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 95eb971ee..38285c41f 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -5380,9 +5380,8 @@ static Stmt* ParseDefaultStmt(Parser* parser) return stmt; } -static Stmt* parseTargetSwitchStmt(Parser* parser) +static Stmt* parseTargetSwitchStmtImpl(Parser* parser, TargetSwitchStmt* stmt) { - TargetSwitchStmt* stmt = parser->astBuilder->create<TargetSwitchStmt>(); parser->FillPosition(stmt); parser->ReadToken(); if (!beginMatch(parser, MatchedTokenType::CurlyBraces)) @@ -5479,6 +5478,18 @@ static Stmt* parseTargetSwitchStmt(Parser* parser) return stmt; } +static Stmt* parseTargetSwitchStmt(Parser* parser) +{ + auto stmt = parser->astBuilder->create<TargetSwitchStmt>(); + return parseTargetSwitchStmtImpl(parser, stmt); +} + +static Stmt* parseStageSwitchStmt(Parser* parser) +{ + auto stmt = parser->astBuilder->create<StageSwitchStmt>(); + return parseTargetSwitchStmtImpl(parser, stmt); +} + static Stmt* parseIntrinsicAsmStmt(Parser* parser) { IntrinsicAsmStmt* stmt = parser->astBuilder->create<IntrinsicAsmStmt>(); @@ -5725,6 +5736,8 @@ Stmt* Parser::ParseStatement(Stmt* parentStmt) statement = ParseSwitchStmt(this); else if (LookAheadToken("__target_switch")) statement = parseTargetSwitchStmt(this); + else if (LookAheadToken("__stage_switch")) + statement = parseStageSwitchStmt(this); else if (LookAheadToken("__intrinsic_asm")) statement = parseIntrinsicAsmStmt(this); else if (LookAheadToken("case")) diff --git a/tests/language-feature/stage-switch.slang b/tests/language-feature/stage-switch.slang new file mode 100644 index 000000000..1bdd1f85a --- /dev/null +++ b/tests/language-feature/stage-switch.slang @@ -0,0 +1,38 @@ + +//TEST:SIMPLE(filecheck=CHECK):-target spirv + +float ddx_or(float val, float defaultVal) +{ + __stage_switch + { + case fragment: + return ddx(val); + default: + return defaultVal; + } +} + +float intermediate(float val) +{ + return ddx_or(val, 1.0); +} + +RWStructuredBuffer<float> output; + +[numthreads(1,1,1)] +void computeMain() +{ + // CHECK-LABEL: %computeMain = OpFunction + // CHECK: OpStore %{{.*}} %float_1 + // CHECK: OpFunctionEnd + output[0] = intermediate(2.0); +} + +[shader("fragment")] +float4 fragmentMain(float vin) : SV_Target +{ + // CHECK-LABEL: %fragmentMain = OpFunction + // CHECK: OpDPdx + // CHECK: OpFunctionEnd + return intermediate(vin); +}
\ No newline at end of file diff --git a/tests/spirv/primitive-id-2.slang b/tests/spirv/primitive-id-2.slang new file mode 100644 index 000000000..b10514b53 --- /dev/null +++ b/tests/spirv/primitive-id-2.slang @@ -0,0 +1,13 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv -entry main -stage fragment + +// CHECK: OpCapability Geometry +// CHECK: BuiltIn PrimitiveId + +#version 450 + +out vec4 color; + +void main() +{ + color = float4(gl_PrimitiveID, 0, 0, 1); +}
\ No newline at end of file |
