diff options
| author | Tim Foley <tfoleyNV@users.noreply.github.com> | 2019-12-19 07:11:56 -0800 |
|---|---|---|
| committer | jsmall-nvidia <jsmall@nvidia.com> | 2019-12-19 10:11:55 -0500 |
| commit | 60934d98fbc20d83b5e149e72a197ec4f5c61580 (patch) | |
| tree | 0bdac186e47aad93f3c1f661a74c3b65dec3a8be /source | |
| parent | 15b46afc2d0c10561bb8440b2eec565a5edfad32 (diff) | |
Fix invocation of `[mutating]` methods (#1156)
The logic for invoking methods (member functions) in `slang-lower-to-ir.cpp` was failing to take into account whether the callee was `[mutating]` or not. Instead, it would always lower the `base` expression in something like `base.f(...)` as an r-value expression, consistent with a non-`[mutating]` method.
The incorrect code generation strategy somehow turned out to work in many cases, but it broke in cases where a `[mutating]` method was called on an `inout` parameter. E.g., in this code:
```hlsl
struct Stuff { [mutating] void doThing() { ... } }
void broken(inout Stuff s)
{
s.doThing();
}
```
The `broken` function would fail to write back the value mutated by `doThing` to its `s` parameter before returning.
The crux of the fix here is inside `visitInvokeExpr()`. Instead of directly calling `lowerRValueExpr` on the base expression of a method/member-function call, we instead compute the "direction" of the `this` parameter in the callee, and use that to emit the argument expression appropriately.
In order to enable that change, there are several refactorings included:
* The existing `ParameterDirection` and `getParameterDirection()` calls were lifted out from the declaration visitor to the global scope, so that they could be shared between lowering of functions and their call sites.
* The logic for determining the "direction" of a `this` parameter was factored out of `collectParameterLists()` into its own `getThisParamDirection()` subroutine (again so that functions and call sites can share matching logic).
* The logic for turning an AST expression used as a call argument into IR argument(s)* was pulled out into its own `addCallArgsForParam` *and* was refactored to rely on a `ParameterDirection` instead of directly inspecting the modifiers on a `ParamDecl`. This allows the function to be used for ordinary/direct arguments and the `this` argument, and also ensures that the caller and callee will agree on the direction of parameters.
Fixing the way that `[mutating]` methods are called actually broke some test cases, specifically in the cases where a `[mutating]` method was being called on a value with an interface-constrained generic type:
```hlsl
interface IThing { [mutating] void doStuff(); }
void myFunc<T : IThing>(inout T thing)
{
thing.doStuff();
}
```
Our argument passing for `inout` parameters currently requires that we make a temp copy of `thing` into a local, and then pass that local as argument for the `inout` parameter, before copying back. The issue that arose was that a simple version of the logic uses the type of the `base` expression in `base.someMethod(...)` as the type of the local variable, but for an interface method call the base expression will have been cast to the interface type (we effectively have `((IThing) thing).doStuff()`.
The fix here was to query the this type through the member function we are calling, and to share that logic between the function-call and function-declaration cases, to try and make sure they match, which meant even more logic got hoisted out of the declaration-emission logic and to the top level.
Note: This change does *not* clean up any other clarity or performance concerns around `out` and `inout` parameters; it is only focused on correctness.
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 353 |
1 files changed, 215 insertions, 138 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 19062f1df..7cf5215c9 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1999,6 +1999,124 @@ LoweredValInfo tryGetAddress( LoweredValInfo const& inVal, TryGetAddressMode mode); + /// Represents the "direction" that a parameter is being passed (e.g., `in` or `out` +enum ParameterDirection +{ + kParameterDirection_In, ///< Copy in + kParameterDirection_Out, ///< Copy out + kParameterDirection_InOut, ///< Copy in, copy out + kParameterDirection_Ref, ///< By-reference +}; + + /// Compute the direction for a parameter based on its declaration +ParameterDirection getParameterDirection(VarDeclBase* paramDecl) +{ + if( paramDecl->HasModifier<RefModifier>() ) + { + // The AST specified `ref`: + return kParameterDirection_Ref; + } + if( paramDecl->HasModifier<InOutModifier>() ) + { + // The AST specified `inout`: + return kParameterDirection_InOut; + } + if (paramDecl->HasModifier<OutModifier>()) + { + // We saw an `out` modifier, so now we need + // to check if there was a paired `in`. + if(paramDecl->HasModifier<InModifier>()) + return kParameterDirection_InOut; + else + return kParameterDirection_Out; + } + else + { + // No direction modifier, or just `in`: + return kParameterDirection_In; + } +} + + /// Compute the direction for a `this` parameter based on the declaration of its parent function +ParameterDirection getThisParamDirection(Decl* parentDecl) +{ + // Applications can opt in to a mutable `this` parameter, + // by applying the `[mutating]` attribute to their + // declaration. + // + if( parentDecl->HasModifier<MutatingAttribute>() ) + { + return kParameterDirection_InOut; + } + + // TODO: If/when we support user-defined subscripts or properties, + // we should probably make the `set` accessor on those default to + // `[mutating]` rather than require users to specify it. There + // might need to be a `[nonmutating]` modifier for the rare case + // where a user wants to opt out. + + // For now we make any `this` parameter default to `in`. + // + return kParameterDirection_In; +} + +DeclRef<Decl> createDefaultSpecializedDeclRefImpl(IRGenContext* context, Decl* decl) +{ + DeclRef<Decl> declRef; + declRef.decl = decl; + declRef.substitutions = createDefaultSubstitutions(context->getSession(), decl); + return declRef; +} +// +// The client should actually call the templated wrapper, to preserve type information. +template<typename D> +DeclRef<D> createDefaultSpecializedDeclRef(IRGenContext* context, D* decl) +{ + DeclRef<Decl> declRef = createDefaultSpecializedDeclRefImpl(context, decl); + return declRef.as<D>(); +} + + /// Get the type of the `this` parameter introduced by `parentDeclRef`, or null. + /// + /// E.g., if `parentDeclRef` is a `struct` declaration, then this will + /// return the type of that `struct`. + /// + /// If this function is called on a declaration that does not itself directly + /// introduce a notion of `this`, then null will be returned. Note that this + /// includes things like function declarations themselves, which inherit the + /// definition of `this` from their parent/outer declaration. + /// +RefPtr<Type> getThisParamTypeForContainer( + IRGenContext* context, + DeclRef<Decl> parentDeclRef) +{ + if( auto aggTypeDeclRef = parentDeclRef.as<AggTypeDecl>() ) + { + return DeclRefType::Create(context->getSession(), aggTypeDeclRef); + } + else if( auto extensionDeclRef = parentDeclRef.as<ExtensionDecl>() ) + { + return GetTargetType(extensionDeclRef); + } + + return nullptr; +} + +RefPtr<Type> getThisParamTypeForCallable( + IRGenContext* context, + DeclRef<Decl> callableDeclRef) +{ + auto parentDeclRef = callableDeclRef.GetParent(); + + if(auto subscriptDeclRef = parentDeclRef.as<SubscriptDecl>()) + parentDeclRef = subscriptDeclRef.GetParent(); + + if(auto genericDeclRef = parentDeclRef.as<GenericDecl>()) + parentDeclRef = genericDeclRef.GetParent(); + + return getThisParamTypeForContainer(context, parentDeclRef); +} + // @@ -2419,53 +2537,27 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> LoweredValInfo src; }; - void addDirectCallArgs( - InvokeExpr* expr, - DeclRef<CallableDecl> funcDeclRef, - List<IRInst*>* ioArgs, + /// Add argument(s) corresponding to one parameter to a call + /// + /// The `argExpr` is the AST-level expression being passed as an argument to the call. + /// The `paramType` and `paramDirection` represent what is known about the receiving + /// parameter of the callee (e.g., if the parameter `in`, `inout`, etc.). + /// The `ioArgs` array receives the IR-level argument(s) that are added for the given + /// argument expression. + /// The `ioFixups` array receives any "fixup" code that needs to be run *after* the + /// call completes (e.g., to move from a scratch variable used for an `inout` argument back + /// into the original location). + /// + void addCallArgsForParam( + IRType* paramType, + ParameterDirection paramDirection, + Expr* argExpr, + List<IRInst*>* ioArgs, List<OutArgumentFixup>* ioFixups) { - UInt argCount = expr->Arguments.getCount(); - UInt argCounter = 0; - for (auto paramDeclRef : getMembersOfType<ParamDecl>(funcDeclRef)) + switch(paramDirection) { - auto paramDecl = paramDeclRef.getDecl(); - IRType* paramType = lowerType(context, GetType(paramDeclRef)); - - UInt argIndex = argCounter++; - RefPtr<Expr> argExpr; - if(argIndex < argCount) - { - argExpr = expr->Arguments[argIndex]; - } - else - { - // We have run out of arguments supplied at the call site, - // but there are still parameters remaining. This must mean - // that these parameters have default argument expressions - // associated with them. - argExpr = getInitExpr(paramDeclRef); - - // Assert that such an expression must have been present. - SLANG_ASSERT(argExpr); - - // TODO: The approach we are taking here to default arguments - // is simplistic, and has consequences for the front-end as - // well as binary serialization of modules. - // - // We could consider some more refined approaches where, e.g., - // functions with default arguments generate multiple IR-level - // functions, that compute and provide the default values. - // - // Alternatively, each parameter with defaults could be generated - // into its own callable function that provides the default value, - // so that calling modules can call into a pre-generated function. - // - // Each of these options involves trade-offs, and we need to - // make a conscious decision at some point. - } - - if(paramDecl->HasModifier<RefModifier>()) + case kParameterDirection_Ref: { // A `ref` qualified parameter must be implemented with by-reference // parameter passing, so the argument value should be lowered as @@ -2482,8 +2574,10 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> IRInst* argPtr = getAddress(context, loweredArg, argExpr->loc); (*ioArgs).add(argPtr); } - else if (paramDecl->HasModifier<OutModifier>() - || paramDecl->HasModifier<InOutModifier>()) + break; + + case kParameterDirection_Out: + case kParameterDirection_InOut: { // This is a `out` or `inout` parameter, and so // the argument must be lowered as an l-value. @@ -2508,6 +2602,13 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> // In each of these cases, the safe option is to create // a temporary variable to use for argument-passing, // and then do copy-in/copy-out around the call. + // + // TODO: We should consider ruling out case (2) as undefined + // behavior, and specify that whether `inout` and `out` are + // handled via copy-in-copy-out or by-reference parameter + // passing is an implementation detail. That would allow + // us to avoid introducing a copy except where it is required + // for the semantics of (1). LoweredValInfo tempVar = createVar(context, paramType); @@ -2515,8 +2616,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> // to ensure that we pass in the original value stored // in the argument, which we accomplish by assigning // from the l-value to our temp. - if (paramDecl->HasModifier<InModifier>() - || paramDecl->HasModifier<InOutModifier>()) + if(paramDirection == kParameterDirection_InOut) { assign(context, tempVar, loweredArg); } @@ -2536,13 +2636,67 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> (*ioFixups).add(fixup); } - else + break; + + default: { // This is a pure input parameter, and so we will // pass it as an r-value. LoweredValInfo loweredArg = lowerRValueExpr(context, argExpr); addArgs(context, ioArgs, loweredArg); } + break; + } + } + + void addDirectCallArgs( + InvokeExpr* expr, + DeclRef<CallableDecl> funcDeclRef, + List<IRInst*>* ioArgs, + List<OutArgumentFixup>* ioFixups) + { + UInt argCount = expr->Arguments.getCount(); + UInt argCounter = 0; + for (auto paramDeclRef : getMembersOfType<ParamDecl>(funcDeclRef)) + { + auto paramDecl = paramDeclRef.getDecl(); + IRType* paramType = lowerType(context, GetType(paramDeclRef)); + auto paramDirection = getParameterDirection(paramDecl); + + UInt argIndex = argCounter++; + RefPtr<Expr> argExpr; + if(argIndex < argCount) + { + argExpr = expr->Arguments[argIndex]; + } + else + { + // We have run out of arguments supplied at the call site, + // but there are still parameters remaining. This must mean + // that these parameters have default argument expressions + // associated with them. + argExpr = getInitExpr(paramDeclRef); + + // Assert that such an expression must have been present. + SLANG_ASSERT(argExpr); + + // TODO: The approach we are taking here to default arguments + // is simplistic, and has consequences for the front-end as + // well as binary serialization of modules. + // + // We could consider some more refined approaches where, e.g., + // functions with default arguments generate multiple IR-level + // functions, that compute and provide the default values. + // + // Alternatively, each parameter with defaults could be generated + // into its own callable function that provides the default value, + // so that calling modules can call into a pre-generated function. + // + // Each of these options involves trade-offs, and we need to + // make a conscious decision at some point. + } + + addCallArgsForParam(paramType, paramDirection, argExpr, ioArgs, ioFixups); } } @@ -2693,8 +2847,14 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> // a member function: if( baseExpr ) { - auto loweredBaseVal = lowerRValueExpr(context, baseExpr); - addArgs(context, &irArgs, loweredBaseVal); + auto thisType = getThisParamTypeForCallable(context, funcDeclRef); + auto irThisType = lowerType(context, thisType); + addCallArgsForParam( + irThisType, + getThisParamDirection(funcDeclRef.getDecl()), + baseExpr, + &irArgs, + &argFixups); } // Then we have the "direct" arguments to the call. @@ -5240,21 +5400,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> } - DeclRef<Decl> createDefaultSpecializedDeclRefImpl(Decl* decl) - { - DeclRef<Decl> declRef; - declRef.decl = decl; - declRef.substitutions = createDefaultSubstitutions(context->getSession(), decl); - return declRef; - } - // - // The client should actually call the templated wrapper, to preserve type information. - template<typename D> - DeclRef<D> createDefaultSpecializedDeclRef(D* decl) - { - DeclRef<Decl> declRef = createDefaultSpecializedDeclRefImpl(decl); - return declRef.as<D>(); - } // When lowering something callable (most commonly a function declaration), @@ -5295,13 +5440,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // To handle this we break out the relevant data into derived // structures: // - enum ParameterDirection - { - kParameterDirection_In, ///< Copy in - kParameterDirection_Out, ///< Copy out - kParameterDirection_InOut, ///< Copy in, copy out - kParameterDirection_Ref, ///< By-reference - }; struct ParameterInfo { // This AST-level type of the parameter @@ -5318,36 +5456,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> bool isThisParam = false; }; // - // We need a way to compute the appropriate `ParameterDirection` for a - // declared parameter: - // - ParameterDirection getParameterDirection(VarDeclBase* paramDecl) - { - if( paramDecl->HasModifier<RefModifier>() ) - { - // The AST specified `ref`: - return kParameterDirection_Ref; - } - if( paramDecl->HasModifier<InOutModifier>() ) - { - // The AST specified `inout`: - return kParameterDirection_InOut; - } - if (paramDecl->HasModifier<OutModifier>()) - { - // We saw an `out` modifier, so now we need - // to check if there was a paired `in`. - if(paramDecl->HasModifier<InModifier>()) - return kParameterDirection_InOut; - else - return kParameterDirection_Out; - } - else - { - // No direction modifier, or just `in`: - return kParameterDirection_In; - } - } // We need a way to be able to create a `ParameterInfo` given the declaration // of a parameter: // @@ -5412,22 +5520,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> ioParameterLists->params.add(info); } - void addThisParameter( - ParameterDirection direction, - AggTypeDecl* typeDecl, - ParameterLists* ioParameterLists) - { - // We need to construct an appopriate declaration-reference - // for the type declaration we were given. In particular, - // we need to specialize it for any generic parameters - // that are in scope here. - auto declRef = createDefaultSpecializedDeclRef(typeDecl); - RefPtr<Type> type = DeclRefType::Create(context->getSession(), declRef); - addThisParameter( - direction, - type, - ioParameterLists); - } // // And here is our function that will do the recursive walk: void collectParameterLists( @@ -5457,26 +5549,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // parameter corresponding to the outer declaration. if( innerMode != kParameterListCollectMode_Static ) { - // For now we make any `this` parameter default to `in`. - // - ParameterDirection direction = kParameterDirection_In; - // - // Applications can opt in to a mutable `this` parameter, - // by applying the `[mutating]` attribute to their - // declaration. - // - if( decl->HasModifier<MutatingAttribute>() ) - { - direction = kParameterDirection_InOut; - } - - if( auto aggTypeDecl = as<AggTypeDecl>(parentDecl) ) - { - addThisParameter(direction, aggTypeDecl, ioParameterLists); - } - else if( auto extensionDecl = as<ExtensionDecl>(parentDecl) ) + ParameterDirection direction = getThisParamDirection(decl); + auto thisType = getThisParamTypeForContainer(context, createDefaultSpecializedDeclRef(context, parentDecl)); + if(thisType) { - addThisParameter(direction, extensionDecl->targetType, ioParameterLists); + addThisParameter(direction, thisType, ioParameterLists); } } } |
