summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-emit.cpp3
-rw-r--r--source/slang/slang-ir-check-recursion.cpp126
-rw-r--r--source/slang/slang-ir-check-recursion.h (renamed from source/slang/slang-ir-check-recursive-type.h)4
-rw-r--r--source/slang/slang-ir-check-recursive-type.cpp65
-rw-r--r--source/slang/slang-ir-check-unsupported-inst.cpp44
-rw-r--r--source/slang/slang-lower-to-ir.cpp2
6 files changed, 133 insertions, 111 deletions
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index a9d5c5e50..04ad55c1f 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -22,7 +22,7 @@
#include "slang-ir-autodiff.h"
#include "slang-ir-bind-existentials.h"
#include "slang-ir-byte-address-legalize.h"
-#include "slang-ir-check-recursive-type.h"
+#include "slang-ir-check-recursion.h"
#include "slang-ir-check-shader-parameter-type.h"
#include "slang-ir-check-unsupported-inst.h"
#include "slang-ir-cleanup-void.h"
@@ -884,6 +884,7 @@ Result linkAndOptimizeIR(
if (targetProgram->getOptionSet().shouldRunNonEssentialValidation())
{
checkForRecursiveTypes(irModule, sink);
+ checkForRecursiveFunctions(codeGenContext->getTargetReq(), irModule, sink);
// For some targets, we are more restrictive about what types are allowed
// to be used as shader parameters in ConstantBuffer/ParameterBlock.
diff --git a/source/slang/slang-ir-check-recursion.cpp b/source/slang/slang-ir-check-recursion.cpp
new file mode 100644
index 000000000..404437a46
--- /dev/null
+++ b/source/slang/slang-ir-check-recursion.cpp
@@ -0,0 +1,126 @@
+#include "slang-ir-check-recursion.h"
+
+#include "slang-ir-util.h"
+
+namespace Slang
+{
+bool checkTypeRecursionImpl(
+ HashSet<IRInst*>& checkedTypes,
+ HashSet<IRInst*>& stack,
+ IRInst* type,
+ IRInst* field,
+ DiagnosticSink* sink)
+{
+ auto visitElementType = [&](IRInst* elementType, IRInst* field) -> bool
+ {
+ if (!stack.add(elementType))
+ {
+ sink->diagnose(field ? field : type, Diagnostics::recursiveType, type);
+ return false;
+ }
+ if (checkedTypes.add(elementType))
+ checkTypeRecursionImpl(checkedTypes, stack, elementType, field, sink);
+ stack.remove(elementType);
+ return true;
+ };
+ if (auto arrayType = as<IRArrayTypeBase>(type))
+ {
+ return visitElementType(arrayType->getElementType(), field);
+ }
+ else if (auto structType = as<IRStructType>(type))
+ {
+ for (auto sfield : structType->getFields())
+ if (!visitElementType(sfield->getFieldType(), sfield))
+ return false;
+ }
+ return true;
+}
+
+void checkTypeRecursion(HashSet<IRInst*>& checkedTypes, IRInst* type, DiagnosticSink* sink)
+{
+ HashSet<IRInst*> stack;
+ if (checkedTypes.add(type))
+ {
+ stack.add(type);
+ checkTypeRecursionImpl(checkedTypes, stack, type, nullptr, sink);
+ }
+}
+
+void checkForRecursiveTypes(IRModule* module, DiagnosticSink* sink)
+{
+ HashSet<IRInst*> checkedTypes;
+ for (auto globalInst : module->getGlobalInsts())
+ {
+ switch (globalInst->getOp())
+ {
+ case kIROp_StructType:
+ {
+ checkTypeRecursion(checkedTypes, globalInst, sink);
+ }
+ break;
+ default:
+ break;
+ }
+ }
+}
+
+bool checkFunctionRecursionImpl(
+ HashSet<IRFunc*>& checkedFuncs,
+ HashSet<IRFunc*>& callStack,
+ IRFunc* func,
+ DiagnosticSink* sink)
+{
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst : block->getChildren())
+ {
+ auto callInst = as<IRCall>(inst);
+ if (!callInst)
+ continue;
+ auto callee = as<IRFunc>(callInst->getCallee());
+ if (!callee)
+ continue;
+ if (!callStack.add(callee))
+ {
+ sink->diagnose(callInst, Diagnostics::unsupportedRecursion, callee);
+ return false;
+ }
+ if (checkedFuncs.add(callee))
+ checkFunctionRecursionImpl(checkedFuncs, callStack, callee, sink);
+ callStack.remove(callee);
+ }
+ }
+ return true;
+}
+
+void checkFunctionRecursion(HashSet<IRFunc*>& checkedFuncs, IRFunc* func, DiagnosticSink* sink)
+{
+ HashSet<IRFunc*> callStack;
+ if (checkedFuncs.add(func))
+ {
+ callStack.add(func);
+ checkFunctionRecursionImpl(checkedFuncs, callStack, func, sink);
+ }
+}
+
+void checkForRecursiveFunctions(TargetRequest* target, IRModule* module, DiagnosticSink* sink)
+{
+ HashSet<IRFunc*> checkedFuncsForRecursionDetection;
+ for (auto globalInst : module->getGlobalInsts())
+ {
+ switch (globalInst->getOp())
+ {
+ case kIROp_Func:
+ if (!isCPUTarget(target))
+ checkFunctionRecursion(
+ checkedFuncsForRecursionDetection,
+ as<IRFunc>(globalInst),
+ sink);
+ break;
+ default:
+ break;
+ }
+ }
+}
+
+} // namespace Slang
diff --git a/source/slang/slang-ir-check-recursive-type.h b/source/slang/slang-ir-check-recursion.h
index dd5796c86..1bfcfbee9 100644
--- a/source/slang/slang-ir-check-recursive-type.h
+++ b/source/slang/slang-ir-check-recursion.h
@@ -4,6 +4,10 @@ namespace Slang
{
struct IRModule;
class DiagnosticSink;
+class TargetRequest;
void checkForRecursiveTypes(IRModule* module, DiagnosticSink* sink);
+
+void checkForRecursiveFunctions(TargetRequest* target, IRModule* module, DiagnosticSink* sink);
+
} // namespace Slang
diff --git a/source/slang/slang-ir-check-recursive-type.cpp b/source/slang/slang-ir-check-recursive-type.cpp
deleted file mode 100644
index ee4541735..000000000
--- a/source/slang/slang-ir-check-recursive-type.cpp
+++ /dev/null
@@ -1,65 +0,0 @@
-#include "slang-ir-check-recursive-type.h"
-
-#include "slang-ir-util.h"
-
-namespace Slang
-{
-bool checkTypeRecursionImpl(
- HashSet<IRInst*>& checkedTypes,
- HashSet<IRInst*>& stack,
- IRInst* type,
- IRInst* field,
- DiagnosticSink* sink)
-{
- auto visitElementType = [&](IRInst* elementType, IRInst* field) -> bool
- {
- if (!stack.add(elementType))
- {
- sink->diagnose(field ? field : type, Diagnostics::recursiveType, type);
- return false;
- }
- if (checkedTypes.add(elementType))
- checkTypeRecursionImpl(checkedTypes, stack, elementType, field, sink);
- stack.remove(elementType);
- return true;
- };
- if (auto arrayType = as<IRArrayTypeBase>(type))
- {
- return visitElementType(arrayType->getElementType(), field);
- }
- else if (auto structType = as<IRStructType>(type))
- {
- for (auto sfield : structType->getFields())
- if (!visitElementType(sfield->getFieldType(), sfield))
- return false;
- }
- return true;
-}
-
-void checkTypeRecursion(HashSet<IRInst*>& checkedTypes, IRInst* type, DiagnosticSink* sink)
-{
- HashSet<IRInst*> stack;
- if (checkedTypes.add(type))
- {
- stack.add(type);
- checkTypeRecursionImpl(checkedTypes, stack, type, nullptr, sink);
- }
-}
-
-void checkForRecursiveTypes(IRModule* module, DiagnosticSink* sink)
-{
- HashSet<IRInst*> checkedTypes;
- for (auto globalInst : module->getGlobalInsts())
- {
- switch (globalInst->getOp())
- {
- case kIROp_StructType:
- {
- checkTypeRecursion(checkedTypes, globalInst, sink);
- }
- break;
- }
- }
-}
-
-} // namespace Slang
diff --git a/source/slang/slang-ir-check-unsupported-inst.cpp b/source/slang/slang-ir-check-unsupported-inst.cpp
index ea9e7cc64..3bf570dc1 100644
--- a/source/slang/slang-ir-check-unsupported-inst.cpp
+++ b/source/slang/slang-ir-check-unsupported-inst.cpp
@@ -5,46 +5,6 @@
namespace Slang
{
-bool isCPUTarget(TargetRequest* targetReq);
-
-bool checkRecursionImpl(
- HashSet<IRFunc*>& checkedFuncs,
- HashSet<IRFunc*>& callStack,
- IRFunc* func,
- DiagnosticSink* sink)
-{
- for (auto block : func->getBlocks())
- {
- for (auto inst : block->getChildren())
- {
- auto callInst = as<IRCall>(inst);
- if (!callInst)
- continue;
- auto callee = as<IRFunc>(callInst->getCallee());
- if (!callee)
- continue;
- if (!callStack.add(callee))
- {
- sink->diagnose(callInst, Diagnostics::unsupportedRecursion, callee);
- return false;
- }
- if (checkedFuncs.add(callee))
- checkRecursionImpl(checkedFuncs, callStack, callee, sink);
- callStack.remove(callee);
- }
- }
- return true;
-}
-
-void checkRecursion(HashSet<IRFunc*>& checkedFuncs, IRFunc* func, DiagnosticSink* sink)
-{
- HashSet<IRFunc*> callStack;
- if (checkedFuncs.add(func))
- {
- callStack.add(func);
- checkRecursionImpl(checkedFuncs, callStack, func, sink);
- }
-}
void checkUnsupportedInst(TargetRequest* target, IRFunc* func, DiagnosticSink* sink)
{
@@ -65,8 +25,6 @@ void checkUnsupportedInst(TargetRequest* target, IRFunc* func, DiagnosticSink* s
void checkUnsupportedInst(TargetRequest* target, IRModule* module, DiagnosticSink* sink)
{
- HashSet<IRFunc*> checkedFuncsForRecursionDetection;
-
for (auto globalInst : module->getGlobalInsts())
{
switch (globalInst->getOp())
@@ -84,8 +42,6 @@ void checkUnsupportedInst(TargetRequest* target, IRModule* module, DiagnosticSin
break;
}
case kIROp_Func:
- if (!isCPUTarget(target))
- checkRecursion(checkedFuncsForRecursionDetection, as<IRFunc>(globalInst), sink);
checkUnsupportedInst(target, as<IRFunc>(globalInst), sink);
break;
case kIROp_Generic:
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 92c442433..06c5f005b 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -9,7 +9,7 @@
#include "slang-ir-autodiff.h"
#include "slang-ir-bit-field-accessors.h"
#include "slang-ir-check-differentiability.h"
-#include "slang-ir-check-recursive-type.h"
+#include "slang-ir-check-recursion.h"
#include "slang-ir-clone.h"
#include "slang-ir-constexpr.h"
#include "slang-ir-dce.h"