diff options
| author | Yong He <yonghe@outlook.com> | 2025-02-06 22:02:43 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-02-06 22:02:43 -0800 |
| commit | bae87afb20f95f9f27c64c4955bbc4464c576509 (patch) | |
| tree | 44d079bd76002d69be20efdbd03ac6ff62ef8caf /source | |
| parent | 075b10e69055acc6536d74c1cb3399e0fe75338d (diff) | |
Support stage_switch. (#6311)
* Support stage_switch.
* Update proposal status.
* Fix gl_InstanceID.
* Fix.
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/glsl.meta.slang | 29 | ||||
| -rw-r--r-- | source/slang/slang-ast-stmt.h | 5 | ||||
| -rw-r--r-- | source/slang/slang-capability.cpp | 60 | ||||
| -rw-r--r-- | source/slang/slang-capability.h | 5 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 26 | ||||
| -rw-r--r-- | source/slang/slang-check-stmt.cpp | 44 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 7 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize-stage-switch.cpp | 198 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize-stage-switch.h | 14 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 96 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 17 |
16 files changed, 496 insertions, 28 deletions
diff --git a/source/slang/glsl.meta.slang b/source/slang/glsl.meta.slang index ef3bfd683..0bad3c681 100644 --- a/source/slang/glsl.meta.slang +++ b/source/slang/glsl.meta.slang @@ -4824,24 +4824,41 @@ public property uint3 gl_LaunchSizeEXT } } +internal in int __gl_PrimitiveID : SV_PrimitiveID; public property int gl_PrimitiveID { - [require(cuda_glsl_hlsl_spirv, raytracing_anyhit_closesthit_intersection)] + [require(cuda_glsl_hlsl_spirv)] get { - setupExtForRayTracingBuiltIn(); - return PrimitiveIndex(); + __stage_switch + { + case anyhit: + case closesthit: + case intersection: + setupExtForRayTracingBuiltIn(); + return PrimitiveIndex(); + default: + return __gl_PrimitiveID; + } } } public property int gl_InstanceID { - [require(cuda_glsl_hlsl_spirv, raytracing_anyhit_closesthit_intersection)] + [require(cuda_glsl_hlsl_spirv)] get { - setupExtForRayTracingBuiltIn(); - return InstanceIndex(); + __stage_switch + { + case anyhit: + case closesthit: + case intersection: + setupExtForRayTracingBuiltIn(); + return InstanceIndex(); + default: + return gl_InstanceIndex; + } } } diff --git a/source/slang/slang-ast-stmt.h b/source/slang/slang-ast-stmt.h index eba7027c2..09c491287 100644 --- a/source/slang/slang-ast-stmt.h +++ b/source/slang/slang-ast-stmt.h @@ -113,6 +113,11 @@ class TargetSwitchStmt : public Stmt List<TargetCaseStmt*> targetCases; }; +class StageSwitchStmt : public TargetSwitchStmt +{ + SLANG_AST_CLASS(StageSwitchStmt) +}; + class IntrinsicAsmStmt : public Stmt { SLANG_AST_CLASS(IntrinsicAsmStmt) diff --git a/source/slang/slang-capability.cpp b/source/slang/slang-capability.cpp index 7c4e34d32..1eb0cae31 100644 --- a/source/slang/slang-capability.cpp +++ b/source/slang/slang-capability.cpp @@ -100,6 +100,45 @@ bool isDirectChildOfAbstractAtom(CapabilityAtom name) return _getInfo(name).abstractBase != CapabilityName::Invalid; } +bool isStageAtom(CapabilityName name, CapabilityName& outCanonicalStage) +{ + auto& info = _getInfo(name); + if (info.abstractBase == CapabilityName::stage) + { + outCanonicalStage = name; + return true; + } + switch (name) + { + case CapabilityName::anyhit: + outCanonicalStage = CapabilityName::_anyhit; + return true; + case CapabilityName::closesthit: + outCanonicalStage = CapabilityName::_closesthit; + return true; + case CapabilityName::miss: + outCanonicalStage = CapabilityName::_miss; + return true; + case CapabilityName::intersection: + outCanonicalStage = CapabilityName::_intersection; + return true; + case CapabilityName::raygen: + outCanonicalStage = CapabilityName::_raygen; + return true; + case CapabilityName::callable: + outCanonicalStage = CapabilityName::_callable; + return true; + case CapabilityName::mesh: + outCanonicalStage = CapabilityName::_mesh; + return true; + case CapabilityName::amplification: + outCanonicalStage = CapabilityName::_amplification; + return true; + default: + return false; + } +} + bool isTargetVersionAtom(CapabilityAtom name) { if (name >= CapabilityAtom::_spirv_1_0 && name <= getLatestSpirvAtom()) @@ -620,7 +659,26 @@ CapabilitySet CapabilitySet::getTargetsThisHasButOtherDoesNot(const CapabilitySe if (other.m_targetSets.tryGetValue(i.first)) continue; - newSet.m_targetSets[i.first] = this->m_targetSets[i.first]; + newSet.m_targetSets[i.first] = i.second; + } + return newSet; +} + +CapabilitySet CapabilitySet::getStagesThisHasButOtherDoesNot(const CapabilitySet& other) +{ + CapabilitySet newSet{}; + for (auto& i : this->m_targetSets) + { + if (auto otherTarget = other.m_targetSets.tryGetValue(i.first)) + { + auto& thisTarget = m_targetSets[i.first]; + for (auto& stage : thisTarget.shaderStageSets) + { + if (otherTarget->shaderStageSets.containsKey(stage.first)) + continue; + newSet.m_targetSets[i.first].shaderStageSets[stage.first] = stage.second; + } + } } return newSet; } diff --git a/source/slang/slang-capability.h b/source/slang/slang-capability.h index 631cd307a..7c429d825 100644 --- a/source/slang/slang-capability.h +++ b/source/slang/slang-capability.h @@ -169,6 +169,9 @@ public: /// Return a capability set of 'target' atoms 'this' has, but 'other' does not. CapabilitySet getTargetsThisHasButOtherDoesNot(const CapabilitySet& other); + /// Return a capability set of 'stage' atoms 'this' has, but 'other' does not. + CapabilitySet getStagesThisHasButOtherDoesNot(const CapabilitySet& other); + /// Are these two capability sets equal? bool operator==(CapabilitySet const& that) const; @@ -359,7 +362,7 @@ void getCapabilityNames(List<UnownedStringSlice>& ioNames); UnownedStringSlice capabilityNameToString(CapabilityName name); bool isDirectChildOfAbstractAtom(CapabilityAtom name); - +bool isStageAtom(CapabilityName name, CapabilityName& outCanonicalStage); /// Return true if `name` represents an atom for a target version, e.g. spirv_1_5. bool isTargetVersionAtom(CapabilityAtom name); diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index ab3335bfb..39ab41421 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -12735,14 +12735,26 @@ struct CapabilityDeclReferenceVisitor std::swap(stmt->targetCases[i], stmt->targetCases[i + 1]); continue; } - - if (!maybeRequireCapability) - targetCap = (CapabilitySet(CapabilityName::any_target) - .getTargetsThisHasButOtherDoesNot(set)); + if (as<StageSwitchStmt>(stmt)) + { + if (!maybeRequireCapability) + targetCap = (CapabilitySet(CapabilityName::any_target) + .getStagesThisHasButOtherDoesNot(set)); + else + targetCap = + (maybeRequireCapability->capabilitySet.getStagesThisHasButOtherDoesNot( + set)); + } else - targetCap = - (maybeRequireCapability->capabilitySet.getTargetsThisHasButOtherDoesNot( - set)); + { + if (!maybeRequireCapability) + targetCap = (CapabilitySet(CapabilityName::any_target) + .getTargetsThisHasButOtherDoesNot(set)); + else + targetCap = + (maybeRequireCapability->capabilitySet.getTargetsThisHasButOtherDoesNot( + set)); + } } else { diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp index 151c9324a..db6f00d23 100644 --- a/source/slang/slang-check-stmt.cpp +++ b/source/slang/slang-check-stmt.cpp @@ -367,6 +367,40 @@ void SemanticsStmtVisitor::visitTargetSwitchStmt(TargetSwitchStmt* stmt) HashSet<Stmt*> checkedStmt; for (auto caseStmt : stmt->targetCases) { + CapabilitySet set((CapabilityName)caseStmt->capability); + + CapabilityName canonicalStage = CapabilityName::Invalid; + bool isStage = isStageAtom((CapabilityName)caseStmt->capability, canonicalStage); + if (as<StageSwitchStmt>(stmt)) + { + if (!isStage && caseStmt->capability != 0) + { + getSink()->diagnose( + caseStmt->capabilityToken.loc, + Diagnostics::unknownStageName, + caseStmt->capabilityToken); + } + caseStmt->capability = (int)canonicalStage; + } + else + { + if (isStage) + { + getSink()->diagnose( + caseStmt->capabilityToken.loc, + Diagnostics::targetSwitchCaseCannotBeAStage); + } + else if ( + caseStmt->capabilityToken.getContentLength() != 0 && + (set.getCapabilityTargetSets().getCount() != 1 || set.isInvalid() || set.isEmpty())) + { + getSink()->diagnose( + caseStmt->capabilityToken.loc, + Diagnostics::invalidTargetSwitchCase, + capabilityNameToString((CapabilityName)caseStmt->capability)); + } + } + if (checkedStmt.contains(caseStmt->body)) continue; subContext.checkStmt(caseStmt); @@ -377,7 +411,6 @@ void SemanticsStmtVisitor::visitTargetSwitchStmt(TargetSwitchStmt* stmt) void SemanticsStmtVisitor::visitTargetCaseStmt(TargetCaseStmt* stmt) { auto switchStmt = FindOuterStmt<TargetSwitchStmt>(); - CapabilitySet set((CapabilityName)stmt->capability); if (getShared()->isInLanguageServer() && getShared()->getSession()->getCompletionRequestTokenName() == stmt->capabilityToken.getName()) @@ -385,15 +418,6 @@ void SemanticsStmtVisitor::visitTargetCaseStmt(TargetCaseStmt* stmt) getShared()->getLinkage()->contentAssistInfo.completionSuggestions.scopeKind = CompletionSuggestions::ScopeKind::Capabilities; } - - if (stmt->capabilityToken.getContentLength() != 0 && - (set.getCapabilityTargetSets().getCount() != 1 || set.isInvalid() || set.isEmpty())) - { - getSink()->diagnose( - stmt->capabilityToken.loc, - Diagnostics::invalidTargetSwitchCase, - capabilityNameToString((CapabilityName)stmt->capability)); - } if (!switchStmt) { getSink()->diagnose(stmt, Diagnostics::caseOutsideSwitch); diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index a1768415e..59a1bbdb6 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -561,6 +561,13 @@ DIAGNOSTIC( Error, spirvUndefinedId, "SPIRV id '%$0' is not defined in the current assembly block location") + +DIAGNOSTIC( + 29115, + Error, + targetSwitchCaseCannotBeAStage, + "cannot use a stage name in '__target_switch', use '__stage_switch' for stage-specific code.") + // // 3xxxx - Semantic analysis // diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 58376bbc1..f093104bd 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -90,6 +90,7 @@ #include "slang-ir-specialize-buffer-load-arg.h" #include "slang-ir-specialize-matrix-layout.h" #include "slang-ir-specialize-resources.h" +#include "slang-ir-specialize-stage-switch.h" #include "slang-ir-specialize.h" #include "slang-ir-ssa-simplification.h" #include "slang-ir-ssa.h" @@ -322,6 +323,7 @@ struct RequiredLoweringPassSet bool dynamicResource; bool dynamicResourceHeap; bool resolveVaryingInputRef; + bool specializeStageSwitch; }; // Scan the IR module and determine which lowering/legalization passes are needed based @@ -444,6 +446,9 @@ void calcRequiredLoweringPassSet( case kIROp_ResolveVaryingInputRef: result.resolveVaryingInputRef = true; break; + case kIROp_GetCurrentStage: + result.specializeStageSwitch = true; + break; } if (!result.generics || !result.existentialTypeLayout) { @@ -1027,6 +1032,10 @@ Result linkAndOptimizeIR( cleanupGenerics(targetProgram, irModule, sink); dumpIRIfEnabled(codeGenContext, irModule, "AFTER-LOWER-GENERICS"); + // After dynamic dispatch logic is resolved into ordinary function calls, + // we can now run our stage specialization logic. + if (requiredLoweringPassSet.specializeStageSwitch) + specializeStageSwitch(irModule); if (sink->getErrorCount() != 0) return SLANG_FAIL; #if 0 diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 19bfa5e72..6fc7d56ad 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -2177,6 +2177,7 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_Printf: case kIROp_MakeCoopVector: case kIROp_MakeCoopVectorFromValuePack: + case kIROp_GetCurrentStage: return transcribeNonDiffInst(builder, origInst); // A call to createDynamicObject<T>(arbitraryData) cannot provide a diff value, diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 59a6852b1..4de7457a3 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -410,6 +410,9 @@ INST(InterfaceRequirementEntry, interface_req_entry, 2, GLOBAL) // We will materialize this inst during `translateGLSLGlobalVar`. INST(GetWorkGroupSize, GetWorkGroupSize, 0, HOISTABLE) +// An inst that returns the current stage of the calling entry point. +INST(GetCurrentStage, GetCurrentStage, 0, 0) + INST(Param, param, 0, 0) INST(StructField, field, 2, 0) INST(Var, var, 0, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index fcafe4bc6..3249f34b0 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2424,7 +2424,7 @@ struct IRCall : IRInst IR_LEAF_ISA(Call) IRInst* getCallee() { return getOperand(0); } - + IRUse* getCalleeUse() { return getOperands(); } UInt getArgCount() { return getOperandCount() - 1; } IRUse* getArgs() { return getOperands() + 1; } IROperandList<IRInst> getArgsList() @@ -3881,6 +3881,8 @@ public: // its rate, if any. void setDataType(IRInst* inst, IRType* dataType); + IRInst* emitGetCurrentStage(); + /// Extract the value wrapped inside an existential box. IRInst* emitGetValueFromBoundInterface(IRType* type, IRInst* boundInterfaceValue); diff --git a/source/slang/slang-ir-specialize-stage-switch.cpp b/source/slang/slang-ir-specialize-stage-switch.cpp new file mode 100644 index 000000000..f65aa4d4c --- /dev/null +++ b/source/slang/slang-ir-specialize-stage-switch.cpp @@ -0,0 +1,198 @@ +#include "slang-ir-specialize-stage-switch.h" + +#include "slang-capability.h" +#include "slang-compiler.h" +#include "slang-ir-call-graph.h" +#include "slang-ir-insts.h" +#include "slang-ir-util.h" +#include "slang-ir.h" + +namespace Slang +{ +bool funcHasGetCurrentStageInst(IRGlobalValueWithCode* func) +{ + for (auto block : func->getBlocks()) + { + for (auto inst : block->getChildren()) + { + if (inst->getOp() == kIROp_GetCurrentStage) + { + return true; + } + } + } + return false; +} + +void discoverStageSpecificFunctions(HashSet<IRInst*>& stageSpecificFunctions, IRModule* module) +{ + List<IRInst*> workList; + for (auto inst : module->getGlobalInsts()) + { + if (auto func = as<IRGlobalValueWithCode>(inst)) + { + if (funcHasGetCurrentStageInst(func)) + { + workList.add(inst); + stageSpecificFunctions.add(func); + } + } + } + for (Index i = 0; i < workList.getCount(); i++) + { + auto callee = workList[i]; + traverseUses( + callee, + [&](IRUse* use) + { + if (use->getUser()->getOp() == kIROp_Call) + { + auto parentFunc = getParentFunc(use->getUser()); + if (parentFunc && stageSpecificFunctions.add(parentFunc)) + { + workList.add(parentFunc); + } + } + }); + } +} + +// Given a func, replace all `GetCurrentStage` insts with the given stage, and rewrite all calls to +// stage specific functions to the specialized function for the given stage. +// +void specializeFuncToStage( + Stage stage, + IRGlobalValueWithCode* func, + Dictionary<IRInst*, Dictionary<Stage, IRInst*>>& mapFuncToStageSpecializedFunc) +{ + // Collect all insts that may need to be modified. + List<IRInst*> instsToModify; + for (auto block : func->getBlocks()) + { + for (auto inst : block->getChildren()) + { + switch (inst->getOp()) + { + case kIROp_GetCurrentStage: + case kIROp_Call: + instsToModify.add(inst); + break; + } + } + } + + IRInst* stageVal = nullptr; + IRBuilder builder(func); + for (auto inst : instsToModify) + { + builder.setInsertBefore(inst); + + switch (inst->getOp()) + { + case kIROp_GetCurrentStage: + { + // Replace `GetCurrentStage` with the stage it is specialized to. + if (!stageVal) + { + stageVal = builder.getIntValue((IRIntegerValue)stage); + } + inst->replaceUsesWith(stageVal); + inst->removeAndDeallocate(); + break; + } + case kIROp_Call: + { + // Replace calls to stage specific functions with the specialized function for the + // given stage. + auto callInst = static_cast<IRCall*>(inst); + auto callee = callInst->getCallee(); + auto specializedFuncs = mapFuncToStageSpecializedFunc.tryGetValue(callee); + if (specializedFuncs) + { + auto specializedFunc = specializedFuncs->tryGetValue(stage); + if (specializedFunc) + { + builder.replaceOperand(callInst->getCalleeUse(), *specializedFunc); + } + } + break; + } + } + } +} + +void specializeStageSwitch(IRModule* module) +{ + Dictionary<IRInst*, HashSet<IRFunc*>> mapInstToReferencingEntryPoints; + buildEntryPointReferenceGraph(mapInstToReferencingEntryPoints, module); + + HashSet<IRInst*> stageSpecificFunctions; + discoverStageSpecificFunctions(stageSpecificFunctions, module); + + // Clone all stage specific functions for each stage they are used in. + Dictionary<IRInst*, Dictionary<Stage, IRInst*>> mapFuncToStageSpecializedFunc; + for (auto func : stageSpecificFunctions) + { + auto referencingEntryPoints = mapInstToReferencingEntryPoints.tryGetValue(func); + if (!referencingEntryPoints) + continue; + if (func->findDecoration<IREntryPointDecoration>()) + continue; + Dictionary<Stage, IRInst*> specializedFuncs; + for (auto entryPoint : *referencingEntryPoints) + { + auto entryPointDecor = entryPoint->findDecoration<IREntryPointDecoration>(); + if (!entryPointDecor) + continue; + auto stage = entryPointDecor->getProfile().getStage(); + auto stageSpecializedFunc = specializedFuncs.tryGetValue(stage); + if (stageSpecializedFunc) + continue; + IRCloneEnv cloneEnv; + IRBuilder builder(func); + builder.setInsertBefore(func); + auto clonedFunc = cloneInst(&cloneEnv, &builder, func); + specializedFuncs[stage] = clonedFunc; + } + mapFuncToStageSpecializedFunc.add(func, _Move(specializedFuncs)); + } + + // Rewrite entrypoint and cloned functions to replace `GetCurrentStage` with the stage they are + // specialized to. + for (auto func : stageSpecificFunctions) + { + // Is this an entrypoint? + if (auto entryPointDecor = func->findDecoration<IREntryPointDecoration>()) + { + auto stage = entryPointDecor->getProfile().getStage(); + specializeFuncToStage( + stage, + as<IRGlobalValueWithCode>(func), + mapFuncToStageSpecializedFunc); + } + else + { + // Is this a cloned function? + auto specializedFuncs = mapFuncToStageSpecializedFunc.tryGetValue(func); + if (!specializedFuncs) + continue; + for (auto pair : *specializedFuncs) + { + auto stage = pair.first; + auto specializedFunc = pair.second; + specializeFuncToStage( + stage, + as<IRGlobalValueWithCode>(specializedFunc), + mapFuncToStageSpecializedFunc); + } + } + } + + // Remove all original stage specific functions. + for (auto f : mapFuncToStageSpecializedFunc) + { + f.first->removeAndDeallocate(); + } +} + +} // namespace Slang diff --git a/source/slang/slang-ir-specialize-stage-switch.h b/source/slang/slang-ir-specialize-stage-switch.h new file mode 100644 index 000000000..54d967e5d --- /dev/null +++ b/source/slang/slang-ir-specialize-stage-switch.h @@ -0,0 +1,14 @@ +#ifndef SLANG_IR_SPECIALIZE_STAGE_SWITCH_H +#define SLANG_IR_SPECIALIZE_STAGE_SWITCH_H + +namespace Slang +{ +struct IRModule; + +// Repalce all stage_switch insts with the case that matches current calling entrypoint. +// +void specializeStageSwitch(IRModule* module); + +} // namespace Slang + +#endif diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 3314567f1..6cf7a1786 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3209,6 +3209,11 @@ void IRBuilder::setDataType(IRInst* inst, IRType* dataType) } } +IRInst* IRBuilder::emitGetCurrentStage() +{ + return emitIntrinsicInst(getIntType(), kIROp_GetCurrentStage, 0, nullptr); +} + IRInst* IRBuilder::emitGetValueFromBoundInterface(IRType* type, IRInst* boundInterfaceValue) { auto inst = @@ -8301,6 +8306,7 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options) case kIROp_ResolveVaryingInputRef: case kIROp_GetPerVertexInputArray: case kIROp_MetalCastToDepthTexture: + case kIROp_GetCurrentStage: return false; case kIROp_ForwardDifferentiate: diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 36003dcb2..06cdf430b 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -6719,6 +6719,102 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> } } + void visitStageSwitchStmt(StageSwitchStmt* stmt) + { + if (!stmt->targetCases.getCount()) + return; + + // We will lower stage switch as a normal switch statement, so they can participate in all + // optimizations. + auto builder = getBuilder(); + startBlockIfNeeded(stmt); + + // First emit code to get the current stage to switch on: + auto conditionVal = builder->emitGetCurrentStage(); + + // Remember the initial block so that we can add to it + // after we've collected all the `case`s + auto initialBlock = builder->getBlock(); + + // Next, create a block to use as the target for any `break` statements + auto breakLabel = createBlock(); + + // Register the `break` label so + // that we can find it for nested statements. + context->shared->breakLabels.add(stmt, breakLabel); + + builder->setInsertInto(initialBlock->getParent()); + + // Iterate over the body of the statement, looking + // for `case` or `default` statements: + SwitchStmtInfo info; + info.initialBlock = initialBlock; + info.defaultLabel = nullptr; + + Dictionary<Stmt*, IRBlock*> mapCaseStmtToBlock; + for (auto targetCase : stmt->targetCases) + { + IRBlock* caseBlock = nullptr; + if (!mapCaseStmtToBlock.tryGetValue(targetCase->body, caseBlock)) + { + caseBlock = builder->emitBlock(); + lowerStmt(context, targetCase->body); + mapCaseStmtToBlock.add(targetCase->body, caseBlock); + if (!builder->getBlock()->getTerminator()) + builder->emitBranch(breakLabel); + } + if (targetCase->capability == 0) + { + info.defaultLabel = caseBlock; + } + else + { + auto stage = getStageFromAtom((CapabilityAtom)targetCase->capability); + info.cases.add(builder->getIntValue(builder->getIntType(), (IRIntegerValue)stage)); + info.cases.add(caseBlock); + } + } + + // If the current block (the end of the last + // `case`) is not terminated, then terminate with a + // `break` operation. + // + // Double check that we aren't in the initial + // block, so we don't get tripped up on an + // empty `switch`. + auto curBlock = builder->getBlock(); + if (curBlock != initialBlock) + { + // Is the block already terminated? + if (!curBlock->getTerminator()) + { + // Not terminated, so add one. + builder->emitBreak(breakLabel); + } + } + + // If there was no `default` statement, then the + // default case will just branch directly to the end. + auto defaultLabel = info.defaultLabel ? info.defaultLabel : breakLabel; + + // Now that we've collected the cases, we are + // prepared to emit the `switch` instruction + // itself. + builder->setInsertInto(initialBlock); + builder->emitSwitch( + conditionVal, + breakLabel, + defaultLabel, + info.cases.getCount(), + info.cases.getBuffer()); + + // Finally we insert the label that a `break` will jump to + // (and that control flow will fall through to otherwise). + // This is the block that subsequent code will go into. + insertBlock(breakLabel); + context->shared->breakLabels.remove(stmt); + } + void visitTargetSwitchStmt(TargetSwitchStmt* stmt) { if (!stmt->targetCases.getCount()) diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 95eb971ee..38285c41f 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -5380,9 +5380,8 @@ static Stmt* ParseDefaultStmt(Parser* parser) return stmt; } -static Stmt* parseTargetSwitchStmt(Parser* parser) +static Stmt* parseTargetSwitchStmtImpl(Parser* parser, TargetSwitchStmt* stmt) { - TargetSwitchStmt* stmt = parser->astBuilder->create<TargetSwitchStmt>(); parser->FillPosition(stmt); parser->ReadToken(); if (!beginMatch(parser, MatchedTokenType::CurlyBraces)) @@ -5479,6 +5478,18 @@ static Stmt* parseTargetSwitchStmt(Parser* parser) return stmt; } +static Stmt* parseTargetSwitchStmt(Parser* parser) +{ + auto stmt = parser->astBuilder->create<TargetSwitchStmt>(); + return parseTargetSwitchStmtImpl(parser, stmt); +} + +static Stmt* parseStageSwitchStmt(Parser* parser) +{ + auto stmt = parser->astBuilder->create<StageSwitchStmt>(); + return parseTargetSwitchStmtImpl(parser, stmt); +} + static Stmt* parseIntrinsicAsmStmt(Parser* parser) { IntrinsicAsmStmt* stmt = parser->astBuilder->create<IntrinsicAsmStmt>(); @@ -5725,6 +5736,8 @@ Stmt* Parser::ParseStatement(Stmt* parentStmt) statement = ParseSwitchStmt(this); else if (LookAheadToken("__target_switch")) statement = parseTargetSwitchStmt(this); + else if (LookAheadToken("__stage_switch")) + statement = parseStageSwitchStmt(this); else if (LookAheadToken("__intrinsic_asm")) statement = parseIntrinsicAsmStmt(this); else if (LookAheadToken("case")) |
