summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-validate.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-11 15:33:28 -0800
committerGitHub <noreply@github.com>2023-01-11 15:33:28 -0800
commita3ac6e71cbc922b7c941c45f23ee18a9fc274d1f (patch)
treeacf8c18601f124e9290494f8b379d2420369fc35 /source/slang/slang-ir-validate.cpp
parent20262684bcbb707d16669b2670039df870b65ca8 (diff)
Make backward differentiation work with generics. (#2586)
* Make backward differentiation work with generics. * Fix. * Another fix. * More fix. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-validate.cpp')
-rw-r--r--source/slang/slang-ir-validate.cpp43
1 files changed, 41 insertions, 2 deletions
diff --git a/source/slang/slang-ir-validate.cpp b/source/slang/slang-ir-validate.cpp
index 46817e212..a49eda322 100644
--- a/source/slang/slang-ir-validate.cpp
+++ b/source/slang/slang-ir-validate.cpp
@@ -29,7 +29,14 @@ namespace Slang
{
if (!condition)
{
- context->getSink()->diagnose(inst, Diagnostics::irValidationFailed, message);
+ if (context)
+ {
+ context->getSink()->diagnose(inst, Diagnostics::irValidationFailed, message);
+ }
+ else
+ {
+ SLANG_ASSERT_FAILURE("IR validation failed");
+ }
}
}
@@ -143,7 +150,10 @@ namespace Slang
// If `operandValue` precedes `inst`, then we should
// have already seen it, because we scan parent instructions
// in order.
- validate(context, context->seenInsts.Contains(operandValue), inst, "def must come before use in same block");
+ if (context)
+ {
+ validate(context, context->seenInsts.Contains(operandValue), inst, "def must come before use in same block");
+ }
return;
}
@@ -196,6 +206,34 @@ namespace Slang
}
}
+ static thread_local bool _enableIRValidationAtInsert = false;
+ void disableIRValidationAtInsert()
+ {
+ _enableIRValidationAtInsert = false;
+ }
+ void enableIRValidationAtInsert()
+ {
+ _enableIRValidationAtInsert = true;
+ }
+ void validateIRInstOperands(IRInst* inst)
+ {
+ if (!_enableIRValidationAtInsert)
+ return;
+ switch (inst->getOp())
+ {
+ case kIROp_loop:
+ case kIROp_ifElse:
+ case kIROp_unconditionalBranch:
+ case kIROp_conditionalBranch:
+ case kIROp_Switch:
+ return;
+ default:
+ break;
+ }
+
+ validateIRInstOperands(nullptr, inst);
+ }
+
void validateCodeBody(IRValidateContext* context, IRGlobalValueWithCode* code)
{
HashSet<IRBlock*> blocks;
@@ -296,4 +334,5 @@ namespace Slang
auto sink = codeGenContext->getSink();
validateIRModule(module, sink);
}
+
}