From ecc5a39feecbf73feedf352214406c8752af798a Mon Sep 17 00:00:00 2001 From: Darren Wihandi <65404740+fairywreath@users.noreply.github.com> Date: Thu, 5 Dec 2024 20:09:40 -0500 Subject: Do recursive function checks early during IR linking (#5777) --- source/slang/slang-ir-check-recursion.cpp | 126 ++++++++++++++++++++++++++++++ 1 file changed, 126 insertions(+) create mode 100644 source/slang/slang-ir-check-recursion.cpp (limited to 'source/slang/slang-ir-check-recursion.cpp') diff --git a/source/slang/slang-ir-check-recursion.cpp b/source/slang/slang-ir-check-recursion.cpp new file mode 100644 index 000000000..404437a46 --- /dev/null +++ b/source/slang/slang-ir-check-recursion.cpp @@ -0,0 +1,126 @@ +#include "slang-ir-check-recursion.h" + +#include "slang-ir-util.h" + +namespace Slang +{ +bool checkTypeRecursionImpl( + HashSet& checkedTypes, + HashSet& stack, + IRInst* type, + IRInst* field, + DiagnosticSink* sink) +{ + auto visitElementType = [&](IRInst* elementType, IRInst* field) -> bool + { + if (!stack.add(elementType)) + { + sink->diagnose(field ? field : type, Diagnostics::recursiveType, type); + return false; + } + if (checkedTypes.add(elementType)) + checkTypeRecursionImpl(checkedTypes, stack, elementType, field, sink); + stack.remove(elementType); + return true; + }; + if (auto arrayType = as(type)) + { + return visitElementType(arrayType->getElementType(), field); + } + else if (auto structType = as(type)) + { + for (auto sfield : structType->getFields()) + if (!visitElementType(sfield->getFieldType(), sfield)) + return false; + } + return true; +} + +void checkTypeRecursion(HashSet& checkedTypes, IRInst* type, DiagnosticSink* sink) +{ + HashSet stack; + if (checkedTypes.add(type)) + { + stack.add(type); + checkTypeRecursionImpl(checkedTypes, stack, type, nullptr, sink); + } +} + +void checkForRecursiveTypes(IRModule* module, DiagnosticSink* sink) +{ + HashSet checkedTypes; + for (auto globalInst : module->getGlobalInsts()) + { + switch (globalInst->getOp()) + { + case kIROp_StructType: + { + checkTypeRecursion(checkedTypes, globalInst, sink); + } + break; + default: + break; + } + } +} + +bool checkFunctionRecursionImpl( + HashSet& checkedFuncs, + HashSet& callStack, + IRFunc* func, + DiagnosticSink* sink) +{ + for (auto block : func->getBlocks()) + { + for (auto inst : block->getChildren()) + { + auto callInst = as(inst); + if (!callInst) + continue; + auto callee = as(callInst->getCallee()); + if (!callee) + continue; + if (!callStack.add(callee)) + { + sink->diagnose(callInst, Diagnostics::unsupportedRecursion, callee); + return false; + } + if (checkedFuncs.add(callee)) + checkFunctionRecursionImpl(checkedFuncs, callStack, callee, sink); + callStack.remove(callee); + } + } + return true; +} + +void checkFunctionRecursion(HashSet& checkedFuncs, IRFunc* func, DiagnosticSink* sink) +{ + HashSet callStack; + if (checkedFuncs.add(func)) + { + callStack.add(func); + checkFunctionRecursionImpl(checkedFuncs, callStack, func, sink); + } +} + +void checkForRecursiveFunctions(TargetRequest* target, IRModule* module, DiagnosticSink* sink) +{ + HashSet checkedFuncsForRecursionDetection; + for (auto globalInst : module->getGlobalInsts()) + { + switch (globalInst->getOp()) + { + case kIROp_Func: + if (!isCPUTarget(target)) + checkFunctionRecursion( + checkedFuncsForRecursionDetection, + as(globalInst), + sink); + break; + default: + break; + } + } +} + +} // namespace Slang -- cgit v1.2.3