diff options
| author | Yong He <yonghe@outlook.com> | 2022-06-01 17:37:07 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-06-01 17:37:07 -0700 |
| commit | 17e3b88b541ed7f45d575f0f9caaa808cd0a6619 (patch) | |
| tree | efacd5d4bf6381a5adf8055daa28f91ddc048a76 | |
| parent | fa10f7dc23f8b93c0f9ef3fb5477871a20aaa974 (diff) | |
New language feature: basic error handling. (#2253)
* New language feature: basic error handling.
* Fix.
* Fix `tryCall` encoding according to code review.
Co-authored-by: Yong He <yhe@nvidia.com>
37 files changed, 1205 insertions, 62 deletions
diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj index 20bb20565..89e69b55c 100644 --- a/build/visual-studio/slang/slang.vcxproj +++ b/build/visual-studio/slang/slang.vcxproj @@ -369,12 +369,14 @@ IF EXIST ..\..\..\external\slang-binaries\bin\windows-aarch64\slang-glslang.dll\ <ClInclude Include="..\..\..\source\slang\slang-ir-link.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-liveness.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-bit-cast.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-lower-error-handling.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-existential.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-generic-call.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-generic-function.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-generic-type.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-generics.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-reinterpret.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-lower-result-type.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-tuple-types.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-metadata.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-missing-return.h" />
@@ -511,12 +513,14 @@ IF EXIST ..\..\..\external\slang-binaries\bin\windows-aarch64\slang-glslang.dll\ <ClCompile Include="..\..\..\source\slang\slang-ir-link.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-liveness.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-bit-cast.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-lower-error-handling.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-existential.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-generic-call.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-generic-function.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-generic-type.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-generics.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-reinterpret.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-lower-result-type.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-tuple-types.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-metadata.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-missing-return.cpp" />
diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters index ce66cc2a2..a414d7e6a 100644 --- a/build/visual-studio/slang/slang.vcxproj.filters +++ b/build/visual-studio/slang/slang.vcxproj.filters @@ -204,6 +204,9 @@ <ClInclude Include="..\..\..\source\slang\slang-ir-lower-bit-cast.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-lower-error-handling.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-existential.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -222,6 +225,9 @@ <ClInclude Include="..\..\..\source\slang\slang-ir-lower-reinterpret.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-lower-result-type.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-lower-tuple-types.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -626,6 +632,9 @@ <ClCompile Include="..\..\..\source\slang\slang-ir-lower-bit-cast.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-lower-error-handling.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-existential.cpp">
<Filter>Source Files</Filter>
</ClCompile>
@@ -644,6 +653,9 @@ <ClCompile Include="..\..\..\source\slang\slang-ir-lower-reinterpret.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-lower-result-type.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-lower-tuple-types.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/examples/heterogeneous-hello-world/main.slang b/examples/heterogeneous-hello-world/main.slang index 2e65c7149..2a9c1eaf1 100644 --- a/examples/heterogeneous-hello-world/main.slang +++ b/examples/heterogeneous-hello-world/main.slang @@ -9,20 +9,18 @@ int MessageBoxA(Ptr<void> hwnd, String text, String caption, uint flags); [COM] interface IObject { - int getValue(int value); + int getValue(int value) throws int; } [DllImport("test-com")] IObject createObject(); -public __extern_cpp int main() +public __extern_cpp void main() throws int { //writeln("hello world"); //MessageBoxA(nullptr, "hello world!", "example", 0); IObject object = createObject(); - int rs = object.getValue(2); - - return 0; + int rs = try object.getValue(2); }
\ No newline at end of file diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index 5b596ded0..9c95f3520 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -266,6 +266,9 @@ class CallableDecl : public ContainerDecl } TypeExp returnType; + + // If this callable throws an error code, `errorType` is the type of the error code. + TypeExp errorType; // Fields related to redeclaration, so that we // can support multiple specialized variations diff --git a/source/slang/slang-ast-dump.cpp b/source/slang/slang-ast-dump.cpp index 9da1a0163..d67a35174 100644 --- a/source/slang/slang-ast-dump.cpp +++ b/source/slang/slang-ast-dump.cpp @@ -341,7 +341,10 @@ struct ASTDumpContext { m_writer->emit(getGLSLNameForImageFormat(imageFormat)); } - + void dump(TryClauseType clauseType) + { + m_writer->emit(getTryClauseTypeName(clauseType)); + } void dump(const String& string) { dump(string.getUnownedSlice()); diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index a6c1d432c..f2fae7ced 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -144,6 +144,28 @@ class InvokeExpr: public AppExprBase SLANG_AST_CLASS(InvokeExpr) }; +enum class TryClauseType +{ + None, + Standard, // Normal `try` clause + Optional, // (Not implemented) `try?` clause that returns an optional value. + Assert, // (Not implemented) `try!` clause that should always succeed and triggers runtime error if failed. +}; + +char const* getTryClauseTypeName(TryClauseType value); + +class TryExpr : public Expr +{ + SLANG_AST_CLASS(TryExpr) + + Expr* base; + + TryClauseType tryClauseType = TryClauseType::Standard; + + // The scope of this expr. + Scope* scope = nullptr; +}; + class OperatorExpr: public InvokeExpr { SLANG_AST_CLASS(OperatorExpr) diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index b6bc1f170..bcaf8f5a9 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -549,6 +549,11 @@ void FuncType::_toTextOverride(StringBuilder& out) out << getParamType(pp); } out << toSlice(") -> ") << getResultType(); + + if (!getErrorType()->equals(getASTBuilder()->getVoidType())) + { + out << " throws " << getErrorType(); + } } bool FuncType::_equalsImplOverride(Type * type) @@ -571,6 +576,9 @@ bool FuncType::_equalsImplOverride(Type * type) if (!resultType->equals(funcType->resultType)) return false; + if (!errorType->equals(funcType->errorType)) + return false; + // TODO: if we ever introduce other kinds // of qualification on function types, we'd // want to consider it here. @@ -586,6 +594,9 @@ Val* FuncType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet s // result type Type* substResultType = as<Type>(resultType->substituteImpl(astBuilder, subst, &diff)); + // error type + Type* substErrorType = as<Type>(errorType->substituteImpl(astBuilder, subst, &diff)); + // parameter types List<Type*> substParamTypes; for (auto pp : paramTypes) @@ -601,6 +612,7 @@ Val* FuncType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet s FuncType* substType = astBuilder->create<FuncType>(); substType->resultType = substResultType; substType->paramTypes = substParamTypes; + substType->errorType = substErrorType; return substType; } @@ -608,6 +620,7 @@ Type* FuncType::_createCanonicalTypeOverride() { // result type Type* canResultType = resultType->getCanonicalType(); + Type* canErrorType = errorType->getCanonicalType(); // parameter types List<Type*> canParamTypes; @@ -619,7 +632,7 @@ Type* FuncType::_createCanonicalTypeOverride() FuncType* canType = getASTBuilder()->create<FuncType>(); canType->resultType = canResultType; canType->paramTypes = canParamTypes; - + canType->errorType = canErrorType; return canType; } @@ -634,6 +647,7 @@ HashCode FuncType::_getHashCodeOverride() hashCode, getParamType(pp)->getHashCode()); } + combineHash(hashCode, getErrorType()->getHashCode()); return hashCode; } diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index 7aa1a36ab..8de047af5 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -575,10 +575,12 @@ class FuncType : public Type List<Type*> paramTypes; Type* resultType = nullptr; + Type* errorType = nullptr; UInt getParamCount() { return paramTypes.getCount(); } Type* getParamType(UInt index) { return paramTypes[index]; } Type* getResultType() { return resultType; } + Type* getErrorType() { return errorType; } // Overrides should be public so base classes can access void _toTextOverride(StringBuilder& out); diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 6d579e1fb..af19f484d 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -673,7 +673,7 @@ namespace Slang // Make sure a declaration has been checked, so we can refer to it. // Note that this may lead to us recursively invoking checking, // so this may not be the best way to handle things. - void SemanticsVisitor::ensureDecl(Decl* decl, DeclCheckState state) + void SemanticsVisitor::ensureDecl(Decl* decl, DeclCheckState state, SemanticsContext* baseContext) { // If the `decl` has already been checked up to or beyond `state` // then there is nothing for us to do. @@ -733,7 +733,7 @@ namespace Slang // context, so that the state at the point where a declaration is *referenced* // cannot affect the state in which the declaration is *checked*. // - SemanticsContext subContext(getShared()); + SemanticsContext subContext = baseContext ? SemanticsContext(*baseContext) : SemanticsContext(getShared()); _dispatchDeclCheckingVisitor(decl, nextState, subContext); // In the common case, the visitor will have done the necessary @@ -3609,17 +3609,17 @@ namespace Slang } } - void SemanticsVisitor::ensureDeclBase(DeclBase* declBase, DeclCheckState state) + void SemanticsVisitor::ensureDeclBase(DeclBase* declBase, DeclCheckState state, SemanticsContext* baseContext) { if(auto decl = as<Decl>(declBase)) { - ensureDecl(decl, state); + ensureDecl(decl, state, baseContext); } else if(auto declGroup = as<DeclGroup>(declBase)) { for(auto dd : declGroup->decls) { - ensureDecl(dd, state); + ensureDecl(dd, state, baseContext); } } else @@ -4398,6 +4398,17 @@ namespace Slang { ensureDecl(paramDecl, DeclCheckState::ReadyForReference); } + + auto errorType = decl->errorType; + if (errorType.exp) + { + errorType = CheckProperType(errorType); + } + else + { + errorType = TypeExp(m_astBuilder->getVoidType()); + } + decl->errorType = errorType; } void SemanticsDeclHeaderVisitor::visitFuncDecl(FuncDecl* funcDecl) diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 317ab6a1a..4b29de1bf 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1337,6 +1337,15 @@ namespace Slang // if this is still an invoke expression, test arguments passed to inout/out parameter are LValues if(auto funcType = as<FuncType>(invoke->functionExpr->type)) { + if (!funcType->errorType->equals(m_astBuilder->getVoidType())) + { + // If the callee throws, make sure we are inside a try clause. + if (m_enclosingTryClauseType == TryClauseType::None) + { + getSink()->diagnose(invoke, Diagnostics::mustUseTryClauseToCallAThrowFunc); + } + } + Index paramCount = funcType->getParamCount(); for (Index pp = 0; pp < paramCount; ++pp) { @@ -1529,6 +1538,59 @@ namespace Slang return CheckInvokeExprWithCheckedOperands(expr); } + Expr* SemanticsExprVisitor::visitTryExpr(TryExpr* expr) + { + auto prevTryClauseType = expr->tryClauseType; + m_enclosingTryClauseType = expr->tryClauseType; + expr->base = CheckTerm(expr->base); + m_enclosingTryClauseType = prevTryClauseType; + expr->type = expr->base->type; + if (as<ErrorType>(expr->type)) + return expr; + + auto parentFunc = this->m_parentFunc; + // TODO: check if the try clause is caught. + // For now we assume all `try`s are not caught (because we don't have catch yet). + if (!parentFunc) + { + getSink()->diagnose(expr, Diagnostics::uncaughtTryCallInNonThrowFunc); + return expr; + } + if (parentFunc->errorType->equals(m_astBuilder->getVoidType())) + { + getSink()->diagnose(expr, Diagnostics::uncaughtTryCallInNonThrowFunc); + return expr; + } + if (!as<InvokeExpr>(expr->base)) + { + getSink()->diagnose(expr, Diagnostics::tryClauseMustApplyToInvokeExpr); + return expr; + } + auto base = as<InvokeExpr>(expr->base); + if (auto callee = as<DeclRefExpr>(base->functionExpr)) + { + if (auto funcCallee = as<FuncDecl>(callee->declRef.getDecl())) + { + if (funcCallee->errorType->equals(m_astBuilder->getVoidType())) + { + getSink()->diagnose(expr, Diagnostics::tryInvokeCalleeShouldThrow, callee->declRef); + } + if (!parentFunc->errorType->equals(funcCallee->errorType)) + { + getSink()->diagnose( + expr, + Diagnostics::errorTypeOfCalleeIncompatibleWithCaller, + callee->declRef, + funcCallee->errorType, + parentFunc->errorType); + } + return expr; + } + } + getSink()->diagnose(expr, Diagnostics::calleeOfTryCallMustBeFunc); + return expr; + } + Expr* SemanticsVisitor::MaybeDereference(Expr* inExpr) { Expr* expr = inExpr; diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index e6088ccca..1e549c7da 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -345,6 +345,15 @@ namespace Slang return result; } + TryClauseType getEnclosingTryClauseType() { return m_enclosingTryClauseType; } + + SemanticsContext withEnclosingTryClauseType(TryClauseType tryClauseType) + { + SemanticsContext result(*this); + result.m_enclosingTryClauseType = tryClauseType; + return result; + } + /// A scope that is local to a particular expression, and /// that can be used to allocate temporary bindings that /// might be needed by that expression or its sub-expressions. @@ -380,6 +389,7 @@ namespace Slang ExprLocalScope* m_exprLocalScope = nullptr; + protected: // TODO: consider making more of this state `private`... @@ -389,6 +399,9 @@ namespace Slang /// The linked list of lexically surrounding statements. OuterStmtInfo* m_outerStmts = nullptr; + /// The type of a try clause (if any) enclosing current expr. + TryClauseType m_enclosingTryClauseType = TryClauseType::None; + ASTBuilder* m_astBuilder = nullptr; }; @@ -558,7 +571,7 @@ namespace Slang /// on this function to avoid blowing out the stack or (even worse /// creating a circular dependency). /// - void ensureDecl(Decl* decl, DeclCheckState state); + void ensureDecl(Decl* decl, DeclCheckState state, SemanticsContext* baseContext = nullptr); /// Helper routine allowing `ensureDecl` to be called on a `DeclRef` void ensureDecl(DeclRefBase const& declRef, DeclCheckState state) @@ -572,7 +585,7 @@ namespace Slang /// called on a `DeclGroup` this function just calls `ensureDecl()` /// on each declaration in the group. /// - void ensureDeclBase(DeclBase* decl, DeclCheckState state); + void ensureDeclBase(DeclBase* decl, DeclCheckState state, SemanticsContext* baseContext); // A "proper" type is one that can be used as the type of an expression. // Put simply, it can be a concrete type like `int`, or a generic @@ -1660,6 +1673,8 @@ namespace Slang Expr* visitTypeCastExpr(TypeCastExpr * expr); + Expr* visitTryExpr(TryExpr* expr); + // // Some syntax nodes should not occur in the concrete input syntax, // and will only appear *after* checking is complete. We need to diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp index 45be1d662..57bb3b85a 100644 --- a/source/slang/slang-check-stmt.cpp +++ b/source/slang/slang-check-stmt.cpp @@ -51,7 +51,7 @@ namespace Slang // local `struct` declaration, where it would have members // that need to be recursively checked. // - ensureDeclBase(stmt->decl, DeclCheckState::Checked); + ensureDeclBase(stmt->decl, DeclCheckState::Checked, this); } void SemanticsStmtVisitor::visitBlockStmt(BlockStmt* stmt) @@ -289,7 +289,7 @@ namespace Slang { stmt->device = CheckExpr(stmt->device); stmt->gridDims = CheckExpr(stmt->gridDims); - ensureDeclBase(stmt->dispatchThreadID, DeclCheckState::Checked); + ensureDeclBase(stmt->dispatchThreadID, DeclCheckState::Checked, this); WithOuterStmt subContext(this, stmt); stmt->kernelCall = subContext.CheckExpr(stmt->kernelCall); return; diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 57035e84e..9ab34128a 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -290,6 +290,13 @@ DIAGNOSTIC(33071, Error, expectedAStringLiteral, "expected a string literal") DIAGNOSTIC( -1, Note, noteExplicitConversionPossible, "explicit conversion from '$0' to '$1' is possible") DIAGNOSTIC(30080, Error, ambiguousConversion, "more than one implicit conversion exists from '$0' to '$1'") +DIAGNOSTIC(30090, Error, tryClauseMustApplyToInvokeExpr, "expression in a 'try' clause must be a call to a function or operator overload.") +DIAGNOSTIC(30091, Error, tryInvokeCalleeShouldThrow, "'$0' called from a 'try' clause does not throw an error, make sure the callee is marked as 'throws'") +DIAGNOSTIC(30092, Error, calleeOfTryCallMustBeFunc, "callee in a 'try' clause must be a function") +DIAGNOSTIC(30093, Error, uncaughtTryCallInNonThrowFunc, "the current function or environment is not declared to throw any errors, but the 'try' clause is not caught") +DIAGNOSTIC(30094, Error, mustUseTryClauseToCallAThrowFunc, "the callee may throw an error, and therefore must be called within a 'try' clause") +DIAGNOSTIC(30095, Error, errorTypeOfCalleeIncompatibleWithCaller, "the error type `$1` of callee `$0` is not compatible with the caller's error type `$2`.") + // Attributes DIAGNOSTIC(31000, Error, unknownAttributeName, "unknown attribute '$0'") DIAGNOSTIC(31001, Error, attributeArgumentCountMismatch, "attribute '$0' expects $1 arguments ($2 provided)") diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 0ed17ad7c..72ad80873 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -23,6 +23,7 @@ #include "slang-ir-com-interface.h" #include "slang-ir-lower-generics.h" #include "slang-ir-lower-tuple-types.h" +#include "slang-ir-lower-result-type.h" #include "slang-ir-lower-bit-cast.h" #include "slang-ir-lower-reinterpret.h" #include "slang-ir-metadata.h" @@ -200,6 +201,19 @@ Result linkAndOptimizeIR( // un-specialized IR. dumpIRIfEnabled(codeGenContext, irModule); + switch (target) + { + default: + break; + case CodeGenTarget::HostCPPSource: + lowerComInterfaces(irModule, sink); + generateDllImportFuncs(irModule, sink); + break; + } + + // Lower `Result<T,E>` types into ordinary struct types. + lowerResultType(irModule, sink); + // Replace any global constants with their values. // replaceGlobalConstants(irModule); @@ -689,16 +703,6 @@ Result linkAndOptimizeIR( break; } - switch (target) - { - default: - break; - case CodeGenTarget::HostCPPSource: - lowerComInterfaces(irModule, sink); - generateDllImportFuncs(irModule, sink); - break; - } - // TODO: our current dynamic dispatch pass will remove all uses of witness tables. // If we are going to support function-pointer based, "real" modular dynamic dispatch, // we will need to disable this pass. diff --git a/source/slang/slang-ir-deduplicate.cpp b/source/slang/slang-ir-deduplicate.cpp index f2c199af6..8aef7736c 100644 --- a/source/slang/slang-ir-deduplicate.cpp +++ b/source/slang/slang-ir-deduplicate.cpp @@ -76,4 +76,37 @@ namespace Slang for (auto inst : instToRemove) inst->removeAndDeallocate(); } + + void SharedIRBuilder::replaceGlobalInst(IRInst* oldInst, IRInst* newInst) + { + List<IRUse*> uses; + for (auto use = oldInst->firstUse; use; use = use->nextUse) + { + uses.add(use); + } + + bool shouldUpdateGlobalNumberedCache = false; + for (auto use : uses) + { + use->set(newInst); + // depending on the type of the user inst, we may need to rebuild and update the global + // numbering cache. + if (isGloballyNumberedInst(use->getUser())) + { + shouldUpdateGlobalNumberedCache = true; + } + } + oldInst->removeAndDeallocate(); + if (shouldUpdateGlobalNumberedCache) + { + deduplicateAndRebuildGlobalNumberingMap(); + } + } + + bool SharedIRBuilder::isGloballyNumberedInst(IRInst* inst) + { + if (!inst->getParent() || inst->getParent()->getOp() != kIROp_Module) + return false; + return m_globalValueNumberingMap.ContainsKey(IRInstKey{inst}); + } } diff --git a/source/slang/slang-ir-generics-lowering-context.cpp b/source/slang/slang-ir-generics-lowering-context.cpp index 284e1fa11..bf3b8d855 100644 --- a/source/slang/slang-ir-generics-lowering-context.cpp +++ b/source/slang/slang-ir-generics-lowering-context.cpp @@ -36,6 +36,12 @@ namespace Slang return false; } + bool isComInterfaceType(IRType* type) + { + return type->findDecoration<IRComInterfaceDecoration>() != nullptr || + type->getOp() == kIROp_ComPtrType; + } + bool isTypeValue(IRInst* typeInst) { if (typeInst) @@ -175,7 +181,7 @@ namespace Slang if (isBuiltin(interfaceType)) return (IRType*)paramType; - if (interfaceType->findDecoration<IRComInterfaceDecoration>()) + if (isComInterfaceType((IRType*)interfaceType)) return (IRType*)interfaceType; auto anyValueSize = getInterfaceAnyValueSize( @@ -192,7 +198,7 @@ namespace Slang if (isBuiltin(paramType)) return (IRType*)paramType; - if (paramType->findDecoration<IRComInterfaceDecoration>()) + if (isComInterfaceType((IRType*)paramType)) return (IRType*)paramType; // In the dynamic-dispatch case, a value of interface type diff --git a/source/slang/slang-ir-generics-lowering-context.h b/source/slang/slang-ir-generics-lowering-context.h index 23bdcc78b..475464c8f 100644 --- a/source/slang/slang-ir-generics-lowering-context.h +++ b/source/slang/slang-ir-generics-lowering-context.h @@ -99,6 +99,8 @@ namespace Slang bool isPolymorphicType(IRInst* typeInst); + bool isComInterfaceType(IRType* type); + // Returns true if typeInst represents a type and should be lowered into // Ptr(RTTIType). bool isTypeValue(IRInst* typeInst); diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index c617a0218..b1759026b 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -56,6 +56,7 @@ INST(Nop, nop, 0, 0) INST(ConjunctionType, Conjunction, 0, 0) INST(AttributedType, Attributed, 0, 0) + INST(ResultType, Result, 2, 0) /* BindExistentialsTypeBase */ @@ -274,6 +275,12 @@ INST(makeArray, makeArray, 0, 0) INST(makeStruct, makeStruct, 0, 0) INST(MakeTuple, makeTuple, 0, 0) INST(GetTupleElement, getTupleElement, 2, 0) +INST(MakeResultValue, makeResultValue, 1, 0) +INST(MakeResultValueVoid, makeResultValueVoid, 0, 0) +INST(MakeResultError, makeResultError, 1, 0) +INST(IsResultError, isResultError, 1, 0) +INST(GetResultError, getResultError, 1, 0) +INST(GetResultValue, getResultValue, 1, 0) INST(Call, call, 1, 0) @@ -435,6 +442,9 @@ INST(SwizzledStore, swizzledStore, 2, 0) INST(ifElse, ifElse, 4, 0) INST_RANGE(ConditionalBranch, conditionalBranch, ifElse) + INST(Throw, throw, 1, 0) + // tryCall <successBlock> <failBlock> <callee> <args>... + INST(TryCall, tryCall, 3, 0) // switch <val> <break> <default> <caseVal1> <caseBlock1> ... INST(Switch, switch, 3, 0) @@ -730,7 +740,8 @@ INST_RANGE(Layout, VarLayout, EntryPointLayout) INST(TypeSizeAttr, size, 2, 0) INST(VarOffsetAttr, offset, 2, 0) INST_RANGE(LayoutResourceInfoAttr, TypeSizeAttr, VarOffsetAttr) -INST_RANGE(Attr, PendingLayoutAttr, VarOffsetAttr) + INST(FuncThrowTypeAttr, FuncThrowType, 1, 0) +INST_RANGE(Attr, PendingLayoutAttr, FuncThrowTypeAttr) /* Liveness */ INST(LiveRangeStart, liveRangeStart, 2, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 77b3eabc0..f7f7328a1 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -689,6 +689,14 @@ struct IRVarOffsetAttr : public IRLayoutResourceInfoAttr } }; + /// An attribute that specifies the error type a function is throwing +struct IRFuncThrowTypeAttr : IRAttr +{ + IR_LEAF_ISA(FuncThrowTypeAttr) + + IRType* getErrorType() { return (IRType*)getOperand(0); } +}; + /// An attribute that specifies size information for a single resource kind. struct IRTypeSizeAttr : public IRLayoutResourceInfoAttr { @@ -1547,6 +1555,25 @@ struct IRSwitch : IRTerminatorInst IRBlock* getCaseLabel(UInt index) { return (IRBlock*) getOperand(3 + index*2 + 1); } }; +struct IRThrow : IRTerminatorInst +{ + IR_LEAF_ISA(Throw); + + IRInst* getValue() { return getOperand(0); } +}; + +struct IRTryCall : IRTerminatorInst +{ + IR_LEAF_ISA(TryCall); + + IRBlock* getSuccessBlock() { return cast<IRBlock>(getOperand(0)); } + IRBlock* getFailureBlock() { return cast<IRBlock>(getOperand(1)); } + IRInst* getCallee() { return getOperand(2); } + UInt getArgCount() { return getOperandCount() - 3; } + IRUse* getArgs() { return getOperands() + 3; } + IRInst* getArg(UInt index) { return getOperand(index + 3); } +}; + struct IRSwizzle : IRInst { IRUse base; @@ -1758,6 +1785,52 @@ struct IRGetTupleElement : IRInst IRInst* getElementIndex() { return getOperand(1); } }; +// Constructs an `Result<T,E>` value from an error code. +struct IRMakeResultError : IRInst +{ + IR_LEAF_ISA(MakeResultError) + + IRInst* getErrorValue() { return getOperand(0); } +}; + +// Constructs an `Result<T,E>` value from an valid value. +struct IRMakeResultValue : IRInst +{ + IR_LEAF_ISA(MakeResultValue) + + IRInst* getValue() { return getOperand(0); } +}; + +// Constructs an `Result<void,E>` value that represents success in a function that returns `void`. +struct IRMakeResultValueVoid : IRInst +{ + IR_LEAF_ISA(MakeResultValueVoid) +}; + +// Determines if a `Result` value represents an error. +struct IRIsResultError : IRInst +{ + IR_LEAF_ISA(IsResultError) + + IRInst* getResultOperand() { return getOperand(0); } +}; + +// Extract the value from a `Result`. +struct IRGetResultValue : IRInst +{ + IR_LEAF_ISA(GetResultValue) + + IRInst* getResultOperand() { return getOperand(0); } +}; + +// Extract the error code from a `Result`. +struct IRGetResultError : IRInst +{ + IR_LEAF_ISA(GetResultError) + + IRInst* getResultOperand() { return getOperand(0); } +}; + /// An instruction that packs a concrete value into an existential-type "box" struct IRMakeExistential : IRInst { @@ -1908,12 +1981,17 @@ public: // keys are modified (thus its hash code is changed). void deduplicateAndRebuildGlobalNumberingMap(); + // Replaces all uses of oldInst with newInst, and ensures the global numbering map is valid after the replacement. + void replaceGlobalInst(IRInst* oldInst, IRInst* newInst); + typedef Dictionary<IRInstKey, IRInst*> GlobalValueNumberingMap; typedef Dictionary<IRConstantKey, IRConstant*> ConstantMap; GlobalValueNumberingMap& getGlobalValueNumberingMap() { return m_globalValueNumberingMap; } ConstantMap& getConstantMap() { return m_constantMap; } + bool isGloballyNumberedInst(IRInst* inst); + private: // The module that will own all of the IR IRModule* m_module; @@ -2104,6 +2182,8 @@ public: IRTupleType* getTupleType(IRType* type0, IRType* type1, IRType* type2); IRTupleType* getTupleType(IRType* type0, IRType* type1, IRType* type2, IRType* type3); + IRResultType* getResultType(IRType* valueType, IRType* errorType); + IRBasicBlockType* getBasicBlockType(); IRWitnessTableType* getWitnessTableType(IRType* baseType); IRWitnessTableIDType* getWitnessTableIDType(IRType* baseType); @@ -2152,6 +2232,9 @@ public: IRType* resultType); IRFuncType* getFuncType( + UInt paramCount, IRType* const* paramTypes, IRType* resultType, IRAttr* attribute); + + IRFuncType* getFuncType( List<IRType*> const& paramTypes, IRType* resultType) { @@ -2279,6 +2362,14 @@ public: return emitCallInst(type, func, args.getCount(), args.getBuffer()); } + IRInst* emitTryCallInst( + IRType* type, + IRBlock* successBlock, + IRBlock* failureBlock, + IRInst* func, + UInt argCount, + IRInst* const* args); + IRInst* createIntrinsicInst( IRType* type, IROp op, @@ -2322,6 +2413,13 @@ public: IRInst* emitGetTupleElement(IRType* type, IRInst* tuple, UInt element); + IRInst* emitMakeResultError(IRType* resultType, IRInst* errorVal); + IRInst* emitMakeResultValue(IRType* resultType, IRInst* val); + IRInst* emitMakeResultValueVoid(IRType* resultType); + IRInst* emitIsResultError(IRInst* result); + IRInst* emitGetResultError(IRInst* result); + IRInst* emitGetResultValue(IRInst* result); + IRInst* emitMakeVector( IRType* type, UInt argCount, @@ -2590,6 +2688,8 @@ public: IRInst* emitReturn(); + IRInst* emitThrow(IRInst* val); + IRInst* emitDiscard(); IRInst* emitUnreachable(); @@ -2687,6 +2787,8 @@ public: IRInst* emitAdd(IRType* type, IRInst* left, IRInst* right); IRInst* emitMul(IRType* type, IRInst* left, IRInst* right); + IRInst* emitEql(IRInst* left, IRInst* right); + IRInst* emitNeq(IRInst* left, IRInst* right); IRInst* emitShr(IRType* type, IRInst* op0, IRInst* op1); IRInst* emitShl(IRType* type, IRInst* op0, IRInst* op1); diff --git a/source/slang/slang-ir-lower-error-handling.cpp b/source/slang/slang-ir-lower-error-handling.cpp new file mode 100644 index 000000000..5a1389e57 --- /dev/null +++ b/source/slang/slang-ir-lower-error-handling.cpp @@ -0,0 +1,242 @@ +// slang-ir-lower-error-handling.cpp + +#include "slang-ir-lower-error-handling.h" + +#include "slang-ir.h" +#include "slang-ir-insts.h" + +namespace Slang +{ + +struct ErrorHandlingLoweringContext +{ + IRModule* module; + DiagnosticSink* diagnosticSink; + + SharedIRBuilder sharedBuilder; + + List<IRInst*> workList; + HashSet<IRInst*> workListSet; + + void addToWorkList(IRInst* inst) + { + if (workListSet.Contains(inst)) + return; + + workList.add(inst); + workListSet.Add(inst); + } + + void processFuncType(IRFuncType* funcType) + { + auto throwAttr = funcType->findAttr<IRFuncThrowTypeAttr>(); + if (!throwAttr) + return; + IRBuilder builder(sharedBuilder); + builder.setInsertBefore(funcType); + auto resultType = + builder.getResultType(funcType->getResultType(), throwAttr->getErrorType()); + List<IRType*> paramTypes; + for (UInt i = 0; i < funcType->getParamCount(); i++) + { + if (as<IRAttr>(funcType->getParamType(i))) + break; + paramTypes.add(funcType->getParamType(i)); + } + auto newFuncType = builder.getFuncType(paramTypes, resultType); + sharedBuilder.replaceGlobalInst(funcType, newFuncType); + } + + void processTryCall(IRTryCall* tryCall) + { + // If we see: + // ``` + // value = tryCall(callee, successBlock, failBlock, args) + // successBlock: + // resultParam = IRParam<resultType> + // ... (uses resultParam) ... + // failBlock: + // errorParam = IRParam<errorType> + // (uses errorParam) + // ``` + // We need to rewrite it as + // ``` + // result = call(callee) : Result<callee.returnType, callee.errorType> + // isError = isResultError(result) + // ifElse(isError, failBlock, successBlock) + // successBlock: + // value = getResultValue(result) : returnType + // ... (replaces resultParam with value) + // failBlock: + // error = getResultError(result) : errorType + // ... (replaces errorParam with error) + // ``` + IRFuncType* funcType = cast<IRFuncType>(tryCall->getCallee()->getDataType()); + auto resultValueType = funcType->getResultType(); + auto throwAttr = funcType->findAttr<IRFuncThrowTypeAttr>(); + if (!throwAttr) + { + SLANG_ASSERT_FAILURE("tryCall applied to callee without a IRFuncThrowTypeAttr"); + } + auto errorType = throwAttr->getErrorType(); + + IRBuilder builder(sharedBuilder); + builder.setInsertBefore(tryCall); + + auto resultType = builder.getResultType(resultValueType, errorType); + List<IRInst*> args; + for (UInt i = 0; i < tryCall->getArgCount(); i++) + { + args.add(tryCall->getArg(i)); + } + auto call = builder.emitCallInst(resultType, tryCall->getCallee(), args); + auto isFail = builder.emitIsResultError(call); + auto failBlock = tryCall->getFailureBlock(); + auto successBlock = tryCall->getSuccessBlock(); + + builder.emitIf(isFail, failBlock, successBlock); + + // Replace the params in failBlock to `getResultError(call)`. + builder.setInsertBefore(failBlock->getFirstOrdinaryInst()); + auto errorParam = failBlock->getFirstParam(); + auto errVal = builder.emitGetResultError(call); + errorParam->replaceUsesWith(errVal); + errorParam->removeAndDeallocate(); + + // Replace the params in successBlock to `getResultValue(call)`. + builder.setInsertBefore(successBlock->getFirstOrdinaryInst()); + auto resultParam = successBlock->getFirstParam(); + auto resultValue = builder.emitGetResultValue(call); + resultParam->replaceUsesWith(resultValue); + resultParam->removeAndDeallocate(); + + tryCall->removeAndDeallocate(); + } + + void processReturn(IRReturn* ret) + { + auto parentFunc = getParentFunc(ret); + if (!parentFunc) + return; + auto funcType = cast<IRFuncType>(parentFunc->getDataType()); + auto throwAttr = funcType->findAttr<IRFuncThrowTypeAttr>(); + if (!throwAttr) + return; + + // If we are in a throwing function and sees a `return(val)` inst, + // replace it with a `return makeResultValue(val)`, so that it returns a `Result<T,E>` type. + IRBuilder builder(sharedBuilder); + builder.setInsertBefore(ret); + auto resultType = + builder.getResultType(funcType->getResultType(), throwAttr->getErrorType()); + IRInst* resultVal = nullptr; + if (ret->getOp() == kIROp_ReturnVal) + { + auto val = cast<IRReturnVal>(ret)->getVal(); + resultVal = builder.emitMakeResultValue(resultType, val); + } + else + { + resultVal = builder.emitMakeResultValueVoid(resultType); + } + builder.emitReturn(resultVal); + ret->removeAndDeallocate(); + } + + void processThrow(IRThrow* throwInst) + { + auto parentFunc = getParentFunc(throwInst); + SLANG_ASSERT(parentFunc); + auto funcType = cast<IRFuncType>(parentFunc->getDataType()); + auto throwAttr = funcType->findAttr<IRFuncThrowTypeAttr>(); + SLANG_ASSERT(throwAttr); + + // If we are in a throwing function and sees a `throw(e)` inst, + // replace it with a `return makeResultError(e)`. + IRBuilder builder(sharedBuilder); + builder.setInsertBefore(throwInst); + auto resultType = + builder.getResultType(funcType->getResultType(), throwAttr->getErrorType()); + IRInst* resultVal = builder.emitMakeResultError(resultType, throwInst->getValue()); + builder.emitReturn(resultVal); + throwInst->removeAndDeallocate(); + } + + void processInst(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_TryCall: + processTryCall(cast<IRTryCall>(inst)); + break; + case kIROp_ReturnVal: + case kIROp_ReturnVoid: + processReturn(cast<IRReturn>(inst)); + break; + case kIROp_Throw: + processThrow(cast<IRThrow>(inst)); + break; + default: + break; + } + } + + void processInsts() + { + addToWorkList(module->getModuleInst()); + + while (workList.getCount() != 0) + { + IRInst* inst = workList.getLast(); + + workList.removeLast(); + workListSet.Remove(inst); + + processInst(inst); + + for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) + { + addToWorkList(child); + } + } + } + + void processModule() + { + // Deduplicate equivalent types. + sharedBuilder.deduplicateAndRebuildGlobalNumberingMap(); + + // Translate all IRTryCall, IRThrow, IRReturn, IRReturnVal. + processInsts(); + + // Lower all functypes. + // Function types with an IRThrowTypeAttribute will be translated into a normal function + // type that returns `Result<T,E>`. + List<IRFuncType*> oldFuncTypes; + for (auto child : module->getGlobalInsts()) + { + switch (child->getOp()) + { + case kIROp_FuncType: + oldFuncTypes.add(cast<IRFuncType>(child)); + break; + default: + break; + } + } + for (auto funcType : oldFuncTypes) + { + processFuncType(funcType); + } + } +}; + +void lowerErrorHandling(IRModule* module, DiagnosticSink* sink) +{ + ErrorHandlingLoweringContext context; + context.module = module; + context.diagnosticSink = sink; + context.sharedBuilder.init(module); + return context.processModule(); +} +} diff --git a/source/slang/slang-ir-lower-error-handling.h b/source/slang/slang-ir-lower-error-handling.h new file mode 100644 index 000000000..92a65404c --- /dev/null +++ b/source/slang/slang-ir-lower-error-handling.h @@ -0,0 +1,18 @@ +// slang-ir-lower-error-handling.h +#pragma once + +namespace Slang +{ + +struct IRModule; +class DiagnosticSink; + +/// Lower error handling related opcodes and function calls to use standard control flow. +/// A function with an error code type will be translated into a function that returns `Result<T,E>`, which can be +/// further lowered to standard return values and `out` parameters in a separate pass. +/// Call sites (`IRTryCall`) to error-throwing function will be rewritten to conform to the new function signature. +/// `IRThrow` will be replaced with `IRReturnVal(IRMakeErrorResult(e))`. +/// +void lowerErrorHandling(IRModule* module, DiagnosticSink* sink); + +} diff --git a/source/slang/slang-ir-lower-existential.cpp b/source/slang/slang-ir-lower-existential.cpp index dfa714a82..b0d9e6f2f 100644 --- a/source/slang/slang-ir-lower-existential.cpp +++ b/source/slang/slang-ir-lower-existential.cpp @@ -106,7 +106,7 @@ namespace Slang builder->setInsertBefore(extractInst); IRInst* element = nullptr; - if (extractInst->getOperand(0)->getDataType()->findDecoration<IRComInterfaceDecoration>()) + if (isComInterfaceType(extractInst->getOperand(0)->getDataType())) { // If this is an COM interface, the elements (witness table/rtti) are just the interface value itself. element = extractInst->getOperand(0); diff --git a/source/slang/slang-ir-lower-generic-call.cpp b/source/slang/slang-ir-lower-generic-call.cpp index da9f54764..7dbe11f52 100644 --- a/source/slang/slang-ir-lower-generic-call.cpp +++ b/source/slang/slang-ir-lower-generic-call.cpp @@ -276,9 +276,7 @@ namespace Slang // all occurences of associatedtypes. // If `w` in `lookup_interface_method(w, ...)` is a COM interface, bail. - if (lookupInst->getWitnessTable() - ->getDataType() - ->findDecoration<IRComInterfaceDecoration>()) + if (isComInterfaceType(lookupInst->getWitnessTable()->getDataType())) { return; } diff --git a/source/slang/slang-ir-lower-generic-function.cpp b/source/slang/slang-ir-lower-generic-function.cpp index 600f21b20..a2636e8ed 100644 --- a/source/slang/slang-ir-lower-generic-function.cpp +++ b/source/slang/slang-ir-lower-generic-function.cpp @@ -153,7 +153,7 @@ namespace Slang if (isBuiltin(interfaceType)) return interfaceType; // Do not lower COM interfaces. - if (interfaceType->findDecoration<IRComInterfaceDecoration>()) + if (isComInterfaceType(interfaceType)) return interfaceType; List<IRInterfaceRequirementEntry*> newEntries; diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp index 5d8e6d929..703441252 100644 --- a/source/slang/slang-ir-lower-generics.cpp +++ b/source/slang/slang-ir-lower-generics.cpp @@ -137,7 +137,7 @@ namespace Slang { auto witnessTableType = lookupWitnessMethod->getWitnessTable()->getDataType(); auto interfaceType = cast<IRWitnessTableType>(witnessTableType)->getConformanceType(); - if (interfaceType->findDecoration<IRComInterfaceDecoration>()) + if (isComInterfaceType((IRType*)interfaceType)) return; if (!implementedInterfaces.Contains(interfaceType)) { diff --git a/source/slang/slang-ir-lower-result-type.cpp b/source/slang/slang-ir-lower-result-type.cpp new file mode 100644 index 000000000..e46a0ceb5 --- /dev/null +++ b/source/slang/slang-ir-lower-result-type.cpp @@ -0,0 +1,317 @@ +// slang-ir-lower-result-type.cpp + +#include "slang-ir-lower-result-type.h" +#include "slang-ir.h" +#include "slang-ir-insts.h" + +namespace Slang +{ + struct ResultTypeLoweringContext + { + IRModule* module; + DiagnosticSink* sink; + + SharedIRBuilder sharedBuilderStorage; + + List<IRInst*> workList; + HashSet<IRInst*> workListSet; + + struct LoweredResultTypeInfo : public RefObject + { + IRType* resultType = nullptr; + IRType* errorType = nullptr; + IRType* valueType = nullptr; + IRType* loweredType = nullptr; + IRStructField* valueField = nullptr; + IRStructField* errorField = nullptr; + }; + Dictionary<IRInst*, RefPtr<LoweredResultTypeInfo>> mapLoweredTypeToResultTypeInfo; + Dictionary<IRInst*, RefPtr<LoweredResultTypeInfo>> loweredResultTypes; + + IRType* maybeLowerResultType(IRBuilder* builder, IRType* type) + { + if (auto info = getLoweredResultType(builder, type)) + return info->loweredType; + else + return type; + } + + LoweredResultTypeInfo* getLoweredResultType(IRBuilder* builder, IRInst* type) + { + if (auto loweredInfo = loweredResultTypes.TryGetValue(type)) + return loweredInfo->Ptr(); + if (auto loweredInfo = mapLoweredTypeToResultTypeInfo.TryGetValue(type)) + return loweredInfo->Ptr(); + + if (!type) + return nullptr; + if (type->getOp() != kIROp_ResultType) + return nullptr; + + RefPtr<LoweredResultTypeInfo> info = new LoweredResultTypeInfo(); + info->resultType = (IRType*)type; + info->errorType = cast<IRResultType>(type)->getErrorType(); + + auto resultType = cast<IRResultType>(type); + auto valueType = resultType->getValueType(); + if (valueType->getOp() != kIROp_VoidType) + { + auto structType = builder->createStructType(); + info->loweredType = structType; + builder->addNameHintDecoration(structType, UnownedStringSlice("ResultType")); + + info->valueType = valueType; + auto valueKey = builder->createStructKey(); + builder->addNameHintDecoration(valueKey, UnownedStringSlice("value")); + info->valueField = builder->createStructField(structType, valueKey, (IRType*)valueType); + + auto errorType = resultType->getErrorType(); + auto errorKey = builder->createStructKey(); + builder->addNameHintDecoration(errorKey, UnownedStringSlice("error")); + info->errorField = builder->createStructField(structType, errorKey, (IRType*)errorType); + } + else + { + info->loweredType = resultType->getErrorType(); + } + mapLoweredTypeToResultTypeInfo[info->loweredType] = info; + loweredResultTypes[type] = info; + return info.Ptr(); + } + + void addToWorkList( + IRInst* inst) + { + for (auto ii = inst->getParent(); ii; ii = ii->getParent()) + { + if (as<IRGeneric>(ii)) + return; + } + + if (workListSet.Contains(inst)) + return; + + workList.add(inst); + workListSet.Add(inst); + } + + IRInst* getSuccessErrorValue(IRType* type) + { + switch (type->getOp()) + { + case kIROp_Int8Type: + case kIROp_Int16Type: + case kIROp_IntType: + case kIROp_Int64Type: + case kIROp_UInt8Type: + case kIROp_UInt16Type: + case kIROp_UIntType: + case kIROp_UInt64Type: + break; + default: + SLANG_ASSERT_FAILURE("error type is not lowered to an integer type."); + } + IRBuilder builderStorage(sharedBuilderStorage); + auto builder = &builderStorage; + builder->setInsertInto(module); + return builder->getIntValue(type, 0); + } + + void processMakeResultValue(IRMakeResultValue* inst) + { + IRBuilder builderStorage(sharedBuilderStorage); + auto builder = &builderStorage; + builder->setInsertBefore(inst); + + auto info = getLoweredResultType(builder, inst->getDataType()); + List<IRInst*> operands; + operands.add(inst->getOperand(0)); + operands.add(getSuccessErrorValue(info->errorType)); + auto makeStruct = builder->emitMakeStruct(info->loweredType, operands); + inst->replaceUsesWith(makeStruct); + inst->removeAndDeallocate(); + } + + void processMakeResultValueVoid(IRMakeResultValueVoid* inst) + { + IRBuilder builderStorage(sharedBuilderStorage); + auto builder = &builderStorage; + builder->setInsertBefore(inst); + + auto info = getLoweredResultType(builder, inst->getDataType()); + auto errCode = getSuccessErrorValue(info->errorType); + inst->replaceUsesWith(errCode); + inst->removeAndDeallocate(); + } + + void processMakeResultError(IRMakeResultError* inst) + { + IRBuilder builderStorage(sharedBuilderStorage); + auto builder = &builderStorage; + builder->setInsertBefore(inst); + + auto info = getLoweredResultType(builder, inst->getDataType()); + if (info->valueField) + { + List<IRInst*> operands; + operands.add(builder->emitConstructorInst(info->valueType, 0, nullptr)); + operands.add(inst->getErrorValue()); + auto makeStruct = builder->emitMakeStruct(info->loweredType, operands); + inst->replaceUsesWith(makeStruct); + } + else + { + inst->replaceUsesWith(inst->getErrorValue()); + } + inst->removeAndDeallocate(); + } + + IRInst* getResultError(IRBuilder* builder, IRInst* resultInst) + { + auto loweredResultTypeInfo = getLoweredResultType(builder, resultInst->getDataType()); + SLANG_ASSERT(loweredResultTypeInfo); + if (loweredResultTypeInfo->valueField) + { + auto value = builder->emitFieldExtract( + loweredResultTypeInfo->errorType, + resultInst, + loweredResultTypeInfo->errorField->getKey()); + return value; + } + else + { + return resultInst; + } + } + + void processGetResultError(IRGetResultError* inst) + { + IRBuilder builderStorage(sharedBuilderStorage); + auto builder = &builderStorage; + builder->setInsertBefore(inst); + + auto resultValue = inst->getResultOperand(); + auto errValue = getResultError(builder, resultValue); + inst->replaceUsesWith(errValue); + inst->removeAndDeallocate(); + } + + void processGetResultValue(IRGetResultValue* inst) + { + IRBuilder builderStorage(sharedBuilderStorage); + auto builder = &builderStorage; + builder->setInsertBefore(inst); + + auto base = inst->getResultOperand(); + auto loweredResultTypeInfo = getLoweredResultType(builder, base->getDataType()); + SLANG_ASSERT(loweredResultTypeInfo); + SLANG_ASSERT(loweredResultTypeInfo->valueField); + auto getElement = builder->emitFieldExtract( + loweredResultTypeInfo->errorType, + base, + loweredResultTypeInfo->valueField->getKey()); + inst->replaceUsesWith(getElement); + inst->removeAndDeallocate(); + } + + void processIsResultError(IRIsResultError* inst) + { + IRBuilder builderStorage(sharedBuilderStorage); + auto builder = &builderStorage; + builder->setInsertBefore(inst); + + auto base = inst->getResultOperand(); + auto loweredResultTypeInfo = getLoweredResultType(builder, base->getDataType()); + SLANG_ASSERT(loweredResultTypeInfo); + SLANG_ASSERT(loweredResultTypeInfo->valueField); + + auto resultValue = inst->getResultOperand(); + auto errValue = getResultError(builder, resultValue); + auto isSuccess = + builder->emitNeq(errValue, getSuccessErrorValue(loweredResultTypeInfo->errorType)); + inst->replaceUsesWith(isSuccess); + inst->removeAndDeallocate(); + } + + void processResultType(IRResultType* inst) + { + IRBuilder builderStorage(sharedBuilderStorage); + auto builder = &builderStorage; + builder->setInsertBefore(inst); + + auto loweredResultTypeInfo = getLoweredResultType(builder, inst); + SLANG_ASSERT(loweredResultTypeInfo); + SLANG_UNUSED(loweredResultTypeInfo); + } + + void processInst(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_MakeResultValue: + processMakeResultValue((IRMakeResultValue*)inst); + break; + case kIROp_MakeResultValueVoid: + processMakeResultValueVoid((IRMakeResultValueVoid*)inst); + break; + case kIROp_MakeResultError: + processMakeResultError((IRMakeResultError*)inst); + break; + case kIROp_GetResultError: + processGetResultError((IRGetResultError*)inst); + break; + case kIROp_GetResultValue: + processGetResultValue((IRGetResultValue*)inst); + break; + case kIROp_IsResultError: + processIsResultError((IRIsResultError*)inst); + break; + case kIROp_ResultType: + processResultType((IRResultType*)inst); + break; + default: + break; + } + } + + void processModule() + { + SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; + sharedBuilder->init(module); + + // Deduplicate equivalent types. + sharedBuilder->deduplicateAndRebuildGlobalNumberingMap(); + + addToWorkList(module->getModuleInst()); + + while (workList.getCount() != 0) + { + IRInst* inst = workList.getLast(); + + workList.removeLast(); + workListSet.Remove(inst); + + processInst(inst); + + for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) + { + addToWorkList(child); + } + } + + // Replace all result types with lowered struct types. + for (auto kv : loweredResultTypes) + { + kv.Key->replaceUsesWith(kv.Value->loweredType); + } + } + }; + + void lowerResultType(IRModule* module, DiagnosticSink* sink) + { + ResultTypeLoweringContext context; + context.module = module; + context.sink = sink; + context.processModule(); + } +} diff --git a/source/slang/slang-ir-lower-result-type.h b/source/slang/slang-ir-lower-result-type.h new file mode 100644 index 000000000..f04a1a6cb --- /dev/null +++ b/source/slang/slang-ir-lower-result-type.h @@ -0,0 +1,16 @@ +// slang-ir-lower-result-type.h +#pragma once + +#include "slang-ir.h" + +namespace Slang +{ + struct IRModule; + class DiagnosticSink; + + /// Lower `IRResultType<T,E>` types to ordinary `struct`s. + void lowerResultType( + IRModule* module, + DiagnosticSink* sink); + +} diff --git a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp index 6c6a7dec5..7464a1c35 100644 --- a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp +++ b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp @@ -119,7 +119,7 @@ struct AssociatedTypeLookupSpecializationContext void processLookupInterfaceMethodInst(IRLookupWitnessMethod* inst) { - if (inst->getWitnessTable()->getDataType()->findDecoration<IRComInterfaceDecoration>()) + if (isComInterfaceType(inst->getWitnessTable()->getDataType())) { return; } diff --git a/source/slang/slang-ir-witness-table-wrapper.cpp b/source/slang/slang-ir-witness-table-wrapper.cpp index c8a1e2dbe..5d1a9360e 100644 --- a/source/slang/slang-ir-witness-table-wrapper.cpp +++ b/source/slang/slang-ir-witness-table-wrapper.cpp @@ -170,7 +170,7 @@ namespace Slang auto interfaceType = cast<IRInterfaceType>(witnessTable->getConformanceType()); if (isBuiltin(interfaceType)) return; - if (interfaceType->findDecoration<IRComInterfaceDecoration>()) + if (isComInterfaceType(interfaceType)) return; // We need to consider whether the concrete type that is conforming diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index d454333e6..c71954346 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2587,6 +2587,12 @@ namespace Slang return getTupleType(SLANG_COUNT_OF(operands), operands); } + IRResultType* IRBuilder::getResultType(IRType* valueType, IRType* errorType) + { + IRInst* operands[] = {valueType, errorType}; + return (IRResultType*)getType(kIROp_ResultType, 2, operands); + } + IRBasicBlockType* IRBuilder::getBasicBlockType() { return (IRBasicBlockType*)getType(kIROp_BasicBlockType); @@ -2717,6 +2723,14 @@ namespace Slang (IRInst* const*) paramTypes); } + IRFuncType* IRBuilder::getFuncType( + UInt paramCount, IRType* const* paramTypes, IRType* resultType, IRAttr* attribute) + { + UInt counts[3] = {1, paramCount, 1}; + IRInst** lists[3] = {(IRInst**)&resultType, (IRInst**)paramTypes, (IRInst**)&attribute}; + return (IRFuncType*)findOrEmitHoistableInst(nullptr, kIROp_FuncType, 3, counts, lists); + } + IRWitnessTableType* IRBuilder::getWitnessTableType( IRType* baseType) { @@ -3088,6 +3102,21 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitTryCallInst( + IRType* type, + IRBlock* successBlock, + IRBlock* failureBlock, + IRInst* func, + UInt argCount, + IRInst* const* args) + { + IRInst* fixedArgs[] = {successBlock, failureBlock, func}; + auto inst = createInstWithTrailingArgs<IRTryCall>( + this, kIROp_TryCall, type, 3, fixedArgs, argCount, args); + addInst(inst); + return inst; + } + IRInst* IRBuilder::createIntrinsicInst( IRType* type, IROp op, @@ -3183,6 +3212,46 @@ namespace Slang return emitIntrinsicInst(type, kIROp_GetTupleElement, 2, args); } + IRInst* IRBuilder::emitMakeResultError(IRType* resultType, IRInst* errorVal) + { + return emitIntrinsicInst(resultType, kIROp_MakeResultError, 1, &errorVal); + } + + IRInst* IRBuilder::emitMakeResultValue(IRType* resultType, IRInst* value) + { + return emitIntrinsicInst(resultType, kIROp_MakeResultValue, 1, &value); + } + + IRInst* IRBuilder::emitMakeResultValueVoid(IRType* resultType) + { + return emitIntrinsicInst(resultType, kIROp_MakeResultValueVoid, 0, nullptr); + } + + IRInst* IRBuilder::emitIsResultError(IRInst* result) + { + return emitIntrinsicInst(getBoolType(), kIROp_IsResultError, 1, &result); + } + + IRInst* IRBuilder::emitGetResultError(IRInst* result) + { + SLANG_ASSERT(result->getDataType()); + return emitIntrinsicInst( + cast<IRResultType>(result->getDataType())->getErrorType(), + kIROp_GetResultError, + 1, + &result); + } + + IRInst* IRBuilder::emitGetResultValue(IRInst* result) + { + SLANG_ASSERT(result->getDataType()); + return emitIntrinsicInst( + cast<IRResultType>(result->getDataType())->getValueType(), + kIROp_GetResultValue, + 1, + &result); + } + IRInst* IRBuilder::emitMakeVector( IRType* type, UInt argCount, @@ -3880,6 +3949,13 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitThrow(IRInst* val) + { + auto inst = createInst<IRThrow>(this, kIROp_Throw, nullptr, val); + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitUnreachable() { auto inst = createInst<IRUnreachable>( @@ -4207,6 +4283,20 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitEql(IRInst* left, IRInst* right) + { + auto inst = createInst<IRInst>(this, kIROp_Eql, getBoolType(), left, right); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitNeq(IRInst* left, IRInst* right) + { + auto inst = createInst<IRInst>(this, kIROp_Neq, getBoolType(), left, right); + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitMul(IRType* type, IRInst* left, IRInst* right) { auto inst = createInst<IRInst>( diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 6c766542f..6593e4409 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1458,6 +1458,15 @@ struct IRTupleType : IRType IR_LEAF_ISA(TupleType) }; +/// Represents an `Result<T,E>`, used by functions that throws error codes. +struct IRResultType : IRType +{ + IR_LEAF_ISA(ResultType) + + IRType* getValueType() { return (IRType*)getOperand(0); } + IRType* getErrorType() { return (IRType*)getOperand(1); } +}; + struct IRTypeType : IRType { IR_LEAF_ISA(TypeType); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 5cfb07c1c..1448139fb 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -16,6 +16,7 @@ #include "slang-ir-validate.h" #include "slang-ir-string-hash.h" #include "slang-ir-clone.h" +#include "slang-ir-lower-error-handling.h" #include "slang-mangle.h" #include "slang-type-layout.h" @@ -612,6 +613,12 @@ int32_t getIntrinsicOp( return int32_t(irOp); } +struct TryClauseEnvironment +{ + TryClauseType clauseType = TryClauseType::None; + IRBlock* catchBlock = nullptr; +}; + // Given a `LoweredValInfo` for something callable, along with a // bunch of arguments, emit an appropriate call to it. LoweredValInfo emitCallToVal( @@ -619,7 +626,8 @@ LoweredValInfo emitCallToVal( IRType* type, LoweredValInfo funcVal, UInt argCount, - IRInst* const* args) + IRInst* const* args, + const TryClauseEnvironment& tryEnv) { auto builder = context->irBuilder; switch (funcVal.flavor) @@ -627,8 +635,33 @@ LoweredValInfo emitCallToVal( case LoweredValInfo::Flavor::None: SLANG_UNEXPECTED("null function"); default: - return LoweredValInfo::simple( - builder->emitCallInst(type, getSimpleVal(context, funcVal), argCount, args)); + switch (tryEnv.clauseType) + { + case TryClauseType::None: + return LoweredValInfo::simple( + builder->emitCallInst(type, getSimpleVal(context, funcVal), argCount, args)); + + case TryClauseType::Standard: + { + auto callee = getSimpleVal(context, funcVal); + auto succBlock = builder->createBlock(); + auto failBlock = builder->createBlock(); + auto funcType = as<IRFuncType>(callee->getDataType()); + auto throwAttr = funcType->findAttr<IRFuncThrowTypeAttr>(); + assert(throwAttr); + auto voidType = builder->getVoidType(); + builder->emitTryCallInst(voidType, succBlock, failBlock, callee, argCount, args); + builder->insertBlock(failBlock); + auto errParam = builder->emitParam(throwAttr->getErrorType()); + builder->emitThrow(errParam); + builder->insertBlock(succBlock); + auto value = builder->emitParam(type); + return LoweredValInfo::simple(value); + } + break; + default: + SLANG_UNIMPLEMENTED_X("emitCallToVal(tryClauseType)"); + } } } @@ -655,7 +688,8 @@ LoweredValInfo emitCallToDeclRef( DeclRef<Decl> funcDeclRef, IRType* funcType, UInt argCount, - IRInst* const* args) + IRInst* const* args, + const TryClauseEnvironment& tryEnv) { SLANG_ASSERT(funcType); @@ -700,7 +734,7 @@ LoweredValInfo emitCallToDeclRef( // Fallback case is to emit an actual call. // LoweredValInfo funcVal = emitDeclRef(context, funcDeclRef, funcType); - return emitCallToVal(context, type, funcVal, argCount, args); + return emitCallToVal(context, type, funcVal, argCount, args, tryEnv); } LoweredValInfo emitCallToDeclRef( @@ -708,9 +742,17 @@ LoweredValInfo emitCallToDeclRef( IRType* type, DeclRef<Decl> funcDeclRef, IRType* funcType, - List<IRInst*> const& args) + List<IRInst*> const& args, + const TryClauseEnvironment& tryEnv) { - return emitCallToDeclRef(context, type, funcDeclRef, funcType, args.getCount(), args.getBuffer()); + return emitCallToDeclRef( + context, + type, + funcDeclRef, + funcType, + args.getCount(), + args.getBuffer(), + tryEnv); } /// Represents the "direction" that a parameter is being passed (e.g., `in` or `out` @@ -1580,10 +1622,21 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower { paramTypes.add(lowerType(context, type->getParamType(pp))); } - return getBuilder()->getFuncType( - paramCount, - paramTypes.getBuffer(), - resultType); + if (type->errorType->equals(context->astBuilder->getVoidType())) + { + return getBuilder()->getFuncType( + paramCount, + paramTypes.getBuffer(), + resultType); + } + else + { + auto errorType = lowerType(context, type->getErrorType()); + auto irThrowFuncTypeAttribute = + getBuilder()->getAttr(kIROp_FuncThrowTypeAttr, 1, (IRInst**)&errorType); + return getBuilder()->getFuncType( + paramCount, paramTypes.getBuffer(), resultType, irThrowFuncTypeAttribute); + } } IRType* visitPtrType(PtrType* type) @@ -2784,11 +2837,20 @@ void _lowerFuncDeclBaseTypeInfo( // being accessed, rather than a simple value. irResultType = builder->getPtrType(irResultType); } - - outInfo.type = builder->getFuncType( - paramTypes.getCount(), - paramTypes.getBuffer(), - irResultType); + + auto errorType = lowerType(context, getErrorCodeType(context->astBuilder, declRef)); + if (errorType->getOp() != kIROp_VoidType) + { + IRAttr* throwTypeAttr = nullptr; + throwTypeAttr = builder->getAttr(kIROp_FuncThrowTypeAttr, 1, (IRInst**)&errorType); + outInfo.type = builder->getFuncType( + paramTypes.getCount(), paramTypes.getBuffer(), irResultType, throwTypeAttr); + } + else + { + outInfo.type = + builder->getFuncType(paramTypes.getCount(), paramTypes.getBuffer(), irResultType); + } } static LoweredValInfo _emitCallToAccessor( @@ -2824,7 +2886,8 @@ static LoweredValInfo _emitCallToAccessor( accessorDeclRef, info.type, allArgs.getCount(), - allArgs.getBuffer()); + allArgs.getBuffer(), + TryClauseEnvironment()); applyOutArgumentFixups(context, fixups); @@ -3524,6 +3587,11 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> LoweredValInfo visitInvokeExpr(InvokeExpr* expr) { + return visitInvokeExprImpl(expr, TryClauseEnvironment()); + } + + LoweredValInfo visitInvokeExprImpl(InvokeExpr* expr, const TryClauseEnvironment& tryEnv) + { auto type = lowerType(context, expr->type); // We are going to look at the syntactic form of @@ -3636,7 +3704,8 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> type, funcDeclRef, funcType, - irArgs); + irArgs, + tryEnv); applyOutArgumentFixups(context, argFixups); return result; } @@ -3658,6 +3727,16 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> UNREACHABLE_RETURN(LoweredValInfo()); } + /// Emit code for a `try` invoke. + LoweredValInfo visitTryExpr(TryExpr* expr) + { + auto invokeExpr = as<InvokeExpr>(expr->base); + assert(invokeExpr); + TryClauseEnvironment tryEnv; + tryEnv.clauseType = expr->tryClauseType; + return visitInvokeExprImpl(invokeExpr, tryEnv); + } + /// Emit code to cast `value` to a concrete `superType` (e.g., a `struct`). /// /// The `subTypeWitness` is expected to witness the sub-type relationship @@ -8338,7 +8417,13 @@ RefPtr<IRModule> generateIRForTranslationUnit( // dumpIR(module); - // First, inline calls to any functions that have been + // First, lower error handling logic into normal control flow. + // This includes lowering throwing functions into functions that + // returns a `Result<T,E>` value, translating `tryCall` into + // normal `call` + `ifElse`, etc. + lowerErrorHandling(module, compileRequest->getSink()); + + // Next, inline calls to any functions that have been // marked for mandatory "early" inlining. // performMandatoryEarlyInlining(module); diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index db532a601..2604ffd9b 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -1485,6 +1485,12 @@ namespace Slang parser->PushScope(decl); parseParameterList(parser, decl); + + if (AdvanceIf(parser, "throws")) + { + decl->errorType = parser->ParseTypeExp(); + } + _parseOptSemantics(parser, decl); decl->body = parseOptBody(parser); @@ -3383,6 +3389,10 @@ namespace Slang { parser->PushScope(decl); parseModernParamList(parser, decl); + if (AdvanceIf(parser, "throws")) + { + decl->errorType = parser->ParseTypeExp(); + } if(AdvanceIf(parser, TokenType::RightArrow)) { decl->returnType = parser->ParseTypeExp(); @@ -4860,6 +4870,15 @@ namespace Slang return parser->astBuilder->create<NullPtrLiteralExpr>(); } + static NodeBase* parseTryExpr(Parser* parser, void* /*userData*/) + { + auto tryExpr = parser->astBuilder->create<TryExpr>(); + tryExpr->tryClauseType = TryClauseType::Standard; + tryExpr->base = parser->ParseLeafExpression(); + tryExpr->scope = parser->currentScope; + return tryExpr; + } + static bool _isFinite(double value) { // Lets type pun double to uint64_t, so we can detect special double values @@ -6263,7 +6282,7 @@ namespace Slang // keyword (no further tokens expected/allowed), // and which can be represented just by creating // a new AST node of the corresponding type. - + _makeParseModifier("in", InModifier::kReflectClassInfo), _makeParseModifier("input", InputModifier::kReflectClassInfo), _makeParseModifier("out", OutModifier::kReflectClassInfo), @@ -6336,6 +6355,7 @@ namespace Slang _makeParseExpr("true", parseTrueExpr), _makeParseExpr("false", parseFalseExpr), _makeParseExpr("nullptr", parseNullPtrExpr), + _makeParseExpr("try", parseTryExpr), _makeParseExpr("__TaggedUnion", parseTaggedUnionType), }; diff --git a/source/slang/slang-serialize-type-info.h b/source/slang/slang-serialize-type-info.h index 7ed45bb0b..c80eb8051 100644 --- a/source/slang/slang-serialize-type-info.h +++ b/source/slang/slang-serialize-type-info.h @@ -142,6 +142,12 @@ struct SerialTypeInfo<bool> } }; +// Specialization for all enum types +template<typename T> +struct SerialTypeInfo<T, typename std::enable_if<std::is_enum<T>::value>::type> + : public SerialIdentityTypeInfo<T> +{}; + // Pointer // Could handle different pointer base types with some more template magic here, but instead went with Pointer type to keep // things simpler. diff --git a/source/slang/slang-serialize.h b/source/slang/slang-serialize.h index 990a36adc..fcd2daa1f 100644 --- a/source/slang/slang-serialize.h +++ b/source/slang/slang-serialize.h @@ -37,7 +37,7 @@ struct SerialClass; struct SerialField; // Type used to implement mechanisms to convert to and from serial types. -template <typename T> +template <typename T, typename /*enumTypeSFINAE*/ = void> struct SerialTypeInfo; enum class SerialTypeKind : uint8_t diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index b9a5b8cd5..24ccfa4a4 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -1057,6 +1057,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt FuncType* funcType = astBuilder->create<FuncType>(); funcType->resultType = getResultType(astBuilder, declRef); + funcType->errorType = getErrorCodeType(astBuilder, declRef); for (auto paramDeclRef : getParameters(declRef)) { auto paramDecl = paramDeclRef.getDecl(); @@ -1269,9 +1270,27 @@ char const* getGLSLNameForImageFormat(ImageFormat format) return kImageFormatInfos[Index(format)].name.begin(); } - const ImageFormatInfo& getImageFormatInfo(ImageFormat format) - { - return kImageFormatInfos[Index(format)]; - } + +const ImageFormatInfo& getImageFormatInfo(ImageFormat format) +{ + return kImageFormatInfos[Index(format)]; +} + +char const* getTryClauseTypeName(TryClauseType c) +{ + switch (c) + { + case TryClauseType::None: + return "None"; + case TryClauseType::Standard: + return "Standard"; + case TryClauseType::Optional: + return "Optional"; + case TryClauseType::Assert: + return "Assert"; + default: + return "Unknown"; + } +} } // namespace Slang diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h index c144ceb70..b4463a77d 100644 --- a/source/slang/slang-syntax.h +++ b/source/slang/slang-syntax.h @@ -162,6 +162,18 @@ namespace Slang return declRef.substitute(astBuilder, declRef.getDecl()->returnType.type); } + inline Type* getErrorCodeType(ASTBuilder* astBuilder, DeclRef<CallableDecl> const& declRef) + { + if (declRef.getDecl()->errorType.type) + { + return declRef.substitute(astBuilder, declRef.getDecl()->errorType.type); + } + else + { + return astBuilder->getVoidType(); + } + } + inline FilteredMemberRefList<ParamDecl> getParameters(DeclRef<CallableDecl> const& declRef) { return getMembersOfType<ParamDecl>(declRef); |
