summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-lower-error-handling.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-06-01 17:37:07 -0700
committerGitHub <noreply@github.com>2022-06-01 17:37:07 -0700
commit17e3b88b541ed7f45d575f0f9caaa808cd0a6619 (patch)
treeefacd5d4bf6381a5adf8055daa28f91ddc048a76 /source/slang/slang-ir-lower-error-handling.cpp
parentfa10f7dc23f8b93c0f9ef3fb5477871a20aaa974 (diff)
New language feature: basic error handling. (#2253)
* New language feature: basic error handling. * Fix. * Fix `tryCall` encoding according to code review. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-lower-error-handling.cpp')
-rw-r--r--source/slang/slang-ir-lower-error-handling.cpp242
1 files changed, 242 insertions, 0 deletions
diff --git a/source/slang/slang-ir-lower-error-handling.cpp b/source/slang/slang-ir-lower-error-handling.cpp
new file mode 100644
index 000000000..5a1389e57
--- /dev/null
+++ b/source/slang/slang-ir-lower-error-handling.cpp
@@ -0,0 +1,242 @@
+// slang-ir-lower-error-handling.cpp
+
+#include "slang-ir-lower-error-handling.h"
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+
+namespace Slang
+{
+
+struct ErrorHandlingLoweringContext
+{
+ IRModule* module;
+ DiagnosticSink* diagnosticSink;
+
+ SharedIRBuilder sharedBuilder;
+
+ List<IRInst*> workList;
+ HashSet<IRInst*> workListSet;
+
+ void addToWorkList(IRInst* inst)
+ {
+ if (workListSet.Contains(inst))
+ return;
+
+ workList.add(inst);
+ workListSet.Add(inst);
+ }
+
+ void processFuncType(IRFuncType* funcType)
+ {
+ auto throwAttr = funcType->findAttr<IRFuncThrowTypeAttr>();
+ if (!throwAttr)
+ return;
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertBefore(funcType);
+ auto resultType =
+ builder.getResultType(funcType->getResultType(), throwAttr->getErrorType());
+ List<IRType*> paramTypes;
+ for (UInt i = 0; i < funcType->getParamCount(); i++)
+ {
+ if (as<IRAttr>(funcType->getParamType(i)))
+ break;
+ paramTypes.add(funcType->getParamType(i));
+ }
+ auto newFuncType = builder.getFuncType(paramTypes, resultType);
+ sharedBuilder.replaceGlobalInst(funcType, newFuncType);
+ }
+
+ void processTryCall(IRTryCall* tryCall)
+ {
+ // If we see:
+ // ```
+ // value = tryCall(callee, successBlock, failBlock, args)
+ // successBlock:
+ // resultParam = IRParam<resultType>
+ // ... (uses resultParam) ...
+ // failBlock:
+ // errorParam = IRParam<errorType>
+ // (uses errorParam)
+ // ```
+ // We need to rewrite it as
+ // ```
+ // result = call(callee) : Result<callee.returnType, callee.errorType>
+ // isError = isResultError(result)
+ // ifElse(isError, failBlock, successBlock)
+ // successBlock:
+ // value = getResultValue(result) : returnType
+ // ... (replaces resultParam with value)
+ // failBlock:
+ // error = getResultError(result) : errorType
+ // ... (replaces errorParam with error)
+ // ```
+ IRFuncType* funcType = cast<IRFuncType>(tryCall->getCallee()->getDataType());
+ auto resultValueType = funcType->getResultType();
+ auto throwAttr = funcType->findAttr<IRFuncThrowTypeAttr>();
+ if (!throwAttr)
+ {
+ SLANG_ASSERT_FAILURE("tryCall applied to callee without a IRFuncThrowTypeAttr");
+ }
+ auto errorType = throwAttr->getErrorType();
+
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertBefore(tryCall);
+
+ auto resultType = builder.getResultType(resultValueType, errorType);
+ List<IRInst*> args;
+ for (UInt i = 0; i < tryCall->getArgCount(); i++)
+ {
+ args.add(tryCall->getArg(i));
+ }
+ auto call = builder.emitCallInst(resultType, tryCall->getCallee(), args);
+ auto isFail = builder.emitIsResultError(call);
+ auto failBlock = tryCall->getFailureBlock();
+ auto successBlock = tryCall->getSuccessBlock();
+
+ builder.emitIf(isFail, failBlock, successBlock);
+
+ // Replace the params in failBlock to `getResultError(call)`.
+ builder.setInsertBefore(failBlock->getFirstOrdinaryInst());
+ auto errorParam = failBlock->getFirstParam();
+ auto errVal = builder.emitGetResultError(call);
+ errorParam->replaceUsesWith(errVal);
+ errorParam->removeAndDeallocate();
+
+ // Replace the params in successBlock to `getResultValue(call)`.
+ builder.setInsertBefore(successBlock->getFirstOrdinaryInst());
+ auto resultParam = successBlock->getFirstParam();
+ auto resultValue = builder.emitGetResultValue(call);
+ resultParam->replaceUsesWith(resultValue);
+ resultParam->removeAndDeallocate();
+
+ tryCall->removeAndDeallocate();
+ }
+
+ void processReturn(IRReturn* ret)
+ {
+ auto parentFunc = getParentFunc(ret);
+ if (!parentFunc)
+ return;
+ auto funcType = cast<IRFuncType>(parentFunc->getDataType());
+ auto throwAttr = funcType->findAttr<IRFuncThrowTypeAttr>();
+ if (!throwAttr)
+ return;
+
+ // If we are in a throwing function and sees a `return(val)` inst,
+ // replace it with a `return makeResultValue(val)`, so that it returns a `Result<T,E>` type.
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertBefore(ret);
+ auto resultType =
+ builder.getResultType(funcType->getResultType(), throwAttr->getErrorType());
+ IRInst* resultVal = nullptr;
+ if (ret->getOp() == kIROp_ReturnVal)
+ {
+ auto val = cast<IRReturnVal>(ret)->getVal();
+ resultVal = builder.emitMakeResultValue(resultType, val);
+ }
+ else
+ {
+ resultVal = builder.emitMakeResultValueVoid(resultType);
+ }
+ builder.emitReturn(resultVal);
+ ret->removeAndDeallocate();
+ }
+
+ void processThrow(IRThrow* throwInst)
+ {
+ auto parentFunc = getParentFunc(throwInst);
+ SLANG_ASSERT(parentFunc);
+ auto funcType = cast<IRFuncType>(parentFunc->getDataType());
+ auto throwAttr = funcType->findAttr<IRFuncThrowTypeAttr>();
+ SLANG_ASSERT(throwAttr);
+
+ // If we are in a throwing function and sees a `throw(e)` inst,
+ // replace it with a `return makeResultError(e)`.
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertBefore(throwInst);
+ auto resultType =
+ builder.getResultType(funcType->getResultType(), throwAttr->getErrorType());
+ IRInst* resultVal = builder.emitMakeResultError(resultType, throwInst->getValue());
+ builder.emitReturn(resultVal);
+ throwInst->removeAndDeallocate();
+ }
+
+ void processInst(IRInst* inst)
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_TryCall:
+ processTryCall(cast<IRTryCall>(inst));
+ break;
+ case kIROp_ReturnVal:
+ case kIROp_ReturnVoid:
+ processReturn(cast<IRReturn>(inst));
+ break;
+ case kIROp_Throw:
+ processThrow(cast<IRThrow>(inst));
+ break;
+ default:
+ break;
+ }
+ }
+
+ void processInsts()
+ {
+ addToWorkList(module->getModuleInst());
+
+ while (workList.getCount() != 0)
+ {
+ IRInst* inst = workList.getLast();
+
+ workList.removeLast();
+ workListSet.Remove(inst);
+
+ processInst(inst);
+
+ for (auto child = inst->getLastChild(); child; child = child->getPrevInst())
+ {
+ addToWorkList(child);
+ }
+ }
+ }
+
+ void processModule()
+ {
+ // Deduplicate equivalent types.
+ sharedBuilder.deduplicateAndRebuildGlobalNumberingMap();
+
+ // Translate all IRTryCall, IRThrow, IRReturn, IRReturnVal.
+ processInsts();
+
+ // Lower all functypes.
+ // Function types with an IRThrowTypeAttribute will be translated into a normal function
+ // type that returns `Result<T,E>`.
+ List<IRFuncType*> oldFuncTypes;
+ for (auto child : module->getGlobalInsts())
+ {
+ switch (child->getOp())
+ {
+ case kIROp_FuncType:
+ oldFuncTypes.add(cast<IRFuncType>(child));
+ break;
+ default:
+ break;
+ }
+ }
+ for (auto funcType : oldFuncTypes)
+ {
+ processFuncType(funcType);
+ }
+ }
+};
+
+void lowerErrorHandling(IRModule* module, DiagnosticSink* sink)
+{
+ ErrorHandlingLoweringContext context;
+ context.module = module;
+ context.diagnosticSink = sink;
+ context.sharedBuilder.init(module);
+ return context.processModule();
+}
+}