summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJulius Ikkala <julius.ikkala@gmail.com>2025-05-23 22:27:37 +0300
committerGitHub <noreply@github.com>2025-05-23 12:27:37 -0700
commit57c3f938221c427b78da7087f8a832ba4a271a7c (patch)
treee9a6d26278dc1ad75b222ac4fc9b7a1d8449e576
parentd108bfa677c70808b32bd77e93637ed34c19c75d (diff)
Implement throw & catch statements (#6916)
* Implement throw statement It already existed in the IR, so only parsing, checking and lowering was missing. * Initial catch implementation Likely very broken. * Error out when catch() isn't last in scope * Prevent accessing variables from scope preceding catch As those may actually not be available at that point. * Add IError and use it in Result type lowering * Add diagnostic tests * Allow caught throws in non-throw functions * Fix catch propagating between functions & SPIR-V merge issue * Add test for non-trivial error types * Fix MSVC build * Fix invalid value type from Result lowering * Also lower error handling in templates * Lower result types only after specialization * Attempt to disambiguate error enums by witness table * Revert matching by witness, types should be distinct too * Don't assert valueField when getting Result's error value It may not exist if the function returns void, but getting the error value is still legitimate. * Update tests for new error numbers & get rid of expected.txt * Change catch lowering to resemble breaking a loop ... To make SPIR-V happy. * Fix dead catch blocks and invalid cached dominator tree * More SPIR-V adjustment * Lower catch as two nested loops * Add defer interaction test and revert broken defer changes * Fix enum type when throwing literals * Cleanup and bikeshedding * Document error handling mechanism * Fix table of contents * Use boolean tag in Result<T, E> * Use anyValue storage for Result<T,E> * Remove IError * Fix formatting * Eradicate success values from docs and tests * Use parseModernParamDecl for catch parameter * Implement do-catch syntax * Implement catch-all * Fix formatting * Fix marshalling native calls that throw --------- Co-authored-by: Yong He <yonghe@outlook.com>
-rw-r--r--docs/user-guide/03-convenience-features.md65
-rw-r--r--docs/user-guide/toc.html1
-rw-r--r--source/slang/slang-ast-iterator.h14
-rw-r--r--source/slang/slang-ast-stmt.h16
-rw-r--r--source/slang/slang-check-decl.cpp9
-rw-r--r--source/slang/slang-check-expr.cpp69
-rw-r--r--source/slang/slang-check-impl.h23
-rw-r--r--source/slang/slang-check-stmt.cpp76
-rw-r--r--source/slang/slang-diagnostic-defs.h22
-rw-r--r--source/slang/slang-emit.cpp10
-rw-r--r--source/slang/slang-ir-lower-error-handling.cpp48
-rw-r--r--source/slang/slang-ir-lower-result-type.cpp155
-rw-r--r--source/slang/slang-ir-marshal-native-call.cpp39
-rw-r--r--source/slang/slang-ir.cpp6
-rw-r--r--source/slang/slang-language-server-ast-lookup.cpp11
-rw-r--r--source/slang/slang-lower-to-ir.cpp168
-rw-r--r--source/slang/slang-parser.cpp82
-rw-r--r--tests/language-feature/error-handling/basic.slang91
-rw-r--r--tests/language-feature/error-handling/catch-all.slang57
-rw-r--r--tests/language-feature/error-handling/defer-interaction.slang58
-rw-r--r--tests/language-feature/error-handling/generics.slang47
-rw-r--r--tests/language-feature/error-handling/non-trivial-error-type.slang40
-rw-r--r--tests/language-feature/error-handling/throw-in-defer.slang17
-rw-r--r--tests/language-feature/error-handling/throw-type-mismatch.slang17
-rw-r--r--tests/language-feature/error-handling/throw-without-throws.slang19
-rw-r--r--tests/language-feature/error-handling/try-in-defer.slang19
26 files changed, 1003 insertions, 176 deletions
diff --git a/docs/user-guide/03-convenience-features.md b/docs/user-guide/03-convenience-features.md
index 5646292aa..d522daf81 100644
--- a/docs/user-guide/03-convenience-features.md
+++ b/docs/user-guide/03-convenience-features.md
@@ -793,6 +793,71 @@ by using the `[ForceInline]` decoration:
int f(int x) { return x + 1; }
```
+Error handling
+-----------------
+
+Slang supports an error handling mechanism that is superficially similar to
+exceptions in many other languages, but has some unique characteristics.
+
+In contrast to C++ exceptions, this mechanism makes the control flow of errors
+more explicit, and the performance charasteristics are similar to adding an
+if-statement after every potentially throwing function call to check and handle
+the error.
+
+In order to be able to throw an error, a function must declare the type of that
+error with `throws`:
+```
+enum MyError
+{
+ Failure,
+ CatastrophicFailure
+}
+
+int f() throws MyError
+{
+ if (computerIsBroken())
+ throw MyError.CatastrophicFailure;
+ return 42;
+}
+```
+Currently, functions may only throw a single type of error.
+
+To call a function that may throw, you must prepend it with `try`:
+
+```
+let result = try f();
+```
+
+If you don't catch the `try`, related errors are re-thrown and the calling
+function must declare that it `throws` that error type:
+
+```
+void g() throws MyError
+{
+ // This would not compile if `g()` wasn't declared to throw MyError as well.
+ let result = try f();
+ printf("Success: %d\n", result);
+}
+```
+
+To catch an error, you can use a `do-catch` statement:
+
+```
+void g()
+{
+ do
+ {
+ let result = try f();
+ printf("Success: %d\n", result);
+ }
+ catch(err: MyError)
+ {
+ printf("Not good!\n");
+ }
+}
+```
+
+You can chain multiple catch statements for different types of errors.
Special Scoping Syntax
-------------------
diff --git a/docs/user-guide/toc.html b/docs/user-guide/toc.html
index 47e0c05ad..c83f36282 100644
--- a/docs/user-guide/toc.html
+++ b/docs/user-guide/toc.html
@@ -49,6 +49,7 @@
<li data-link="convenience-features#extensions"><span>Extensions</span></li>
<li data-link="convenience-features#multi-level-break"><span>Multi-level break</span></li>
<li data-link="convenience-features#force-inlining"><span>Force inlining</span></li>
+<li data-link="convenience-features#error-handling"><span>Error handling</span></li>
<li data-link="convenience-features#special-scoping-syntax"><span>Special Scoping Syntax</span></li>
<li data-link="convenience-features#user-defined-attributes-experimental"><span>User Defined Attributes (Experimental)</span></li>
</ul>
diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h
index 0a4b17cc0..047c33d1a 100644
--- a/source/slang/slang-ast-iterator.h
+++ b/source/slang/slang-ast-iterator.h
@@ -438,6 +438,20 @@ struct ASTIterator
dispatchIfNotNull(stmt->statement);
}
+ void visitThrowStmt(ThrowStmt* stmt)
+ {
+ iterator->maybeDispatchCallback(stmt);
+ iterator->visitExpr(stmt->expression);
+ }
+
+ void visitCatchStmt(CatchStmt* stmt)
+ {
+ if (stmt->errorVar)
+ iterator->visitDecl(stmt->errorVar);
+ dispatchIfNotNull(stmt->tryBody);
+ dispatchIfNotNull(stmt->handleBody);
+ }
+
void visitWhileStmt(WhileStmt* stmt)
{
iterator->maybeDispatchCallback(stmt);
diff --git a/source/slang/slang-ast-stmt.h b/source/slang/slang-ast-stmt.h
index a1b7c274e..1fe97adf1 100644
--- a/source/slang/slang-ast-stmt.h
+++ b/source/slang/slang-ast-stmt.h
@@ -295,6 +295,22 @@ class DeferStmt : public Stmt
};
FIDDLE()
+class ThrowStmt : public Stmt
+{
+ FIDDLE(...)
+ FIDDLE() Expr* expression = nullptr;
+};
+
+FIDDLE()
+class CatchStmt : public Stmt
+{
+ FIDDLE(...)
+ FIDDLE() ParamDecl* errorVar = nullptr; // null => catch-all
+ FIDDLE() Stmt* tryBody = nullptr;
+ FIDDLE() Stmt* handleBody = nullptr;
+};
+
+FIDDLE()
class ExpressionStmt : public Stmt
{
FIDDLE(...)
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 1331839a4..7a0dcb06f 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -856,6 +856,15 @@ struct SemanticsDeclReferenceVisitor : public SemanticsDeclVisitorBase,
void visitDeferStmt(DeferStmt* stmt) { dispatchIfNotNull(stmt->statement); }
+ void visitThrowStmt(ThrowStmt* stmt) { dispatchIfNotNull(stmt->expression); }
+
+ void visitCatchStmt(CatchStmt* stmt)
+ {
+ dispatchIfNotNull(stmt->errorVar);
+ dispatchIfNotNull(stmt->tryBody);
+ dispatchIfNotNull(stmt->handleBody);
+ }
+
void visitWhileStmt(WhileStmt* stmt)
{
dispatchIfNotNull(stmt->predicate);
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index db507c060..75b1b7024 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -3951,45 +3951,64 @@ Expr* SemanticsExprVisitor::visitTryExpr(TryExpr* expr)
return expr;
auto parentFunc = this->m_parentFunc;
- // TODO: check if the try clause is caught.
- // For now we assume all `try`s are not caught (because we don't have catch yet).
- if (!parentFunc)
+ auto base = as<InvokeExpr>(expr->base);
+ if (!base)
{
- getSink()->diagnose(expr, Diagnostics::uncaughtTryCallInNonThrowFunc);
+ getSink()->diagnose(expr, Diagnostics::tryClauseMustApplyToInvokeExpr);
return expr;
}
- if (parentFunc->errorType->equals(m_astBuilder->getBottomType()))
+
+ auto callee = as<DeclRefExpr>(base->functionExpr);
+ if (!callee)
{
- getSink()->diagnose(expr, Diagnostics::uncaughtTryCallInNonThrowFunc);
+ getSink()->diagnose(expr, Diagnostics::calleeOfTryCallMustBeFunc);
return expr;
}
- if (!as<InvokeExpr>(expr->base))
+
+ auto funcCallee = as<FuncDecl>(callee->declRef.getDecl());
+ Stmt* catchStmt = nullptr;
+ if (funcCallee)
{
- getSink()->diagnose(expr, Diagnostics::tryClauseMustApplyToInvokeExpr);
+ if (funcCallee->errorType->equals(m_astBuilder->getBottomType()))
+ {
+ getSink()->diagnose(expr, Diagnostics::tryInvokeCalleeShouldThrow, callee->declRef);
+ return expr;
+ }
+ catchStmt = findMatchingCatchStmt(funcCallee->errorType);
+ }
+
+ if (FindOuterStmt<DeferStmt>(catchStmt))
+ {
+ // 'try' may jump outside a defer statement, which isn't allowed for
+ // now.
+ getSink()->diagnose(expr, Diagnostics::uncaughtTryInsideDefer);
return expr;
}
- auto base = as<InvokeExpr>(expr->base);
- if (auto callee = as<DeclRefExpr>(base->functionExpr))
+
+ if (!catchStmt)
{
- if (auto funcCallee = as<FuncDecl>(callee->declRef.getDecl()))
+ // Uncaught try.
+ if (!parentFunc)
{
- if (funcCallee->errorType->equals(m_astBuilder->getBottomType()))
- {
- getSink()->diagnose(expr, Diagnostics::tryInvokeCalleeShouldThrow, callee->declRef);
- }
- if (!parentFunc->errorType->equals(funcCallee->errorType))
- {
- getSink()->diagnose(
- expr,
- Diagnostics::errorTypeOfCalleeIncompatibleWithCaller,
- callee->declRef,
- funcCallee->errorType,
- parentFunc->errorType);
- }
+ getSink()->diagnose(expr, Diagnostics::uncaughtTryCallInNonThrowFunc);
+ return expr;
+ }
+ if (parentFunc->errorType->equals(m_astBuilder->getBottomType()))
+ {
+ getSink()->diagnose(expr, Diagnostics::uncaughtTryCallInNonThrowFunc);
+ return expr;
+ }
+ if (funcCallee && !parentFunc->errorType->equals(funcCallee->errorType))
+ {
+ getSink()->diagnose(
+ expr,
+ Diagnostics::errorTypeOfCalleeIncompatibleWithCaller,
+ callee->declRef,
+ funcCallee->errorType,
+ parentFunc->errorType);
return expr;
}
}
- getSink()->diagnose(expr, Diagnostics::calleeOfTryCallMustBeFunc);
return expr;
}
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 87f2df53d..950a150c4 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -1028,6 +1028,20 @@ public:
return result;
}
+ template<typename T>
+ T* FindOuterStmt(Stmt* searchUntil = nullptr)
+ {
+ for (auto outerStmtInfo = m_outerStmts; outerStmtInfo && outerStmtInfo->stmt != searchUntil;
+ outerStmtInfo = outerStmtInfo->next)
+ {
+ auto outerStmt = outerStmtInfo->stmt;
+ auto found = as<T>(outerStmt);
+ if (found)
+ return found;
+ }
+ return nullptr;
+ }
+
// Setup the flag to indicate disabling the short-circuiting evaluation
// for the logical expressions associted with the subcontext
SemanticsContext disableShortCircuitLogicalExpr()
@@ -2867,6 +2881,8 @@ public:
void addVisibilityModifier(Decl* decl, DeclVisibility vis);
void checkRayPayloadStructFields(StructDecl* structDecl);
+
+ CatchStmt* findMatchingCatchStmt(Type* errorType);
};
@@ -3011,9 +3027,6 @@ struct SemanticsStmtVisitor : public SemanticsVisitor, StmtVisitor<SemanticsStmt
void checkStmt(Stmt* stmt);
- template<typename T>
- T* FindOuterStmt(Stmt* searchUntil = nullptr);
-
Stmt* findOuterStmtWithLabel(Name* label);
void visitDeclStmt(DeclStmt* stmt);
@@ -3058,6 +3071,10 @@ struct SemanticsStmtVisitor : public SemanticsVisitor, StmtVisitor<SemanticsStmt
void visitDeferStmt(DeferStmt* stmt);
+ void visitThrowStmt(ThrowStmt* stmt);
+
+ void visitCatchStmt(CatchStmt* stmt);
+
void visitWhileStmt(WhileStmt* stmt);
void visitGpuForeachStmt(GpuForeachStmt* stmt);
diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp
index 2dc8c2685..9d1899462 100644
--- a/source/slang/slang-check-stmt.cpp
+++ b/source/slang/slang-check-stmt.cpp
@@ -41,6 +41,19 @@ void SemanticsVisitor::checkStmt(Stmt* stmt, SemanticsContext const& context)
checkModifiers(stmt);
}
+CatchStmt* SemanticsVisitor::findMatchingCatchStmt(Type* errorType)
+{
+ for (auto outerStmtInfo = m_outerStmts; outerStmtInfo; outerStmtInfo = outerStmtInfo->next)
+ {
+ if (auto catchStmt = as<CatchStmt>(outerStmtInfo->stmt))
+ {
+ if (!catchStmt->errorVar || catchStmt->errorVar->getType()->equals(errorType))
+ return catchStmt;
+ }
+ }
+ return nullptr;
+}
+
void SemanticsStmtVisitor::visitDeclStmt(DeclStmt* stmt)
{
// When we encounter a declaration during statement checking,
@@ -118,20 +131,6 @@ void SemanticsStmtVisitor::checkStmt(Stmt* stmt)
SemanticsVisitor::checkStmt(stmt, *this);
}
-template<typename T>
-T* SemanticsStmtVisitor::FindOuterStmt(Stmt* searchUntil)
-{
- for (auto outerStmtInfo = m_outerStmts; outerStmtInfo && outerStmtInfo->stmt != searchUntil;
- outerStmtInfo = outerStmtInfo->next)
- {
- auto outerStmt = outerStmtInfo->stmt;
- auto found = as<T>(outerStmt);
- if (found)
- return found;
- }
- return nullptr;
-}
-
Stmt* SemanticsStmtVisitor::findOuterStmtWithLabel(Name* label)
{
for (auto outerStmtInfo = m_outerStmts; outerStmtInfo; outerStmtInfo = outerStmtInfo->next)
@@ -616,6 +615,55 @@ void SemanticsStmtVisitor::visitDeferStmt(DeferStmt* stmt)
subContext.checkStmt(stmt->statement);
}
+void SemanticsStmtVisitor::visitThrowStmt(ThrowStmt* stmt)
+{
+ stmt->expression = CheckTerm(stmt->expression);
+ Stmt* catchStmt = findMatchingCatchStmt(stmt->expression->type);
+
+ auto parentFunc = getParentFunc();
+ if (!catchStmt && (!parentFunc || parentFunc->errorType->equals(m_astBuilder->getBottomType())))
+ {
+ getSink()->diagnose(stmt, Diagnostics::uncaughtThrowInNonThrowFunc);
+ return;
+ }
+
+ if (!catchStmt && !stmt->expression->type->equals(m_astBuilder->getErrorType()))
+ {
+ if (!parentFunc->errorType->equals(stmt->expression->type))
+ {
+ getSink()->diagnose(
+ stmt->expression,
+ Diagnostics::throwTypeIncompatibleWithErrorType,
+ stmt->expression->type,
+ parentFunc->errorType);
+ }
+ }
+
+ if (FindOuterStmt<DeferStmt>(catchStmt))
+ {
+ // Allowing 'throw' to escape a defer statement gets quite complex, for
+ // similar reasons as 'return' - if you have two (or more) defers,
+ // both of which exit the outer scope, it's unclear which one gets
+ // called and when. Both can't fully run. That kind of goes against the
+ // point of 'defer', which is to _always_ run some code when exiting
+ // scopes.
+ getSink()->diagnose(stmt, Diagnostics::uncaughtThrowInsideDefer);
+ }
+}
+
+void SemanticsStmtVisitor::visitCatchStmt(CatchStmt* stmt)
+{
+ if (stmt->errorVar)
+ {
+ ensureDeclBase(stmt->errorVar, DeclCheckState::DefinitionChecked, this);
+ stmt->errorVar->hiddenFromLookup = false;
+ }
+
+ WithOuterStmt subContext(this, stmt);
+ subContext.checkStmt(stmt->tryBody);
+ subContext.checkStmt(stmt->handleBody);
+}
+
void SemanticsStmtVisitor::visitExpressionStmt(ExpressionStmt* stmt)
{
stmt->expression = CheckExpr(stmt->expression);
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index bf90c6608..3babc9a56 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -1001,6 +1001,28 @@ DIAGNOSTIC(
nonCopyableTypeCapturedInLambda,
"cannot capture non-copyable type '$0' in a lambda expression.")
+DIAGNOSTIC(
+ 30113,
+ Error,
+ uncaughtThrowInsideDefer,
+ "'throw' expressions require a matching 'catch' inside a defer statement.")
+DIAGNOSTIC(
+ 30114,
+ Error,
+ uncaughtTryInsideDefer,
+ "'try' expressions require a matching 'catch' inside a defer statement.")
+DIAGNOSTIC(
+ 30115,
+ Error,
+ uncaughtThrowInNonThrowFunc,
+ "the current function or environment is not declared to throw any errors, but contains an "
+ "uncaught 'throw' statement.")
+DIAGNOSTIC(
+ 30116,
+ Error,
+ throwTypeIncompatibleWithErrorType,
+ "the type `$0` of `throw` is not compatible with function's error type `$1`.")
+
// Include
DIAGNOSTIC(
30500,
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 260bee0ff..a4362b912 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -806,10 +806,6 @@ Result linkAndOptimizeIR(
break;
}
- // Lower `Result<T,E>` types into ordinary struct types.
- if (requiredLoweringPassSet.resultType)
- lowerResultType(irModule, sink);
-
#if 0
dumpIRIfEnabled(codeGenContext, irModule, "UNIONS DESUGARED");
#endif
@@ -951,6 +947,12 @@ Result linkAndOptimizeIR(
break;
}
+ // Lower `Result<T,E>` types into ordinary struct types. This must happen
+ // after specialization, since otherwise incompatible copies of the lowered
+ // result structure are generated.
+ if (requiredLoweringPassSet.resultType)
+ lowerResultType(irModule, sink);
+
// Report checkpointing information
if (codeGenContext->shouldReportCheckpointIntermediates())
{
diff --git a/source/slang/slang-ir-lower-error-handling.cpp b/source/slang/slang-ir-lower-error-handling.cpp
index 917658ac4..a885fc2d2 100644
--- a/source/slang/slang-ir-lower-error-handling.cpp
+++ b/source/slang/slang-ir-lower-error-handling.cpp
@@ -15,6 +15,7 @@ struct ErrorHandlingLoweringContext
InstWorkList workList;
InstHashSet workListSet;
+ List<IRFuncType*> oldFuncTypes;
ErrorHandlingLoweringContext(IRModule* inModule)
: module(inModule), workList(inModule), workListSet(inModule)
@@ -37,8 +38,10 @@ struct ErrorHandlingLoweringContext
return;
IRBuilder builder(module);
builder.setInsertBefore(funcType);
+
auto resultType =
builder.getResultType(funcType->getResultType(), throwAttr->getErrorType());
+
List<IRType*> paramTypes;
for (UInt i = 0; i < funcType->getParamCount(); i++)
{
@@ -48,6 +51,7 @@ struct ErrorHandlingLoweringContext
}
auto newFuncType = builder.getFuncType(paramTypes, resultType);
funcType->replaceUsesWith(newFuncType);
+ funcType->removeAndDeallocate();
}
void processTryCall(IRTryCall* tryCall)
@@ -99,14 +103,27 @@ struct ErrorHandlingLoweringContext
auto failBlock = tryCall->getFailureBlock();
auto successBlock = tryCall->getSuccessBlock();
- builder.emitIf(isFail, failBlock, successBlock);
-
- // Replace the params in failBlock to `getResultError(call)`.
- builder.setInsertBefore(failBlock->getFirstOrdinaryInst());
- auto errorParam = failBlock->getFirstParam();
- auto errVal = builder.emitGetResultError(call);
- errorParam->replaceUsesWith(errVal);
- errorParam->removeAndDeallocate();
+ if (failBlock->getFirstParam())
+ {
+ // The isFail branch could otherwise just jump to the handler, but
+ // there's unfortunately the error parameter that needs to be passed as
+ // well, and it can't be done in IfElse. So there's an extra block in
+ // between to do that.
+ auto handlerJumpBlock = builder.createBlock();
+ auto branch = builder.emitIf(isFail, handlerJumpBlock, successBlock);
+
+ builder.setInsertAfter(branch->getParent());
+ builder.addInst(handlerJumpBlock);
+ builder.setInsertInto(handlerJumpBlock);
+
+ auto errVal = builder.emitGetResultError(call);
+ builder.emitBranch(failBlock, 1, &errVal);
+ }
+ else
+ {
+ // Catch-all with no parameter, so we can just jump to it directly.
+ builder.emitIf(isFail, failBlock, successBlock);
+ }
// Replace the params in successBlock to `getResultValue(call)`.
builder.setInsertBefore(successBlock->getFirstOrdinaryInst());
@@ -173,6 +190,9 @@ struct ErrorHandlingLoweringContext
case kIROp_Throw:
processThrow(cast<IRThrow>(inst));
break;
+ case kIROp_FuncType:
+ oldFuncTypes.add(cast<IRFuncType>(inst));
+ break;
default:
break;
}
@@ -206,18 +226,6 @@ struct ErrorHandlingLoweringContext
// Lower all functypes.
// Function types with an IRThrowTypeAttribute will be translated into a normal function
// type that returns `Result<T,E>`.
- List<IRFuncType*> oldFuncTypes;
- for (auto child : module->getGlobalInsts())
- {
- switch (child->getOp())
- {
- case kIROp_FuncType:
- oldFuncTypes.add(cast<IRFuncType>(child));
- break;
- default:
- break;
- }
- }
for (auto funcType : oldFuncTypes)
{
processFuncType(funcType);
diff --git a/source/slang/slang-ir-lower-result-type.cpp b/source/slang/slang-ir-lower-result-type.cpp
index 4cf684f33..3e7a2f523 100644
--- a/source/slang/slang-ir-lower-result-type.cpp
+++ b/source/slang/slang-ir-lower-result-type.cpp
@@ -2,6 +2,7 @@
#include "slang-ir-lower-result-type.h"
+#include "slang-ir-any-value-marshalling.h"
#include "slang-ir-insts.h"
#include "slang-ir.h"
@@ -23,11 +24,13 @@ struct ResultTypeLoweringContext
struct LoweredResultTypeInfo : public RefObject
{
IRType* resultType = nullptr;
- IRType* errorType = nullptr;
- IRType* valueType = nullptr;
IRType* loweredType = nullptr;
- IRStructField* valueField = nullptr;
- IRStructField* errorField = nullptr;
+ IRType* tagType = nullptr;
+ IRType* valueType = nullptr;
+ IRType* errorType = nullptr;
+ IRType* anyValueType = nullptr;
+ IRStructField* tagField = nullptr;
+ IRStructField* anyValueField = nullptr;
};
Dictionary<IRInst*, RefPtr<LoweredResultTypeInfo>> mapLoweredTypeToResultTypeInfo;
Dictionary<IRInst*, RefPtr<LoweredResultTypeInfo>> loweredResultTypes;
@@ -53,31 +56,40 @@ struct ResultTypeLoweringContext
return nullptr;
RefPtr<LoweredResultTypeInfo> info = new LoweredResultTypeInfo();
+ auto resultType = cast<IRResultType>(type);
info->resultType = (IRType*)type;
- info->errorType = cast<IRResultType>(type)->getErrorType();
- auto resultType = cast<IRResultType>(type);
+ auto structType = builder->createStructType();
+ info->loweredType = structType;
+ builder->addNameHintDecoration(structType, UnownedStringSlice("ResultType"));
+
+ info->tagType = builder->getBoolType();
+ auto tagKey = builder->createStructKey();
+ builder->addNameHintDecoration(tagKey, UnownedStringSlice("tag"));
+ info->tagField = builder->createStructField(structType, tagKey, info->tagType);
+
+ SlangInt anyValueSize = 0;
auto valueType = resultType->getValueType();
if (valueType->getOp() != kIROp_VoidType)
{
- auto structType = builder->createStructType();
- info->loweredType = structType;
- builder->addNameHintDecoration(structType, UnownedStringSlice("ResultType"));
-
+ anyValueSize = getAnyValueSize(valueType);
info->valueType = valueType;
- auto valueKey = builder->createStructKey();
- builder->addNameHintDecoration(valueKey, UnownedStringSlice("value"));
- info->valueField = builder->createStructField(structType, valueKey, (IRType*)valueType);
-
- auto errorType = resultType->getErrorType();
- auto errorKey = builder->createStructKey();
- builder->addNameHintDecoration(errorKey, UnownedStringSlice("error"));
- info->errorField = builder->createStructField(structType, errorKey, (IRType*)errorType);
- }
- else
- {
- info->loweredType = resultType->getErrorType();
}
+
+ auto errorType = resultType->getErrorType();
+ info->errorType = errorType;
+
+ auto errSize = getAnyValueSize(errorType);
+ if (errSize > anyValueSize)
+ anyValueSize = errSize;
+
+ info->anyValueType =
+ builder->getAnyValueType(builder->getIntValue(builder->getUIntType(), anyValueSize));
+ auto anyValueKey = builder->createStructKey();
+ builder->addNameHintDecoration(anyValueKey, UnownedStringSlice("anyValue"));
+ info->anyValueField =
+ builder->createStructField(structType, anyValueKey, info->anyValueType);
+
mapLoweredTypeToResultTypeInfo[info->loweredType] = info;
loweredResultTypes[type] = info;
return info.Ptr();
@@ -98,30 +110,6 @@ struct ResultTypeLoweringContext
workListSet.add(inst);
}
- IRInst* getSuccessErrorValue(IRType* type)
- {
- switch (type->getOp())
- {
- case kIROp_Int8Type:
- case kIROp_Int16Type:
- case kIROp_IntType:
- case kIROp_Int64Type:
- case kIROp_IntPtrType:
- case kIROp_UInt8Type:
- case kIROp_UInt16Type:
- case kIROp_UIntType:
- case kIROp_UInt64Type:
- case kIROp_UIntPtrType:
- break;
- default:
- SLANG_ASSERT_FAILURE("error type is not lowered to an integer type.");
- }
- IRBuilder builderStorage(module);
- auto builder = &builderStorage;
- builder->setInsertInto(module);
- return builder->getIntValue(type, 0);
- }
-
void processMakeResultValue(IRMakeResultValue* inst)
{
IRBuilder builderStorage(module);
@@ -129,19 +117,16 @@ struct ResultTypeLoweringContext
builder->setInsertBefore(inst);
auto info = getLoweredResultType(builder, inst->getDataType());
- if (info->loweredType->getOp() == kIROp_StructType)
- {
- List<IRInst*> operands;
- operands.add(inst->getOperand(0));
- operands.add(getSuccessErrorValue(info->errorType));
- auto makeStruct = builder->emitMakeStruct(info->loweredType, operands);
- inst->replaceUsesWith(makeStruct);
- }
- else
- {
- auto errCode = getSuccessErrorValue(info->errorType);
- inst->replaceUsesWith(errCode);
- }
+
+ List<IRInst*> operands;
+ operands.add(builder->getBoolValue(false));
+ auto packInst = builder->emitPackAnyValue(
+ info->anyValueType,
+ info->valueType ? inst->getOperand(0) : builder->emitDefaultConstruct(info->errorType));
+ operands.add(packInst);
+
+ auto makeStruct = builder->emitMakeStruct(info->loweredType, operands);
+ inst->replaceUsesWith(makeStruct);
inst->removeAndDeallocate();
}
@@ -152,18 +137,15 @@ struct ResultTypeLoweringContext
builder->setInsertBefore(inst);
auto info = getLoweredResultType(builder, inst->getDataType());
- if (info->valueField)
- {
- List<IRInst*> operands;
- operands.add(builder->emitDefaultConstruct(info->valueType));
- operands.add(inst->getErrorValue());
- auto makeStruct = builder->emitMakeStruct(info->loweredType, operands);
- inst->replaceUsesWith(makeStruct);
- }
- else
- {
- inst->replaceUsesWith(inst->getErrorValue());
- }
+
+ auto packInst = builder->emitPackAnyValue(info->anyValueType, inst->getErrorValue());
+
+ List<IRInst*> operands;
+ operands.add(builder->getBoolValue(true));
+ operands.add(packInst);
+
+ auto makeStruct = builder->emitMakeStruct(info->loweredType, operands);
+ inst->replaceUsesWith(makeStruct);
inst->removeAndDeallocate();
}
@@ -171,13 +153,14 @@ struct ResultTypeLoweringContext
{
auto loweredResultTypeInfo = getLoweredResultType(builder, resultInst->getDataType());
SLANG_ASSERT(loweredResultTypeInfo);
- if (loweredResultTypeInfo->valueField)
+ if (loweredResultTypeInfo->valueType)
{
auto value = builder->emitFieldExtract(
- loweredResultTypeInfo->errorType,
+ loweredResultTypeInfo->anyValueType,
resultInst,
- loweredResultTypeInfo->errorField->getKey());
- return value;
+ loweredResultTypeInfo->anyValueField->getKey());
+ auto unpackInst = builder->emitUnpackAnyValue(loweredResultTypeInfo->errorType, value);
+ return unpackInst;
}
else
{
@@ -206,12 +189,14 @@ struct ResultTypeLoweringContext
auto base = inst->getResultOperand();
auto loweredResultTypeInfo = getLoweredResultType(builder, base->getDataType());
SLANG_ASSERT(loweredResultTypeInfo);
- SLANG_ASSERT(loweredResultTypeInfo->valueField);
+ SLANG_ASSERT(loweredResultTypeInfo->valueType);
+
auto getElement = builder->emitFieldExtract(
- loweredResultTypeInfo->errorType,
+ loweredResultTypeInfo->anyValueType,
base,
- loweredResultTypeInfo->valueField->getKey());
- inst->replaceUsesWith(getElement);
+ loweredResultTypeInfo->anyValueField->getKey());
+ auto unpackInst = builder->emitUnpackAnyValue(loweredResultTypeInfo->valueType, getElement);
+ inst->replaceUsesWith(unpackInst);
inst->removeAndDeallocate();
}
@@ -224,13 +209,13 @@ struct ResultTypeLoweringContext
auto base = inst->getResultOperand();
auto loweredResultTypeInfo = getLoweredResultType(builder, base->getDataType());
SLANG_ASSERT(loweredResultTypeInfo);
- SLANG_ASSERT(loweredResultTypeInfo->valueField);
- auto resultValue = inst->getResultOperand();
- auto errValue = getResultError(builder, resultValue);
- auto isSuccess =
- builder->emitNeq(errValue, getSuccessErrorValue(loweredResultTypeInfo->errorType));
- inst->replaceUsesWith(isSuccess);
+ auto isFailure = builder->emitFieldExtract(
+ loweredResultTypeInfo->tagType,
+ base,
+ loweredResultTypeInfo->tagField->getKey());
+
+ inst->replaceUsesWith(isFailure);
inst->removeAndDeallocate();
}
diff --git a/source/slang/slang-ir-marshal-native-call.cpp b/source/slang/slang-ir-marshal-native-call.cpp
index 2257ef098..1d3f04318 100644
--- a/source/slang/slang-ir-marshal-native-call.cpp
+++ b/source/slang/slang-ir-marshal-native-call.cpp
@@ -42,8 +42,10 @@ IRFuncType* NativeCallMarshallingContext::getNativeFuncType(
if (auto resultType = as<IRResultType>(declaredFuncType->getResultType()))
{
auto nativeResultType = getNativeType(builder, resultType->getValueType());
+ auto nativeErrorType = getNativeType(builder, resultType->getErrorType());
nativeParamTypes.add(builder.getPtrType(nativeResultType));
- returnType = resultType->getErrorType();
+ nativeParamTypes.add(builder.getPtrType(nativeErrorType));
+ returnType = builder.getIntType();
}
else
{
@@ -243,10 +245,8 @@ IRFunc* NativeCallMarshallingContext::generateDLLExportWrapperFunc(
IRBlock* trueBlock = nullptr;
IRBlock* falseBlock = nullptr;
IRBlock* afterBlock = nullptr;
- builder.emitIfElseWithBlocks(isResultError, trueBlock, falseBlock, afterBlock);
- builder.setInsertInto(trueBlock);
- builder.emitReturn(builder.emitGetResultError(callInst));
+ builder.emitIfElseWithBlocks(isResultError, trueBlock, falseBlock, afterBlock);
builder.setInsertInto(falseBlock);
auto resultVal = builder.emitGetResultValue(callInst);
@@ -258,8 +258,22 @@ IRFunc* NativeCallMarshallingContext::generateDLLExportWrapperFunc(
builder.emitStore(params[nativeParamConsumeIndex], nativeVals[i]);
nativeParamConsumeIndex++;
}
+ // S_OK
builder.emitReturn(builder.getIntValue(builder.getIntType(), 0));
+ builder.setInsertInto(trueBlock);
+ nativeVals.clear();
+ auto errorVal = builder.emitGetResultError(callInst);
+ marshalManagedValueToNativeResultValue(builder, errorVal, nativeVals);
+ for (Index i = 0; i < nativeVals.getCount(); i++)
+ {
+ SLANG_RELEASE_ASSERT(nativeParamConsumeIndex < params.getCount());
+ builder.emitStore(params[nativeParamConsumeIndex], nativeVals[i]);
+ nativeParamConsumeIndex++;
+ }
+ // E_FAIL
+ builder.emitReturn(builder.getIntValue(builder.getIntType(), 0x80004005));
+
builder.setInsertInto(afterBlock);
builder.emitUnreachable();
}
@@ -296,10 +310,13 @@ IRInst* NativeCallMarshallingContext::marshalNativeCall(
IRType* originalReturnType = originalFuncType->getResultType();
IRVar* resultVar = nullptr;
+ IRVar* errorVar = nullptr;
if (auto resultType = as<IRResultType>(originalReturnType))
{
// Declare a local variable to receive result.
resultVar = builder.emitVar(getNativeType(builder, resultType->getValueType()));
+ errorVar = builder.emitVar(getNativeType(builder, resultType->getErrorType()));
+ args.add(resultVar);
args.add(resultVar);
}
@@ -314,16 +331,18 @@ IRInst* NativeCallMarshallingContext::marshalNativeCall(
if (auto resultType = as<IRResultType>(originalReturnType))
{
auto val = builder.emitLoad(resultVar);
- auto err = call;
+ auto err = builder.emitLoad(errorVar);
+ auto tag = call;
val = marshalNativeValueToManagedValue(builder, val);
- auto intErr = err;
- if (err->getDataType()->getOp() != kIROp_IntType)
+ err = marshalNativeValueToManagedValue(builder, err);
+ auto intTag = tag;
+ if (tag->getDataType()->getOp() != kIROp_IntType)
{
- intErr = builder.emitCast(builder.getIntType(), err);
+ intTag = builder.emitCast(builder.getIntType(), tag);
}
- auto errIsError = builder.emitLess(intErr, builder.getIntValue(builder.getIntType(), 0));
+ auto tagIsError = builder.emitLess(intTag, builder.getIntValue(builder.getIntType(), 0));
IRBlock *trueBlock, *falseBlock, *afterBlock;
- builder.emitIfElseWithBlocks(errIsError, trueBlock, falseBlock, afterBlock);
+ builder.emitIfElseWithBlocks(tagIsError, trueBlock, falseBlock, afterBlock);
builder.setInsertInto(trueBlock);
returnValue = builder.emitMakeResultError(resultType, err);
builder.emitBranch(afterBlock, 1, &returnValue);
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index b09d9e6e2..6d5bd9a25 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -620,6 +620,12 @@ static IRBlock::SuccessorList getSuccessors(IRInst* terminator)
end = begin + 1;
break;
+ case kIROp_TryCall:
+ // tryCall <successBlock> <failBlock> <callee> <args>...
+ begin = operands + 0;
+ end = begin + 2;
+ break;
+
default:
SLANG_UNEXPECTED("unhandled terminator instruction");
UNREACHABLE_RETURN(IRBlock::SuccessorList(nullptr, nullptr));
diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp
index b08388aee..d0ed9cdac 100644
--- a/source/slang/slang-language-server-ast-lookup.cpp
+++ b/source/slang/slang-language-server-ast-lookup.cpp
@@ -645,6 +645,17 @@ struct ASTLookupStmtVisitor : public StmtVisitor<ASTLookupStmtVisitor, bool>
bool visitDeferStmt(DeferStmt* stmt) { return dispatchIfNotNull(stmt->statement); }
+ bool visitThrowStmt(ThrowStmt* stmt) { return checkExpr(stmt->expression); }
+
+ bool visitCatchStmt(CatchStmt* stmt)
+ {
+ if (stmt->errorVar && _findAstNodeImpl(*context, stmt->errorVar))
+ return true;
+ if (dispatchIfNotNull(stmt->tryBody))
+ return true;
+ return dispatchIfNotNull(stmt->handleBody);
+ }
+
bool visitWhileStmt(WhileStmt* stmt)
{
if (checkExpr(stmt->predicate))
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 9920075fe..ebcdaf1e1 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -555,6 +555,17 @@ struct AstOrIRType
explicit operator bool() { return astType || irType; }
};
+struct CatchHandler
+{
+ // 'nullptr' implies catch-all.
+ IRType* errorType = nullptr;
+
+ // Block of the handler statement. Takes a value of errorType as parameter.
+ IRBlock* errorHandler = nullptr;
+
+ CatchHandler* prev = nullptr;
+};
+
struct IRGenContext
{
ASTBuilder* astBuilder;
@@ -599,6 +610,9 @@ struct IRGenContext
// The current scope end for use with `defer`.
IRBlock* scopeEndBlock = nullptr;
+ // A chain of nested `catch` handlers for `try` and `throw.
+ CatchHandler* catchHandler = nullptr;
+
// Callback function to call when after lowering a type.
std::function<IRType*(IRGenContext* context, Type* type, IRType* irType)> lowerTypeCallback =
nullptr;
@@ -754,6 +768,19 @@ int32_t getIntrinsicOp(Decl* decl, IntrinsicOpModifier* intrinsicOpMod)
return int32_t(irOp);
}
+static CatchHandler findErrorHandler(IRGenContext* context, IRType* type)
+{
+ for (auto handler = context->catchHandler; handler != nullptr;
+ handler = context->catchHandler->prev)
+ {
+ if (!handler->errorType || handler->errorType == type)
+ {
+ return *handler;
+ }
+ }
+ return CatchHandler();
+}
+
struct TryClauseEnvironment
{
TryClauseType clauseType = TryClauseType::None;
@@ -809,18 +836,28 @@ LoweredValInfo emitCallToVal(
case TryClauseType::Standard:
{
auto callee = getSimpleVal(context, funcVal);
- auto succBlock = builder->createBlock();
- auto failBlock = builder->createBlock();
auto funcType = as<IRFuncType>(callee->getDataType());
auto throwAttr = funcType->findAttr<IRFuncThrowTypeAttr>();
assert(throwAttr);
+
+ auto handler = findErrorHandler(context, throwAttr->getErrorType());
+ auto succBlock = builder->createBlock();
+ auto failBlock =
+ handler.errorHandler ? handler.errorHandler : builder->createBlock();
+
auto voidType = builder->getVoidType();
builder->emitTryCallInst(voidType, succBlock, failBlock, callee, argCount, args);
- builder->insertBlock(failBlock);
- auto errParam = builder->emitParam(throwAttr->getErrorType());
- builder->emitThrow(errParam);
builder->insertBlock(succBlock);
auto value = builder->emitParam(type);
+
+ if (!handler.errorHandler)
+ {
+ // We have to create a default fail block, which just re-throws.
+ builder->insertBlock(failBlock);
+ auto errParam = builder->emitParam(throwAttr->getErrorType());
+ builder->emitThrow(errParam);
+ builder->setInsertInto(succBlock);
+ }
return LoweredValInfo::simple(value);
}
break;
@@ -1982,8 +2019,9 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
else
{
auto errorType = lowerType(context, type->getErrorType());
+ IRInst* operands[] = {errorType};
auto irThrowFuncTypeAttribute =
- getBuilder()->getAttr(kIROp_FuncThrowTypeAttr, 1, (IRInst**)&errorType);
+ getBuilder()->getAttr(kIROp_FuncThrowTypeAttr, 1, operands);
return getBuilder()->getFuncType(
paramCount,
paramTypes.getBuffer(),
@@ -3449,9 +3487,9 @@ void _lowerFuncDeclBaseTypeInfo(
if (!getErrorCodeType(context->astBuilder, declRef)
->equals(context->astBuilder->getBottomType()))
{
- auto errorType = lowerType(context, getErrorCodeType(context->astBuilder, declRef));
- IRAttr* throwTypeAttr = nullptr;
- throwTypeAttr = builder->getAttr(kIROp_FuncThrowTypeAttr, 1, (IRInst**)&errorType);
+ auto irErrorType = lowerType(context, getErrorCodeType(context->astBuilder, declRef));
+ IRInst* operands[] = {irErrorType};
+ IRAttr* throwTypeAttr = builder->getAttr(kIROp_FuncThrowTypeAttr, 1, operands);
outInfo.type = builder->getFuncType(
paramTypes.getCount(),
paramTypes.getBuffer(),
@@ -6671,6 +6709,117 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
builder->setInsertInto(mergeBlock);
}
+ void visitThrowStmt(ThrowStmt* stmt)
+ {
+ auto builder = getBuilder();
+ startBlockIfNeeded(stmt);
+
+ auto loweredExpr = lowerRValueExpr(context, stmt->expression);
+ auto loweredVal = getSimpleVal(context, loweredExpr);
+ auto throwType = lowerType(context, stmt->expression->type);
+
+ CatchHandler handler;
+ if (loweredVal && throwType)
+ {
+ handler = findErrorHandler(context, throwType);
+ }
+
+ if (handler.errorHandler)
+ {
+ builder->emitBranch(handler.errorHandler, 1, &loweredVal);
+ }
+ else
+ {
+ builder->emitThrow(getSimpleVal(context, loweredExpr));
+ }
+ }
+
+ void visitCatchStmt(CatchStmt* stmt)
+ {
+ auto builder = getBuilder();
+ startBlockIfNeeded(stmt);
+
+ // The mental model here is that the below Catch statement:
+ //
+ // let val = try MayThrowFunc();
+ // // Do stuff with val
+ // catch(err: Error)
+ // {
+ // catchBlock(err);
+ // }
+ //
+ // lowers similarly to:
+ //
+ // handlerLoop: for(;;)
+ // {
+ // E err; // Actually just a parameter for the catchBlock, not a real variable.
+ // bodyLoop: for(;;)
+ // {
+ // // Body goes here
+ // Result<T, E> r = mayThrowFunc();
+ // if(isResultError(r))
+ // {
+ // err = r.error;
+ // break bodyLoop;
+ // }
+ // let val = r.getSuccessValue();
+ // // Do stuff with val
+ // break handlerLoop;
+ // }
+ // catchBlock(err);
+ // break handlerLoop;
+ // }
+ //
+ // This approach allows for it to generate valid SPIR-V. Just jumping
+ // around with unstructured conditional jumps doesn't work there.
+
+ IRBlock* handlerLoopHead = createBlock();
+ IRBlock* handlerBreakLabel = createBlock();
+ IRBlock* bodyLoopHead = createBlock();
+ IRBlock* bodyBreakLabel = createBlock();
+
+ builder->emitLoop(handlerLoopHead, handlerBreakLabel, handlerLoopHead);
+ insertBlock(handlerLoopHead);
+
+ builder->emitLoop(bodyLoopHead, bodyBreakLabel, bodyLoopHead);
+ insertBlock(bodyLoopHead);
+
+ CatchHandler catchHandler;
+ catchHandler.errorType =
+ stmt->errorVar ? lowerType(context, stmt->errorVar->getType()) : nullptr;
+ catchHandler.errorHandler = bodyBreakLabel;
+ catchHandler.prev = context->catchHandler;
+ context->catchHandler = &catchHandler;
+
+ // Note that the tryBody doesn't actually have to have it's own scope or
+ // block. If there's a `defer` in the tryBody, it can run after the
+ // catch statement.
+ lowerStmt(context, stmt->tryBody);
+
+ // Put break; at the end of the body if there's nothing else there yet.
+ // This prevents the catch handler from running.
+ emitBranchIfNeeded(handlerBreakLabel);
+
+ context->catchHandler = catchHandler.prev;
+
+ insertBlock(bodyBreakLabel);
+
+ if (catchHandler.errorType)
+ {
+ auto irParam = builder->emitParam(catchHandler.errorType);
+ auto paramVal = LoweredValInfo::simple(irParam);
+ context->setGlobalValue(stmt->errorVar, paramVal);
+ }
+
+ IRBlock* prevScopeEndBlock = pushScopeBlock(handlerBreakLabel);
+ lowerStmt(context, stmt->handleBody);
+ popScopeBlock(prevScopeEndBlock, true);
+
+ emitBranchIfNeeded(handlerBreakLabel);
+
+ insertBlock(handlerBreakLabel);
+ }
+
void visitDiscardStmt(DiscardStmt* stmt)
{
startBlockIfNeeded(stmt);
@@ -8536,6 +8685,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
subContextStorage.returnDestination = LoweredValInfo();
subContextStorage.lowerTypeCallback = nullptr;
+ subContextStorage.catchHandler = nullptr;
}
IRBuilder* getBuilder() { return &subBuilderStorage; }
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index b2a006adc..f968f9fe1 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -211,11 +211,14 @@ public:
Stmt* parseIfLetStatement();
ForStmt* ParseForStatement();
WhileStmt* ParseWhileStatement();
- DoWhileStmt* ParseDoWhileStatement();
+ DoWhileStmt* ParseDoWhileStatement(Stmt* body);
+ CatchStmt* ParseDoCatchStatement(Stmt* body);
+ Stmt* ParseDoStatement();
BreakStmt* ParseBreakStatement();
ContinueStmt* ParseContinueStatement();
ReturnStmt* ParseReturnStatement();
DeferStmt* ParseDeferStatement();
+ ThrowStmt* ParseThrowStatement();
ExpressionStmt* ParseExpressionStatement();
Expr* ParseExpression(Precedence level = Precedence::Comma);
@@ -5741,7 +5744,7 @@ Stmt* Parser::ParseStatement(Stmt* parentStmt)
else if (LookAheadToken("while"))
statement = ParseWhileStatement();
else if (LookAheadToken("do"))
- statement = ParseDoWhileStatement();
+ statement = ParseDoStatement();
else if (LookAheadToken("break"))
statement = ParseBreakStatement();
else if (LookAheadToken("continue"))
@@ -5781,6 +5784,10 @@ Stmt* Parser::ParseStatement(Stmt* parentStmt)
{
statement = ParseExpressionStatement();
}
+ else if (LookAheadToken("throw"))
+ {
+ statement = ParseThrowStatement();
+ }
else if (LookAheadToken(TokenType::Identifier) || LookAheadToken(TokenType::Scope))
{
if (LookAheadToken(TokenType::Identifier) && LookAheadToken(TokenType::Colon, 1))
@@ -5941,7 +5948,6 @@ Stmt* Parser::parseBlockStatement()
Stmt* body = nullptr;
-
if (!tokenReader.isAtEnd())
{
FillPosition(blockStatement);
@@ -6259,12 +6265,11 @@ WhileStmt* Parser::ParseWhileStatement()
return whileStatement;
}
-DoWhileStmt* Parser::ParseDoWhileStatement()
+DoWhileStmt* Parser::ParseDoWhileStatement(Stmt* body)
{
DoWhileStmt* doWhileStatement = astBuilder->create<DoWhileStmt>();
FillPosition(doWhileStatement);
- ReadToken("do");
- doWhileStatement->statement = ParseStatement();
+ doWhileStatement->statement = body;
ReadToken("while");
ReadToken(TokenType::LParent);
doWhileStatement->predicate = ParseExpression();
@@ -6273,6 +6278,62 @@ DoWhileStmt* Parser::ParseDoWhileStatement()
return doWhileStatement;
}
+CatchStmt* Parser::ParseDoCatchStatement(Stmt* body)
+{
+ for (;;)
+ {
+ ScopeDecl* scopeDecl = astBuilder->create<ScopeDecl>();
+ pushScopeAndSetParent(scopeDecl);
+
+ CatchStmt* catchStatement = astBuilder->create<CatchStmt>();
+ FillPosition(catchStatement);
+ ReadToken("catch");
+
+ // Optional error parameter. If not given, the catch catches all error
+ // types.
+ if (AdvanceIf(this, TokenType::LParent))
+ {
+ ParamDecl* errorVar = parseModernParamDecl(this);
+ catchStatement->errorVar = errorVar;
+ AddMember(scopeDecl, errorVar);
+ ReadToken(TokenType::RParent);
+ }
+
+ catchStatement->tryBody = body;
+ catchStatement->handleBody = ParseStatement();
+
+ PopScope();
+
+ if (!LookAheadToken("catch"))
+ return catchStatement;
+
+ // Use this catch as the body for the next one, if multiple are chained.
+ body = catchStatement;
+ }
+}
+
+Stmt* Parser::ParseDoStatement()
+{
+ SourceLoc position = tokenReader.peekLoc();
+ ReadToken("do");
+ Stmt* statement = ParseStatement();
+ if (LookAheadToken("while"))
+ {
+ Stmt* whileStatement = ParseDoWhileStatement(statement);
+ whileStatement->loc = position;
+ return whileStatement;
+ }
+ else if (LookAheadToken("catch"))
+ {
+ return ParseDoCatchStatement(statement);
+ }
+ else
+ {
+ Unexpected(this, "while' or 'catch");
+ return statement;
+ }
+}
+
BreakStmt* Parser::ParseBreakStatement()
{
BreakStmt* breakStatement = astBuilder->create<BreakStmt>();
@@ -6315,6 +6376,15 @@ DeferStmt* Parser::ParseDeferStatement()
return deferStatement;
}
+ThrowStmt* Parser::ParseThrowStatement()
+{
+ ThrowStmt* throwStatement = astBuilder->create<ThrowStmt>();
+ FillPosition(throwStatement);
+ ReadToken("throw");
+ throwStatement->expression = ParseExpression();
+ return throwStatement;
+}
+
ExpressionStmt* Parser::ParseExpressionStatement()
{
ExpressionStmt* statement = astBuilder->create<ExpressionStmt>();
diff --git a/tests/language-feature/error-handling/basic.slang b/tests/language-feature/error-handling/basic.slang
new file mode 100644
index 000000000..4cb270cb8
--- /dev/null
+++ b/tests/language-feature/error-handling/basic.slang
@@ -0,0 +1,91 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -cpu -shaderobj
+
+// CHECK: 2
+// CHECK-NEXT: 0
+// CHECK-NEXT: 11
+// CHECK-NEXT: 12
+// CHECK-NEXT: 6
+// CHECK-NEXT: 1
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+enum MyError1
+{
+ Fail
+};
+
+enum MyError2
+{
+ Fail1,
+ Fail2 = 0x12
+};
+
+void throwingFunc() throws MyError1
+{
+ throw MyError1.Fail;
+}
+
+int maybeBadFunc1(int n) throws MyError1
+{
+ if (n == 1) throw MyError1.Fail;
+ return n;
+}
+
+int maybeBadFunc2(int n) throws MyError2
+{
+ if (n == 2) throw MyError2.Fail2;
+ return n;
+}
+
+int multiCatchFunc(int n)
+{
+ do
+ {
+ let a = try maybeBadFunc1(n);
+ let b = try maybeBadFunc2(n);
+ return a+b;
+ }
+ catch(err: MyError1)
+ {
+ return 0x11;
+ }
+ catch(err: MyError2)
+ {
+ return reinterpret<int>(err);
+ }
+}
+
+int containedThrow()
+{
+ do
+ {
+ throw MyError1.Fail;
+ }
+ catch(err: MyError1)
+ {
+ return 1;
+ }
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(int3 dispatchThreadID: SV_DispatchThreadID)
+{
+ do
+ {
+ try throwingFunc();
+ outputBuffer[0] = 1;
+ }
+ catch(err: MyError1)
+ {
+ outputBuffer[0] = 2;
+ }
+
+ outputBuffer[1] = multiCatchFunc(0);
+ outputBuffer[2] = multiCatchFunc(1);
+ outputBuffer[3] = multiCatchFunc(2);
+ outputBuffer[4] = multiCatchFunc(3);
+ outputBuffer[5] = containedThrow();
+}
diff --git a/tests/language-feature/error-handling/catch-all.slang b/tests/language-feature/error-handling/catch-all.slang
new file mode 100644
index 000000000..afa0693e4
--- /dev/null
+++ b/tests/language-feature/error-handling/catch-all.slang
@@ -0,0 +1,57 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -cpu -shaderobj
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+enum MyError1
+{
+ Fail = 0x10
+};
+
+enum MyError2
+{
+ Fail = 0x20
+};
+
+int f(int n) throws MyError1
+{
+ if (n == 1) throw MyError1.Fail;
+ return n;
+}
+
+int g(int n) throws MyError2
+{
+ if (n == 2) throw MyError2.Fail;
+ return n;
+}
+
+void handlerFunc(int i, int n)
+{
+ do
+ {
+ int a = try f(n);
+ int b = try g(n);
+ int c = a+b+1;
+ outputBuffer[i] = c;
+ }
+ catch(err: MyError1)
+ {
+ outputBuffer[i] = reinterpret<int>(err);
+ }
+ catch
+ {
+ outputBuffer[i] = 0x30;
+ }
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(int3 dispatchThreadID: SV_DispatchThreadID)
+{
+ int i = 0;
+ handlerFunc(0, 0); // CHECK: 1
+ handlerFunc(1, 1); // CHECK-NEXT: 10
+ handlerFunc(2, 2); // CHECK-NEXT: 30
+ handlerFunc(3, 3); // CHECK-NEXT 7
+}
diff --git a/tests/language-feature/error-handling/defer-interaction.slang b/tests/language-feature/error-handling/defer-interaction.slang
new file mode 100644
index 000000000..0b3f9f829
--- /dev/null
+++ b/tests/language-feature/error-handling/defer-interaction.slang
@@ -0,0 +1,58 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -cpu -shaderobj
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+enum MyError
+{
+ Fail
+};
+
+int maybeThrowingFunc(int n) throws MyError
+{
+ if (n == 3)
+ throw MyError.Fail;
+ return n;
+}
+
+void testFunc(int n, inout int i)
+{
+ int value = n;
+ defer
+ {
+ outputBuffer[i++] = value;
+ }
+
+ defer
+ {
+ do
+ {
+ let m = try maybeThrowingFunc(n);
+ value += m;
+ }
+ catch(err: MyError)
+ {
+ defer
+ {
+ outputBuffer[i++] = 0x80;
+ }
+ outputBuffer[i++] = 0xFF;
+ }
+ }
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(int3 dispatchThreadID: SV_DispatchThreadID)
+{
+ int i = 0;
+ // CHECK: 2
+ testFunc(1, i);
+ // CHECK-NEXT: 4
+ testFunc(2, i);
+ // CHECK-NEXT: FF
+ // CHECK-NEXT: 80
+ // CHECK-NEXT: 3
+ testFunc(3, i);
+}
diff --git a/tests/language-feature/error-handling/generics.slang b/tests/language-feature/error-handling/generics.slang
new file mode 100644
index 000000000..377ef622d
--- /dev/null
+++ b/tests/language-feature/error-handling/generics.slang
@@ -0,0 +1,47 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -cpu -shaderobj
+
+// CHECK: 5
+// CHECK-NEXT: 1
+// CHECK-NEXT: 0
+
+//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+enum MyError
+{
+ Fail = 1
+};
+
+T func<T: __BuiltinFloatingPointType>(T val) throws MyError
+{
+ do
+ {
+ if (val >= T(3))
+ throw MyError.Fail;
+ return val * T(2);
+ }
+ catch(err: MyError)
+ {
+ // Just rethrow to test catching inside a generic as well.
+ throw err;
+ }
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(int3 dispatchThreadID: SV_DispatchThreadID)
+{
+ int i = 0;
+ do
+ {
+ outputBuffer[i] = int(try func(2.5f));
+ i+=1;
+ outputBuffer[i] = int(try func(3.5f));
+ i+=1;
+ }
+ catch(err: MyError)
+ {
+ outputBuffer[i] = reinterpret<int>(err);
+ }
+}
diff --git a/tests/language-feature/error-handling/non-trivial-error-type.slang b/tests/language-feature/error-handling/non-trivial-error-type.slang
new file mode 100644
index 000000000..9e03536d3
--- /dev/null
+++ b/tests/language-feature/error-handling/non-trivial-error-type.slang
@@ -0,0 +1,40 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -cpu -shaderobj
+
+// CHECK: 2
+// CHECK-NEXT: 13
+// CHECK-NEXT: 0
+
+//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+struct MyError
+{
+ int code = 0;
+ int param = 0;
+};
+
+int func(int val) throws MyError
+{
+ if (val >= 3)
+ throw MyError(1, val);
+ return val * 2;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(int3 dispatchThreadID: SV_DispatchThreadID)
+{
+ int i = 0;
+ do
+ {
+ outputBuffer[i] = try func(1);
+ i+=1;
+ outputBuffer[i] = try func(3);
+ i+=1;
+ }
+ catch(err: MyError)
+ {
+ outputBuffer[i] = err.code * 0x10 + err.param;
+ }
+}
diff --git a/tests/language-feature/error-handling/throw-in-defer.slang b/tests/language-feature/error-handling/throw-in-defer.slang
new file mode 100644
index 000000000..d19bf227e
--- /dev/null
+++ b/tests/language-feature/error-handling/throw-in-defer.slang
@@ -0,0 +1,17 @@
+//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK):
+enum MyError
+{
+ Fail
+}
+
+void f() throws MyError
+{
+ defer {
+ // Throw isn't allowed to escape defer for the same reason as 'return',
+ // it'd prevent other defer statements from running. This is legal if
+ // you catch it, though.
+ throw MyError.Fail;
+ }
+}
+
+// CHECK: error 30113
diff --git a/tests/language-feature/error-handling/throw-type-mismatch.slang b/tests/language-feature/error-handling/throw-type-mismatch.slang
new file mode 100644
index 000000000..7e5278962
--- /dev/null
+++ b/tests/language-feature/error-handling/throw-type-mismatch.slang
@@ -0,0 +1,17 @@
+//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK):
+enum MyError1
+{
+ Fail
+}
+
+enum MyError2
+{
+ Fail
+}
+
+int g() throws MyError1
+{
+ throw MyError2.Fail;
+}
+
+// CHECK: error 30116
diff --git a/tests/language-feature/error-handling/throw-without-throws.slang b/tests/language-feature/error-handling/throw-without-throws.slang
new file mode 100644
index 000000000..e37198b82
--- /dev/null
+++ b/tests/language-feature/error-handling/throw-without-throws.slang
@@ -0,0 +1,19 @@
+//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK):
+enum MyError
+{
+ Fail
+}
+
+int g() throws MyError
+{
+ throw MyError.Fail;
+}
+
+void f()
+{
+ let n = try g();
+ throw MyError.Fail;
+}
+
+// CHECK: error 30093
+// CHECK: error 30115
diff --git a/tests/language-feature/error-handling/try-in-defer.slang b/tests/language-feature/error-handling/try-in-defer.slang
new file mode 100644
index 000000000..54cd18e12
--- /dev/null
+++ b/tests/language-feature/error-handling/try-in-defer.slang
@@ -0,0 +1,19 @@
+//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK):
+enum MyError
+{
+ Fail
+}
+
+int g() throws MyError
+{
+ throw MyError.Fail;
+}
+
+void f() throws MyError
+{
+ defer {
+ let n = try g();
+ }
+}
+
+// CHECK: error 30114