summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-05-25 23:39:30 -0700
committerGitHub <noreply@github.com>2023-05-26 06:39:30 +0000
commitab284ca61d0c4c29ac7331b99a98f95bb3ad44e5 (patch)
tree761c50913404ee5bf0b827db781f75173c6c123a /source
parentf88e1299b7715190ce82f3f4473f0d0eeaa2000e (diff)
Fix bug in legalizeFuncType that leads to invalid IR. (#2902)
* Fix bug in legalizeFuncType that leads to invalid IR. * Diagnose on functions that never returns when differentiate it. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-diagnostic-defs.h2
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp9
-rw-r--r--source/slang/slang-ir-legalize-types.cpp1
-rw-r--r--source/slang/slang-ir-single-return.cpp9
-rw-r--r--source/slang/slang-ir-single-return.h2
5 files changed, 20 insertions, 3 deletions
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 722f25843..29ce30323 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -597,6 +597,8 @@ DIAGNOSTIC(40011, Error, unconstrainedGenericParameterNotAllowedInDynamicFunctio
DIAGNOSTIC(40020, Error, cannotUnrollLoop, "loop does not terminate within the limited number of iterations, unrolling is aborted.")
+DIAGNOSTIC(40030, Fatal, functionNeverReturnsFatal, "function '$0' never returns, compilation ceased.")
+
// 41000 - IR-level validation issues
DIAGNOSTIC(41000, Warning, unreachableCode, "unreachable code detected")
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 70c43cdcb..3f0036b06 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -537,10 +537,17 @@ namespace Slang
DifferentiableTypeConformanceContext diffTypeContext(autoDiffSharedContext);
diffTypeContext.setFunc(func);
- if (!isSingleReturnFunc(func))
+ auto returnCount = getReturnCount(func);
+ if (returnCount > 1)
{
convertFuncToSingleReturnForm(func->getModule(), func);
}
+ else if (returnCount == 0)
+ {
+ // The function is ill-formed and never returns (such as having an infinite loop),
+ // we can't possibly reverse-differentiate such functions, so we will diagnose it here.
+ getSink()->diagnose(func->sourceLoc, Diagnostics::functionNeverReturnsFatal, func);
+ }
eliminateContinueBlocksInFunc(func->getModule(), func);
diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp
index 135425676..162498bf0 100644
--- a/source/slang/slang-ir-legalize-types.cpp
+++ b/source/slang/slang-ir-legalize-types.cpp
@@ -2285,6 +2285,7 @@ struct LegalFuncBuilder
// prefers to modify an IR node in-place rather than create a distinct
// legalized copy of it.
//
+ irBuilder->setInsertBefore(oldFunc);
auto newFuncType = irBuilder->getFuncType(
m_paramTypes.getCount(),
m_paramTypes.getBuffer(),
diff --git a/source/slang/slang-ir-single-return.cpp b/source/slang/slang-ir-single-return.cpp
index 10d8972bc..0b61e5065 100644
--- a/source/slang/slang-ir-single-return.cpp
+++ b/source/slang/slang-ir-single-return.cpp
@@ -89,7 +89,7 @@ void convertFuncToSingleReturnForm(IRModule* irModule, IRGlobalValueWithCode* fu
context.processFunc(func);
}
-bool isSingleReturnFunc(IRGlobalValueWithCode* func)
+int getReturnCount(IRGlobalValueWithCode* func)
{
int returnCount = 0;
for (auto block : func->getBlocks())
@@ -102,7 +102,12 @@ bool isSingleReturnFunc(IRGlobalValueWithCode* func)
}
}
}
- return returnCount <= 1;
+ return returnCount;
+}
+
+bool isSingleReturnFunc(IRGlobalValueWithCode* func)
+{
+ return getReturnCount(func) == 1;
}
} // namespace Slang
diff --git a/source/slang/slang-ir-single-return.h b/source/slang/slang-ir-single-return.h
index bb186634d..e38a37db8 100644
--- a/source/slang/slang-ir-single-return.h
+++ b/source/slang/slang-ir-single-return.h
@@ -10,4 +10,6 @@ namespace Slang
void convertFuncToSingleReturnForm(IRModule* module, IRGlobalValueWithCode* func);
bool isSingleReturnFunc(IRGlobalValueWithCode* func);
+
+ int getReturnCount(IRGlobalValueWithCode* func);
}