summaryrefslogtreecommitdiff
path: root/source/slang
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/slang-emit.cpp7
-rw-r--r--source/slang/slang-ir-autodiff-cfg-norm.cpp7
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp3
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h3
-rw-r--r--source/slang/slang-ir-validate.cpp13
-rw-r--r--source/slang/slang-ir-validate.h34
6 files changed, 51 insertions, 16 deletions
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index c0c1ba75a..f4d535466 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -950,9 +950,10 @@ Result linkAndOptimizeIR(
if (requiredLoweringPassSet.autodiff)
{
dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-AUTODIFF");
- enableIRValidationAtInsert();
- changed |= processAutodiffCalls(targetProgram, irModule, sink);
- disableIRValidationAtInsert();
+ {
+ auto validationScope = enableIRValidationScope();
+ changed |= processAutodiffCalls(targetProgram, irModule, sink);
+ }
dumpIRIfEnabled(codeGenContext, irModule, "AFTER-AUTODIFF");
}
diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp
index 30e832719..720fdece2 100644
--- a/source/slang/slang-ir-autodiff-cfg-norm.cpp
+++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp
@@ -755,9 +755,10 @@ void normalizeCFG(
sortBlocksInFunc(func);
legalizeDefUse(func);
- disableIRValidationAtInsert();
- constructSSA(module, func);
- enableIRValidationAtInsert();
+ {
+ auto validationScope = disableIRValidationScope();
+ constructSSA(module, func);
+ }
module->invalidateAnalysisForInst(func);
#if _DEBUG
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 003790793..a2ee2bdf9 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -1900,12 +1900,11 @@ SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func)
if (SLANG_SUCCEEDED(result))
{
- disableIRValidationAtInsert();
+ auto validationScope = disableIRValidationScope();
auto simplifyOptions = IRSimplificationOptions::getDefault(nullptr);
simplifyOptions.removeRedundancy = true;
simplifyOptions.hoistLoopInvariantInsts = true;
simplifyFunc(autoDiffSharedContext->targetProgram, func, simplifyOptions);
- enableIRValidationAtInsert();
}
return result;
}
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index ec435ee87..80b2038aa 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -436,9 +436,8 @@ struct DiffUnzipPass
if (intermediateVar)
{
- disableIRValidationAtInsert();
+ auto validationScope = disableIRValidationScope();
diffBuilder->addBackwardDerivativePrimalContextDecoration(callInst, intermediateVar);
- enableIRValidationAtInsert();
}
IRInst* diffVal = nullptr;
diff --git a/source/slang/slang-ir-validate.cpp b/source/slang/slang-ir-validate.cpp
index 565ae97d8..b3d6504ab 100644
--- a/source/slang/slang-ir-validate.cpp
+++ b/source/slang/slang-ir-validate.cpp
@@ -273,14 +273,19 @@ void validateIRInstOperands(IRValidateContext* context, IRInst* inst)
}
static thread_local bool _enableIRValidationAtInsert = false;
-void disableIRValidationAtInsert()
+
+// RAII class implementation for exception-safe IR validation state management
+IRValidationScope::IRValidationScope(bool enableValidation)
+ : m_previousState(_enableIRValidationAtInsert)
{
- _enableIRValidationAtInsert = false;
+ _enableIRValidationAtInsert = enableValidation;
}
-void enableIRValidationAtInsert()
+
+IRValidationScope::~IRValidationScope()
{
- _enableIRValidationAtInsert = true;
+ _enableIRValidationAtInsert = m_previousState;
}
+
void validateIRInstOperands(IRInst* inst)
{
if (!_enableIRValidationAtInsert)
diff --git a/source/slang/slang-ir-validate.h b/source/slang/slang-ir-validate.h
index 722359452..7fc882f37 100644
--- a/source/slang/slang-ir-validate.h
+++ b/source/slang/slang-ir-validate.h
@@ -36,8 +36,38 @@ void validateIRModuleIfEnabled(CompileRequestBase* compileRequest, IRModule* mod
void validateIRModuleIfEnabled(CodeGenContext* codeGenContext, IRModule* module);
-void disableIRValidationAtInsert();
-void enableIRValidationAtInsert();
+// RAII class to manage IR validation state in an exception-safe manner
+class [[nodiscard]] IRValidationScope
+{
+public:
+ // Constructor saves current state and sets new state
+ explicit IRValidationScope(bool enableValidation);
+
+ // Destructor automatically restores previous state
+ ~IRValidationScope();
+
+ // Non-copyable to prevent accidental copies
+ IRValidationScope(const IRValidationScope&) = delete;
+ IRValidationScope& operator=(const IRValidationScope&) = delete;
+
+ // Non-movable to keep it simple
+ IRValidationScope(IRValidationScope&&) = delete;
+ IRValidationScope& operator=(IRValidationScope&&) = delete;
+
+private:
+ bool m_previousState;
+};
+
+// Convenience functions to create scoped guards
+[[nodiscard]] inline IRValidationScope enableIRValidationScope()
+{
+ return IRValidationScope(true);
+}
+
+[[nodiscard]] inline IRValidationScope disableIRValidationScope()
+{
+ return IRValidationScope(false);
+}
// Validate that the destination of an atomic operation is appropriate, meaning it's
// either 'groupshared' or in a device buffer.