summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-check-recursion.cpp
diff options
context:
space:
mode:
authorDarren Wihandi <65404740+fairywreath@users.noreply.github.com>2024-12-05 20:09:40 -0500
committerGitHub <noreply@github.com>2024-12-05 17:09:40 -0800
commitecc5a39feecbf73feedf352214406c8752af798a (patch)
treedcb65f45cbca2dae0ffe0cf7b68b4e4f5f410f46 /source/slang/slang-ir-check-recursion.cpp
parentd4136c93448bfdd8561af331ea6eebcec14719e3 (diff)
Do recursive function checks early during IR linking (#5777)
Diffstat (limited to 'source/slang/slang-ir-check-recursion.cpp')
-rw-r--r--source/slang/slang-ir-check-recursion.cpp126
1 files changed, 126 insertions, 0 deletions
diff --git a/source/slang/slang-ir-check-recursion.cpp b/source/slang/slang-ir-check-recursion.cpp
new file mode 100644
index 000000000..404437a46
--- /dev/null
+++ b/source/slang/slang-ir-check-recursion.cpp
@@ -0,0 +1,126 @@
+#include "slang-ir-check-recursion.h"
+
+#include "slang-ir-util.h"
+
+namespace Slang
+{
+bool checkTypeRecursionImpl(
+ HashSet<IRInst*>& checkedTypes,
+ HashSet<IRInst*>& stack,
+ IRInst* type,
+ IRInst* field,
+ DiagnosticSink* sink)
+{
+ auto visitElementType = [&](IRInst* elementType, IRInst* field) -> bool
+ {
+ if (!stack.add(elementType))
+ {
+ sink->diagnose(field ? field : type, Diagnostics::recursiveType, type);
+ return false;
+ }
+ if (checkedTypes.add(elementType))
+ checkTypeRecursionImpl(checkedTypes, stack, elementType, field, sink);
+ stack.remove(elementType);
+ return true;
+ };
+ if (auto arrayType = as<IRArrayTypeBase>(type))
+ {
+ return visitElementType(arrayType->getElementType(), field);
+ }
+ else if (auto structType = as<IRStructType>(type))
+ {
+ for (auto sfield : structType->getFields())
+ if (!visitElementType(sfield->getFieldType(), sfield))
+ return false;
+ }
+ return true;
+}
+
+void checkTypeRecursion(HashSet<IRInst*>& checkedTypes, IRInst* type, DiagnosticSink* sink)
+{
+ HashSet<IRInst*> stack;
+ if (checkedTypes.add(type))
+ {
+ stack.add(type);
+ checkTypeRecursionImpl(checkedTypes, stack, type, nullptr, sink);
+ }
+}
+
+void checkForRecursiveTypes(IRModule* module, DiagnosticSink* sink)
+{
+ HashSet<IRInst*> checkedTypes;
+ for (auto globalInst : module->getGlobalInsts())
+ {
+ switch (globalInst->getOp())
+ {
+ case kIROp_StructType:
+ {
+ checkTypeRecursion(checkedTypes, globalInst, sink);
+ }
+ break;
+ default:
+ break;
+ }
+ }
+}
+
+bool checkFunctionRecursionImpl(
+ HashSet<IRFunc*>& checkedFuncs,
+ HashSet<IRFunc*>& callStack,
+ IRFunc* func,
+ DiagnosticSink* sink)
+{
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst : block->getChildren())
+ {
+ auto callInst = as<IRCall>(inst);
+ if (!callInst)
+ continue;
+ auto callee = as<IRFunc>(callInst->getCallee());
+ if (!callee)
+ continue;
+ if (!callStack.add(callee))
+ {
+ sink->diagnose(callInst, Diagnostics::unsupportedRecursion, callee);
+ return false;
+ }
+ if (checkedFuncs.add(callee))
+ checkFunctionRecursionImpl(checkedFuncs, callStack, callee, sink);
+ callStack.remove(callee);
+ }
+ }
+ return true;
+}
+
+void checkFunctionRecursion(HashSet<IRFunc*>& checkedFuncs, IRFunc* func, DiagnosticSink* sink)
+{
+ HashSet<IRFunc*> callStack;
+ if (checkedFuncs.add(func))
+ {
+ callStack.add(func);
+ checkFunctionRecursionImpl(checkedFuncs, callStack, func, sink);
+ }
+}
+
+void checkForRecursiveFunctions(TargetRequest* target, IRModule* module, DiagnosticSink* sink)
+{
+ HashSet<IRFunc*> checkedFuncsForRecursionDetection;
+ for (auto globalInst : module->getGlobalInsts())
+ {
+ switch (globalInst->getOp())
+ {
+ case kIROp_Func:
+ if (!isCPUTarget(target))
+ checkFunctionRecursion(
+ checkedFuncsForRecursionDetection,
+ as<IRFunc>(globalInst),
+ sink);
+ break;
+ default:
+ break;
+ }
+ }
+}
+
+} // namespace Slang