From 57c3f938221c427b78da7087f8a832ba4a271a7c Mon Sep 17 00:00:00 2001 From: Julius Ikkala Date: Fri, 23 May 2025 22:27:37 +0300 Subject: Implement throw & catch statements (#6916) * Implement throw statement It already existed in the IR, so only parsing, checking and lowering was missing. * Initial catch implementation Likely very broken. * Error out when catch() isn't last in scope * Prevent accessing variables from scope preceding catch As those may actually not be available at that point. * Add IError and use it in Result type lowering * Add diagnostic tests * Allow caught throws in non-throw functions * Fix catch propagating between functions & SPIR-V merge issue * Add test for non-trivial error types * Fix MSVC build * Fix invalid value type from Result lowering * Also lower error handling in templates * Lower result types only after specialization * Attempt to disambiguate error enums by witness table * Revert matching by witness, types should be distinct too * Don't assert valueField when getting Result's error value It may not exist if the function returns void, but getting the error value is still legitimate. * Update tests for new error numbers & get rid of expected.txt * Change catch lowering to resemble breaking a loop ... To make SPIR-V happy. * Fix dead catch blocks and invalid cached dominator tree * More SPIR-V adjustment * Lower catch as two nested loops * Add defer interaction test and revert broken defer changes * Fix enum type when throwing literals * Cleanup and bikeshedding * Document error handling mechanism * Fix table of contents * Use boolean tag in Result * Use anyValue storage for Result * Remove IError * Fix formatting * Eradicate success values from docs and tests * Use parseModernParamDecl for catch parameter * Implement do-catch syntax * Implement catch-all * Fix formatting * Fix marshalling native calls that throw --------- Co-authored-by: Yong He --- source/slang/slang-ast-iterator.h | 14 ++ source/slang/slang-ast-stmt.h | 16 +++ source/slang/slang-check-decl.cpp | 9 ++ source/slang/slang-check-expr.cpp | 69 +++++---- source/slang/slang-check-impl.h | 23 ++- source/slang/slang-check-stmt.cpp | 76 ++++++++-- source/slang/slang-diagnostic-defs.h | 22 +++ source/slang/slang-emit.cpp | 10 +- source/slang/slang-ir-lower-error-handling.cpp | 48 ++++--- source/slang/slang-ir-lower-result-type.cpp | 155 +++++++++----------- source/slang/slang-ir-marshal-native-call.cpp | 39 +++-- source/slang/slang-ir.cpp | 6 + source/slang/slang-language-server-ast-lookup.cpp | 11 ++ source/slang/slang-lower-to-ir.cpp | 168 ++++++++++++++++++++-- source/slang/slang-parser.cpp | 82 ++++++++++- 15 files changed, 572 insertions(+), 176 deletions(-) (limited to 'source') diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h index 0a4b17cc0..047c33d1a 100644 --- a/source/slang/slang-ast-iterator.h +++ b/source/slang/slang-ast-iterator.h @@ -438,6 +438,20 @@ struct ASTIterator dispatchIfNotNull(stmt->statement); } + void visitThrowStmt(ThrowStmt* stmt) + { + iterator->maybeDispatchCallback(stmt); + iterator->visitExpr(stmt->expression); + } + + void visitCatchStmt(CatchStmt* stmt) + { + if (stmt->errorVar) + iterator->visitDecl(stmt->errorVar); + dispatchIfNotNull(stmt->tryBody); + dispatchIfNotNull(stmt->handleBody); + } + void visitWhileStmt(WhileStmt* stmt) { iterator->maybeDispatchCallback(stmt); diff --git a/source/slang/slang-ast-stmt.h b/source/slang/slang-ast-stmt.h index a1b7c274e..1fe97adf1 100644 --- a/source/slang/slang-ast-stmt.h +++ b/source/slang/slang-ast-stmt.h @@ -294,6 +294,22 @@ class DeferStmt : public Stmt FIDDLE() Stmt* statement = nullptr; }; +FIDDLE() +class ThrowStmt : public Stmt +{ + FIDDLE(...) + FIDDLE() Expr* expression = nullptr; +}; + +FIDDLE() +class CatchStmt : public Stmt +{ + FIDDLE(...) + FIDDLE() ParamDecl* errorVar = nullptr; // null => catch-all + FIDDLE() Stmt* tryBody = nullptr; + FIDDLE() Stmt* handleBody = nullptr; +}; + FIDDLE() class ExpressionStmt : public Stmt { diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 1331839a4..7a0dcb06f 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -856,6 +856,15 @@ struct SemanticsDeclReferenceVisitor : public SemanticsDeclVisitorBase, void visitDeferStmt(DeferStmt* stmt) { dispatchIfNotNull(stmt->statement); } + void visitThrowStmt(ThrowStmt* stmt) { dispatchIfNotNull(stmt->expression); } + + void visitCatchStmt(CatchStmt* stmt) + { + dispatchIfNotNull(stmt->errorVar); + dispatchIfNotNull(stmt->tryBody); + dispatchIfNotNull(stmt->handleBody); + } + void visitWhileStmt(WhileStmt* stmt) { dispatchIfNotNull(stmt->predicate); diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index db507c060..75b1b7024 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -3951,45 +3951,64 @@ Expr* SemanticsExprVisitor::visitTryExpr(TryExpr* expr) return expr; auto parentFunc = this->m_parentFunc; - // TODO: check if the try clause is caught. - // For now we assume all `try`s are not caught (because we don't have catch yet). - if (!parentFunc) + auto base = as(expr->base); + if (!base) { - getSink()->diagnose(expr, Diagnostics::uncaughtTryCallInNonThrowFunc); + getSink()->diagnose(expr, Diagnostics::tryClauseMustApplyToInvokeExpr); return expr; } - if (parentFunc->errorType->equals(m_astBuilder->getBottomType())) + + auto callee = as(base->functionExpr); + if (!callee) { - getSink()->diagnose(expr, Diagnostics::uncaughtTryCallInNonThrowFunc); + getSink()->diagnose(expr, Diagnostics::calleeOfTryCallMustBeFunc); return expr; } - if (!as(expr->base)) + + auto funcCallee = as(callee->declRef.getDecl()); + Stmt* catchStmt = nullptr; + if (funcCallee) { - getSink()->diagnose(expr, Diagnostics::tryClauseMustApplyToInvokeExpr); + if (funcCallee->errorType->equals(m_astBuilder->getBottomType())) + { + getSink()->diagnose(expr, Diagnostics::tryInvokeCalleeShouldThrow, callee->declRef); + return expr; + } + catchStmt = findMatchingCatchStmt(funcCallee->errorType); + } + + if (FindOuterStmt(catchStmt)) + { + // 'try' may jump outside a defer statement, which isn't allowed for + // now. + getSink()->diagnose(expr, Diagnostics::uncaughtTryInsideDefer); return expr; } - auto base = as(expr->base); - if (auto callee = as(base->functionExpr)) + + if (!catchStmt) { - if (auto funcCallee = as(callee->declRef.getDecl())) + // Uncaught try. + if (!parentFunc) { - if (funcCallee->errorType->equals(m_astBuilder->getBottomType())) - { - getSink()->diagnose(expr, Diagnostics::tryInvokeCalleeShouldThrow, callee->declRef); - } - if (!parentFunc->errorType->equals(funcCallee->errorType)) - { - getSink()->diagnose( - expr, - Diagnostics::errorTypeOfCalleeIncompatibleWithCaller, - callee->declRef, - funcCallee->errorType, - parentFunc->errorType); - } + getSink()->diagnose(expr, Diagnostics::uncaughtTryCallInNonThrowFunc); + return expr; + } + if (parentFunc->errorType->equals(m_astBuilder->getBottomType())) + { + getSink()->diagnose(expr, Diagnostics::uncaughtTryCallInNonThrowFunc); + return expr; + } + if (funcCallee && !parentFunc->errorType->equals(funcCallee->errorType)) + { + getSink()->diagnose( + expr, + Diagnostics::errorTypeOfCalleeIncompatibleWithCaller, + callee->declRef, + funcCallee->errorType, + parentFunc->errorType); return expr; } } - getSink()->diagnose(expr, Diagnostics::calleeOfTryCallMustBeFunc); return expr; } diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 87f2df53d..950a150c4 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1028,6 +1028,20 @@ public: return result; } + template + T* FindOuterStmt(Stmt* searchUntil = nullptr) + { + for (auto outerStmtInfo = m_outerStmts; outerStmtInfo && outerStmtInfo->stmt != searchUntil; + outerStmtInfo = outerStmtInfo->next) + { + auto outerStmt = outerStmtInfo->stmt; + auto found = as(outerStmt); + if (found) + return found; + } + return nullptr; + } + // Setup the flag to indicate disabling the short-circuiting evaluation // for the logical expressions associted with the subcontext SemanticsContext disableShortCircuitLogicalExpr() @@ -2867,6 +2881,8 @@ public: void addVisibilityModifier(Decl* decl, DeclVisibility vis); void checkRayPayloadStructFields(StructDecl* structDecl); + + CatchStmt* findMatchingCatchStmt(Type* errorType); }; @@ -3011,9 +3027,6 @@ struct SemanticsStmtVisitor : public SemanticsVisitor, StmtVisitor - T* FindOuterStmt(Stmt* searchUntil = nullptr); - Stmt* findOuterStmtWithLabel(Name* label); void visitDeclStmt(DeclStmt* stmt); @@ -3058,6 +3071,10 @@ struct SemanticsStmtVisitor : public SemanticsVisitor, StmtVisitornext) + { + if (auto catchStmt = as(outerStmtInfo->stmt)) + { + if (!catchStmt->errorVar || catchStmt->errorVar->getType()->equals(errorType)) + return catchStmt; + } + } + return nullptr; +} + void SemanticsStmtVisitor::visitDeclStmt(DeclStmt* stmt) { // When we encounter a declaration during statement checking, @@ -118,20 +131,6 @@ void SemanticsStmtVisitor::checkStmt(Stmt* stmt) SemanticsVisitor::checkStmt(stmt, *this); } -template -T* SemanticsStmtVisitor::FindOuterStmt(Stmt* searchUntil) -{ - for (auto outerStmtInfo = m_outerStmts; outerStmtInfo && outerStmtInfo->stmt != searchUntil; - outerStmtInfo = outerStmtInfo->next) - { - auto outerStmt = outerStmtInfo->stmt; - auto found = as(outerStmt); - if (found) - return found; - } - return nullptr; -} - Stmt* SemanticsStmtVisitor::findOuterStmtWithLabel(Name* label) { for (auto outerStmtInfo = m_outerStmts; outerStmtInfo; outerStmtInfo = outerStmtInfo->next) @@ -616,6 +615,55 @@ void SemanticsStmtVisitor::visitDeferStmt(DeferStmt* stmt) subContext.checkStmt(stmt->statement); } +void SemanticsStmtVisitor::visitThrowStmt(ThrowStmt* stmt) +{ + stmt->expression = CheckTerm(stmt->expression); + Stmt* catchStmt = findMatchingCatchStmt(stmt->expression->type); + + auto parentFunc = getParentFunc(); + if (!catchStmt && (!parentFunc || parentFunc->errorType->equals(m_astBuilder->getBottomType()))) + { + getSink()->diagnose(stmt, Diagnostics::uncaughtThrowInNonThrowFunc); + return; + } + + if (!catchStmt && !stmt->expression->type->equals(m_astBuilder->getErrorType())) + { + if (!parentFunc->errorType->equals(stmt->expression->type)) + { + getSink()->diagnose( + stmt->expression, + Diagnostics::throwTypeIncompatibleWithErrorType, + stmt->expression->type, + parentFunc->errorType); + } + } + + if (FindOuterStmt(catchStmt)) + { + // Allowing 'throw' to escape a defer statement gets quite complex, for + // similar reasons as 'return' - if you have two (or more) defers, + // both of which exit the outer scope, it's unclear which one gets + // called and when. Both can't fully run. That kind of goes against the + // point of 'defer', which is to _always_ run some code when exiting + // scopes. + getSink()->diagnose(stmt, Diagnostics::uncaughtThrowInsideDefer); + } +} + +void SemanticsStmtVisitor::visitCatchStmt(CatchStmt* stmt) +{ + if (stmt->errorVar) + { + ensureDeclBase(stmt->errorVar, DeclCheckState::DefinitionChecked, this); + stmt->errorVar->hiddenFromLookup = false; + } + + WithOuterStmt subContext(this, stmt); + subContext.checkStmt(stmt->tryBody); + subContext.checkStmt(stmt->handleBody); +} + void SemanticsStmtVisitor::visitExpressionStmt(ExpressionStmt* stmt) { stmt->expression = CheckExpr(stmt->expression); diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index bf90c6608..3babc9a56 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -1001,6 +1001,28 @@ DIAGNOSTIC( nonCopyableTypeCapturedInLambda, "cannot capture non-copyable type '$0' in a lambda expression.") +DIAGNOSTIC( + 30113, + Error, + uncaughtThrowInsideDefer, + "'throw' expressions require a matching 'catch' inside a defer statement.") +DIAGNOSTIC( + 30114, + Error, + uncaughtTryInsideDefer, + "'try' expressions require a matching 'catch' inside a defer statement.") +DIAGNOSTIC( + 30115, + Error, + uncaughtThrowInNonThrowFunc, + "the current function or environment is not declared to throw any errors, but contains an " + "uncaught 'throw' statement.") +DIAGNOSTIC( + 30116, + Error, + throwTypeIncompatibleWithErrorType, + "the type `$0` of `throw` is not compatible with function's error type `$1`.") + // Include DIAGNOSTIC( 30500, diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 260bee0ff..a4362b912 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -806,10 +806,6 @@ Result linkAndOptimizeIR( break; } - // Lower `Result` types into ordinary struct types. - if (requiredLoweringPassSet.resultType) - lowerResultType(irModule, sink); - #if 0 dumpIRIfEnabled(codeGenContext, irModule, "UNIONS DESUGARED"); #endif @@ -951,6 +947,12 @@ Result linkAndOptimizeIR( break; } + // Lower `Result` types into ordinary struct types. This must happen + // after specialization, since otherwise incompatible copies of the lowered + // result structure are generated. + if (requiredLoweringPassSet.resultType) + lowerResultType(irModule, sink); + // Report checkpointing information if (codeGenContext->shouldReportCheckpointIntermediates()) { diff --git a/source/slang/slang-ir-lower-error-handling.cpp b/source/slang/slang-ir-lower-error-handling.cpp index 917658ac4..a885fc2d2 100644 --- a/source/slang/slang-ir-lower-error-handling.cpp +++ b/source/slang/slang-ir-lower-error-handling.cpp @@ -15,6 +15,7 @@ struct ErrorHandlingLoweringContext InstWorkList workList; InstHashSet workListSet; + List oldFuncTypes; ErrorHandlingLoweringContext(IRModule* inModule) : module(inModule), workList(inModule), workListSet(inModule) @@ -37,8 +38,10 @@ struct ErrorHandlingLoweringContext return; IRBuilder builder(module); builder.setInsertBefore(funcType); + auto resultType = builder.getResultType(funcType->getResultType(), throwAttr->getErrorType()); + List paramTypes; for (UInt i = 0; i < funcType->getParamCount(); i++) { @@ -48,6 +51,7 @@ struct ErrorHandlingLoweringContext } auto newFuncType = builder.getFuncType(paramTypes, resultType); funcType->replaceUsesWith(newFuncType); + funcType->removeAndDeallocate(); } void processTryCall(IRTryCall* tryCall) @@ -99,14 +103,27 @@ struct ErrorHandlingLoweringContext auto failBlock = tryCall->getFailureBlock(); auto successBlock = tryCall->getSuccessBlock(); - builder.emitIf(isFail, failBlock, successBlock); - - // Replace the params in failBlock to `getResultError(call)`. - builder.setInsertBefore(failBlock->getFirstOrdinaryInst()); - auto errorParam = failBlock->getFirstParam(); - auto errVal = builder.emitGetResultError(call); - errorParam->replaceUsesWith(errVal); - errorParam->removeAndDeallocate(); + if (failBlock->getFirstParam()) + { + // The isFail branch could otherwise just jump to the handler, but + // there's unfortunately the error parameter that needs to be passed as + // well, and it can't be done in IfElse. So there's an extra block in + // between to do that. + auto handlerJumpBlock = builder.createBlock(); + auto branch = builder.emitIf(isFail, handlerJumpBlock, successBlock); + + builder.setInsertAfter(branch->getParent()); + builder.addInst(handlerJumpBlock); + builder.setInsertInto(handlerJumpBlock); + + auto errVal = builder.emitGetResultError(call); + builder.emitBranch(failBlock, 1, &errVal); + } + else + { + // Catch-all with no parameter, so we can just jump to it directly. + builder.emitIf(isFail, failBlock, successBlock); + } // Replace the params in successBlock to `getResultValue(call)`. builder.setInsertBefore(successBlock->getFirstOrdinaryInst()); @@ -173,6 +190,9 @@ struct ErrorHandlingLoweringContext case kIROp_Throw: processThrow(cast(inst)); break; + case kIROp_FuncType: + oldFuncTypes.add(cast(inst)); + break; default: break; } @@ -206,18 +226,6 @@ struct ErrorHandlingLoweringContext // Lower all functypes. // Function types with an IRThrowTypeAttribute will be translated into a normal function // type that returns `Result`. - List oldFuncTypes; - for (auto child : module->getGlobalInsts()) - { - switch (child->getOp()) - { - case kIROp_FuncType: - oldFuncTypes.add(cast(child)); - break; - default: - break; - } - } for (auto funcType : oldFuncTypes) { processFuncType(funcType); diff --git a/source/slang/slang-ir-lower-result-type.cpp b/source/slang/slang-ir-lower-result-type.cpp index 4cf684f33..3e7a2f523 100644 --- a/source/slang/slang-ir-lower-result-type.cpp +++ b/source/slang/slang-ir-lower-result-type.cpp @@ -2,6 +2,7 @@ #include "slang-ir-lower-result-type.h" +#include "slang-ir-any-value-marshalling.h" #include "slang-ir-insts.h" #include "slang-ir.h" @@ -23,11 +24,13 @@ struct ResultTypeLoweringContext struct LoweredResultTypeInfo : public RefObject { IRType* resultType = nullptr; - IRType* errorType = nullptr; - IRType* valueType = nullptr; IRType* loweredType = nullptr; - IRStructField* valueField = nullptr; - IRStructField* errorField = nullptr; + IRType* tagType = nullptr; + IRType* valueType = nullptr; + IRType* errorType = nullptr; + IRType* anyValueType = nullptr; + IRStructField* tagField = nullptr; + IRStructField* anyValueField = nullptr; }; Dictionary> mapLoweredTypeToResultTypeInfo; Dictionary> loweredResultTypes; @@ -53,31 +56,40 @@ struct ResultTypeLoweringContext return nullptr; RefPtr info = new LoweredResultTypeInfo(); + auto resultType = cast(type); info->resultType = (IRType*)type; - info->errorType = cast(type)->getErrorType(); - auto resultType = cast(type); + auto structType = builder->createStructType(); + info->loweredType = structType; + builder->addNameHintDecoration(structType, UnownedStringSlice("ResultType")); + + info->tagType = builder->getBoolType(); + auto tagKey = builder->createStructKey(); + builder->addNameHintDecoration(tagKey, UnownedStringSlice("tag")); + info->tagField = builder->createStructField(structType, tagKey, info->tagType); + + SlangInt anyValueSize = 0; auto valueType = resultType->getValueType(); if (valueType->getOp() != kIROp_VoidType) { - auto structType = builder->createStructType(); - info->loweredType = structType; - builder->addNameHintDecoration(structType, UnownedStringSlice("ResultType")); - + anyValueSize = getAnyValueSize(valueType); info->valueType = valueType; - auto valueKey = builder->createStructKey(); - builder->addNameHintDecoration(valueKey, UnownedStringSlice("value")); - info->valueField = builder->createStructField(structType, valueKey, (IRType*)valueType); - - auto errorType = resultType->getErrorType(); - auto errorKey = builder->createStructKey(); - builder->addNameHintDecoration(errorKey, UnownedStringSlice("error")); - info->errorField = builder->createStructField(structType, errorKey, (IRType*)errorType); - } - else - { - info->loweredType = resultType->getErrorType(); } + + auto errorType = resultType->getErrorType(); + info->errorType = errorType; + + auto errSize = getAnyValueSize(errorType); + if (errSize > anyValueSize) + anyValueSize = errSize; + + info->anyValueType = + builder->getAnyValueType(builder->getIntValue(builder->getUIntType(), anyValueSize)); + auto anyValueKey = builder->createStructKey(); + builder->addNameHintDecoration(anyValueKey, UnownedStringSlice("anyValue")); + info->anyValueField = + builder->createStructField(structType, anyValueKey, info->anyValueType); + mapLoweredTypeToResultTypeInfo[info->loweredType] = info; loweredResultTypes[type] = info; return info.Ptr(); @@ -98,30 +110,6 @@ struct ResultTypeLoweringContext workListSet.add(inst); } - IRInst* getSuccessErrorValue(IRType* type) - { - switch (type->getOp()) - { - case kIROp_Int8Type: - case kIROp_Int16Type: - case kIROp_IntType: - case kIROp_Int64Type: - case kIROp_IntPtrType: - case kIROp_UInt8Type: - case kIROp_UInt16Type: - case kIROp_UIntType: - case kIROp_UInt64Type: - case kIROp_UIntPtrType: - break; - default: - SLANG_ASSERT_FAILURE("error type is not lowered to an integer type."); - } - IRBuilder builderStorage(module); - auto builder = &builderStorage; - builder->setInsertInto(module); - return builder->getIntValue(type, 0); - } - void processMakeResultValue(IRMakeResultValue* inst) { IRBuilder builderStorage(module); @@ -129,19 +117,16 @@ struct ResultTypeLoweringContext builder->setInsertBefore(inst); auto info = getLoweredResultType(builder, inst->getDataType()); - if (info->loweredType->getOp() == kIROp_StructType) - { - List operands; - operands.add(inst->getOperand(0)); - operands.add(getSuccessErrorValue(info->errorType)); - auto makeStruct = builder->emitMakeStruct(info->loweredType, operands); - inst->replaceUsesWith(makeStruct); - } - else - { - auto errCode = getSuccessErrorValue(info->errorType); - inst->replaceUsesWith(errCode); - } + + List operands; + operands.add(builder->getBoolValue(false)); + auto packInst = builder->emitPackAnyValue( + info->anyValueType, + info->valueType ? inst->getOperand(0) : builder->emitDefaultConstruct(info->errorType)); + operands.add(packInst); + + auto makeStruct = builder->emitMakeStruct(info->loweredType, operands); + inst->replaceUsesWith(makeStruct); inst->removeAndDeallocate(); } @@ -152,18 +137,15 @@ struct ResultTypeLoweringContext builder->setInsertBefore(inst); auto info = getLoweredResultType(builder, inst->getDataType()); - if (info->valueField) - { - List operands; - operands.add(builder->emitDefaultConstruct(info->valueType)); - operands.add(inst->getErrorValue()); - auto makeStruct = builder->emitMakeStruct(info->loweredType, operands); - inst->replaceUsesWith(makeStruct); - } - else - { - inst->replaceUsesWith(inst->getErrorValue()); - } + + auto packInst = builder->emitPackAnyValue(info->anyValueType, inst->getErrorValue()); + + List operands; + operands.add(builder->getBoolValue(true)); + operands.add(packInst); + + auto makeStruct = builder->emitMakeStruct(info->loweredType, operands); + inst->replaceUsesWith(makeStruct); inst->removeAndDeallocate(); } @@ -171,13 +153,14 @@ struct ResultTypeLoweringContext { auto loweredResultTypeInfo = getLoweredResultType(builder, resultInst->getDataType()); SLANG_ASSERT(loweredResultTypeInfo); - if (loweredResultTypeInfo->valueField) + if (loweredResultTypeInfo->valueType) { auto value = builder->emitFieldExtract( - loweredResultTypeInfo->errorType, + loweredResultTypeInfo->anyValueType, resultInst, - loweredResultTypeInfo->errorField->getKey()); - return value; + loweredResultTypeInfo->anyValueField->getKey()); + auto unpackInst = builder->emitUnpackAnyValue(loweredResultTypeInfo->errorType, value); + return unpackInst; } else { @@ -206,12 +189,14 @@ struct ResultTypeLoweringContext auto base = inst->getResultOperand(); auto loweredResultTypeInfo = getLoweredResultType(builder, base->getDataType()); SLANG_ASSERT(loweredResultTypeInfo); - SLANG_ASSERT(loweredResultTypeInfo->valueField); + SLANG_ASSERT(loweredResultTypeInfo->valueType); + auto getElement = builder->emitFieldExtract( - loweredResultTypeInfo->errorType, + loweredResultTypeInfo->anyValueType, base, - loweredResultTypeInfo->valueField->getKey()); - inst->replaceUsesWith(getElement); + loweredResultTypeInfo->anyValueField->getKey()); + auto unpackInst = builder->emitUnpackAnyValue(loweredResultTypeInfo->valueType, getElement); + inst->replaceUsesWith(unpackInst); inst->removeAndDeallocate(); } @@ -224,13 +209,13 @@ struct ResultTypeLoweringContext auto base = inst->getResultOperand(); auto loweredResultTypeInfo = getLoweredResultType(builder, base->getDataType()); SLANG_ASSERT(loweredResultTypeInfo); - SLANG_ASSERT(loweredResultTypeInfo->valueField); - auto resultValue = inst->getResultOperand(); - auto errValue = getResultError(builder, resultValue); - auto isSuccess = - builder->emitNeq(errValue, getSuccessErrorValue(loweredResultTypeInfo->errorType)); - inst->replaceUsesWith(isSuccess); + auto isFailure = builder->emitFieldExtract( + loweredResultTypeInfo->tagType, + base, + loweredResultTypeInfo->tagField->getKey()); + + inst->replaceUsesWith(isFailure); inst->removeAndDeallocate(); } diff --git a/source/slang/slang-ir-marshal-native-call.cpp b/source/slang/slang-ir-marshal-native-call.cpp index 2257ef098..1d3f04318 100644 --- a/source/slang/slang-ir-marshal-native-call.cpp +++ b/source/slang/slang-ir-marshal-native-call.cpp @@ -42,8 +42,10 @@ IRFuncType* NativeCallMarshallingContext::getNativeFuncType( if (auto resultType = as(declaredFuncType->getResultType())) { auto nativeResultType = getNativeType(builder, resultType->getValueType()); + auto nativeErrorType = getNativeType(builder, resultType->getErrorType()); nativeParamTypes.add(builder.getPtrType(nativeResultType)); - returnType = resultType->getErrorType(); + nativeParamTypes.add(builder.getPtrType(nativeErrorType)); + returnType = builder.getIntType(); } else { @@ -243,10 +245,8 @@ IRFunc* NativeCallMarshallingContext::generateDLLExportWrapperFunc( IRBlock* trueBlock = nullptr; IRBlock* falseBlock = nullptr; IRBlock* afterBlock = nullptr; - builder.emitIfElseWithBlocks(isResultError, trueBlock, falseBlock, afterBlock); - builder.setInsertInto(trueBlock); - builder.emitReturn(builder.emitGetResultError(callInst)); + builder.emitIfElseWithBlocks(isResultError, trueBlock, falseBlock, afterBlock); builder.setInsertInto(falseBlock); auto resultVal = builder.emitGetResultValue(callInst); @@ -258,8 +258,22 @@ IRFunc* NativeCallMarshallingContext::generateDLLExportWrapperFunc( builder.emitStore(params[nativeParamConsumeIndex], nativeVals[i]); nativeParamConsumeIndex++; } + // S_OK builder.emitReturn(builder.getIntValue(builder.getIntType(), 0)); + builder.setInsertInto(trueBlock); + nativeVals.clear(); + auto errorVal = builder.emitGetResultError(callInst); + marshalManagedValueToNativeResultValue(builder, errorVal, nativeVals); + for (Index i = 0; i < nativeVals.getCount(); i++) + { + SLANG_RELEASE_ASSERT(nativeParamConsumeIndex < params.getCount()); + builder.emitStore(params[nativeParamConsumeIndex], nativeVals[i]); + nativeParamConsumeIndex++; + } + // E_FAIL + builder.emitReturn(builder.getIntValue(builder.getIntType(), 0x80004005)); + builder.setInsertInto(afterBlock); builder.emitUnreachable(); } @@ -296,10 +310,13 @@ IRInst* NativeCallMarshallingContext::marshalNativeCall( IRType* originalReturnType = originalFuncType->getResultType(); IRVar* resultVar = nullptr; + IRVar* errorVar = nullptr; if (auto resultType = as(originalReturnType)) { // Declare a local variable to receive result. resultVar = builder.emitVar(getNativeType(builder, resultType->getValueType())); + errorVar = builder.emitVar(getNativeType(builder, resultType->getErrorType())); + args.add(resultVar); args.add(resultVar); } @@ -314,16 +331,18 @@ IRInst* NativeCallMarshallingContext::marshalNativeCall( if (auto resultType = as(originalReturnType)) { auto val = builder.emitLoad(resultVar); - auto err = call; + auto err = builder.emitLoad(errorVar); + auto tag = call; val = marshalNativeValueToManagedValue(builder, val); - auto intErr = err; - if (err->getDataType()->getOp() != kIROp_IntType) + err = marshalNativeValueToManagedValue(builder, err); + auto intTag = tag; + if (tag->getDataType()->getOp() != kIROp_IntType) { - intErr = builder.emitCast(builder.getIntType(), err); + intTag = builder.emitCast(builder.getIntType(), tag); } - auto errIsError = builder.emitLess(intErr, builder.getIntValue(builder.getIntType(), 0)); + auto tagIsError = builder.emitLess(intTag, builder.getIntValue(builder.getIntType(), 0)); IRBlock *trueBlock, *falseBlock, *afterBlock; - builder.emitIfElseWithBlocks(errIsError, trueBlock, falseBlock, afterBlock); + builder.emitIfElseWithBlocks(tagIsError, trueBlock, falseBlock, afterBlock); builder.setInsertInto(trueBlock); returnValue = builder.emitMakeResultError(resultType, err); builder.emitBranch(afterBlock, 1, &returnValue); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index b09d9e6e2..6d5bd9a25 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -620,6 +620,12 @@ static IRBlock::SuccessorList getSuccessors(IRInst* terminator) end = begin + 1; break; + case kIROp_TryCall: + // tryCall ... + begin = operands + 0; + end = begin + 2; + break; + default: SLANG_UNEXPECTED("unhandled terminator instruction"); UNREACHABLE_RETURN(IRBlock::SuccessorList(nullptr, nullptr)); diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp index b08388aee..d0ed9cdac 100644 --- a/source/slang/slang-language-server-ast-lookup.cpp +++ b/source/slang/slang-language-server-ast-lookup.cpp @@ -645,6 +645,17 @@ struct ASTLookupStmtVisitor : public StmtVisitor bool visitDeferStmt(DeferStmt* stmt) { return dispatchIfNotNull(stmt->statement); } + bool visitThrowStmt(ThrowStmt* stmt) { return checkExpr(stmt->expression); } + + bool visitCatchStmt(CatchStmt* stmt) + { + if (stmt->errorVar && _findAstNodeImpl(*context, stmt->errorVar)) + return true; + if (dispatchIfNotNull(stmt->tryBody)) + return true; + return dispatchIfNotNull(stmt->handleBody); + } + bool visitWhileStmt(WhileStmt* stmt) { if (checkExpr(stmt->predicate)) diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 9920075fe..ebcdaf1e1 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -555,6 +555,17 @@ struct AstOrIRType explicit operator bool() { return astType || irType; } }; +struct CatchHandler +{ + // 'nullptr' implies catch-all. + IRType* errorType = nullptr; + + // Block of the handler statement. Takes a value of errorType as parameter. + IRBlock* errorHandler = nullptr; + + CatchHandler* prev = nullptr; +}; + struct IRGenContext { ASTBuilder* astBuilder; @@ -599,6 +610,9 @@ struct IRGenContext // The current scope end for use with `defer`. IRBlock* scopeEndBlock = nullptr; + // A chain of nested `catch` handlers for `try` and `throw. + CatchHandler* catchHandler = nullptr; + // Callback function to call when after lowering a type. std::function lowerTypeCallback = nullptr; @@ -754,6 +768,19 @@ int32_t getIntrinsicOp(Decl* decl, IntrinsicOpModifier* intrinsicOpMod) return int32_t(irOp); } +static CatchHandler findErrorHandler(IRGenContext* context, IRType* type) +{ + for (auto handler = context->catchHandler; handler != nullptr; + handler = context->catchHandler->prev) + { + if (!handler->errorType || handler->errorType == type) + { + return *handler; + } + } + return CatchHandler(); +} + struct TryClauseEnvironment { TryClauseType clauseType = TryClauseType::None; @@ -809,18 +836,28 @@ LoweredValInfo emitCallToVal( case TryClauseType::Standard: { auto callee = getSimpleVal(context, funcVal); - auto succBlock = builder->createBlock(); - auto failBlock = builder->createBlock(); auto funcType = as(callee->getDataType()); auto throwAttr = funcType->findAttr(); assert(throwAttr); + + auto handler = findErrorHandler(context, throwAttr->getErrorType()); + auto succBlock = builder->createBlock(); + auto failBlock = + handler.errorHandler ? handler.errorHandler : builder->createBlock(); + auto voidType = builder->getVoidType(); builder->emitTryCallInst(voidType, succBlock, failBlock, callee, argCount, args); - builder->insertBlock(failBlock); - auto errParam = builder->emitParam(throwAttr->getErrorType()); - builder->emitThrow(errParam); builder->insertBlock(succBlock); auto value = builder->emitParam(type); + + if (!handler.errorHandler) + { + // We have to create a default fail block, which just re-throws. + builder->insertBlock(failBlock); + auto errParam = builder->emitParam(throwAttr->getErrorType()); + builder->emitThrow(errParam); + builder->setInsertInto(succBlock); + } return LoweredValInfo::simple(value); } break; @@ -1982,8 +2019,9 @@ struct ValLoweringVisitor : ValVisitorgetErrorType()); + IRInst* operands[] = {errorType}; auto irThrowFuncTypeAttribute = - getBuilder()->getAttr(kIROp_FuncThrowTypeAttr, 1, (IRInst**)&errorType); + getBuilder()->getAttr(kIROp_FuncThrowTypeAttr, 1, operands); return getBuilder()->getFuncType( paramCount, paramTypes.getBuffer(), @@ -3449,9 +3487,9 @@ void _lowerFuncDeclBaseTypeInfo( if (!getErrorCodeType(context->astBuilder, declRef) ->equals(context->astBuilder->getBottomType())) { - auto errorType = lowerType(context, getErrorCodeType(context->astBuilder, declRef)); - IRAttr* throwTypeAttr = nullptr; - throwTypeAttr = builder->getAttr(kIROp_FuncThrowTypeAttr, 1, (IRInst**)&errorType); + auto irErrorType = lowerType(context, getErrorCodeType(context->astBuilder, declRef)); + IRInst* operands[] = {irErrorType}; + IRAttr* throwTypeAttr = builder->getAttr(kIROp_FuncThrowTypeAttr, 1, operands); outInfo.type = builder->getFuncType( paramTypes.getCount(), paramTypes.getBuffer(), @@ -6671,6 +6709,117 @@ struct StmtLoweringVisitor : StmtVisitor builder->setInsertInto(mergeBlock); } + void visitThrowStmt(ThrowStmt* stmt) + { + auto builder = getBuilder(); + startBlockIfNeeded(stmt); + + auto loweredExpr = lowerRValueExpr(context, stmt->expression); + auto loweredVal = getSimpleVal(context, loweredExpr); + auto throwType = lowerType(context, stmt->expression->type); + + CatchHandler handler; + if (loweredVal && throwType) + { + handler = findErrorHandler(context, throwType); + } + + if (handler.errorHandler) + { + builder->emitBranch(handler.errorHandler, 1, &loweredVal); + } + else + { + builder->emitThrow(getSimpleVal(context, loweredExpr)); + } + } + + void visitCatchStmt(CatchStmt* stmt) + { + auto builder = getBuilder(); + startBlockIfNeeded(stmt); + + // The mental model here is that the below Catch statement: + // + // let val = try MayThrowFunc(); + // // Do stuff with val + // catch(err: Error) + // { + // catchBlock(err); + // } + // + // lowers similarly to: + // + // handlerLoop: for(;;) + // { + // E err; // Actually just a parameter for the catchBlock, not a real variable. + // bodyLoop: for(;;) + // { + // // Body goes here + // Result r = mayThrowFunc(); + // if(isResultError(r)) + // { + // err = r.error; + // break bodyLoop; + // } + // let val = r.getSuccessValue(); + // // Do stuff with val + // break handlerLoop; + // } + // catchBlock(err); + // break handlerLoop; + // } + // + // This approach allows for it to generate valid SPIR-V. Just jumping + // around with unstructured conditional jumps doesn't work there. + + IRBlock* handlerLoopHead = createBlock(); + IRBlock* handlerBreakLabel = createBlock(); + IRBlock* bodyLoopHead = createBlock(); + IRBlock* bodyBreakLabel = createBlock(); + + builder->emitLoop(handlerLoopHead, handlerBreakLabel, handlerLoopHead); + insertBlock(handlerLoopHead); + + builder->emitLoop(bodyLoopHead, bodyBreakLabel, bodyLoopHead); + insertBlock(bodyLoopHead); + + CatchHandler catchHandler; + catchHandler.errorType = + stmt->errorVar ? lowerType(context, stmt->errorVar->getType()) : nullptr; + catchHandler.errorHandler = bodyBreakLabel; + catchHandler.prev = context->catchHandler; + context->catchHandler = &catchHandler; + + // Note that the tryBody doesn't actually have to have it's own scope or + // block. If there's a `defer` in the tryBody, it can run after the + // catch statement. + lowerStmt(context, stmt->tryBody); + + // Put break; at the end of the body if there's nothing else there yet. + // This prevents the catch handler from running. + emitBranchIfNeeded(handlerBreakLabel); + + context->catchHandler = catchHandler.prev; + + insertBlock(bodyBreakLabel); + + if (catchHandler.errorType) + { + auto irParam = builder->emitParam(catchHandler.errorType); + auto paramVal = LoweredValInfo::simple(irParam); + context->setGlobalValue(stmt->errorVar, paramVal); + } + + IRBlock* prevScopeEndBlock = pushScopeBlock(handlerBreakLabel); + lowerStmt(context, stmt->handleBody); + popScopeBlock(prevScopeEndBlock, true); + + emitBranchIfNeeded(handlerBreakLabel); + + insertBlock(handlerBreakLabel); + } + void visitDiscardStmt(DiscardStmt* stmt) { startBlockIfNeeded(stmt); @@ -8536,6 +8685,7 @@ struct DeclLoweringVisitor : DeclVisitor subContextStorage.returnDestination = LoweredValInfo(); subContextStorage.lowerTypeCallback = nullptr; + subContextStorage.catchHandler = nullptr; } IRBuilder* getBuilder() { return &subBuilderStorage; } diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index b2a006adc..f968f9fe1 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -211,11 +211,14 @@ public: Stmt* parseIfLetStatement(); ForStmt* ParseForStatement(); WhileStmt* ParseWhileStatement(); - DoWhileStmt* ParseDoWhileStatement(); + DoWhileStmt* ParseDoWhileStatement(Stmt* body); + CatchStmt* ParseDoCatchStatement(Stmt* body); + Stmt* ParseDoStatement(); BreakStmt* ParseBreakStatement(); ContinueStmt* ParseContinueStatement(); ReturnStmt* ParseReturnStatement(); DeferStmt* ParseDeferStatement(); + ThrowStmt* ParseThrowStatement(); ExpressionStmt* ParseExpressionStatement(); Expr* ParseExpression(Precedence level = Precedence::Comma); @@ -5741,7 +5744,7 @@ Stmt* Parser::ParseStatement(Stmt* parentStmt) else if (LookAheadToken("while")) statement = ParseWhileStatement(); else if (LookAheadToken("do")) - statement = ParseDoWhileStatement(); + statement = ParseDoStatement(); else if (LookAheadToken("break")) statement = ParseBreakStatement(); else if (LookAheadToken("continue")) @@ -5781,6 +5784,10 @@ Stmt* Parser::ParseStatement(Stmt* parentStmt) { statement = ParseExpressionStatement(); } + else if (LookAheadToken("throw")) + { + statement = ParseThrowStatement(); + } else if (LookAheadToken(TokenType::Identifier) || LookAheadToken(TokenType::Scope)) { if (LookAheadToken(TokenType::Identifier) && LookAheadToken(TokenType::Colon, 1)) @@ -5941,7 +5948,6 @@ Stmt* Parser::parseBlockStatement() Stmt* body = nullptr; - if (!tokenReader.isAtEnd()) { FillPosition(blockStatement); @@ -6259,12 +6265,11 @@ WhileStmt* Parser::ParseWhileStatement() return whileStatement; } -DoWhileStmt* Parser::ParseDoWhileStatement() +DoWhileStmt* Parser::ParseDoWhileStatement(Stmt* body) { DoWhileStmt* doWhileStatement = astBuilder->create(); FillPosition(doWhileStatement); - ReadToken("do"); - doWhileStatement->statement = ParseStatement(); + doWhileStatement->statement = body; ReadToken("while"); ReadToken(TokenType::LParent); doWhileStatement->predicate = ParseExpression(); @@ -6273,6 +6278,62 @@ DoWhileStmt* Parser::ParseDoWhileStatement() return doWhileStatement; } +CatchStmt* Parser::ParseDoCatchStatement(Stmt* body) +{ + for (;;) + { + ScopeDecl* scopeDecl = astBuilder->create(); + pushScopeAndSetParent(scopeDecl); + + CatchStmt* catchStatement = astBuilder->create(); + FillPosition(catchStatement); + ReadToken("catch"); + + // Optional error parameter. If not given, the catch catches all error + // types. + if (AdvanceIf(this, TokenType::LParent)) + { + ParamDecl* errorVar = parseModernParamDecl(this); + catchStatement->errorVar = errorVar; + AddMember(scopeDecl, errorVar); + ReadToken(TokenType::RParent); + } + + catchStatement->tryBody = body; + catchStatement->handleBody = ParseStatement(); + + PopScope(); + + if (!LookAheadToken("catch")) + return catchStatement; + + // Use this catch as the body for the next one, if multiple are chained. + body = catchStatement; + } +} + +Stmt* Parser::ParseDoStatement() +{ + SourceLoc position = tokenReader.peekLoc(); + ReadToken("do"); + Stmt* statement = ParseStatement(); + if (LookAheadToken("while")) + { + Stmt* whileStatement = ParseDoWhileStatement(statement); + whileStatement->loc = position; + return whileStatement; + } + else if (LookAheadToken("catch")) + { + return ParseDoCatchStatement(statement); + } + else + { + Unexpected(this, "while' or 'catch"); + return statement; + } +} + BreakStmt* Parser::ParseBreakStatement() { BreakStmt* breakStatement = astBuilder->create(); @@ -6315,6 +6376,15 @@ DeferStmt* Parser::ParseDeferStatement() return deferStatement; } +ThrowStmt* Parser::ParseThrowStatement() +{ + ThrowStmt* throwStatement = astBuilder->create(); + FillPosition(throwStatement); + ReadToken("throw"); + throwStatement->expression = ParseExpression(); + return throwStatement; +} + ExpressionStmt* Parser::ParseExpressionStatement() { ExpressionStmt* statement = astBuilder->create(); -- cgit v1.2.3