summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorTim Foley <tfoleyNV@users.noreply.github.com>2019-12-19 07:11:56 -0800
committerjsmall-nvidia <jsmall@nvidia.com>2019-12-19 10:11:55 -0500
commit60934d98fbc20d83b5e149e72a197ec4f5c61580 (patch)
tree0bdac186e47aad93f3c1f661a74c3b65dec3a8be /source
parent15b46afc2d0c10561bb8440b2eec565a5edfad32 (diff)
Fix invocation of `[mutating]` methods (#1156)
The logic for invoking methods (member functions) in `slang-lower-to-ir.cpp` was failing to take into account whether the callee was `[mutating]` or not. Instead, it would always lower the `base` expression in something like `base.f(...)` as an r-value expression, consistent with a non-`[mutating]` method. The incorrect code generation strategy somehow turned out to work in many cases, but it broke in cases where a `[mutating]` method was called on an `inout` parameter. E.g., in this code: ```hlsl struct Stuff { [mutating] void doThing() { ... } } void broken(inout Stuff s) { s.doThing(); } ``` The `broken` function would fail to write back the value mutated by `doThing` to its `s` parameter before returning. The crux of the fix here is inside `visitInvokeExpr()`. Instead of directly calling `lowerRValueExpr` on the base expression of a method/member-function call, we instead compute the "direction" of the `this` parameter in the callee, and use that to emit the argument expression appropriately. In order to enable that change, there are several refactorings included: * The existing `ParameterDirection` and `getParameterDirection()` calls were lifted out from the declaration visitor to the global scope, so that they could be shared between lowering of functions and their call sites. * The logic for determining the "direction" of a `this` parameter was factored out of `collectParameterLists()` into its own `getThisParamDirection()` subroutine (again so that functions and call sites can share matching logic). * The logic for turning an AST expression used as a call argument into IR argument(s)* was pulled out into its own `addCallArgsForParam` *and* was refactored to rely on a `ParameterDirection` instead of directly inspecting the modifiers on a `ParamDecl`. This allows the function to be used for ordinary/direct arguments and the `this` argument, and also ensures that the caller and callee will agree on the direction of parameters. Fixing the way that `[mutating]` methods are called actually broke some test cases, specifically in the cases where a `[mutating]` method was being called on a value with an interface-constrained generic type: ```hlsl interface IThing { [mutating] void doStuff(); } void myFunc<T : IThing>(inout T thing) { thing.doStuff(); } ``` Our argument passing for `inout` parameters currently requires that we make a temp copy of `thing` into a local, and then pass that local as argument for the `inout` parameter, before copying back. The issue that arose was that a simple version of the logic uses the type of the `base` expression in `base.someMethod(...)` as the type of the local variable, but for an interface method call the base expression will have been cast to the interface type (we effectively have `((IThing) thing).doStuff()`. The fix here was to query the this type through the member function we are calling, and to share that logic between the function-call and function-declaration cases, to try and make sure they match, which meant even more logic got hoisted out of the declaration-emission logic and to the top level. Note: This change does *not* clean up any other clarity or performance concerns around `out` and `inout` parameters; it is only focused on correctness.
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-lower-to-ir.cpp353
1 files changed, 215 insertions, 138 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 19062f1df..7cf5215c9 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1999,6 +1999,124 @@ LoweredValInfo tryGetAddress(
LoweredValInfo const& inVal,
TryGetAddressMode mode);
+ /// Represents the "direction" that a parameter is being passed (e.g., `in` or `out`
+enum ParameterDirection
+{
+ kParameterDirection_In, ///< Copy in
+ kParameterDirection_Out, ///< Copy out
+ kParameterDirection_InOut, ///< Copy in, copy out
+ kParameterDirection_Ref, ///< By-reference
+};
+
+ /// Compute the direction for a parameter based on its declaration
+ParameterDirection getParameterDirection(VarDeclBase* paramDecl)
+{
+ if( paramDecl->HasModifier<RefModifier>() )
+ {
+ // The AST specified `ref`:
+ return kParameterDirection_Ref;
+ }
+ if( paramDecl->HasModifier<InOutModifier>() )
+ {
+ // The AST specified `inout`:
+ return kParameterDirection_InOut;
+ }
+ if (paramDecl->HasModifier<OutModifier>())
+ {
+ // We saw an `out` modifier, so now we need
+ // to check if there was a paired `in`.
+ if(paramDecl->HasModifier<InModifier>())
+ return kParameterDirection_InOut;
+ else
+ return kParameterDirection_Out;
+ }
+ else
+ {
+ // No direction modifier, or just `in`:
+ return kParameterDirection_In;
+ }
+}
+
+ /// Compute the direction for a `this` parameter based on the declaration of its parent function
+ParameterDirection getThisParamDirection(Decl* parentDecl)
+{
+ // Applications can opt in to a mutable `this` parameter,
+ // by applying the `[mutating]` attribute to their
+ // declaration.
+ //
+ if( parentDecl->HasModifier<MutatingAttribute>() )
+ {
+ return kParameterDirection_InOut;
+ }
+
+ // TODO: If/when we support user-defined subscripts or properties,
+ // we should probably make the `set` accessor on those default to
+ // `[mutating]` rather than require users to specify it. There
+ // might need to be a `[nonmutating]` modifier for the rare case
+ // where a user wants to opt out.
+
+ // For now we make any `this` parameter default to `in`.
+ //
+ return kParameterDirection_In;
+}
+
+DeclRef<Decl> createDefaultSpecializedDeclRefImpl(IRGenContext* context, Decl* decl)
+{
+ DeclRef<Decl> declRef;
+ declRef.decl = decl;
+ declRef.substitutions = createDefaultSubstitutions(context->getSession(), decl);
+ return declRef;
+}
+//
+// The client should actually call the templated wrapper, to preserve type information.
+template<typename D>
+DeclRef<D> createDefaultSpecializedDeclRef(IRGenContext* context, D* decl)
+{
+ DeclRef<Decl> declRef = createDefaultSpecializedDeclRefImpl(context, decl);
+ return declRef.as<D>();
+}
+
+ /// Get the type of the `this` parameter introduced by `parentDeclRef`, or null.
+ ///
+ /// E.g., if `parentDeclRef` is a `struct` declaration, then this will
+ /// return the type of that `struct`.
+ ///
+ /// If this function is called on a declaration that does not itself directly
+ /// introduce a notion of `this`, then null will be returned. Note that this
+ /// includes things like function declarations themselves, which inherit the
+ /// definition of `this` from their parent/outer declaration.
+ ///
+RefPtr<Type> getThisParamTypeForContainer(
+ IRGenContext* context,
+ DeclRef<Decl> parentDeclRef)
+{
+ if( auto aggTypeDeclRef = parentDeclRef.as<AggTypeDecl>() )
+ {
+ return DeclRefType::Create(context->getSession(), aggTypeDeclRef);
+ }
+ else if( auto extensionDeclRef = parentDeclRef.as<ExtensionDecl>() )
+ {
+ return GetTargetType(extensionDeclRef);
+ }
+
+ return nullptr;
+}
+
+RefPtr<Type> getThisParamTypeForCallable(
+ IRGenContext* context,
+ DeclRef<Decl> callableDeclRef)
+{
+ auto parentDeclRef = callableDeclRef.GetParent();
+
+ if(auto subscriptDeclRef = parentDeclRef.as<SubscriptDecl>())
+ parentDeclRef = subscriptDeclRef.GetParent();
+
+ if(auto genericDeclRef = parentDeclRef.as<GenericDecl>())
+ parentDeclRef = genericDeclRef.GetParent();
+
+ return getThisParamTypeForContainer(context, parentDeclRef);
+}
+
//
@@ -2419,53 +2537,27 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
LoweredValInfo src;
};
- void addDirectCallArgs(
- InvokeExpr* expr,
- DeclRef<CallableDecl> funcDeclRef,
- List<IRInst*>* ioArgs,
+ /// Add argument(s) corresponding to one parameter to a call
+ ///
+ /// The `argExpr` is the AST-level expression being passed as an argument to the call.
+ /// The `paramType` and `paramDirection` represent what is known about the receiving
+ /// parameter of the callee (e.g., if the parameter `in`, `inout`, etc.).
+ /// The `ioArgs` array receives the IR-level argument(s) that are added for the given
+ /// argument expression.
+ /// The `ioFixups` array receives any "fixup" code that needs to be run *after* the
+ /// call completes (e.g., to move from a scratch variable used for an `inout` argument back
+ /// into the original location).
+ ///
+ void addCallArgsForParam(
+ IRType* paramType,
+ ParameterDirection paramDirection,
+ Expr* argExpr,
+ List<IRInst*>* ioArgs,
List<OutArgumentFixup>* ioFixups)
{
- UInt argCount = expr->Arguments.getCount();
- UInt argCounter = 0;
- for (auto paramDeclRef : getMembersOfType<ParamDecl>(funcDeclRef))
+ switch(paramDirection)
{
- auto paramDecl = paramDeclRef.getDecl();
- IRType* paramType = lowerType(context, GetType(paramDeclRef));
-
- UInt argIndex = argCounter++;
- RefPtr<Expr> argExpr;
- if(argIndex < argCount)
- {
- argExpr = expr->Arguments[argIndex];
- }
- else
- {
- // We have run out of arguments supplied at the call site,
- // but there are still parameters remaining. This must mean
- // that these parameters have default argument expressions
- // associated with them.
- argExpr = getInitExpr(paramDeclRef);
-
- // Assert that such an expression must have been present.
- SLANG_ASSERT(argExpr);
-
- // TODO: The approach we are taking here to default arguments
- // is simplistic, and has consequences for the front-end as
- // well as binary serialization of modules.
- //
- // We could consider some more refined approaches where, e.g.,
- // functions with default arguments generate multiple IR-level
- // functions, that compute and provide the default values.
- //
- // Alternatively, each parameter with defaults could be generated
- // into its own callable function that provides the default value,
- // so that calling modules can call into a pre-generated function.
- //
- // Each of these options involves trade-offs, and we need to
- // make a conscious decision at some point.
- }
-
- if(paramDecl->HasModifier<RefModifier>())
+ case kParameterDirection_Ref:
{
// A `ref` qualified parameter must be implemented with by-reference
// parameter passing, so the argument value should be lowered as
@@ -2482,8 +2574,10 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
IRInst* argPtr = getAddress(context, loweredArg, argExpr->loc);
(*ioArgs).add(argPtr);
}
- else if (paramDecl->HasModifier<OutModifier>()
- || paramDecl->HasModifier<InOutModifier>())
+ break;
+
+ case kParameterDirection_Out:
+ case kParameterDirection_InOut:
{
// This is a `out` or `inout` parameter, and so
// the argument must be lowered as an l-value.
@@ -2508,6 +2602,13 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
// In each of these cases, the safe option is to create
// a temporary variable to use for argument-passing,
// and then do copy-in/copy-out around the call.
+ //
+ // TODO: We should consider ruling out case (2) as undefined
+ // behavior, and specify that whether `inout` and `out` are
+ // handled via copy-in-copy-out or by-reference parameter
+ // passing is an implementation detail. That would allow
+ // us to avoid introducing a copy except where it is required
+ // for the semantics of (1).
LoweredValInfo tempVar = createVar(context, paramType);
@@ -2515,8 +2616,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
// to ensure that we pass in the original value stored
// in the argument, which we accomplish by assigning
// from the l-value to our temp.
- if (paramDecl->HasModifier<InModifier>()
- || paramDecl->HasModifier<InOutModifier>())
+ if(paramDirection == kParameterDirection_InOut)
{
assign(context, tempVar, loweredArg);
}
@@ -2536,13 +2636,67 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
(*ioFixups).add(fixup);
}
- else
+ break;
+
+ default:
{
// This is a pure input parameter, and so we will
// pass it as an r-value.
LoweredValInfo loweredArg = lowerRValueExpr(context, argExpr);
addArgs(context, ioArgs, loweredArg);
}
+ break;
+ }
+ }
+
+ void addDirectCallArgs(
+ InvokeExpr* expr,
+ DeclRef<CallableDecl> funcDeclRef,
+ List<IRInst*>* ioArgs,
+ List<OutArgumentFixup>* ioFixups)
+ {
+ UInt argCount = expr->Arguments.getCount();
+ UInt argCounter = 0;
+ for (auto paramDeclRef : getMembersOfType<ParamDecl>(funcDeclRef))
+ {
+ auto paramDecl = paramDeclRef.getDecl();
+ IRType* paramType = lowerType(context, GetType(paramDeclRef));
+ auto paramDirection = getParameterDirection(paramDecl);
+
+ UInt argIndex = argCounter++;
+ RefPtr<Expr> argExpr;
+ if(argIndex < argCount)
+ {
+ argExpr = expr->Arguments[argIndex];
+ }
+ else
+ {
+ // We have run out of arguments supplied at the call site,
+ // but there are still parameters remaining. This must mean
+ // that these parameters have default argument expressions
+ // associated with them.
+ argExpr = getInitExpr(paramDeclRef);
+
+ // Assert that such an expression must have been present.
+ SLANG_ASSERT(argExpr);
+
+ // TODO: The approach we are taking here to default arguments
+ // is simplistic, and has consequences for the front-end as
+ // well as binary serialization of modules.
+ //
+ // We could consider some more refined approaches where, e.g.,
+ // functions with default arguments generate multiple IR-level
+ // functions, that compute and provide the default values.
+ //
+ // Alternatively, each parameter with defaults could be generated
+ // into its own callable function that provides the default value,
+ // so that calling modules can call into a pre-generated function.
+ //
+ // Each of these options involves trade-offs, and we need to
+ // make a conscious decision at some point.
+ }
+
+ addCallArgsForParam(paramType, paramDirection, argExpr, ioArgs, ioFixups);
}
}
@@ -2693,8 +2847,14 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
// a member function:
if( baseExpr )
{
- auto loweredBaseVal = lowerRValueExpr(context, baseExpr);
- addArgs(context, &irArgs, loweredBaseVal);
+ auto thisType = getThisParamTypeForCallable(context, funcDeclRef);
+ auto irThisType = lowerType(context, thisType);
+ addCallArgsForParam(
+ irThisType,
+ getThisParamDirection(funcDeclRef.getDecl()),
+ baseExpr,
+ &irArgs,
+ &argFixups);
}
// Then we have the "direct" arguments to the call.
@@ -5240,21 +5400,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
- DeclRef<Decl> createDefaultSpecializedDeclRefImpl(Decl* decl)
- {
- DeclRef<Decl> declRef;
- declRef.decl = decl;
- declRef.substitutions = createDefaultSubstitutions(context->getSession(), decl);
- return declRef;
- }
- //
- // The client should actually call the templated wrapper, to preserve type information.
- template<typename D>
- DeclRef<D> createDefaultSpecializedDeclRef(D* decl)
- {
- DeclRef<Decl> declRef = createDefaultSpecializedDeclRefImpl(decl);
- return declRef.as<D>();
- }
// When lowering something callable (most commonly a function declaration),
@@ -5295,13 +5440,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// To handle this we break out the relevant data into derived
// structures:
//
- enum ParameterDirection
- {
- kParameterDirection_In, ///< Copy in
- kParameterDirection_Out, ///< Copy out
- kParameterDirection_InOut, ///< Copy in, copy out
- kParameterDirection_Ref, ///< By-reference
- };
struct ParameterInfo
{
// This AST-level type of the parameter
@@ -5318,36 +5456,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
bool isThisParam = false;
};
//
- // We need a way to compute the appropriate `ParameterDirection` for a
- // declared parameter:
- //
- ParameterDirection getParameterDirection(VarDeclBase* paramDecl)
- {
- if( paramDecl->HasModifier<RefModifier>() )
- {
- // The AST specified `ref`:
- return kParameterDirection_Ref;
- }
- if( paramDecl->HasModifier<InOutModifier>() )
- {
- // The AST specified `inout`:
- return kParameterDirection_InOut;
- }
- if (paramDecl->HasModifier<OutModifier>())
- {
- // We saw an `out` modifier, so now we need
- // to check if there was a paired `in`.
- if(paramDecl->HasModifier<InModifier>())
- return kParameterDirection_InOut;
- else
- return kParameterDirection_Out;
- }
- else
- {
- // No direction modifier, or just `in`:
- return kParameterDirection_In;
- }
- }
// We need a way to be able to create a `ParameterInfo` given the declaration
// of a parameter:
//
@@ -5412,22 +5520,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
ioParameterLists->params.add(info);
}
- void addThisParameter(
- ParameterDirection direction,
- AggTypeDecl* typeDecl,
- ParameterLists* ioParameterLists)
- {
- // We need to construct an appopriate declaration-reference
- // for the type declaration we were given. In particular,
- // we need to specialize it for any generic parameters
- // that are in scope here.
- auto declRef = createDefaultSpecializedDeclRef(typeDecl);
- RefPtr<Type> type = DeclRefType::Create(context->getSession(), declRef);
- addThisParameter(
- direction,
- type,
- ioParameterLists);
- }
//
// And here is our function that will do the recursive walk:
void collectParameterLists(
@@ -5457,26 +5549,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// parameter corresponding to the outer declaration.
if( innerMode != kParameterListCollectMode_Static )
{
- // For now we make any `this` parameter default to `in`.
- //
- ParameterDirection direction = kParameterDirection_In;
- //
- // Applications can opt in to a mutable `this` parameter,
- // by applying the `[mutating]` attribute to their
- // declaration.
- //
- if( decl->HasModifier<MutatingAttribute>() )
- {
- direction = kParameterDirection_InOut;
- }
-
- if( auto aggTypeDecl = as<AggTypeDecl>(parentDecl) )
- {
- addThisParameter(direction, aggTypeDecl, ioParameterLists);
- }
- else if( auto extensionDecl = as<ExtensionDecl>(parentDecl) )
+ ParameterDirection direction = getThisParamDirection(decl);
+ auto thisType = getThisParamTypeForContainer(context, createDefaultSpecializedDeclRef(context, parentDecl));
+ if(thisType)
{
- addThisParameter(direction, extensionDecl->targetType, ioParameterLists);
+ addThisParameter(direction, thisType, ioParameterLists);
}
}
}