diff options
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/slang-emit.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-cfg-norm.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-validate.cpp | 13 | ||||
| -rw-r--r-- | source/slang/slang-ir-validate.h | 34 |
6 files changed, 51 insertions, 16 deletions
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index c0c1ba75a..f4d535466 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -950,9 +950,10 @@ Result linkAndOptimizeIR( if (requiredLoweringPassSet.autodiff) { dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-AUTODIFF"); - enableIRValidationAtInsert(); - changed |= processAutodiffCalls(targetProgram, irModule, sink); - disableIRValidationAtInsert(); + { + auto validationScope = enableIRValidationScope(); + changed |= processAutodiffCalls(targetProgram, irModule, sink); + } dumpIRIfEnabled(codeGenContext, irModule, "AFTER-AUTODIFF"); } diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp index 30e832719..720fdece2 100644 --- a/source/slang/slang-ir-autodiff-cfg-norm.cpp +++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp @@ -755,9 +755,10 @@ void normalizeCFG( sortBlocksInFunc(func); legalizeDefUse(func); - disableIRValidationAtInsert(); - constructSSA(module, func); - enableIRValidationAtInsert(); + { + auto validationScope = disableIRValidationScope(); + constructSSA(module, func); + } module->invalidateAnalysisForInst(func); #if _DEBUG diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 003790793..a2ee2bdf9 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1900,12 +1900,11 @@ SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func) if (SLANG_SUCCEEDED(result)) { - disableIRValidationAtInsert(); + auto validationScope = disableIRValidationScope(); auto simplifyOptions = IRSimplificationOptions::getDefault(nullptr); simplifyOptions.removeRedundancy = true; simplifyOptions.hoistLoopInvariantInsts = true; simplifyFunc(autoDiffSharedContext->targetProgram, func, simplifyOptions); - enableIRValidationAtInsert(); } return result; } diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index ec435ee87..80b2038aa 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -436,9 +436,8 @@ struct DiffUnzipPass if (intermediateVar) { - disableIRValidationAtInsert(); + auto validationScope = disableIRValidationScope(); diffBuilder->addBackwardDerivativePrimalContextDecoration(callInst, intermediateVar); - enableIRValidationAtInsert(); } IRInst* diffVal = nullptr; diff --git a/source/slang/slang-ir-validate.cpp b/source/slang/slang-ir-validate.cpp index 565ae97d8..b3d6504ab 100644 --- a/source/slang/slang-ir-validate.cpp +++ b/source/slang/slang-ir-validate.cpp @@ -273,14 +273,19 @@ void validateIRInstOperands(IRValidateContext* context, IRInst* inst) } static thread_local bool _enableIRValidationAtInsert = false; -void disableIRValidationAtInsert() + +// RAII class implementation for exception-safe IR validation state management +IRValidationScope::IRValidationScope(bool enableValidation) + : m_previousState(_enableIRValidationAtInsert) { - _enableIRValidationAtInsert = false; + _enableIRValidationAtInsert = enableValidation; } -void enableIRValidationAtInsert() + +IRValidationScope::~IRValidationScope() { - _enableIRValidationAtInsert = true; + _enableIRValidationAtInsert = m_previousState; } + void validateIRInstOperands(IRInst* inst) { if (!_enableIRValidationAtInsert) diff --git a/source/slang/slang-ir-validate.h b/source/slang/slang-ir-validate.h index 722359452..7fc882f37 100644 --- a/source/slang/slang-ir-validate.h +++ b/source/slang/slang-ir-validate.h @@ -36,8 +36,38 @@ void validateIRModuleIfEnabled(CompileRequestBase* compileRequest, IRModule* mod void validateIRModuleIfEnabled(CodeGenContext* codeGenContext, IRModule* module); -void disableIRValidationAtInsert(); -void enableIRValidationAtInsert(); +// RAII class to manage IR validation state in an exception-safe manner +class [[nodiscard]] IRValidationScope +{ +public: + // Constructor saves current state and sets new state + explicit IRValidationScope(bool enableValidation); + + // Destructor automatically restores previous state + ~IRValidationScope(); + + // Non-copyable to prevent accidental copies + IRValidationScope(const IRValidationScope&) = delete; + IRValidationScope& operator=(const IRValidationScope&) = delete; + + // Non-movable to keep it simple + IRValidationScope(IRValidationScope&&) = delete; + IRValidationScope& operator=(IRValidationScope&&) = delete; + +private: + bool m_previousState; +}; + +// Convenience functions to create scoped guards +[[nodiscard]] inline IRValidationScope enableIRValidationScope() +{ + return IRValidationScope(true); +} + +[[nodiscard]] inline IRValidationScope disableIRValidationScope() +{ + return IRValidationScope(false); +} // Validate that the destination of an atomic operation is appropriate, meaning it's // either 'groupshared' or in a device buffer. |
