summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2025-02-06 22:02:43 -0800
committerGitHub <noreply@github.com>2025-02-06 22:02:43 -0800
commitbae87afb20f95f9f27c64c4955bbc4464c576509 (patch)
tree44d079bd76002d69be20efdbd03ac6ff62ef8caf /source/slang
parent075b10e69055acc6536d74c1cb3399e0fe75338d (diff)
Support stage_switch. (#6311)
* Support stage_switch. * Update proposal status. * Fix gl_InstanceID. * Fix.
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/glsl.meta.slang29
-rw-r--r--source/slang/slang-ast-stmt.h5
-rw-r--r--source/slang/slang-capability.cpp60
-rw-r--r--source/slang/slang-capability.h5
-rw-r--r--source/slang/slang-check-decl.cpp26
-rw-r--r--source/slang/slang-check-stmt.cpp44
-rw-r--r--source/slang/slang-diagnostic-defs.h7
-rw-r--r--source/slang/slang-emit.cpp9
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp1
-rw-r--r--source/slang/slang-ir-inst-defs.h3
-rw-r--r--source/slang/slang-ir-insts.h4
-rw-r--r--source/slang/slang-ir-specialize-stage-switch.cpp198
-rw-r--r--source/slang/slang-ir-specialize-stage-switch.h14
-rw-r--r--source/slang/slang-ir.cpp6
-rw-r--r--source/slang/slang-lower-to-ir.cpp96
-rw-r--r--source/slang/slang-parser.cpp17
16 files changed, 496 insertions, 28 deletions
diff --git a/source/slang/glsl.meta.slang b/source/slang/glsl.meta.slang
index ef3bfd683..0bad3c681 100644
--- a/source/slang/glsl.meta.slang
+++ b/source/slang/glsl.meta.slang
@@ -4824,24 +4824,41 @@ public property uint3 gl_LaunchSizeEXT
}
}
+internal in int __gl_PrimitiveID : SV_PrimitiveID;
public property int gl_PrimitiveID
{
- [require(cuda_glsl_hlsl_spirv, raytracing_anyhit_closesthit_intersection)]
+ [require(cuda_glsl_hlsl_spirv)]
get
{
- setupExtForRayTracingBuiltIn();
- return PrimitiveIndex();
+ __stage_switch
+ {
+ case anyhit:
+ case closesthit:
+ case intersection:
+ setupExtForRayTracingBuiltIn();
+ return PrimitiveIndex();
+ default:
+ return __gl_PrimitiveID;
+ }
}
}
public property int gl_InstanceID
{
- [require(cuda_glsl_hlsl_spirv, raytracing_anyhit_closesthit_intersection)]
+ [require(cuda_glsl_hlsl_spirv)]
get
{
- setupExtForRayTracingBuiltIn();
- return InstanceIndex();
+ __stage_switch
+ {
+ case anyhit:
+ case closesthit:
+ case intersection:
+ setupExtForRayTracingBuiltIn();
+ return InstanceIndex();
+ default:
+ return gl_InstanceIndex;
+ }
}
}
diff --git a/source/slang/slang-ast-stmt.h b/source/slang/slang-ast-stmt.h
index eba7027c2..09c491287 100644
--- a/source/slang/slang-ast-stmt.h
+++ b/source/slang/slang-ast-stmt.h
@@ -113,6 +113,11 @@ class TargetSwitchStmt : public Stmt
List<TargetCaseStmt*> targetCases;
};
+class StageSwitchStmt : public TargetSwitchStmt
+{
+ SLANG_AST_CLASS(StageSwitchStmt)
+};
+
class IntrinsicAsmStmt : public Stmt
{
SLANG_AST_CLASS(IntrinsicAsmStmt)
diff --git a/source/slang/slang-capability.cpp b/source/slang/slang-capability.cpp
index 7c4e34d32..1eb0cae31 100644
--- a/source/slang/slang-capability.cpp
+++ b/source/slang/slang-capability.cpp
@@ -100,6 +100,45 @@ bool isDirectChildOfAbstractAtom(CapabilityAtom name)
return _getInfo(name).abstractBase != CapabilityName::Invalid;
}
+bool isStageAtom(CapabilityName name, CapabilityName& outCanonicalStage)
+{
+ auto& info = _getInfo(name);
+ if (info.abstractBase == CapabilityName::stage)
+ {
+ outCanonicalStage = name;
+ return true;
+ }
+ switch (name)
+ {
+ case CapabilityName::anyhit:
+ outCanonicalStage = CapabilityName::_anyhit;
+ return true;
+ case CapabilityName::closesthit:
+ outCanonicalStage = CapabilityName::_closesthit;
+ return true;
+ case CapabilityName::miss:
+ outCanonicalStage = CapabilityName::_miss;
+ return true;
+ case CapabilityName::intersection:
+ outCanonicalStage = CapabilityName::_intersection;
+ return true;
+ case CapabilityName::raygen:
+ outCanonicalStage = CapabilityName::_raygen;
+ return true;
+ case CapabilityName::callable:
+ outCanonicalStage = CapabilityName::_callable;
+ return true;
+ case CapabilityName::mesh:
+ outCanonicalStage = CapabilityName::_mesh;
+ return true;
+ case CapabilityName::amplification:
+ outCanonicalStage = CapabilityName::_amplification;
+ return true;
+ default:
+ return false;
+ }
+}
+
bool isTargetVersionAtom(CapabilityAtom name)
{
if (name >= CapabilityAtom::_spirv_1_0 && name <= getLatestSpirvAtom())
@@ -620,7 +659,26 @@ CapabilitySet CapabilitySet::getTargetsThisHasButOtherDoesNot(const CapabilitySe
if (other.m_targetSets.tryGetValue(i.first))
continue;
- newSet.m_targetSets[i.first] = this->m_targetSets[i.first];
+ newSet.m_targetSets[i.first] = i.second;
+ }
+ return newSet;
+}
+
+CapabilitySet CapabilitySet::getStagesThisHasButOtherDoesNot(const CapabilitySet& other)
+{
+ CapabilitySet newSet{};
+ for (auto& i : this->m_targetSets)
+ {
+ if (auto otherTarget = other.m_targetSets.tryGetValue(i.first))
+ {
+ auto& thisTarget = m_targetSets[i.first];
+ for (auto& stage : thisTarget.shaderStageSets)
+ {
+ if (otherTarget->shaderStageSets.containsKey(stage.first))
+ continue;
+ newSet.m_targetSets[i.first].shaderStageSets[stage.first] = stage.second;
+ }
+ }
}
return newSet;
}
diff --git a/source/slang/slang-capability.h b/source/slang/slang-capability.h
index 631cd307a..7c429d825 100644
--- a/source/slang/slang-capability.h
+++ b/source/slang/slang-capability.h
@@ -169,6 +169,9 @@ public:
/// Return a capability set of 'target' atoms 'this' has, but 'other' does not.
CapabilitySet getTargetsThisHasButOtherDoesNot(const CapabilitySet& other);
+ /// Return a capability set of 'stage' atoms 'this' has, but 'other' does not.
+ CapabilitySet getStagesThisHasButOtherDoesNot(const CapabilitySet& other);
+
/// Are these two capability sets equal?
bool operator==(CapabilitySet const& that) const;
@@ -359,7 +362,7 @@ void getCapabilityNames(List<UnownedStringSlice>& ioNames);
UnownedStringSlice capabilityNameToString(CapabilityName name);
bool isDirectChildOfAbstractAtom(CapabilityAtom name);
-
+bool isStageAtom(CapabilityName name, CapabilityName& outCanonicalStage);
/// Return true if `name` represents an atom for a target version, e.g. spirv_1_5.
bool isTargetVersionAtom(CapabilityAtom name);
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index ab3335bfb..39ab41421 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -12735,14 +12735,26 @@ struct CapabilityDeclReferenceVisitor
std::swap(stmt->targetCases[i], stmt->targetCases[i + 1]);
continue;
}
-
- if (!maybeRequireCapability)
- targetCap = (CapabilitySet(CapabilityName::any_target)
- .getTargetsThisHasButOtherDoesNot(set));
+ if (as<StageSwitchStmt>(stmt))
+ {
+ if (!maybeRequireCapability)
+ targetCap = (CapabilitySet(CapabilityName::any_target)
+ .getStagesThisHasButOtherDoesNot(set));
+ else
+ targetCap =
+ (maybeRequireCapability->capabilitySet.getStagesThisHasButOtherDoesNot(
+ set));
+ }
else
- targetCap =
- (maybeRequireCapability->capabilitySet.getTargetsThisHasButOtherDoesNot(
- set));
+ {
+ if (!maybeRequireCapability)
+ targetCap = (CapabilitySet(CapabilityName::any_target)
+ .getTargetsThisHasButOtherDoesNot(set));
+ else
+ targetCap =
+ (maybeRequireCapability->capabilitySet.getTargetsThisHasButOtherDoesNot(
+ set));
+ }
}
else
{
diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp
index 151c9324a..db6f00d23 100644
--- a/source/slang/slang-check-stmt.cpp
+++ b/source/slang/slang-check-stmt.cpp
@@ -367,6 +367,40 @@ void SemanticsStmtVisitor::visitTargetSwitchStmt(TargetSwitchStmt* stmt)
HashSet<Stmt*> checkedStmt;
for (auto caseStmt : stmt->targetCases)
{
+ CapabilitySet set((CapabilityName)caseStmt->capability);
+
+ CapabilityName canonicalStage = CapabilityName::Invalid;
+ bool isStage = isStageAtom((CapabilityName)caseStmt->capability, canonicalStage);
+ if (as<StageSwitchStmt>(stmt))
+ {
+ if (!isStage && caseStmt->capability != 0)
+ {
+ getSink()->diagnose(
+ caseStmt->capabilityToken.loc,
+ Diagnostics::unknownStageName,
+ caseStmt->capabilityToken);
+ }
+ caseStmt->capability = (int)canonicalStage;
+ }
+ else
+ {
+ if (isStage)
+ {
+ getSink()->diagnose(
+ caseStmt->capabilityToken.loc,
+ Diagnostics::targetSwitchCaseCannotBeAStage);
+ }
+ else if (
+ caseStmt->capabilityToken.getContentLength() != 0 &&
+ (set.getCapabilityTargetSets().getCount() != 1 || set.isInvalid() || set.isEmpty()))
+ {
+ getSink()->diagnose(
+ caseStmt->capabilityToken.loc,
+ Diagnostics::invalidTargetSwitchCase,
+ capabilityNameToString((CapabilityName)caseStmt->capability));
+ }
+ }
+
if (checkedStmt.contains(caseStmt->body))
continue;
subContext.checkStmt(caseStmt);
@@ -377,7 +411,6 @@ void SemanticsStmtVisitor::visitTargetSwitchStmt(TargetSwitchStmt* stmt)
void SemanticsStmtVisitor::visitTargetCaseStmt(TargetCaseStmt* stmt)
{
auto switchStmt = FindOuterStmt<TargetSwitchStmt>();
- CapabilitySet set((CapabilityName)stmt->capability);
if (getShared()->isInLanguageServer() &&
getShared()->getSession()->getCompletionRequestTokenName() ==
stmt->capabilityToken.getName())
@@ -385,15 +418,6 @@ void SemanticsStmtVisitor::visitTargetCaseStmt(TargetCaseStmt* stmt)
getShared()->getLinkage()->contentAssistInfo.completionSuggestions.scopeKind =
CompletionSuggestions::ScopeKind::Capabilities;
}
-
- if (stmt->capabilityToken.getContentLength() != 0 &&
- (set.getCapabilityTargetSets().getCount() != 1 || set.isInvalid() || set.isEmpty()))
- {
- getSink()->diagnose(
- stmt->capabilityToken.loc,
- Diagnostics::invalidTargetSwitchCase,
- capabilityNameToString((CapabilityName)stmt->capability));
- }
if (!switchStmt)
{
getSink()->diagnose(stmt, Diagnostics::caseOutsideSwitch);
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index a1768415e..59a1bbdb6 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -561,6 +561,13 @@ DIAGNOSTIC(
Error,
spirvUndefinedId,
"SPIRV id '%$0' is not defined in the current assembly block location")
+
+DIAGNOSTIC(
+ 29115,
+ Error,
+ targetSwitchCaseCannotBeAStage,
+ "cannot use a stage name in '__target_switch', use '__stage_switch' for stage-specific code.")
+
//
// 3xxxx - Semantic analysis
//
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 58376bbc1..f093104bd 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -90,6 +90,7 @@
#include "slang-ir-specialize-buffer-load-arg.h"
#include "slang-ir-specialize-matrix-layout.h"
#include "slang-ir-specialize-resources.h"
+#include "slang-ir-specialize-stage-switch.h"
#include "slang-ir-specialize.h"
#include "slang-ir-ssa-simplification.h"
#include "slang-ir-ssa.h"
@@ -322,6 +323,7 @@ struct RequiredLoweringPassSet
bool dynamicResource;
bool dynamicResourceHeap;
bool resolveVaryingInputRef;
+ bool specializeStageSwitch;
};
// Scan the IR module and determine which lowering/legalization passes are needed based
@@ -444,6 +446,9 @@ void calcRequiredLoweringPassSet(
case kIROp_ResolveVaryingInputRef:
result.resolveVaryingInputRef = true;
break;
+ case kIROp_GetCurrentStage:
+ result.specializeStageSwitch = true;
+ break;
}
if (!result.generics || !result.existentialTypeLayout)
{
@@ -1027,6 +1032,10 @@ Result linkAndOptimizeIR(
cleanupGenerics(targetProgram, irModule, sink);
dumpIRIfEnabled(codeGenContext, irModule, "AFTER-LOWER-GENERICS");
+ // After dynamic dispatch logic is resolved into ordinary function calls,
+ // we can now run our stage specialization logic.
+ if (requiredLoweringPassSet.specializeStageSwitch)
+ specializeStageSwitch(irModule);
if (sink->getErrorCount() != 0)
return SLANG_FAIL;
#if 0
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 19bfa5e72..6fc7d56ad 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -2177,6 +2177,7 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_Printf:
case kIROp_MakeCoopVector:
case kIROp_MakeCoopVectorFromValuePack:
+ case kIROp_GetCurrentStage:
return transcribeNonDiffInst(builder, origInst);
// A call to createDynamicObject<T>(arbitraryData) cannot provide a diff value,
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 59a6852b1..4de7457a3 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -410,6 +410,9 @@ INST(InterfaceRequirementEntry, interface_req_entry, 2, GLOBAL)
// We will materialize this inst during `translateGLSLGlobalVar`.
INST(GetWorkGroupSize, GetWorkGroupSize, 0, HOISTABLE)
+// An inst that returns the current stage of the calling entry point.
+INST(GetCurrentStage, GetCurrentStage, 0, 0)
+
INST(Param, param, 0, 0)
INST(StructField, field, 2, 0)
INST(Var, var, 0, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index fcafe4bc6..3249f34b0 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -2424,7 +2424,7 @@ struct IRCall : IRInst
IR_LEAF_ISA(Call)
IRInst* getCallee() { return getOperand(0); }
-
+ IRUse* getCalleeUse() { return getOperands(); }
UInt getArgCount() { return getOperandCount() - 1; }
IRUse* getArgs() { return getOperands() + 1; }
IROperandList<IRInst> getArgsList()
@@ -3881,6 +3881,8 @@ public:
// its rate, if any.
void setDataType(IRInst* inst, IRType* dataType);
+ IRInst* emitGetCurrentStage();
+
/// Extract the value wrapped inside an existential box.
IRInst* emitGetValueFromBoundInterface(IRType* type, IRInst* boundInterfaceValue);
diff --git a/source/slang/slang-ir-specialize-stage-switch.cpp b/source/slang/slang-ir-specialize-stage-switch.cpp
new file mode 100644
index 000000000..f65aa4d4c
--- /dev/null
+++ b/source/slang/slang-ir-specialize-stage-switch.cpp
@@ -0,0 +1,198 @@
+#include "slang-ir-specialize-stage-switch.h"
+
+#include "slang-capability.h"
+#include "slang-compiler.h"
+#include "slang-ir-call-graph.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-util.h"
+#include "slang-ir.h"
+
+namespace Slang
+{
+bool funcHasGetCurrentStageInst(IRGlobalValueWithCode* func)
+{
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst : block->getChildren())
+ {
+ if (inst->getOp() == kIROp_GetCurrentStage)
+ {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+void discoverStageSpecificFunctions(HashSet<IRInst*>& stageSpecificFunctions, IRModule* module)
+{
+ List<IRInst*> workList;
+ for (auto inst : module->getGlobalInsts())
+ {
+ if (auto func = as<IRGlobalValueWithCode>(inst))
+ {
+ if (funcHasGetCurrentStageInst(func))
+ {
+ workList.add(inst);
+ stageSpecificFunctions.add(func);
+ }
+ }
+ }
+ for (Index i = 0; i < workList.getCount(); i++)
+ {
+ auto callee = workList[i];
+ traverseUses(
+ callee,
+ [&](IRUse* use)
+ {
+ if (use->getUser()->getOp() == kIROp_Call)
+ {
+ auto parentFunc = getParentFunc(use->getUser());
+ if (parentFunc && stageSpecificFunctions.add(parentFunc))
+ {
+ workList.add(parentFunc);
+ }
+ }
+ });
+ }
+}
+
+// Given a func, replace all `GetCurrentStage` insts with the given stage, and rewrite all calls to
+// stage specific functions to the specialized function for the given stage.
+//
+void specializeFuncToStage(
+ Stage stage,
+ IRGlobalValueWithCode* func,
+ Dictionary<IRInst*, Dictionary<Stage, IRInst*>>& mapFuncToStageSpecializedFunc)
+{
+ // Collect all insts that may need to be modified.
+ List<IRInst*> instsToModify;
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst : block->getChildren())
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_GetCurrentStage:
+ case kIROp_Call:
+ instsToModify.add(inst);
+ break;
+ }
+ }
+ }
+
+ IRInst* stageVal = nullptr;
+ IRBuilder builder(func);
+ for (auto inst : instsToModify)
+ {
+ builder.setInsertBefore(inst);
+
+ switch (inst->getOp())
+ {
+ case kIROp_GetCurrentStage:
+ {
+ // Replace `GetCurrentStage` with the stage it is specialized to.
+ if (!stageVal)
+ {
+ stageVal = builder.getIntValue((IRIntegerValue)stage);
+ }
+ inst->replaceUsesWith(stageVal);
+ inst->removeAndDeallocate();
+ break;
+ }
+ case kIROp_Call:
+ {
+ // Replace calls to stage specific functions with the specialized function for the
+ // given stage.
+ auto callInst = static_cast<IRCall*>(inst);
+ auto callee = callInst->getCallee();
+ auto specializedFuncs = mapFuncToStageSpecializedFunc.tryGetValue(callee);
+ if (specializedFuncs)
+ {
+ auto specializedFunc = specializedFuncs->tryGetValue(stage);
+ if (specializedFunc)
+ {
+ builder.replaceOperand(callInst->getCalleeUse(), *specializedFunc);
+ }
+ }
+ break;
+ }
+ }
+ }
+}
+
+void specializeStageSwitch(IRModule* module)
+{
+ Dictionary<IRInst*, HashSet<IRFunc*>> mapInstToReferencingEntryPoints;
+ buildEntryPointReferenceGraph(mapInstToReferencingEntryPoints, module);
+
+ HashSet<IRInst*> stageSpecificFunctions;
+ discoverStageSpecificFunctions(stageSpecificFunctions, module);
+
+ // Clone all stage specific functions for each stage they are used in.
+ Dictionary<IRInst*, Dictionary<Stage, IRInst*>> mapFuncToStageSpecializedFunc;
+ for (auto func : stageSpecificFunctions)
+ {
+ auto referencingEntryPoints = mapInstToReferencingEntryPoints.tryGetValue(func);
+ if (!referencingEntryPoints)
+ continue;
+ if (func->findDecoration<IREntryPointDecoration>())
+ continue;
+ Dictionary<Stage, IRInst*> specializedFuncs;
+ for (auto entryPoint : *referencingEntryPoints)
+ {
+ auto entryPointDecor = entryPoint->findDecoration<IREntryPointDecoration>();
+ if (!entryPointDecor)
+ continue;
+ auto stage = entryPointDecor->getProfile().getStage();
+ auto stageSpecializedFunc = specializedFuncs.tryGetValue(stage);
+ if (stageSpecializedFunc)
+ continue;
+ IRCloneEnv cloneEnv;
+ IRBuilder builder(func);
+ builder.setInsertBefore(func);
+ auto clonedFunc = cloneInst(&cloneEnv, &builder, func);
+ specializedFuncs[stage] = clonedFunc;
+ }
+ mapFuncToStageSpecializedFunc.add(func, _Move(specializedFuncs));
+ }
+
+ // Rewrite entrypoint and cloned functions to replace `GetCurrentStage` with the stage they are
+ // specialized to.
+ for (auto func : stageSpecificFunctions)
+ {
+ // Is this an entrypoint?
+ if (auto entryPointDecor = func->findDecoration<IREntryPointDecoration>())
+ {
+ auto stage = entryPointDecor->getProfile().getStage();
+ specializeFuncToStage(
+ stage,
+ as<IRGlobalValueWithCode>(func),
+ mapFuncToStageSpecializedFunc);
+ }
+ else
+ {
+ // Is this a cloned function?
+ auto specializedFuncs = mapFuncToStageSpecializedFunc.tryGetValue(func);
+ if (!specializedFuncs)
+ continue;
+ for (auto pair : *specializedFuncs)
+ {
+ auto stage = pair.first;
+ auto specializedFunc = pair.second;
+ specializeFuncToStage(
+ stage,
+ as<IRGlobalValueWithCode>(specializedFunc),
+ mapFuncToStageSpecializedFunc);
+ }
+ }
+ }
+
+ // Remove all original stage specific functions.
+ for (auto f : mapFuncToStageSpecializedFunc)
+ {
+ f.first->removeAndDeallocate();
+ }
+}
+
+} // namespace Slang
diff --git a/source/slang/slang-ir-specialize-stage-switch.h b/source/slang/slang-ir-specialize-stage-switch.h
new file mode 100644
index 000000000..54d967e5d
--- /dev/null
+++ b/source/slang/slang-ir-specialize-stage-switch.h
@@ -0,0 +1,14 @@
+#ifndef SLANG_IR_SPECIALIZE_STAGE_SWITCH_H
+#define SLANG_IR_SPECIALIZE_STAGE_SWITCH_H
+
+namespace Slang
+{
+struct IRModule;
+
+// Repalce all stage_switch insts with the case that matches current calling entrypoint.
+//
+void specializeStageSwitch(IRModule* module);
+
+} // namespace Slang
+
+#endif
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 3314567f1..6cf7a1786 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3209,6 +3209,11 @@ void IRBuilder::setDataType(IRInst* inst, IRType* dataType)
}
}
+IRInst* IRBuilder::emitGetCurrentStage()
+{
+ return emitIntrinsicInst(getIntType(), kIROp_GetCurrentStage, 0, nullptr);
+}
+
IRInst* IRBuilder::emitGetValueFromBoundInterface(IRType* type, IRInst* boundInterfaceValue)
{
auto inst =
@@ -8301,6 +8306,7 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options)
case kIROp_ResolveVaryingInputRef:
case kIROp_GetPerVertexInputArray:
case kIROp_MetalCastToDepthTexture:
+ case kIROp_GetCurrentStage:
return false;
case kIROp_ForwardDifferentiate:
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 36003dcb2..06cdf430b 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -6719,6 +6719,102 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
}
}
+ void visitStageSwitchStmt(StageSwitchStmt* stmt)
+ {
+ if (!stmt->targetCases.getCount())
+ return;
+
+ // We will lower stage switch as a normal switch statement, so they can participate in all
+ // optimizations.
+ auto builder = getBuilder();
+ startBlockIfNeeded(stmt);
+
+ // First emit code to get the current stage to switch on:
+ auto conditionVal = builder->emitGetCurrentStage();
+
+ // Remember the initial block so that we can add to it
+ // after we've collected all the `case`s
+ auto initialBlock = builder->getBlock();
+
+ // Next, create a block to use as the target for any `break` statements
+ auto breakLabel = createBlock();
+
+ // Register the `break` label so
+ // that we can find it for nested statements.
+ context->shared->breakLabels.add(stmt, breakLabel);
+
+ builder->setInsertInto(initialBlock->getParent());
+
+ // Iterate over the body of the statement, looking
+ // for `case` or `default` statements:
+ SwitchStmtInfo info;
+ info.initialBlock = initialBlock;
+ info.defaultLabel = nullptr;
+
+ Dictionary<Stmt*, IRBlock*> mapCaseStmtToBlock;
+ for (auto targetCase : stmt->targetCases)
+ {
+ IRBlock* caseBlock = nullptr;
+ if (!mapCaseStmtToBlock.tryGetValue(targetCase->body, caseBlock))
+ {
+ caseBlock = builder->emitBlock();
+ lowerStmt(context, targetCase->body);
+ mapCaseStmtToBlock.add(targetCase->body, caseBlock);
+ if (!builder->getBlock()->getTerminator())
+ builder->emitBranch(breakLabel);
+ }
+ if (targetCase->capability == 0)
+ {
+ info.defaultLabel = caseBlock;
+ }
+ else
+ {
+ auto stage = getStageFromAtom((CapabilityAtom)targetCase->capability);
+ info.cases.add(builder->getIntValue(builder->getIntType(), (IRIntegerValue)stage));
+ info.cases.add(caseBlock);
+ }
+ }
+
+ // If the current block (the end of the last
+ // `case`) is not terminated, then terminate with a
+ // `break` operation.
+ //
+ // Double check that we aren't in the initial
+ // block, so we don't get tripped up on an
+ // empty `switch`.
+ auto curBlock = builder->getBlock();
+ if (curBlock != initialBlock)
+ {
+ // Is the block already terminated?
+ if (!curBlock->getTerminator())
+ {
+ // Not terminated, so add one.
+ builder->emitBreak(breakLabel);
+ }
+ }
+
+ // If there was no `default` statement, then the
+ // default case will just branch directly to the end.
+ auto defaultLabel = info.defaultLabel ? info.defaultLabel : breakLabel;
+
+ // Now that we've collected the cases, we are
+ // prepared to emit the `switch` instruction
+ // itself.
+ builder->setInsertInto(initialBlock);
+ builder->emitSwitch(
+ conditionVal,
+ breakLabel,
+ defaultLabel,
+ info.cases.getCount(),
+ info.cases.getBuffer());
+
+ // Finally we insert the label that a `break` will jump to
+ // (and that control flow will fall through to otherwise).
+ // This is the block that subsequent code will go into.
+ insertBlock(breakLabel);
+ context->shared->breakLabels.remove(stmt);
+ }
+
void visitTargetSwitchStmt(TargetSwitchStmt* stmt)
{
if (!stmt->targetCases.getCount())
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index 95eb971ee..38285c41f 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -5380,9 +5380,8 @@ static Stmt* ParseDefaultStmt(Parser* parser)
return stmt;
}
-static Stmt* parseTargetSwitchStmt(Parser* parser)
+static Stmt* parseTargetSwitchStmtImpl(Parser* parser, TargetSwitchStmt* stmt)
{
- TargetSwitchStmt* stmt = parser->astBuilder->create<TargetSwitchStmt>();
parser->FillPosition(stmt);
parser->ReadToken();
if (!beginMatch(parser, MatchedTokenType::CurlyBraces))
@@ -5479,6 +5478,18 @@ static Stmt* parseTargetSwitchStmt(Parser* parser)
return stmt;
}
+static Stmt* parseTargetSwitchStmt(Parser* parser)
+{
+ auto stmt = parser->astBuilder->create<TargetSwitchStmt>();
+ return parseTargetSwitchStmtImpl(parser, stmt);
+}
+
+static Stmt* parseStageSwitchStmt(Parser* parser)
+{
+ auto stmt = parser->astBuilder->create<StageSwitchStmt>();
+ return parseTargetSwitchStmtImpl(parser, stmt);
+}
+
static Stmt* parseIntrinsicAsmStmt(Parser* parser)
{
IntrinsicAsmStmt* stmt = parser->astBuilder->create<IntrinsicAsmStmt>();
@@ -5725,6 +5736,8 @@ Stmt* Parser::ParseStatement(Stmt* parentStmt)
statement = ParseSwitchStmt(this);
else if (LookAheadToken("__target_switch"))
statement = parseTargetSwitchStmt(this);
+ else if (LookAheadToken("__stage_switch"))
+ statement = parseStageSwitchStmt(this);
else if (LookAheadToken("__intrinsic_asm"))
statement = parseIntrinsicAsmStmt(this);
else if (LookAheadToken("case"))