diff options
| author | Yong He <yonghe@outlook.com> | 2022-06-01 17:37:07 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-06-01 17:37:07 -0700 |
| commit | 17e3b88b541ed7f45d575f0f9caaa808cd0a6619 (patch) | |
| tree | efacd5d4bf6381a5adf8055daa28f91ddc048a76 /source/slang/slang-ir-lower-error-handling.cpp | |
| parent | fa10f7dc23f8b93c0f9ef3fb5477871a20aaa974 (diff) | |
New language feature: basic error handling. (#2253)
* New language feature: basic error handling.
* Fix.
* Fix `tryCall` encoding according to code review.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-lower-error-handling.cpp')
| -rw-r--r-- | source/slang/slang-ir-lower-error-handling.cpp | 242 |
1 files changed, 242 insertions, 0 deletions
diff --git a/source/slang/slang-ir-lower-error-handling.cpp b/source/slang/slang-ir-lower-error-handling.cpp new file mode 100644 index 000000000..5a1389e57 --- /dev/null +++ b/source/slang/slang-ir-lower-error-handling.cpp @@ -0,0 +1,242 @@ +// slang-ir-lower-error-handling.cpp + +#include "slang-ir-lower-error-handling.h" + +#include "slang-ir.h" +#include "slang-ir-insts.h" + +namespace Slang +{ + +struct ErrorHandlingLoweringContext +{ + IRModule* module; + DiagnosticSink* diagnosticSink; + + SharedIRBuilder sharedBuilder; + + List<IRInst*> workList; + HashSet<IRInst*> workListSet; + + void addToWorkList(IRInst* inst) + { + if (workListSet.Contains(inst)) + return; + + workList.add(inst); + workListSet.Add(inst); + } + + void processFuncType(IRFuncType* funcType) + { + auto throwAttr = funcType->findAttr<IRFuncThrowTypeAttr>(); + if (!throwAttr) + return; + IRBuilder builder(sharedBuilder); + builder.setInsertBefore(funcType); + auto resultType = + builder.getResultType(funcType->getResultType(), throwAttr->getErrorType()); + List<IRType*> paramTypes; + for (UInt i = 0; i < funcType->getParamCount(); i++) + { + if (as<IRAttr>(funcType->getParamType(i))) + break; + paramTypes.add(funcType->getParamType(i)); + } + auto newFuncType = builder.getFuncType(paramTypes, resultType); + sharedBuilder.replaceGlobalInst(funcType, newFuncType); + } + + void processTryCall(IRTryCall* tryCall) + { + // If we see: + // ``` + // value = tryCall(callee, successBlock, failBlock, args) + // successBlock: + // resultParam = IRParam<resultType> + // ... (uses resultParam) ... + // failBlock: + // errorParam = IRParam<errorType> + // (uses errorParam) + // ``` + // We need to rewrite it as + // ``` + // result = call(callee) : Result<callee.returnType, callee.errorType> + // isError = isResultError(result) + // ifElse(isError, failBlock, successBlock) + // successBlock: + // value = getResultValue(result) : returnType + // ... (replaces resultParam with value) + // failBlock: + // error = getResultError(result) : errorType + // ... (replaces errorParam with error) + // ``` + IRFuncType* funcType = cast<IRFuncType>(tryCall->getCallee()->getDataType()); + auto resultValueType = funcType->getResultType(); + auto throwAttr = funcType->findAttr<IRFuncThrowTypeAttr>(); + if (!throwAttr) + { + SLANG_ASSERT_FAILURE("tryCall applied to callee without a IRFuncThrowTypeAttr"); + } + auto errorType = throwAttr->getErrorType(); + + IRBuilder builder(sharedBuilder); + builder.setInsertBefore(tryCall); + + auto resultType = builder.getResultType(resultValueType, errorType); + List<IRInst*> args; + for (UInt i = 0; i < tryCall->getArgCount(); i++) + { + args.add(tryCall->getArg(i)); + } + auto call = builder.emitCallInst(resultType, tryCall->getCallee(), args); + auto isFail = builder.emitIsResultError(call); + auto failBlock = tryCall->getFailureBlock(); + auto successBlock = tryCall->getSuccessBlock(); + + builder.emitIf(isFail, failBlock, successBlock); + + // Replace the params in failBlock to `getResultError(call)`. + builder.setInsertBefore(failBlock->getFirstOrdinaryInst()); + auto errorParam = failBlock->getFirstParam(); + auto errVal = builder.emitGetResultError(call); + errorParam->replaceUsesWith(errVal); + errorParam->removeAndDeallocate(); + + // Replace the params in successBlock to `getResultValue(call)`. + builder.setInsertBefore(successBlock->getFirstOrdinaryInst()); + auto resultParam = successBlock->getFirstParam(); + auto resultValue = builder.emitGetResultValue(call); + resultParam->replaceUsesWith(resultValue); + resultParam->removeAndDeallocate(); + + tryCall->removeAndDeallocate(); + } + + void processReturn(IRReturn* ret) + { + auto parentFunc = getParentFunc(ret); + if (!parentFunc) + return; + auto funcType = cast<IRFuncType>(parentFunc->getDataType()); + auto throwAttr = funcType->findAttr<IRFuncThrowTypeAttr>(); + if (!throwAttr) + return; + + // If we are in a throwing function and sees a `return(val)` inst, + // replace it with a `return makeResultValue(val)`, so that it returns a `Result<T,E>` type. + IRBuilder builder(sharedBuilder); + builder.setInsertBefore(ret); + auto resultType = + builder.getResultType(funcType->getResultType(), throwAttr->getErrorType()); + IRInst* resultVal = nullptr; + if (ret->getOp() == kIROp_ReturnVal) + { + auto val = cast<IRReturnVal>(ret)->getVal(); + resultVal = builder.emitMakeResultValue(resultType, val); + } + else + { + resultVal = builder.emitMakeResultValueVoid(resultType); + } + builder.emitReturn(resultVal); + ret->removeAndDeallocate(); + } + + void processThrow(IRThrow* throwInst) + { + auto parentFunc = getParentFunc(throwInst); + SLANG_ASSERT(parentFunc); + auto funcType = cast<IRFuncType>(parentFunc->getDataType()); + auto throwAttr = funcType->findAttr<IRFuncThrowTypeAttr>(); + SLANG_ASSERT(throwAttr); + + // If we are in a throwing function and sees a `throw(e)` inst, + // replace it with a `return makeResultError(e)`. + IRBuilder builder(sharedBuilder); + builder.setInsertBefore(throwInst); + auto resultType = + builder.getResultType(funcType->getResultType(), throwAttr->getErrorType()); + IRInst* resultVal = builder.emitMakeResultError(resultType, throwInst->getValue()); + builder.emitReturn(resultVal); + throwInst->removeAndDeallocate(); + } + + void processInst(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_TryCall: + processTryCall(cast<IRTryCall>(inst)); + break; + case kIROp_ReturnVal: + case kIROp_ReturnVoid: + processReturn(cast<IRReturn>(inst)); + break; + case kIROp_Throw: + processThrow(cast<IRThrow>(inst)); + break; + default: + break; + } + } + + void processInsts() + { + addToWorkList(module->getModuleInst()); + + while (workList.getCount() != 0) + { + IRInst* inst = workList.getLast(); + + workList.removeLast(); + workListSet.Remove(inst); + + processInst(inst); + + for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) + { + addToWorkList(child); + } + } + } + + void processModule() + { + // Deduplicate equivalent types. + sharedBuilder.deduplicateAndRebuildGlobalNumberingMap(); + + // Translate all IRTryCall, IRThrow, IRReturn, IRReturnVal. + processInsts(); + + // Lower all functypes. + // Function types with an IRThrowTypeAttribute will be translated into a normal function + // type that returns `Result<T,E>`. + List<IRFuncType*> oldFuncTypes; + for (auto child : module->getGlobalInsts()) + { + switch (child->getOp()) + { + case kIROp_FuncType: + oldFuncTypes.add(cast<IRFuncType>(child)); + break; + default: + break; + } + } + for (auto funcType : oldFuncTypes) + { + processFuncType(funcType); + } + } +}; + +void lowerErrorHandling(IRModule* module, DiagnosticSink* sink) +{ + ErrorHandlingLoweringContext context; + context.module = module; + context.diagnosticSink = sink; + context.sharedBuilder.init(module); + return context.processModule(); +} +} |
