summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-06-01 17:37:07 -0700
committerGitHub <noreply@github.com>2022-06-01 17:37:07 -0700
commit17e3b88b541ed7f45d575f0f9caaa808cd0a6619 (patch)
treeefacd5d4bf6381a5adf8055daa28f91ddc048a76 /source/slang
parentfa10f7dc23f8b93c0f9ef3fb5477871a20aaa974 (diff)
New language feature: basic error handling. (#2253)
* New language feature: basic error handling. * Fix. * Fix `tryCall` encoding according to code review. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/slang-ast-decl.h3
-rw-r--r--source/slang/slang-ast-dump.cpp5
-rw-r--r--source/slang/slang-ast-expr.h22
-rw-r--r--source/slang/slang-ast-type.cpp16
-rw-r--r--source/slang/slang-ast-type.h2
-rw-r--r--source/slang/slang-check-decl.cpp21
-rw-r--r--source/slang/slang-check-expr.cpp62
-rw-r--r--source/slang/slang-check-impl.h19
-rw-r--r--source/slang/slang-check-stmt.cpp4
-rw-r--r--source/slang/slang-diagnostic-defs.h7
-rw-r--r--source/slang/slang-emit.cpp24
-rw-r--r--source/slang/slang-ir-deduplicate.cpp33
-rw-r--r--source/slang/slang-ir-generics-lowering-context.cpp10
-rw-r--r--source/slang/slang-ir-generics-lowering-context.h2
-rw-r--r--source/slang/slang-ir-inst-defs.h13
-rw-r--r--source/slang/slang-ir-insts.h102
-rw-r--r--source/slang/slang-ir-lower-error-handling.cpp242
-rw-r--r--source/slang/slang-ir-lower-error-handling.h18
-rw-r--r--source/slang/slang-ir-lower-existential.cpp2
-rw-r--r--source/slang/slang-ir-lower-generic-call.cpp4
-rw-r--r--source/slang/slang-ir-lower-generic-function.cpp2
-rw-r--r--source/slang/slang-ir-lower-generics.cpp2
-rw-r--r--source/slang/slang-ir-lower-result-type.cpp317
-rw-r--r--source/slang/slang-ir-lower-result-type.h16
-rw-r--r--source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp2
-rw-r--r--source/slang/slang-ir-witness-table-wrapper.cpp2
-rw-r--r--source/slang/slang-ir.cpp90
-rw-r--r--source/slang/slang-ir.h9
-rw-r--r--source/slang/slang-lower-to-ir.cpp123
-rw-r--r--source/slang/slang-parser.cpp22
-rw-r--r--source/slang/slang-serialize-type-info.h6
-rw-r--r--source/slang/slang-serialize.h2
-rw-r--r--source/slang/slang-syntax.cpp27
-rw-r--r--source/slang/slang-syntax.h12
34 files changed, 1186 insertions, 57 deletions
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h
index 5b596ded0..9c95f3520 100644
--- a/source/slang/slang-ast-decl.h
+++ b/source/slang/slang-ast-decl.h
@@ -266,6 +266,9 @@ class CallableDecl : public ContainerDecl
}
TypeExp returnType;
+
+ // If this callable throws an error code, `errorType` is the type of the error code.
+ TypeExp errorType;
// Fields related to redeclaration, so that we
// can support multiple specialized variations
diff --git a/source/slang/slang-ast-dump.cpp b/source/slang/slang-ast-dump.cpp
index 9da1a0163..d67a35174 100644
--- a/source/slang/slang-ast-dump.cpp
+++ b/source/slang/slang-ast-dump.cpp
@@ -341,7 +341,10 @@ struct ASTDumpContext
{
m_writer->emit(getGLSLNameForImageFormat(imageFormat));
}
-
+ void dump(TryClauseType clauseType)
+ {
+ m_writer->emit(getTryClauseTypeName(clauseType));
+ }
void dump(const String& string)
{
dump(string.getUnownedSlice());
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h
index a6c1d432c..f2fae7ced 100644
--- a/source/slang/slang-ast-expr.h
+++ b/source/slang/slang-ast-expr.h
@@ -144,6 +144,28 @@ class InvokeExpr: public AppExprBase
SLANG_AST_CLASS(InvokeExpr)
};
+enum class TryClauseType
+{
+ None,
+ Standard, // Normal `try` clause
+ Optional, // (Not implemented) `try?` clause that returns an optional value.
+ Assert, // (Not implemented) `try!` clause that should always succeed and triggers runtime error if failed.
+};
+
+char const* getTryClauseTypeName(TryClauseType value);
+
+class TryExpr : public Expr
+{
+ SLANG_AST_CLASS(TryExpr)
+
+ Expr* base;
+
+ TryClauseType tryClauseType = TryClauseType::Standard;
+
+ // The scope of this expr.
+ Scope* scope = nullptr;
+};
+
class OperatorExpr: public InvokeExpr
{
SLANG_AST_CLASS(OperatorExpr)
diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp
index b6bc1f170..bcaf8f5a9 100644
--- a/source/slang/slang-ast-type.cpp
+++ b/source/slang/slang-ast-type.cpp
@@ -549,6 +549,11 @@ void FuncType::_toTextOverride(StringBuilder& out)
out << getParamType(pp);
}
out << toSlice(") -> ") << getResultType();
+
+ if (!getErrorType()->equals(getASTBuilder()->getVoidType()))
+ {
+ out << " throws " << getErrorType();
+ }
}
bool FuncType::_equalsImplOverride(Type * type)
@@ -571,6 +576,9 @@ bool FuncType::_equalsImplOverride(Type * type)
if (!resultType->equals(funcType->resultType))
return false;
+ if (!errorType->equals(funcType->errorType))
+ return false;
+
// TODO: if we ever introduce other kinds
// of qualification on function types, we'd
// want to consider it here.
@@ -586,6 +594,9 @@ Val* FuncType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet s
// result type
Type* substResultType = as<Type>(resultType->substituteImpl(astBuilder, subst, &diff));
+ // error type
+ Type* substErrorType = as<Type>(errorType->substituteImpl(astBuilder, subst, &diff));
+
// parameter types
List<Type*> substParamTypes;
for (auto pp : paramTypes)
@@ -601,6 +612,7 @@ Val* FuncType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet s
FuncType* substType = astBuilder->create<FuncType>();
substType->resultType = substResultType;
substType->paramTypes = substParamTypes;
+ substType->errorType = substErrorType;
return substType;
}
@@ -608,6 +620,7 @@ Type* FuncType::_createCanonicalTypeOverride()
{
// result type
Type* canResultType = resultType->getCanonicalType();
+ Type* canErrorType = errorType->getCanonicalType();
// parameter types
List<Type*> canParamTypes;
@@ -619,7 +632,7 @@ Type* FuncType::_createCanonicalTypeOverride()
FuncType* canType = getASTBuilder()->create<FuncType>();
canType->resultType = canResultType;
canType->paramTypes = canParamTypes;
-
+ canType->errorType = canErrorType;
return canType;
}
@@ -634,6 +647,7 @@ HashCode FuncType::_getHashCodeOverride()
hashCode,
getParamType(pp)->getHashCode());
}
+ combineHash(hashCode, getErrorType()->getHashCode());
return hashCode;
}
diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h
index 7aa1a36ab..8de047af5 100644
--- a/source/slang/slang-ast-type.h
+++ b/source/slang/slang-ast-type.h
@@ -575,10 +575,12 @@ class FuncType : public Type
List<Type*> paramTypes;
Type* resultType = nullptr;
+ Type* errorType = nullptr;
UInt getParamCount() { return paramTypes.getCount(); }
Type* getParamType(UInt index) { return paramTypes[index]; }
Type* getResultType() { return resultType; }
+ Type* getErrorType() { return errorType; }
// Overrides should be public so base classes can access
void _toTextOverride(StringBuilder& out);
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 6d579e1fb..af19f484d 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -673,7 +673,7 @@ namespace Slang
// Make sure a declaration has been checked, so we can refer to it.
// Note that this may lead to us recursively invoking checking,
// so this may not be the best way to handle things.
- void SemanticsVisitor::ensureDecl(Decl* decl, DeclCheckState state)
+ void SemanticsVisitor::ensureDecl(Decl* decl, DeclCheckState state, SemanticsContext* baseContext)
{
// If the `decl` has already been checked up to or beyond `state`
// then there is nothing for us to do.
@@ -733,7 +733,7 @@ namespace Slang
// context, so that the state at the point where a declaration is *referenced*
// cannot affect the state in which the declaration is *checked*.
//
- SemanticsContext subContext(getShared());
+ SemanticsContext subContext = baseContext ? SemanticsContext(*baseContext) : SemanticsContext(getShared());
_dispatchDeclCheckingVisitor(decl, nextState, subContext);
// In the common case, the visitor will have done the necessary
@@ -3609,17 +3609,17 @@ namespace Slang
}
}
- void SemanticsVisitor::ensureDeclBase(DeclBase* declBase, DeclCheckState state)
+ void SemanticsVisitor::ensureDeclBase(DeclBase* declBase, DeclCheckState state, SemanticsContext* baseContext)
{
if(auto decl = as<Decl>(declBase))
{
- ensureDecl(decl, state);
+ ensureDecl(decl, state, baseContext);
}
else if(auto declGroup = as<DeclGroup>(declBase))
{
for(auto dd : declGroup->decls)
{
- ensureDecl(dd, state);
+ ensureDecl(dd, state, baseContext);
}
}
else
@@ -4398,6 +4398,17 @@ namespace Slang
{
ensureDecl(paramDecl, DeclCheckState::ReadyForReference);
}
+
+ auto errorType = decl->errorType;
+ if (errorType.exp)
+ {
+ errorType = CheckProperType(errorType);
+ }
+ else
+ {
+ errorType = TypeExp(m_astBuilder->getVoidType());
+ }
+ decl->errorType = errorType;
}
void SemanticsDeclHeaderVisitor::visitFuncDecl(FuncDecl* funcDecl)
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 317ab6a1a..4b29de1bf 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1337,6 +1337,15 @@ namespace Slang
// if this is still an invoke expression, test arguments passed to inout/out parameter are LValues
if(auto funcType = as<FuncType>(invoke->functionExpr->type))
{
+ if (!funcType->errorType->equals(m_astBuilder->getVoidType()))
+ {
+ // If the callee throws, make sure we are inside a try clause.
+ if (m_enclosingTryClauseType == TryClauseType::None)
+ {
+ getSink()->diagnose(invoke, Diagnostics::mustUseTryClauseToCallAThrowFunc);
+ }
+ }
+
Index paramCount = funcType->getParamCount();
for (Index pp = 0; pp < paramCount; ++pp)
{
@@ -1529,6 +1538,59 @@ namespace Slang
return CheckInvokeExprWithCheckedOperands(expr);
}
+ Expr* SemanticsExprVisitor::visitTryExpr(TryExpr* expr)
+ {
+ auto prevTryClauseType = expr->tryClauseType;
+ m_enclosingTryClauseType = expr->tryClauseType;
+ expr->base = CheckTerm(expr->base);
+ m_enclosingTryClauseType = prevTryClauseType;
+ expr->type = expr->base->type;
+ if (as<ErrorType>(expr->type))
+ return expr;
+
+ auto parentFunc = this->m_parentFunc;
+ // TODO: check if the try clause is caught.
+ // For now we assume all `try`s are not caught (because we don't have catch yet).
+ if (!parentFunc)
+ {
+ getSink()->diagnose(expr, Diagnostics::uncaughtTryCallInNonThrowFunc);
+ return expr;
+ }
+ if (parentFunc->errorType->equals(m_astBuilder->getVoidType()))
+ {
+ getSink()->diagnose(expr, Diagnostics::uncaughtTryCallInNonThrowFunc);
+ return expr;
+ }
+ if (!as<InvokeExpr>(expr->base))
+ {
+ getSink()->diagnose(expr, Diagnostics::tryClauseMustApplyToInvokeExpr);
+ return expr;
+ }
+ auto base = as<InvokeExpr>(expr->base);
+ if (auto callee = as<DeclRefExpr>(base->functionExpr))
+ {
+ if (auto funcCallee = as<FuncDecl>(callee->declRef.getDecl()))
+ {
+ if (funcCallee->errorType->equals(m_astBuilder->getVoidType()))
+ {
+ getSink()->diagnose(expr, Diagnostics::tryInvokeCalleeShouldThrow, callee->declRef);
+ }
+ if (!parentFunc->errorType->equals(funcCallee->errorType))
+ {
+ getSink()->diagnose(
+ expr,
+ Diagnostics::errorTypeOfCalleeIncompatibleWithCaller,
+ callee->declRef,
+ funcCallee->errorType,
+ parentFunc->errorType);
+ }
+ return expr;
+ }
+ }
+ getSink()->diagnose(expr, Diagnostics::calleeOfTryCallMustBeFunc);
+ return expr;
+ }
+
Expr* SemanticsVisitor::MaybeDereference(Expr* inExpr)
{
Expr* expr = inExpr;
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index e6088ccca..1e549c7da 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -345,6 +345,15 @@ namespace Slang
return result;
}
+ TryClauseType getEnclosingTryClauseType() { return m_enclosingTryClauseType; }
+
+ SemanticsContext withEnclosingTryClauseType(TryClauseType tryClauseType)
+ {
+ SemanticsContext result(*this);
+ result.m_enclosingTryClauseType = tryClauseType;
+ return result;
+ }
+
/// A scope that is local to a particular expression, and
/// that can be used to allocate temporary bindings that
/// might be needed by that expression or its sub-expressions.
@@ -380,6 +389,7 @@ namespace Slang
ExprLocalScope* m_exprLocalScope = nullptr;
+
protected:
// TODO: consider making more of this state `private`...
@@ -389,6 +399,9 @@ namespace Slang
/// The linked list of lexically surrounding statements.
OuterStmtInfo* m_outerStmts = nullptr;
+ /// The type of a try clause (if any) enclosing current expr.
+ TryClauseType m_enclosingTryClauseType = TryClauseType::None;
+
ASTBuilder* m_astBuilder = nullptr;
};
@@ -558,7 +571,7 @@ namespace Slang
/// on this function to avoid blowing out the stack or (even worse
/// creating a circular dependency).
///
- void ensureDecl(Decl* decl, DeclCheckState state);
+ void ensureDecl(Decl* decl, DeclCheckState state, SemanticsContext* baseContext = nullptr);
/// Helper routine allowing `ensureDecl` to be called on a `DeclRef`
void ensureDecl(DeclRefBase const& declRef, DeclCheckState state)
@@ -572,7 +585,7 @@ namespace Slang
/// called on a `DeclGroup` this function just calls `ensureDecl()`
/// on each declaration in the group.
///
- void ensureDeclBase(DeclBase* decl, DeclCheckState state);
+ void ensureDeclBase(DeclBase* decl, DeclCheckState state, SemanticsContext* baseContext);
// A "proper" type is one that can be used as the type of an expression.
// Put simply, it can be a concrete type like `int`, or a generic
@@ -1660,6 +1673,8 @@ namespace Slang
Expr* visitTypeCastExpr(TypeCastExpr * expr);
+ Expr* visitTryExpr(TryExpr* expr);
+
//
// Some syntax nodes should not occur in the concrete input syntax,
// and will only appear *after* checking is complete. We need to
diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp
index 45be1d662..57bb3b85a 100644
--- a/source/slang/slang-check-stmt.cpp
+++ b/source/slang/slang-check-stmt.cpp
@@ -51,7 +51,7 @@ namespace Slang
// local `struct` declaration, where it would have members
// that need to be recursively checked.
//
- ensureDeclBase(stmt->decl, DeclCheckState::Checked);
+ ensureDeclBase(stmt->decl, DeclCheckState::Checked, this);
}
void SemanticsStmtVisitor::visitBlockStmt(BlockStmt* stmt)
@@ -289,7 +289,7 @@ namespace Slang
{
stmt->device = CheckExpr(stmt->device);
stmt->gridDims = CheckExpr(stmt->gridDims);
- ensureDeclBase(stmt->dispatchThreadID, DeclCheckState::Checked);
+ ensureDeclBase(stmt->dispatchThreadID, DeclCheckState::Checked, this);
WithOuterStmt subContext(this, stmt);
stmt->kernelCall = subContext.CheckExpr(stmt->kernelCall);
return;
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 57035e84e..9ab34128a 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -290,6 +290,13 @@ DIAGNOSTIC(33071, Error, expectedAStringLiteral, "expected a string literal")
DIAGNOSTIC( -1, Note, noteExplicitConversionPossible, "explicit conversion from '$0' to '$1' is possible")
DIAGNOSTIC(30080, Error, ambiguousConversion, "more than one implicit conversion exists from '$0' to '$1'")
+DIAGNOSTIC(30090, Error, tryClauseMustApplyToInvokeExpr, "expression in a 'try' clause must be a call to a function or operator overload.")
+DIAGNOSTIC(30091, Error, tryInvokeCalleeShouldThrow, "'$0' called from a 'try' clause does not throw an error, make sure the callee is marked as 'throws'")
+DIAGNOSTIC(30092, Error, calleeOfTryCallMustBeFunc, "callee in a 'try' clause must be a function")
+DIAGNOSTIC(30093, Error, uncaughtTryCallInNonThrowFunc, "the current function or environment is not declared to throw any errors, but the 'try' clause is not caught")
+DIAGNOSTIC(30094, Error, mustUseTryClauseToCallAThrowFunc, "the callee may throw an error, and therefore must be called within a 'try' clause")
+DIAGNOSTIC(30095, Error, errorTypeOfCalleeIncompatibleWithCaller, "the error type `$1` of callee `$0` is not compatible with the caller's error type `$2`.")
+
// Attributes
DIAGNOSTIC(31000, Error, unknownAttributeName, "unknown attribute '$0'")
DIAGNOSTIC(31001, Error, attributeArgumentCountMismatch, "attribute '$0' expects $1 arguments ($2 provided)")
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 0ed17ad7c..72ad80873 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -23,6 +23,7 @@
#include "slang-ir-com-interface.h"
#include "slang-ir-lower-generics.h"
#include "slang-ir-lower-tuple-types.h"
+#include "slang-ir-lower-result-type.h"
#include "slang-ir-lower-bit-cast.h"
#include "slang-ir-lower-reinterpret.h"
#include "slang-ir-metadata.h"
@@ -200,6 +201,19 @@ Result linkAndOptimizeIR(
// un-specialized IR.
dumpIRIfEnabled(codeGenContext, irModule);
+ switch (target)
+ {
+ default:
+ break;
+ case CodeGenTarget::HostCPPSource:
+ lowerComInterfaces(irModule, sink);
+ generateDllImportFuncs(irModule, sink);
+ break;
+ }
+
+ // Lower `Result<T,E>` types into ordinary struct types.
+ lowerResultType(irModule, sink);
+
// Replace any global constants with their values.
//
replaceGlobalConstants(irModule);
@@ -689,16 +703,6 @@ Result linkAndOptimizeIR(
break;
}
- switch (target)
- {
- default:
- break;
- case CodeGenTarget::HostCPPSource:
- lowerComInterfaces(irModule, sink);
- generateDllImportFuncs(irModule, sink);
- break;
- }
-
// TODO: our current dynamic dispatch pass will remove all uses of witness tables.
// If we are going to support function-pointer based, "real" modular dynamic dispatch,
// we will need to disable this pass.
diff --git a/source/slang/slang-ir-deduplicate.cpp b/source/slang/slang-ir-deduplicate.cpp
index f2c199af6..8aef7736c 100644
--- a/source/slang/slang-ir-deduplicate.cpp
+++ b/source/slang/slang-ir-deduplicate.cpp
@@ -76,4 +76,37 @@ namespace Slang
for (auto inst : instToRemove)
inst->removeAndDeallocate();
}
+
+ void SharedIRBuilder::replaceGlobalInst(IRInst* oldInst, IRInst* newInst)
+ {
+ List<IRUse*> uses;
+ for (auto use = oldInst->firstUse; use; use = use->nextUse)
+ {
+ uses.add(use);
+ }
+
+ bool shouldUpdateGlobalNumberedCache = false;
+ for (auto use : uses)
+ {
+ use->set(newInst);
+ // depending on the type of the user inst, we may need to rebuild and update the global
+ // numbering cache.
+ if (isGloballyNumberedInst(use->getUser()))
+ {
+ shouldUpdateGlobalNumberedCache = true;
+ }
+ }
+ oldInst->removeAndDeallocate();
+ if (shouldUpdateGlobalNumberedCache)
+ {
+ deduplicateAndRebuildGlobalNumberingMap();
+ }
+ }
+
+ bool SharedIRBuilder::isGloballyNumberedInst(IRInst* inst)
+ {
+ if (!inst->getParent() || inst->getParent()->getOp() != kIROp_Module)
+ return false;
+ return m_globalValueNumberingMap.ContainsKey(IRInstKey{inst});
+ }
}
diff --git a/source/slang/slang-ir-generics-lowering-context.cpp b/source/slang/slang-ir-generics-lowering-context.cpp
index 284e1fa11..bf3b8d855 100644
--- a/source/slang/slang-ir-generics-lowering-context.cpp
+++ b/source/slang/slang-ir-generics-lowering-context.cpp
@@ -36,6 +36,12 @@ namespace Slang
return false;
}
+ bool isComInterfaceType(IRType* type)
+ {
+ return type->findDecoration<IRComInterfaceDecoration>() != nullptr ||
+ type->getOp() == kIROp_ComPtrType;
+ }
+
bool isTypeValue(IRInst* typeInst)
{
if (typeInst)
@@ -175,7 +181,7 @@ namespace Slang
if (isBuiltin(interfaceType))
return (IRType*)paramType;
- if (interfaceType->findDecoration<IRComInterfaceDecoration>())
+ if (isComInterfaceType((IRType*)interfaceType))
return (IRType*)interfaceType;
auto anyValueSize = getInterfaceAnyValueSize(
@@ -192,7 +198,7 @@ namespace Slang
if (isBuiltin(paramType))
return (IRType*)paramType;
- if (paramType->findDecoration<IRComInterfaceDecoration>())
+ if (isComInterfaceType((IRType*)paramType))
return (IRType*)paramType;
// In the dynamic-dispatch case, a value of interface type
diff --git a/source/slang/slang-ir-generics-lowering-context.h b/source/slang/slang-ir-generics-lowering-context.h
index 23bdcc78b..475464c8f 100644
--- a/source/slang/slang-ir-generics-lowering-context.h
+++ b/source/slang/slang-ir-generics-lowering-context.h
@@ -99,6 +99,8 @@ namespace Slang
bool isPolymorphicType(IRInst* typeInst);
+ bool isComInterfaceType(IRType* type);
+
// Returns true if typeInst represents a type and should be lowered into
// Ptr(RTTIType).
bool isTypeValue(IRInst* typeInst);
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index c617a0218..b1759026b 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -56,6 +56,7 @@ INST(Nop, nop, 0, 0)
INST(ConjunctionType, Conjunction, 0, 0)
INST(AttributedType, Attributed, 0, 0)
+ INST(ResultType, Result, 2, 0)
/* BindExistentialsTypeBase */
@@ -274,6 +275,12 @@ INST(makeArray, makeArray, 0, 0)
INST(makeStruct, makeStruct, 0, 0)
INST(MakeTuple, makeTuple, 0, 0)
INST(GetTupleElement, getTupleElement, 2, 0)
+INST(MakeResultValue, makeResultValue, 1, 0)
+INST(MakeResultValueVoid, makeResultValueVoid, 0, 0)
+INST(MakeResultError, makeResultError, 1, 0)
+INST(IsResultError, isResultError, 1, 0)
+INST(GetResultError, getResultError, 1, 0)
+INST(GetResultValue, getResultValue, 1, 0)
INST(Call, call, 1, 0)
@@ -435,6 +442,9 @@ INST(SwizzledStore, swizzledStore, 2, 0)
INST(ifElse, ifElse, 4, 0)
INST_RANGE(ConditionalBranch, conditionalBranch, ifElse)
+ INST(Throw, throw, 1, 0)
+ // tryCall <successBlock> <failBlock> <callee> <args>...
+ INST(TryCall, tryCall, 3, 0)
// switch <val> <break> <default> <caseVal1> <caseBlock1> ...
INST(Switch, switch, 3, 0)
@@ -730,7 +740,8 @@ INST_RANGE(Layout, VarLayout, EntryPointLayout)
INST(TypeSizeAttr, size, 2, 0)
INST(VarOffsetAttr, offset, 2, 0)
INST_RANGE(LayoutResourceInfoAttr, TypeSizeAttr, VarOffsetAttr)
-INST_RANGE(Attr, PendingLayoutAttr, VarOffsetAttr)
+ INST(FuncThrowTypeAttr, FuncThrowType, 1, 0)
+INST_RANGE(Attr, PendingLayoutAttr, FuncThrowTypeAttr)
/* Liveness */
INST(LiveRangeStart, liveRangeStart, 2, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 77b3eabc0..f7f7328a1 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -689,6 +689,14 @@ struct IRVarOffsetAttr : public IRLayoutResourceInfoAttr
}
};
+ /// An attribute that specifies the error type a function is throwing
+struct IRFuncThrowTypeAttr : IRAttr
+{
+ IR_LEAF_ISA(FuncThrowTypeAttr)
+
+ IRType* getErrorType() { return (IRType*)getOperand(0); }
+};
+
/// An attribute that specifies size information for a single resource kind.
struct IRTypeSizeAttr : public IRLayoutResourceInfoAttr
{
@@ -1547,6 +1555,25 @@ struct IRSwitch : IRTerminatorInst
IRBlock* getCaseLabel(UInt index) { return (IRBlock*) getOperand(3 + index*2 + 1); }
};
+struct IRThrow : IRTerminatorInst
+{
+ IR_LEAF_ISA(Throw);
+
+ IRInst* getValue() { return getOperand(0); }
+};
+
+struct IRTryCall : IRTerminatorInst
+{
+ IR_LEAF_ISA(TryCall);
+
+ IRBlock* getSuccessBlock() { return cast<IRBlock>(getOperand(0)); }
+ IRBlock* getFailureBlock() { return cast<IRBlock>(getOperand(1)); }
+ IRInst* getCallee() { return getOperand(2); }
+ UInt getArgCount() { return getOperandCount() - 3; }
+ IRUse* getArgs() { return getOperands() + 3; }
+ IRInst* getArg(UInt index) { return getOperand(index + 3); }
+};
+
struct IRSwizzle : IRInst
{
IRUse base;
@@ -1758,6 +1785,52 @@ struct IRGetTupleElement : IRInst
IRInst* getElementIndex() { return getOperand(1); }
};
+// Constructs an `Result<T,E>` value from an error code.
+struct IRMakeResultError : IRInst
+{
+ IR_LEAF_ISA(MakeResultError)
+
+ IRInst* getErrorValue() { return getOperand(0); }
+};
+
+// Constructs an `Result<T,E>` value from an valid value.
+struct IRMakeResultValue : IRInst
+{
+ IR_LEAF_ISA(MakeResultValue)
+
+ IRInst* getValue() { return getOperand(0); }
+};
+
+// Constructs an `Result<void,E>` value that represents success in a function that returns `void`.
+struct IRMakeResultValueVoid : IRInst
+{
+ IR_LEAF_ISA(MakeResultValueVoid)
+};
+
+// Determines if a `Result` value represents an error.
+struct IRIsResultError : IRInst
+{
+ IR_LEAF_ISA(IsResultError)
+
+ IRInst* getResultOperand() { return getOperand(0); }
+};
+
+// Extract the value from a `Result`.
+struct IRGetResultValue : IRInst
+{
+ IR_LEAF_ISA(GetResultValue)
+
+ IRInst* getResultOperand() { return getOperand(0); }
+};
+
+// Extract the error code from a `Result`.
+struct IRGetResultError : IRInst
+{
+ IR_LEAF_ISA(GetResultError)
+
+ IRInst* getResultOperand() { return getOperand(0); }
+};
+
/// An instruction that packs a concrete value into an existential-type "box"
struct IRMakeExistential : IRInst
{
@@ -1908,12 +1981,17 @@ public:
// keys are modified (thus its hash code is changed).
void deduplicateAndRebuildGlobalNumberingMap();
+ // Replaces all uses of oldInst with newInst, and ensures the global numbering map is valid after the replacement.
+ void replaceGlobalInst(IRInst* oldInst, IRInst* newInst);
+
typedef Dictionary<IRInstKey, IRInst*> GlobalValueNumberingMap;
typedef Dictionary<IRConstantKey, IRConstant*> ConstantMap;
GlobalValueNumberingMap& getGlobalValueNumberingMap() { return m_globalValueNumberingMap; }
ConstantMap& getConstantMap() { return m_constantMap; }
+ bool isGloballyNumberedInst(IRInst* inst);
+
private:
// The module that will own all of the IR
IRModule* m_module;
@@ -2104,6 +2182,8 @@ public:
IRTupleType* getTupleType(IRType* type0, IRType* type1, IRType* type2);
IRTupleType* getTupleType(IRType* type0, IRType* type1, IRType* type2, IRType* type3);
+ IRResultType* getResultType(IRType* valueType, IRType* errorType);
+
IRBasicBlockType* getBasicBlockType();
IRWitnessTableType* getWitnessTableType(IRType* baseType);
IRWitnessTableIDType* getWitnessTableIDType(IRType* baseType);
@@ -2152,6 +2232,9 @@ public:
IRType* resultType);
IRFuncType* getFuncType(
+ UInt paramCount, IRType* const* paramTypes, IRType* resultType, IRAttr* attribute);
+
+ IRFuncType* getFuncType(
List<IRType*> const& paramTypes,
IRType* resultType)
{
@@ -2279,6 +2362,14 @@ public:
return emitCallInst(type, func, args.getCount(), args.getBuffer());
}
+ IRInst* emitTryCallInst(
+ IRType* type,
+ IRBlock* successBlock,
+ IRBlock* failureBlock,
+ IRInst* func,
+ UInt argCount,
+ IRInst* const* args);
+
IRInst* createIntrinsicInst(
IRType* type,
IROp op,
@@ -2322,6 +2413,13 @@ public:
IRInst* emitGetTupleElement(IRType* type, IRInst* tuple, UInt element);
+ IRInst* emitMakeResultError(IRType* resultType, IRInst* errorVal);
+ IRInst* emitMakeResultValue(IRType* resultType, IRInst* val);
+ IRInst* emitMakeResultValueVoid(IRType* resultType);
+ IRInst* emitIsResultError(IRInst* result);
+ IRInst* emitGetResultError(IRInst* result);
+ IRInst* emitGetResultValue(IRInst* result);
+
IRInst* emitMakeVector(
IRType* type,
UInt argCount,
@@ -2590,6 +2688,8 @@ public:
IRInst* emitReturn();
+ IRInst* emitThrow(IRInst* val);
+
IRInst* emitDiscard();
IRInst* emitUnreachable();
@@ -2687,6 +2787,8 @@ public:
IRInst* emitAdd(IRType* type, IRInst* left, IRInst* right);
IRInst* emitMul(IRType* type, IRInst* left, IRInst* right);
+ IRInst* emitEql(IRInst* left, IRInst* right);
+ IRInst* emitNeq(IRInst* left, IRInst* right);
IRInst* emitShr(IRType* type, IRInst* op0, IRInst* op1);
IRInst* emitShl(IRType* type, IRInst* op0, IRInst* op1);
diff --git a/source/slang/slang-ir-lower-error-handling.cpp b/source/slang/slang-ir-lower-error-handling.cpp
new file mode 100644
index 000000000..5a1389e57
--- /dev/null
+++ b/source/slang/slang-ir-lower-error-handling.cpp
@@ -0,0 +1,242 @@
+// slang-ir-lower-error-handling.cpp
+
+#include "slang-ir-lower-error-handling.h"
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+
+namespace Slang
+{
+
+struct ErrorHandlingLoweringContext
+{
+ IRModule* module;
+ DiagnosticSink* diagnosticSink;
+
+ SharedIRBuilder sharedBuilder;
+
+ List<IRInst*> workList;
+ HashSet<IRInst*> workListSet;
+
+ void addToWorkList(IRInst* inst)
+ {
+ if (workListSet.Contains(inst))
+ return;
+
+ workList.add(inst);
+ workListSet.Add(inst);
+ }
+
+ void processFuncType(IRFuncType* funcType)
+ {
+ auto throwAttr = funcType->findAttr<IRFuncThrowTypeAttr>();
+ if (!throwAttr)
+ return;
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertBefore(funcType);
+ auto resultType =
+ builder.getResultType(funcType->getResultType(), throwAttr->getErrorType());
+ List<IRType*> paramTypes;
+ for (UInt i = 0; i < funcType->getParamCount(); i++)
+ {
+ if (as<IRAttr>(funcType->getParamType(i)))
+ break;
+ paramTypes.add(funcType->getParamType(i));
+ }
+ auto newFuncType = builder.getFuncType(paramTypes, resultType);
+ sharedBuilder.replaceGlobalInst(funcType, newFuncType);
+ }
+
+ void processTryCall(IRTryCall* tryCall)
+ {
+ // If we see:
+ // ```
+ // value = tryCall(callee, successBlock, failBlock, args)
+ // successBlock:
+ // resultParam = IRParam<resultType>
+ // ... (uses resultParam) ...
+ // failBlock:
+ // errorParam = IRParam<errorType>
+ // (uses errorParam)
+ // ```
+ // We need to rewrite it as
+ // ```
+ // result = call(callee) : Result<callee.returnType, callee.errorType>
+ // isError = isResultError(result)
+ // ifElse(isError, failBlock, successBlock)
+ // successBlock:
+ // value = getResultValue(result) : returnType
+ // ... (replaces resultParam with value)
+ // failBlock:
+ // error = getResultError(result) : errorType
+ // ... (replaces errorParam with error)
+ // ```
+ IRFuncType* funcType = cast<IRFuncType>(tryCall->getCallee()->getDataType());
+ auto resultValueType = funcType->getResultType();
+ auto throwAttr = funcType->findAttr<IRFuncThrowTypeAttr>();
+ if (!throwAttr)
+ {
+ SLANG_ASSERT_FAILURE("tryCall applied to callee without a IRFuncThrowTypeAttr");
+ }
+ auto errorType = throwAttr->getErrorType();
+
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertBefore(tryCall);
+
+ auto resultType = builder.getResultType(resultValueType, errorType);
+ List<IRInst*> args;
+ for (UInt i = 0; i < tryCall->getArgCount(); i++)
+ {
+ args.add(tryCall->getArg(i));
+ }
+ auto call = builder.emitCallInst(resultType, tryCall->getCallee(), args);
+ auto isFail = builder.emitIsResultError(call);
+ auto failBlock = tryCall->getFailureBlock();
+ auto successBlock = tryCall->getSuccessBlock();
+
+ builder.emitIf(isFail, failBlock, successBlock);
+
+ // Replace the params in failBlock to `getResultError(call)`.
+ builder.setInsertBefore(failBlock->getFirstOrdinaryInst());
+ auto errorParam = failBlock->getFirstParam();
+ auto errVal = builder.emitGetResultError(call);
+ errorParam->replaceUsesWith(errVal);
+ errorParam->removeAndDeallocate();
+
+ // Replace the params in successBlock to `getResultValue(call)`.
+ builder.setInsertBefore(successBlock->getFirstOrdinaryInst());
+ auto resultParam = successBlock->getFirstParam();
+ auto resultValue = builder.emitGetResultValue(call);
+ resultParam->replaceUsesWith(resultValue);
+ resultParam->removeAndDeallocate();
+
+ tryCall->removeAndDeallocate();
+ }
+
+ void processReturn(IRReturn* ret)
+ {
+ auto parentFunc = getParentFunc(ret);
+ if (!parentFunc)
+ return;
+ auto funcType = cast<IRFuncType>(parentFunc->getDataType());
+ auto throwAttr = funcType->findAttr<IRFuncThrowTypeAttr>();
+ if (!throwAttr)
+ return;
+
+ // If we are in a throwing function and sees a `return(val)` inst,
+ // replace it with a `return makeResultValue(val)`, so that it returns a `Result<T,E>` type.
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertBefore(ret);
+ auto resultType =
+ builder.getResultType(funcType->getResultType(), throwAttr->getErrorType());
+ IRInst* resultVal = nullptr;
+ if (ret->getOp() == kIROp_ReturnVal)
+ {
+ auto val = cast<IRReturnVal>(ret)->getVal();
+ resultVal = builder.emitMakeResultValue(resultType, val);
+ }
+ else
+ {
+ resultVal = builder.emitMakeResultValueVoid(resultType);
+ }
+ builder.emitReturn(resultVal);
+ ret->removeAndDeallocate();
+ }
+
+ void processThrow(IRThrow* throwInst)
+ {
+ auto parentFunc = getParentFunc(throwInst);
+ SLANG_ASSERT(parentFunc);
+ auto funcType = cast<IRFuncType>(parentFunc->getDataType());
+ auto throwAttr = funcType->findAttr<IRFuncThrowTypeAttr>();
+ SLANG_ASSERT(throwAttr);
+
+ // If we are in a throwing function and sees a `throw(e)` inst,
+ // replace it with a `return makeResultError(e)`.
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertBefore(throwInst);
+ auto resultType =
+ builder.getResultType(funcType->getResultType(), throwAttr->getErrorType());
+ IRInst* resultVal = builder.emitMakeResultError(resultType, throwInst->getValue());
+ builder.emitReturn(resultVal);
+ throwInst->removeAndDeallocate();
+ }
+
+ void processInst(IRInst* inst)
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_TryCall:
+ processTryCall(cast<IRTryCall>(inst));
+ break;
+ case kIROp_ReturnVal:
+ case kIROp_ReturnVoid:
+ processReturn(cast<IRReturn>(inst));
+ break;
+ case kIROp_Throw:
+ processThrow(cast<IRThrow>(inst));
+ break;
+ default:
+ break;
+ }
+ }
+
+ void processInsts()
+ {
+ addToWorkList(module->getModuleInst());
+
+ while (workList.getCount() != 0)
+ {
+ IRInst* inst = workList.getLast();
+
+ workList.removeLast();
+ workListSet.Remove(inst);
+
+ processInst(inst);
+
+ for (auto child = inst->getLastChild(); child; child = child->getPrevInst())
+ {
+ addToWorkList(child);
+ }
+ }
+ }
+
+ void processModule()
+ {
+ // Deduplicate equivalent types.
+ sharedBuilder.deduplicateAndRebuildGlobalNumberingMap();
+
+ // Translate all IRTryCall, IRThrow, IRReturn, IRReturnVal.
+ processInsts();
+
+ // Lower all functypes.
+ // Function types with an IRThrowTypeAttribute will be translated into a normal function
+ // type that returns `Result<T,E>`.
+ List<IRFuncType*> oldFuncTypes;
+ for (auto child : module->getGlobalInsts())
+ {
+ switch (child->getOp())
+ {
+ case kIROp_FuncType:
+ oldFuncTypes.add(cast<IRFuncType>(child));
+ break;
+ default:
+ break;
+ }
+ }
+ for (auto funcType : oldFuncTypes)
+ {
+ processFuncType(funcType);
+ }
+ }
+};
+
+void lowerErrorHandling(IRModule* module, DiagnosticSink* sink)
+{
+ ErrorHandlingLoweringContext context;
+ context.module = module;
+ context.diagnosticSink = sink;
+ context.sharedBuilder.init(module);
+ return context.processModule();
+}
+}
diff --git a/source/slang/slang-ir-lower-error-handling.h b/source/slang/slang-ir-lower-error-handling.h
new file mode 100644
index 000000000..92a65404c
--- /dev/null
+++ b/source/slang/slang-ir-lower-error-handling.h
@@ -0,0 +1,18 @@
+// slang-ir-lower-error-handling.h
+#pragma once
+
+namespace Slang
+{
+
+struct IRModule;
+class DiagnosticSink;
+
+/// Lower error handling related opcodes and function calls to use standard control flow.
+/// A function with an error code type will be translated into a function that returns `Result<T,E>`, which can be
+/// further lowered to standard return values and `out` parameters in a separate pass.
+/// Call sites (`IRTryCall`) to error-throwing function will be rewritten to conform to the new function signature.
+/// `IRThrow` will be replaced with `IRReturnVal(IRMakeErrorResult(e))`.
+///
+void lowerErrorHandling(IRModule* module, DiagnosticSink* sink);
+
+}
diff --git a/source/slang/slang-ir-lower-existential.cpp b/source/slang/slang-ir-lower-existential.cpp
index dfa714a82..b0d9e6f2f 100644
--- a/source/slang/slang-ir-lower-existential.cpp
+++ b/source/slang/slang-ir-lower-existential.cpp
@@ -106,7 +106,7 @@ namespace Slang
builder->setInsertBefore(extractInst);
IRInst* element = nullptr;
- if (extractInst->getOperand(0)->getDataType()->findDecoration<IRComInterfaceDecoration>())
+ if (isComInterfaceType(extractInst->getOperand(0)->getDataType()))
{
// If this is an COM interface, the elements (witness table/rtti) are just the interface value itself.
element = extractInst->getOperand(0);
diff --git a/source/slang/slang-ir-lower-generic-call.cpp b/source/slang/slang-ir-lower-generic-call.cpp
index da9f54764..7dbe11f52 100644
--- a/source/slang/slang-ir-lower-generic-call.cpp
+++ b/source/slang/slang-ir-lower-generic-call.cpp
@@ -276,9 +276,7 @@ namespace Slang
// all occurences of associatedtypes.
// If `w` in `lookup_interface_method(w, ...)` is a COM interface, bail.
- if (lookupInst->getWitnessTable()
- ->getDataType()
- ->findDecoration<IRComInterfaceDecoration>())
+ if (isComInterfaceType(lookupInst->getWitnessTable()->getDataType()))
{
return;
}
diff --git a/source/slang/slang-ir-lower-generic-function.cpp b/source/slang/slang-ir-lower-generic-function.cpp
index 600f21b20..a2636e8ed 100644
--- a/source/slang/slang-ir-lower-generic-function.cpp
+++ b/source/slang/slang-ir-lower-generic-function.cpp
@@ -153,7 +153,7 @@ namespace Slang
if (isBuiltin(interfaceType))
return interfaceType;
// Do not lower COM interfaces.
- if (interfaceType->findDecoration<IRComInterfaceDecoration>())
+ if (isComInterfaceType(interfaceType))
return interfaceType;
List<IRInterfaceRequirementEntry*> newEntries;
diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp
index 5d8e6d929..703441252 100644
--- a/source/slang/slang-ir-lower-generics.cpp
+++ b/source/slang/slang-ir-lower-generics.cpp
@@ -137,7 +137,7 @@ namespace Slang
{
auto witnessTableType = lookupWitnessMethod->getWitnessTable()->getDataType();
auto interfaceType = cast<IRWitnessTableType>(witnessTableType)->getConformanceType();
- if (interfaceType->findDecoration<IRComInterfaceDecoration>())
+ if (isComInterfaceType((IRType*)interfaceType))
return;
if (!implementedInterfaces.Contains(interfaceType))
{
diff --git a/source/slang/slang-ir-lower-result-type.cpp b/source/slang/slang-ir-lower-result-type.cpp
new file mode 100644
index 000000000..e46a0ceb5
--- /dev/null
+++ b/source/slang/slang-ir-lower-result-type.cpp
@@ -0,0 +1,317 @@
+// slang-ir-lower-result-type.cpp
+
+#include "slang-ir-lower-result-type.h"
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+
+namespace Slang
+{
+ struct ResultTypeLoweringContext
+ {
+ IRModule* module;
+ DiagnosticSink* sink;
+
+ SharedIRBuilder sharedBuilderStorage;
+
+ List<IRInst*> workList;
+ HashSet<IRInst*> workListSet;
+
+ struct LoweredResultTypeInfo : public RefObject
+ {
+ IRType* resultType = nullptr;
+ IRType* errorType = nullptr;
+ IRType* valueType = nullptr;
+ IRType* loweredType = nullptr;
+ IRStructField* valueField = nullptr;
+ IRStructField* errorField = nullptr;
+ };
+ Dictionary<IRInst*, RefPtr<LoweredResultTypeInfo>> mapLoweredTypeToResultTypeInfo;
+ Dictionary<IRInst*, RefPtr<LoweredResultTypeInfo>> loweredResultTypes;
+
+ IRType* maybeLowerResultType(IRBuilder* builder, IRType* type)
+ {
+ if (auto info = getLoweredResultType(builder, type))
+ return info->loweredType;
+ else
+ return type;
+ }
+
+ LoweredResultTypeInfo* getLoweredResultType(IRBuilder* builder, IRInst* type)
+ {
+ if (auto loweredInfo = loweredResultTypes.TryGetValue(type))
+ return loweredInfo->Ptr();
+ if (auto loweredInfo = mapLoweredTypeToResultTypeInfo.TryGetValue(type))
+ return loweredInfo->Ptr();
+
+ if (!type)
+ return nullptr;
+ if (type->getOp() != kIROp_ResultType)
+ return nullptr;
+
+ RefPtr<LoweredResultTypeInfo> info = new LoweredResultTypeInfo();
+ info->resultType = (IRType*)type;
+ info->errorType = cast<IRResultType>(type)->getErrorType();
+
+ auto resultType = cast<IRResultType>(type);
+ auto valueType = resultType->getValueType();
+ if (valueType->getOp() != kIROp_VoidType)
+ {
+ auto structType = builder->createStructType();
+ info->loweredType = structType;
+ builder->addNameHintDecoration(structType, UnownedStringSlice("ResultType"));
+
+ info->valueType = valueType;
+ auto valueKey = builder->createStructKey();
+ builder->addNameHintDecoration(valueKey, UnownedStringSlice("value"));
+ info->valueField = builder->createStructField(structType, valueKey, (IRType*)valueType);
+
+ auto errorType = resultType->getErrorType();
+ auto errorKey = builder->createStructKey();
+ builder->addNameHintDecoration(errorKey, UnownedStringSlice("error"));
+ info->errorField = builder->createStructField(structType, errorKey, (IRType*)errorType);
+ }
+ else
+ {
+ info->loweredType = resultType->getErrorType();
+ }
+ mapLoweredTypeToResultTypeInfo[info->loweredType] = info;
+ loweredResultTypes[type] = info;
+ return info.Ptr();
+ }
+
+ void addToWorkList(
+ IRInst* inst)
+ {
+ for (auto ii = inst->getParent(); ii; ii = ii->getParent())
+ {
+ if (as<IRGeneric>(ii))
+ return;
+ }
+
+ if (workListSet.Contains(inst))
+ return;
+
+ workList.add(inst);
+ workListSet.Add(inst);
+ }
+
+ IRInst* getSuccessErrorValue(IRType* type)
+ {
+ switch (type->getOp())
+ {
+ case kIROp_Int8Type:
+ case kIROp_Int16Type:
+ case kIROp_IntType:
+ case kIROp_Int64Type:
+ case kIROp_UInt8Type:
+ case kIROp_UInt16Type:
+ case kIROp_UIntType:
+ case kIROp_UInt64Type:
+ break;
+ default:
+ SLANG_ASSERT_FAILURE("error type is not lowered to an integer type.");
+ }
+ IRBuilder builderStorage(sharedBuilderStorage);
+ auto builder = &builderStorage;
+ builder->setInsertInto(module);
+ return builder->getIntValue(type, 0);
+ }
+
+ void processMakeResultValue(IRMakeResultValue* inst)
+ {
+ IRBuilder builderStorage(sharedBuilderStorage);
+ auto builder = &builderStorage;
+ builder->setInsertBefore(inst);
+
+ auto info = getLoweredResultType(builder, inst->getDataType());
+ List<IRInst*> operands;
+ operands.add(inst->getOperand(0));
+ operands.add(getSuccessErrorValue(info->errorType));
+ auto makeStruct = builder->emitMakeStruct(info->loweredType, operands);
+ inst->replaceUsesWith(makeStruct);
+ inst->removeAndDeallocate();
+ }
+
+ void processMakeResultValueVoid(IRMakeResultValueVoid* inst)
+ {
+ IRBuilder builderStorage(sharedBuilderStorage);
+ auto builder = &builderStorage;
+ builder->setInsertBefore(inst);
+
+ auto info = getLoweredResultType(builder, inst->getDataType());
+ auto errCode = getSuccessErrorValue(info->errorType);
+ inst->replaceUsesWith(errCode);
+ inst->removeAndDeallocate();
+ }
+
+ void processMakeResultError(IRMakeResultError* inst)
+ {
+ IRBuilder builderStorage(sharedBuilderStorage);
+ auto builder = &builderStorage;
+ builder->setInsertBefore(inst);
+
+ auto info = getLoweredResultType(builder, inst->getDataType());
+ if (info->valueField)
+ {
+ List<IRInst*> operands;
+ operands.add(builder->emitConstructorInst(info->valueType, 0, nullptr));
+ operands.add(inst->getErrorValue());
+ auto makeStruct = builder->emitMakeStruct(info->loweredType, operands);
+ inst->replaceUsesWith(makeStruct);
+ }
+ else
+ {
+ inst->replaceUsesWith(inst->getErrorValue());
+ }
+ inst->removeAndDeallocate();
+ }
+
+ IRInst* getResultError(IRBuilder* builder, IRInst* resultInst)
+ {
+ auto loweredResultTypeInfo = getLoweredResultType(builder, resultInst->getDataType());
+ SLANG_ASSERT(loweredResultTypeInfo);
+ if (loweredResultTypeInfo->valueField)
+ {
+ auto value = builder->emitFieldExtract(
+ loweredResultTypeInfo->errorType,
+ resultInst,
+ loweredResultTypeInfo->errorField->getKey());
+ return value;
+ }
+ else
+ {
+ return resultInst;
+ }
+ }
+
+ void processGetResultError(IRGetResultError* inst)
+ {
+ IRBuilder builderStorage(sharedBuilderStorage);
+ auto builder = &builderStorage;
+ builder->setInsertBefore(inst);
+
+ auto resultValue = inst->getResultOperand();
+ auto errValue = getResultError(builder, resultValue);
+ inst->replaceUsesWith(errValue);
+ inst->removeAndDeallocate();
+ }
+
+ void processGetResultValue(IRGetResultValue* inst)
+ {
+ IRBuilder builderStorage(sharedBuilderStorage);
+ auto builder = &builderStorage;
+ builder->setInsertBefore(inst);
+
+ auto base = inst->getResultOperand();
+ auto loweredResultTypeInfo = getLoweredResultType(builder, base->getDataType());
+ SLANG_ASSERT(loweredResultTypeInfo);
+ SLANG_ASSERT(loweredResultTypeInfo->valueField);
+ auto getElement = builder->emitFieldExtract(
+ loweredResultTypeInfo->errorType,
+ base,
+ loweredResultTypeInfo->valueField->getKey());
+ inst->replaceUsesWith(getElement);
+ inst->removeAndDeallocate();
+ }
+
+ void processIsResultError(IRIsResultError* inst)
+ {
+ IRBuilder builderStorage(sharedBuilderStorage);
+ auto builder = &builderStorage;
+ builder->setInsertBefore(inst);
+
+ auto base = inst->getResultOperand();
+ auto loweredResultTypeInfo = getLoweredResultType(builder, base->getDataType());
+ SLANG_ASSERT(loweredResultTypeInfo);
+ SLANG_ASSERT(loweredResultTypeInfo->valueField);
+
+ auto resultValue = inst->getResultOperand();
+ auto errValue = getResultError(builder, resultValue);
+ auto isSuccess =
+ builder->emitNeq(errValue, getSuccessErrorValue(loweredResultTypeInfo->errorType));
+ inst->replaceUsesWith(isSuccess);
+ inst->removeAndDeallocate();
+ }
+
+ void processResultType(IRResultType* inst)
+ {
+ IRBuilder builderStorage(sharedBuilderStorage);
+ auto builder = &builderStorage;
+ builder->setInsertBefore(inst);
+
+ auto loweredResultTypeInfo = getLoweredResultType(builder, inst);
+ SLANG_ASSERT(loweredResultTypeInfo);
+ SLANG_UNUSED(loweredResultTypeInfo);
+ }
+
+ void processInst(IRInst* inst)
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_MakeResultValue:
+ processMakeResultValue((IRMakeResultValue*)inst);
+ break;
+ case kIROp_MakeResultValueVoid:
+ processMakeResultValueVoid((IRMakeResultValueVoid*)inst);
+ break;
+ case kIROp_MakeResultError:
+ processMakeResultError((IRMakeResultError*)inst);
+ break;
+ case kIROp_GetResultError:
+ processGetResultError((IRGetResultError*)inst);
+ break;
+ case kIROp_GetResultValue:
+ processGetResultValue((IRGetResultValue*)inst);
+ break;
+ case kIROp_IsResultError:
+ processIsResultError((IRIsResultError*)inst);
+ break;
+ case kIROp_ResultType:
+ processResultType((IRResultType*)inst);
+ break;
+ default:
+ break;
+ }
+ }
+
+ void processModule()
+ {
+ SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
+ sharedBuilder->init(module);
+
+ // Deduplicate equivalent types.
+ sharedBuilder->deduplicateAndRebuildGlobalNumberingMap();
+
+ addToWorkList(module->getModuleInst());
+
+ while (workList.getCount() != 0)
+ {
+ IRInst* inst = workList.getLast();
+
+ workList.removeLast();
+ workListSet.Remove(inst);
+
+ processInst(inst);
+
+ for (auto child = inst->getLastChild(); child; child = child->getPrevInst())
+ {
+ addToWorkList(child);
+ }
+ }
+
+ // Replace all result types with lowered struct types.
+ for (auto kv : loweredResultTypes)
+ {
+ kv.Key->replaceUsesWith(kv.Value->loweredType);
+ }
+ }
+ };
+
+ void lowerResultType(IRModule* module, DiagnosticSink* sink)
+ {
+ ResultTypeLoweringContext context;
+ context.module = module;
+ context.sink = sink;
+ context.processModule();
+ }
+}
diff --git a/source/slang/slang-ir-lower-result-type.h b/source/slang/slang-ir-lower-result-type.h
new file mode 100644
index 000000000..f04a1a6cb
--- /dev/null
+++ b/source/slang/slang-ir-lower-result-type.h
@@ -0,0 +1,16 @@
+// slang-ir-lower-result-type.h
+#pragma once
+
+#include "slang-ir.h"
+
+namespace Slang
+{
+ struct IRModule;
+ class DiagnosticSink;
+
+ /// Lower `IRResultType<T,E>` types to ordinary `struct`s.
+ void lowerResultType(
+ IRModule* module,
+ DiagnosticSink* sink);
+
+}
diff --git a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp
index 6c6a7dec5..7464a1c35 100644
--- a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp
+++ b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp
@@ -119,7 +119,7 @@ struct AssociatedTypeLookupSpecializationContext
void processLookupInterfaceMethodInst(IRLookupWitnessMethod* inst)
{
- if (inst->getWitnessTable()->getDataType()->findDecoration<IRComInterfaceDecoration>())
+ if (isComInterfaceType(inst->getWitnessTable()->getDataType()))
{
return;
}
diff --git a/source/slang/slang-ir-witness-table-wrapper.cpp b/source/slang/slang-ir-witness-table-wrapper.cpp
index c8a1e2dbe..5d1a9360e 100644
--- a/source/slang/slang-ir-witness-table-wrapper.cpp
+++ b/source/slang/slang-ir-witness-table-wrapper.cpp
@@ -170,7 +170,7 @@ namespace Slang
auto interfaceType = cast<IRInterfaceType>(witnessTable->getConformanceType());
if (isBuiltin(interfaceType))
return;
- if (interfaceType->findDecoration<IRComInterfaceDecoration>())
+ if (isComInterfaceType(interfaceType))
return;
// We need to consider whether the concrete type that is conforming
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index d454333e6..c71954346 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -2587,6 +2587,12 @@ namespace Slang
return getTupleType(SLANG_COUNT_OF(operands), operands);
}
+ IRResultType* IRBuilder::getResultType(IRType* valueType, IRType* errorType)
+ {
+ IRInst* operands[] = {valueType, errorType};
+ return (IRResultType*)getType(kIROp_ResultType, 2, operands);
+ }
+
IRBasicBlockType* IRBuilder::getBasicBlockType()
{
return (IRBasicBlockType*)getType(kIROp_BasicBlockType);
@@ -2717,6 +2723,14 @@ namespace Slang
(IRInst* const*) paramTypes);
}
+ IRFuncType* IRBuilder::getFuncType(
+ UInt paramCount, IRType* const* paramTypes, IRType* resultType, IRAttr* attribute)
+ {
+ UInt counts[3] = {1, paramCount, 1};
+ IRInst** lists[3] = {(IRInst**)&resultType, (IRInst**)paramTypes, (IRInst**)&attribute};
+ return (IRFuncType*)findOrEmitHoistableInst(nullptr, kIROp_FuncType, 3, counts, lists);
+ }
+
IRWitnessTableType* IRBuilder::getWitnessTableType(
IRType* baseType)
{
@@ -3088,6 +3102,21 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitTryCallInst(
+ IRType* type,
+ IRBlock* successBlock,
+ IRBlock* failureBlock,
+ IRInst* func,
+ UInt argCount,
+ IRInst* const* args)
+ {
+ IRInst* fixedArgs[] = {successBlock, failureBlock, func};
+ auto inst = createInstWithTrailingArgs<IRTryCall>(
+ this, kIROp_TryCall, type, 3, fixedArgs, argCount, args);
+ addInst(inst);
+ return inst;
+ }
+
IRInst* IRBuilder::createIntrinsicInst(
IRType* type,
IROp op,
@@ -3183,6 +3212,46 @@ namespace Slang
return emitIntrinsicInst(type, kIROp_GetTupleElement, 2, args);
}
+ IRInst* IRBuilder::emitMakeResultError(IRType* resultType, IRInst* errorVal)
+ {
+ return emitIntrinsicInst(resultType, kIROp_MakeResultError, 1, &errorVal);
+ }
+
+ IRInst* IRBuilder::emitMakeResultValue(IRType* resultType, IRInst* value)
+ {
+ return emitIntrinsicInst(resultType, kIROp_MakeResultValue, 1, &value);
+ }
+
+ IRInst* IRBuilder::emitMakeResultValueVoid(IRType* resultType)
+ {
+ return emitIntrinsicInst(resultType, kIROp_MakeResultValueVoid, 0, nullptr);
+ }
+
+ IRInst* IRBuilder::emitIsResultError(IRInst* result)
+ {
+ return emitIntrinsicInst(getBoolType(), kIROp_IsResultError, 1, &result);
+ }
+
+ IRInst* IRBuilder::emitGetResultError(IRInst* result)
+ {
+ SLANG_ASSERT(result->getDataType());
+ return emitIntrinsicInst(
+ cast<IRResultType>(result->getDataType())->getErrorType(),
+ kIROp_GetResultError,
+ 1,
+ &result);
+ }
+
+ IRInst* IRBuilder::emitGetResultValue(IRInst* result)
+ {
+ SLANG_ASSERT(result->getDataType());
+ return emitIntrinsicInst(
+ cast<IRResultType>(result->getDataType())->getValueType(),
+ kIROp_GetResultValue,
+ 1,
+ &result);
+ }
+
IRInst* IRBuilder::emitMakeVector(
IRType* type,
UInt argCount,
@@ -3880,6 +3949,13 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitThrow(IRInst* val)
+ {
+ auto inst = createInst<IRThrow>(this, kIROp_Throw, nullptr, val);
+ addInst(inst);
+ return inst;
+ }
+
IRInst* IRBuilder::emitUnreachable()
{
auto inst = createInst<IRUnreachable>(
@@ -4207,6 +4283,20 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitEql(IRInst* left, IRInst* right)
+ {
+ auto inst = createInst<IRInst>(this, kIROp_Eql, getBoolType(), left, right);
+ addInst(inst);
+ return inst;
+ }
+
+ IRInst* IRBuilder::emitNeq(IRInst* left, IRInst* right)
+ {
+ auto inst = createInst<IRInst>(this, kIROp_Neq, getBoolType(), left, right);
+ addInst(inst);
+ return inst;
+ }
+
IRInst* IRBuilder::emitMul(IRType* type, IRInst* left, IRInst* right)
{
auto inst = createInst<IRInst>(
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 6c766542f..6593e4409 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -1458,6 +1458,15 @@ struct IRTupleType : IRType
IR_LEAF_ISA(TupleType)
};
+/// Represents an `Result<T,E>`, used by functions that throws error codes.
+struct IRResultType : IRType
+{
+ IR_LEAF_ISA(ResultType)
+
+ IRType* getValueType() { return (IRType*)getOperand(0); }
+ IRType* getErrorType() { return (IRType*)getOperand(1); }
+};
+
struct IRTypeType : IRType
{
IR_LEAF_ISA(TypeType);
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 5cfb07c1c..1448139fb 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -16,6 +16,7 @@
#include "slang-ir-validate.h"
#include "slang-ir-string-hash.h"
#include "slang-ir-clone.h"
+#include "slang-ir-lower-error-handling.h"
#include "slang-mangle.h"
#include "slang-type-layout.h"
@@ -612,6 +613,12 @@ int32_t getIntrinsicOp(
return int32_t(irOp);
}
+struct TryClauseEnvironment
+{
+ TryClauseType clauseType = TryClauseType::None;
+ IRBlock* catchBlock = nullptr;
+};
+
// Given a `LoweredValInfo` for something callable, along with a
// bunch of arguments, emit an appropriate call to it.
LoweredValInfo emitCallToVal(
@@ -619,7 +626,8 @@ LoweredValInfo emitCallToVal(
IRType* type,
LoweredValInfo funcVal,
UInt argCount,
- IRInst* const* args)
+ IRInst* const* args,
+ const TryClauseEnvironment& tryEnv)
{
auto builder = context->irBuilder;
switch (funcVal.flavor)
@@ -627,8 +635,33 @@ LoweredValInfo emitCallToVal(
case LoweredValInfo::Flavor::None:
SLANG_UNEXPECTED("null function");
default:
- return LoweredValInfo::simple(
- builder->emitCallInst(type, getSimpleVal(context, funcVal), argCount, args));
+ switch (tryEnv.clauseType)
+ {
+ case TryClauseType::None:
+ return LoweredValInfo::simple(
+ builder->emitCallInst(type, getSimpleVal(context, funcVal), argCount, args));
+
+ case TryClauseType::Standard:
+ {
+ auto callee = getSimpleVal(context, funcVal);
+ auto succBlock = builder->createBlock();
+ auto failBlock = builder->createBlock();
+ auto funcType = as<IRFuncType>(callee->getDataType());
+ auto throwAttr = funcType->findAttr<IRFuncThrowTypeAttr>();
+ assert(throwAttr);
+ auto voidType = builder->getVoidType();
+ builder->emitTryCallInst(voidType, succBlock, failBlock, callee, argCount, args);
+ builder->insertBlock(failBlock);
+ auto errParam = builder->emitParam(throwAttr->getErrorType());
+ builder->emitThrow(errParam);
+ builder->insertBlock(succBlock);
+ auto value = builder->emitParam(type);
+ return LoweredValInfo::simple(value);
+ }
+ break;
+ default:
+ SLANG_UNIMPLEMENTED_X("emitCallToVal(tryClauseType)");
+ }
}
}
@@ -655,7 +688,8 @@ LoweredValInfo emitCallToDeclRef(
DeclRef<Decl> funcDeclRef,
IRType* funcType,
UInt argCount,
- IRInst* const* args)
+ IRInst* const* args,
+ const TryClauseEnvironment& tryEnv)
{
SLANG_ASSERT(funcType);
@@ -700,7 +734,7 @@ LoweredValInfo emitCallToDeclRef(
// Fallback case is to emit an actual call.
//
LoweredValInfo funcVal = emitDeclRef(context, funcDeclRef, funcType);
- return emitCallToVal(context, type, funcVal, argCount, args);
+ return emitCallToVal(context, type, funcVal, argCount, args, tryEnv);
}
LoweredValInfo emitCallToDeclRef(
@@ -708,9 +742,17 @@ LoweredValInfo emitCallToDeclRef(
IRType* type,
DeclRef<Decl> funcDeclRef,
IRType* funcType,
- List<IRInst*> const& args)
+ List<IRInst*> const& args,
+ const TryClauseEnvironment& tryEnv)
{
- return emitCallToDeclRef(context, type, funcDeclRef, funcType, args.getCount(), args.getBuffer());
+ return emitCallToDeclRef(
+ context,
+ type,
+ funcDeclRef,
+ funcType,
+ args.getCount(),
+ args.getBuffer(),
+ tryEnv);
}
/// Represents the "direction" that a parameter is being passed (e.g., `in` or `out`
@@ -1580,10 +1622,21 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
{
paramTypes.add(lowerType(context, type->getParamType(pp)));
}
- return getBuilder()->getFuncType(
- paramCount,
- paramTypes.getBuffer(),
- resultType);
+ if (type->errorType->equals(context->astBuilder->getVoidType()))
+ {
+ return getBuilder()->getFuncType(
+ paramCount,
+ paramTypes.getBuffer(),
+ resultType);
+ }
+ else
+ {
+ auto errorType = lowerType(context, type->getErrorType());
+ auto irThrowFuncTypeAttribute =
+ getBuilder()->getAttr(kIROp_FuncThrowTypeAttr, 1, (IRInst**)&errorType);
+ return getBuilder()->getFuncType(
+ paramCount, paramTypes.getBuffer(), resultType, irThrowFuncTypeAttribute);
+ }
}
IRType* visitPtrType(PtrType* type)
@@ -2784,11 +2837,20 @@ void _lowerFuncDeclBaseTypeInfo(
// being accessed, rather than a simple value.
irResultType = builder->getPtrType(irResultType);
}
-
- outInfo.type = builder->getFuncType(
- paramTypes.getCount(),
- paramTypes.getBuffer(),
- irResultType);
+
+ auto errorType = lowerType(context, getErrorCodeType(context->astBuilder, declRef));
+ if (errorType->getOp() != kIROp_VoidType)
+ {
+ IRAttr* throwTypeAttr = nullptr;
+ throwTypeAttr = builder->getAttr(kIROp_FuncThrowTypeAttr, 1, (IRInst**)&errorType);
+ outInfo.type = builder->getFuncType(
+ paramTypes.getCount(), paramTypes.getBuffer(), irResultType, throwTypeAttr);
+ }
+ else
+ {
+ outInfo.type =
+ builder->getFuncType(paramTypes.getCount(), paramTypes.getBuffer(), irResultType);
+ }
}
static LoweredValInfo _emitCallToAccessor(
@@ -2824,7 +2886,8 @@ static LoweredValInfo _emitCallToAccessor(
accessorDeclRef,
info.type,
allArgs.getCount(),
- allArgs.getBuffer());
+ allArgs.getBuffer(),
+ TryClauseEnvironment());
applyOutArgumentFixups(context, fixups);
@@ -3524,6 +3587,11 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
LoweredValInfo visitInvokeExpr(InvokeExpr* expr)
{
+ return visitInvokeExprImpl(expr, TryClauseEnvironment());
+ }
+
+ LoweredValInfo visitInvokeExprImpl(InvokeExpr* expr, const TryClauseEnvironment& tryEnv)
+ {
auto type = lowerType(context, expr->type);
// We are going to look at the syntactic form of
@@ -3636,7 +3704,8 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
type,
funcDeclRef,
funcType,
- irArgs);
+ irArgs,
+ tryEnv);
applyOutArgumentFixups(context, argFixups);
return result;
}
@@ -3658,6 +3727,16 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
UNREACHABLE_RETURN(LoweredValInfo());
}
+ /// Emit code for a `try` invoke.
+ LoweredValInfo visitTryExpr(TryExpr* expr)
+ {
+ auto invokeExpr = as<InvokeExpr>(expr->base);
+ assert(invokeExpr);
+ TryClauseEnvironment tryEnv;
+ tryEnv.clauseType = expr->tryClauseType;
+ return visitInvokeExprImpl(invokeExpr, tryEnv);
+ }
+
/// Emit code to cast `value` to a concrete `superType` (e.g., a `struct`).
///
/// The `subTypeWitness` is expected to witness the sub-type relationship
@@ -8338,7 +8417,13 @@ RefPtr<IRModule> generateIRForTranslationUnit(
// dumpIR(module);
- // First, inline calls to any functions that have been
+ // First, lower error handling logic into normal control flow.
+ // This includes lowering throwing functions into functions that
+ // returns a `Result<T,E>` value, translating `tryCall` into
+ // normal `call` + `ifElse`, etc.
+ lowerErrorHandling(module, compileRequest->getSink());
+
+ // Next, inline calls to any functions that have been
// marked for mandatory "early" inlining.
//
performMandatoryEarlyInlining(module);
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index db532a601..2604ffd9b 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -1485,6 +1485,12 @@ namespace Slang
parser->PushScope(decl);
parseParameterList(parser, decl);
+
+ if (AdvanceIf(parser, "throws"))
+ {
+ decl->errorType = parser->ParseTypeExp();
+ }
+
_parseOptSemantics(parser, decl);
decl->body = parseOptBody(parser);
@@ -3383,6 +3389,10 @@ namespace Slang
{
parser->PushScope(decl);
parseModernParamList(parser, decl);
+ if (AdvanceIf(parser, "throws"))
+ {
+ decl->errorType = parser->ParseTypeExp();
+ }
if(AdvanceIf(parser, TokenType::RightArrow))
{
decl->returnType = parser->ParseTypeExp();
@@ -4860,6 +4870,15 @@ namespace Slang
return parser->astBuilder->create<NullPtrLiteralExpr>();
}
+ static NodeBase* parseTryExpr(Parser* parser, void* /*userData*/)
+ {
+ auto tryExpr = parser->astBuilder->create<TryExpr>();
+ tryExpr->tryClauseType = TryClauseType::Standard;
+ tryExpr->base = parser->ParseLeafExpression();
+ tryExpr->scope = parser->currentScope;
+ return tryExpr;
+ }
+
static bool _isFinite(double value)
{
// Lets type pun double to uint64_t, so we can detect special double values
@@ -6263,7 +6282,7 @@ namespace Slang
// keyword (no further tokens expected/allowed),
// and which can be represented just by creating
// a new AST node of the corresponding type.
-
+
_makeParseModifier("in", InModifier::kReflectClassInfo),
_makeParseModifier("input", InputModifier::kReflectClassInfo),
_makeParseModifier("out", OutModifier::kReflectClassInfo),
@@ -6336,6 +6355,7 @@ namespace Slang
_makeParseExpr("true", parseTrueExpr),
_makeParseExpr("false", parseFalseExpr),
_makeParseExpr("nullptr", parseNullPtrExpr),
+ _makeParseExpr("try", parseTryExpr),
_makeParseExpr("__TaggedUnion", parseTaggedUnionType),
};
diff --git a/source/slang/slang-serialize-type-info.h b/source/slang/slang-serialize-type-info.h
index 7ed45bb0b..c80eb8051 100644
--- a/source/slang/slang-serialize-type-info.h
+++ b/source/slang/slang-serialize-type-info.h
@@ -142,6 +142,12 @@ struct SerialTypeInfo<bool>
}
};
+// Specialization for all enum types
+template<typename T>
+struct SerialTypeInfo<T, typename std::enable_if<std::is_enum<T>::value>::type>
+ : public SerialIdentityTypeInfo<T>
+{};
+
// Pointer
// Could handle different pointer base types with some more template magic here, but instead went with Pointer type to keep
// things simpler.
diff --git a/source/slang/slang-serialize.h b/source/slang/slang-serialize.h
index 990a36adc..fcd2daa1f 100644
--- a/source/slang/slang-serialize.h
+++ b/source/slang/slang-serialize.h
@@ -37,7 +37,7 @@ struct SerialClass;
struct SerialField;
// Type used to implement mechanisms to convert to and from serial types.
-template <typename T>
+template <typename T, typename /*enumTypeSFINAE*/ = void>
struct SerialTypeInfo;
enum class SerialTypeKind : uint8_t
diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp
index b9a5b8cd5..24ccfa4a4 100644
--- a/source/slang/slang-syntax.cpp
+++ b/source/slang/slang-syntax.cpp
@@ -1057,6 +1057,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
FuncType* funcType = astBuilder->create<FuncType>();
funcType->resultType = getResultType(astBuilder, declRef);
+ funcType->errorType = getErrorCodeType(astBuilder, declRef);
for (auto paramDeclRef : getParameters(declRef))
{
auto paramDecl = paramDeclRef.getDecl();
@@ -1269,9 +1270,27 @@ char const* getGLSLNameForImageFormat(ImageFormat format)
return kImageFormatInfos[Index(format)].name.begin();
}
- const ImageFormatInfo& getImageFormatInfo(ImageFormat format)
- {
- return kImageFormatInfos[Index(format)];
- }
+
+const ImageFormatInfo& getImageFormatInfo(ImageFormat format)
+{
+ return kImageFormatInfos[Index(format)];
+}
+
+char const* getTryClauseTypeName(TryClauseType c)
+{
+ switch (c)
+ {
+ case TryClauseType::None:
+ return "None";
+ case TryClauseType::Standard:
+ return "Standard";
+ case TryClauseType::Optional:
+ return "Optional";
+ case TryClauseType::Assert:
+ return "Assert";
+ default:
+ return "Unknown";
+ }
+}
} // namespace Slang
diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h
index c144ceb70..b4463a77d 100644
--- a/source/slang/slang-syntax.h
+++ b/source/slang/slang-syntax.h
@@ -162,6 +162,18 @@ namespace Slang
return declRef.substitute(astBuilder, declRef.getDecl()->returnType.type);
}
+ inline Type* getErrorCodeType(ASTBuilder* astBuilder, DeclRef<CallableDecl> const& declRef)
+ {
+ if (declRef.getDecl()->errorType.type)
+ {
+ return declRef.substitute(astBuilder, declRef.getDecl()->errorType.type);
+ }
+ else
+ {
+ return astBuilder->getVoidType();
+ }
+ }
+
inline FilteredMemberRefList<ParamDecl> getParameters(DeclRef<CallableDecl> const& declRef)
{
return getMembersOfType<ParamDecl>(declRef);