diff options
| author | Yong He <yonghe@outlook.com> | 2024-03-26 17:35:24 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-03-26 17:35:24 -0700 |
| commit | dfdf243f07c977fa59b1a5968ce053bf590f8120 (patch) | |
| tree | 6121218f9e4d664722ed6192ca08f7c0e3c1d45b /source | |
| parent | 0877d1a3e9d69fdbf4087581df96954e56e4dd97 (diff) | |
Support mutable existential parameters. (#3836)
* Support mutable existential parameters.
* Update test.
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-check-conversion.cpp | 10 | ||||
| -rw-r--r-- | source/slang/slang-check-overload.cpp | 10 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-existential.cpp | 44 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-witness-lookup.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 11 |
5 files changed, 74 insertions, 10 deletions
diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index a2381c7f7..aeb964cc9 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -932,11 +932,15 @@ namespace Slang // to pass a value of a derived `struct` type into methods that // expect a value of its base type. // - // TODO: vet this logic for correctness. - // if (fromExpr && fromExpr->type.isLeftValue) { - (*outToExpr)->type.isLeftValue = true; + // If the original type is a concrete type and toType is an interface type, + // we need to wrap the original expression into a MakeExistential, and the + // result of MakeExistential is not an l-value. + bool toTypeIsInterface = isInterfaceType(toType); + bool fromTypeIsInterface = isInterfaceType(fromType); + if (!toTypeIsInterface || toTypeIsInterface == fromTypeIsInterface) + (*outToExpr)->type.isLeftValue = true; } } if (outCost) diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 1cb6681b3..3831fed84 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1914,8 +1914,16 @@ namespace Slang auto& arg = expr->arguments[i]; if (funcType && i < funcType->getParamCount()) { - if (funcType->getParamDirection(i) == kParameterDirection_Out) + switch (funcType->getParamDirection(i)) + { + case kParameterDirection_Out: + case kParameterDirection_InOut: + case kParameterDirection_Ref: + case kParameterDirection_ConstRef: continue; + default: + break; + } } arg = maybeOpenExistential(arg); } diff --git a/source/slang/slang-ir-lower-existential.cpp b/source/slang/slang-ir-lower-existential.cpp index 7bae856c6..13f46e914 100644 --- a/source/slang/slang-ir-lower-existential.cpp +++ b/source/slang/slang-ir-lower-existential.cpp @@ -133,9 +133,49 @@ namespace Slang processExtractExistentialElement(inst, 1); } - void processExtractExistentialType(IRExtractExistentialType* inst) + void processExtractExistentialType(IRExtractExistentialType* extractInst) { - processExtractExistentialElement(inst, 0); + IRBuilder builderStorage(sharedContext->module); + auto builder = &builderStorage; + builder->setInsertBefore(extractInst); + + IRInst* element = nullptr; + IRInst* anyValueType = nullptr; + if (isComInterfaceType(extractInst->getOperand(0)->getDataType())) + { + // If this is an COM interface, the elements (witness table/rtti) are just the interface value itself. + element = extractInst->getOperand(0); + } + else + { + element = extractTupleElement(builder, extractInst->getOperand(0), 0); + if (auto tupleType = as<IRTupleType>(extractInst->getOperand(0)->getDataType())) + { + anyValueType = tupleType->getOperand(2); + } + } + + // If this instruction is used as a type, we need to replace it with the lowered type, + // which should be an AnyValueType. + // If it is used as a value, then we can replace it with the extracted element. + auto isTypeUse = [](IRUse* use) -> bool + { + auto user = use->getUser(); + if (as<IRType>(user)) + return true; + if (use == &use->getUser()->typeUse) + return true; + return false; + }; + traverseUses(extractInst, [&](IRUse* use) + { + if (anyValueType && isTypeUse(use)) + { + builder->replaceOperand(use, anyValueType); + return; + } + builder->replaceOperand(use, element); + }); } void processGetValueFromBoundInterface(IRGetValueFromBoundInterface* inst) diff --git a/source/slang/slang-ir-lower-witness-lookup.cpp b/source/slang/slang-ir-lower-witness-lookup.cpp index aca41fd69..9e3c5d251 100644 --- a/source/slang/slang-ir-lower-witness-lookup.cpp +++ b/source/slang/slang-ir-lower-witness-lookup.cpp @@ -325,11 +325,16 @@ struct WitnessLookupLoweringContext return resultValue; } - void rewriteCallSite(IRCall* call, IRInst* dispatchFunc, IRInst* existentialObject) + void rewriteCallSite(IRCall* call, IRInst* dispatchFunc, IRInst* initialExistentialObject) { SLANG_RELEASE_ASSERT(call->getArgCount() != 0); call->setOperand(0, dispatchFunc); - call->setOperand(1, existentialObject); + IRBuilder builder(call); + builder.setInsertBefore(call); + auto witnessTable = builder.emitExtractExistentialWitnessTable(initialExistentialObject); + auto newExistentialObject = builder.emitMakeExistential( + initialExistentialObject->getDataType(), call->getOperand(1), witnessTable); + call->setOperand(1, newExistentialObject); } bool processWitnessLookup(IRLookupWitnessMethod* lookupInst) diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 41638dc83..799faec88 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -2503,6 +2503,7 @@ void addArg( LoweredValInfo argVal, //< The lowered value of the argument to add IRType* paramType, //< The type of the corresponding parameter ParameterDirection paramDirection, //< The direction of the parameter (`in`, `out`, etc.) + Type* argType, //< The AST-level type of the argument SourceLoc loc) //< A location to use if we need to report an error { switch(paramDirection) @@ -2543,6 +2544,12 @@ void addArg( // If the value is not one that could yield a simple l-value // then we need to convert it into a temporary // + if (as<IRThisType>(paramType)) + { + // When paramType is ThisType, we need to get the actual argument type + // from the arg. + paramType = lowerType(context, argType); + } if (auto refType = as<IRConstRefType>(paramType)) { paramType = refType->getValueType(); @@ -2616,7 +2623,7 @@ void addCallArgsForParam( case kParameterDirection_InOut: { LoweredValInfo loweredArg = lowerLValueExpr(context, argExpr); - addArg(context, ioArgs, ioFixups, loweredArg, paramType, paramDirection, argExpr->loc); + addArg(context, ioArgs, ioFixups, loweredArg, paramType, paramDirection, argExpr->type, argExpr->loc); } break; @@ -3223,7 +3230,7 @@ static LoweredValInfo _emitCallToAccessor( auto thisParam = info.parameterLists.params[0]; auto thisParamType = lowerType(context, thisParam.type); - addArg(context, &allArgs, &fixups, base, thisParamType, thisParam.direction, SourceLoc()); + addArg(context, &allArgs, &fixups, base, thisParamType, thisParam.direction, thisParam.type, SourceLoc()); } allArgs.addRange(args, argCount); |
