summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-03-26 17:35:24 -0700
committerGitHub <noreply@github.com>2024-03-26 17:35:24 -0700
commitdfdf243f07c977fa59b1a5968ce053bf590f8120 (patch)
tree6121218f9e4d664722ed6192ca08f7c0e3c1d45b /source
parent0877d1a3e9d69fdbf4087581df96954e56e4dd97 (diff)
Support mutable existential parameters. (#3836)
* Support mutable existential parameters. * Update test.
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-check-conversion.cpp10
-rw-r--r--source/slang/slang-check-overload.cpp10
-rw-r--r--source/slang/slang-ir-lower-existential.cpp44
-rw-r--r--source/slang/slang-ir-lower-witness-lookup.cpp9
-rw-r--r--source/slang/slang-lower-to-ir.cpp11
5 files changed, 74 insertions, 10 deletions
diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp
index a2381c7f7..aeb964cc9 100644
--- a/source/slang/slang-check-conversion.cpp
+++ b/source/slang/slang-check-conversion.cpp
@@ -932,11 +932,15 @@ namespace Slang
// to pass a value of a derived `struct` type into methods that
// expect a value of its base type.
//
- // TODO: vet this logic for correctness.
- //
if (fromExpr && fromExpr->type.isLeftValue)
{
- (*outToExpr)->type.isLeftValue = true;
+ // If the original type is a concrete type and toType is an interface type,
+ // we need to wrap the original expression into a MakeExistential, and the
+ // result of MakeExistential is not an l-value.
+ bool toTypeIsInterface = isInterfaceType(toType);
+ bool fromTypeIsInterface = isInterfaceType(fromType);
+ if (!toTypeIsInterface || toTypeIsInterface == fromTypeIsInterface)
+ (*outToExpr)->type.isLeftValue = true;
}
}
if (outCost)
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index 1cb6681b3..3831fed84 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -1914,8 +1914,16 @@ namespace Slang
auto& arg = expr->arguments[i];
if (funcType && i < funcType->getParamCount())
{
- if (funcType->getParamDirection(i) == kParameterDirection_Out)
+ switch (funcType->getParamDirection(i))
+ {
+ case kParameterDirection_Out:
+ case kParameterDirection_InOut:
+ case kParameterDirection_Ref:
+ case kParameterDirection_ConstRef:
continue;
+ default:
+ break;
+ }
}
arg = maybeOpenExistential(arg);
}
diff --git a/source/slang/slang-ir-lower-existential.cpp b/source/slang/slang-ir-lower-existential.cpp
index 7bae856c6..13f46e914 100644
--- a/source/slang/slang-ir-lower-existential.cpp
+++ b/source/slang/slang-ir-lower-existential.cpp
@@ -133,9 +133,49 @@ namespace Slang
processExtractExistentialElement(inst, 1);
}
- void processExtractExistentialType(IRExtractExistentialType* inst)
+ void processExtractExistentialType(IRExtractExistentialType* extractInst)
{
- processExtractExistentialElement(inst, 0);
+ IRBuilder builderStorage(sharedContext->module);
+ auto builder = &builderStorage;
+ builder->setInsertBefore(extractInst);
+
+ IRInst* element = nullptr;
+ IRInst* anyValueType = nullptr;
+ if (isComInterfaceType(extractInst->getOperand(0)->getDataType()))
+ {
+ // If this is an COM interface, the elements (witness table/rtti) are just the interface value itself.
+ element = extractInst->getOperand(0);
+ }
+ else
+ {
+ element = extractTupleElement(builder, extractInst->getOperand(0), 0);
+ if (auto tupleType = as<IRTupleType>(extractInst->getOperand(0)->getDataType()))
+ {
+ anyValueType = tupleType->getOperand(2);
+ }
+ }
+
+ // If this instruction is used as a type, we need to replace it with the lowered type,
+ // which should be an AnyValueType.
+ // If it is used as a value, then we can replace it with the extracted element.
+ auto isTypeUse = [](IRUse* use) -> bool
+ {
+ auto user = use->getUser();
+ if (as<IRType>(user))
+ return true;
+ if (use == &use->getUser()->typeUse)
+ return true;
+ return false;
+ };
+ traverseUses(extractInst, [&](IRUse* use)
+ {
+ if (anyValueType && isTypeUse(use))
+ {
+ builder->replaceOperand(use, anyValueType);
+ return;
+ }
+ builder->replaceOperand(use, element);
+ });
}
void processGetValueFromBoundInterface(IRGetValueFromBoundInterface* inst)
diff --git a/source/slang/slang-ir-lower-witness-lookup.cpp b/source/slang/slang-ir-lower-witness-lookup.cpp
index aca41fd69..9e3c5d251 100644
--- a/source/slang/slang-ir-lower-witness-lookup.cpp
+++ b/source/slang/slang-ir-lower-witness-lookup.cpp
@@ -325,11 +325,16 @@ struct WitnessLookupLoweringContext
return resultValue;
}
- void rewriteCallSite(IRCall* call, IRInst* dispatchFunc, IRInst* existentialObject)
+ void rewriteCallSite(IRCall* call, IRInst* dispatchFunc, IRInst* initialExistentialObject)
{
SLANG_RELEASE_ASSERT(call->getArgCount() != 0);
call->setOperand(0, dispatchFunc);
- call->setOperand(1, existentialObject);
+ IRBuilder builder(call);
+ builder.setInsertBefore(call);
+ auto witnessTable = builder.emitExtractExistentialWitnessTable(initialExistentialObject);
+ auto newExistentialObject = builder.emitMakeExistential(
+ initialExistentialObject->getDataType(), call->getOperand(1), witnessTable);
+ call->setOperand(1, newExistentialObject);
}
bool processWitnessLookup(IRLookupWitnessMethod* lookupInst)
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 41638dc83..799faec88 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -2503,6 +2503,7 @@ void addArg(
LoweredValInfo argVal, //< The lowered value of the argument to add
IRType* paramType, //< The type of the corresponding parameter
ParameterDirection paramDirection, //< The direction of the parameter (`in`, `out`, etc.)
+ Type* argType, //< The AST-level type of the argument
SourceLoc loc) //< A location to use if we need to report an error
{
switch(paramDirection)
@@ -2543,6 +2544,12 @@ void addArg(
// If the value is not one that could yield a simple l-value
// then we need to convert it into a temporary
//
+ if (as<IRThisType>(paramType))
+ {
+ // When paramType is ThisType, we need to get the actual argument type
+ // from the arg.
+ paramType = lowerType(context, argType);
+ }
if (auto refType = as<IRConstRefType>(paramType))
{
paramType = refType->getValueType();
@@ -2616,7 +2623,7 @@ void addCallArgsForParam(
case kParameterDirection_InOut:
{
LoweredValInfo loweredArg = lowerLValueExpr(context, argExpr);
- addArg(context, ioArgs, ioFixups, loweredArg, paramType, paramDirection, argExpr->loc);
+ addArg(context, ioArgs, ioFixups, loweredArg, paramType, paramDirection, argExpr->type, argExpr->loc);
}
break;
@@ -3223,7 +3230,7 @@ static LoweredValInfo _emitCallToAccessor(
auto thisParam = info.parameterLists.params[0];
auto thisParamType = lowerType(context, thisParam.type);
- addArg(context, &allArgs, &fixups, base, thisParamType, thisParam.direction, SourceLoc());
+ addArg(context, &allArgs, &fixups, base, thisParamType, thisParam.direction, thisParam.type, SourceLoc());
}
allArgs.addRange(args, argCount);