summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-lower-to-ir.cpp353
-rw-r--r--tests/compute/mutating-and-inout.slang49
-rw-r--r--tests/compute/mutating-and-inout.slang.expected.txt4
3 files changed, 268 insertions, 138 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 19062f1df..7cf5215c9 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1999,6 +1999,124 @@ LoweredValInfo tryGetAddress(
LoweredValInfo const& inVal,
TryGetAddressMode mode);
+ /// Represents the "direction" that a parameter is being passed (e.g., `in` or `out`
+enum ParameterDirection
+{
+ kParameterDirection_In, ///< Copy in
+ kParameterDirection_Out, ///< Copy out
+ kParameterDirection_InOut, ///< Copy in, copy out
+ kParameterDirection_Ref, ///< By-reference
+};
+
+ /// Compute the direction for a parameter based on its declaration
+ParameterDirection getParameterDirection(VarDeclBase* paramDecl)
+{
+ if( paramDecl->HasModifier<RefModifier>() )
+ {
+ // The AST specified `ref`:
+ return kParameterDirection_Ref;
+ }
+ if( paramDecl->HasModifier<InOutModifier>() )
+ {
+ // The AST specified `inout`:
+ return kParameterDirection_InOut;
+ }
+ if (paramDecl->HasModifier<OutModifier>())
+ {
+ // We saw an `out` modifier, so now we need
+ // to check if there was a paired `in`.
+ if(paramDecl->HasModifier<InModifier>())
+ return kParameterDirection_InOut;
+ else
+ return kParameterDirection_Out;
+ }
+ else
+ {
+ // No direction modifier, or just `in`:
+ return kParameterDirection_In;
+ }
+}
+
+ /// Compute the direction for a `this` parameter based on the declaration of its parent function
+ParameterDirection getThisParamDirection(Decl* parentDecl)
+{
+ // Applications can opt in to a mutable `this` parameter,
+ // by applying the `[mutating]` attribute to their
+ // declaration.
+ //
+ if( parentDecl->HasModifier<MutatingAttribute>() )
+ {
+ return kParameterDirection_InOut;
+ }
+
+ // TODO: If/when we support user-defined subscripts or properties,
+ // we should probably make the `set` accessor on those default to
+ // `[mutating]` rather than require users to specify it. There
+ // might need to be a `[nonmutating]` modifier for the rare case
+ // where a user wants to opt out.
+
+ // For now we make any `this` parameter default to `in`.
+ //
+ return kParameterDirection_In;
+}
+
+DeclRef<Decl> createDefaultSpecializedDeclRefImpl(IRGenContext* context, Decl* decl)
+{
+ DeclRef<Decl> declRef;
+ declRef.decl = decl;
+ declRef.substitutions = createDefaultSubstitutions(context->getSession(), decl);
+ return declRef;
+}
+//
+// The client should actually call the templated wrapper, to preserve type information.
+template<typename D>
+DeclRef<D> createDefaultSpecializedDeclRef(IRGenContext* context, D* decl)
+{
+ DeclRef<Decl> declRef = createDefaultSpecializedDeclRefImpl(context, decl);
+ return declRef.as<D>();
+}
+
+ /// Get the type of the `this` parameter introduced by `parentDeclRef`, or null.
+ ///
+ /// E.g., if `parentDeclRef` is a `struct` declaration, then this will
+ /// return the type of that `struct`.
+ ///
+ /// If this function is called on a declaration that does not itself directly
+ /// introduce a notion of `this`, then null will be returned. Note that this
+ /// includes things like function declarations themselves, which inherit the
+ /// definition of `this` from their parent/outer declaration.
+ ///
+RefPtr<Type> getThisParamTypeForContainer(
+ IRGenContext* context,
+ DeclRef<Decl> parentDeclRef)
+{
+ if( auto aggTypeDeclRef = parentDeclRef.as<AggTypeDecl>() )
+ {
+ return DeclRefType::Create(context->getSession(), aggTypeDeclRef);
+ }
+ else if( auto extensionDeclRef = parentDeclRef.as<ExtensionDecl>() )
+ {
+ return GetTargetType(extensionDeclRef);
+ }
+
+ return nullptr;
+}
+
+RefPtr<Type> getThisParamTypeForCallable(
+ IRGenContext* context,
+ DeclRef<Decl> callableDeclRef)
+{
+ auto parentDeclRef = callableDeclRef.GetParent();
+
+ if(auto subscriptDeclRef = parentDeclRef.as<SubscriptDecl>())
+ parentDeclRef = subscriptDeclRef.GetParent();
+
+ if(auto genericDeclRef = parentDeclRef.as<GenericDecl>())
+ parentDeclRef = genericDeclRef.GetParent();
+
+ return getThisParamTypeForContainer(context, parentDeclRef);
+}
+
//
@@ -2419,53 +2537,27 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
LoweredValInfo src;
};
- void addDirectCallArgs(
- InvokeExpr* expr,
- DeclRef<CallableDecl> funcDeclRef,
- List<IRInst*>* ioArgs,
+ /// Add argument(s) corresponding to one parameter to a call
+ ///
+ /// The `argExpr` is the AST-level expression being passed as an argument to the call.
+ /// The `paramType` and `paramDirection` represent what is known about the receiving
+ /// parameter of the callee (e.g., if the parameter `in`, `inout`, etc.).
+ /// The `ioArgs` array receives the IR-level argument(s) that are added for the given
+ /// argument expression.
+ /// The `ioFixups` array receives any "fixup" code that needs to be run *after* the
+ /// call completes (e.g., to move from a scratch variable used for an `inout` argument back
+ /// into the original location).
+ ///
+ void addCallArgsForParam(
+ IRType* paramType,
+ ParameterDirection paramDirection,
+ Expr* argExpr,
+ List<IRInst*>* ioArgs,
List<OutArgumentFixup>* ioFixups)
{
- UInt argCount = expr->Arguments.getCount();
- UInt argCounter = 0;
- for (auto paramDeclRef : getMembersOfType<ParamDecl>(funcDeclRef))
+ switch(paramDirection)
{
- auto paramDecl = paramDeclRef.getDecl();
- IRType* paramType = lowerType(context, GetType(paramDeclRef));
-
- UInt argIndex = argCounter++;
- RefPtr<Expr> argExpr;
- if(argIndex < argCount)
- {
- argExpr = expr->Arguments[argIndex];
- }
- else
- {
- // We have run out of arguments supplied at the call site,
- // but there are still parameters remaining. This must mean
- // that these parameters have default argument expressions
- // associated with them.
- argExpr = getInitExpr(paramDeclRef);
-
- // Assert that such an expression must have been present.
- SLANG_ASSERT(argExpr);
-
- // TODO: The approach we are taking here to default arguments
- // is simplistic, and has consequences for the front-end as
- // well as binary serialization of modules.
- //
- // We could consider some more refined approaches where, e.g.,
- // functions with default arguments generate multiple IR-level
- // functions, that compute and provide the default values.
- //
- // Alternatively, each parameter with defaults could be generated
- // into its own callable function that provides the default value,
- // so that calling modules can call into a pre-generated function.
- //
- // Each of these options involves trade-offs, and we need to
- // make a conscious decision at some point.
- }
-
- if(paramDecl->HasModifier<RefModifier>())
+ case kParameterDirection_Ref:
{
// A `ref` qualified parameter must be implemented with by-reference
// parameter passing, so the argument value should be lowered as
@@ -2482,8 +2574,10 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
IRInst* argPtr = getAddress(context, loweredArg, argExpr->loc);
(*ioArgs).add(argPtr);
}
- else if (paramDecl->HasModifier<OutModifier>()
- || paramDecl->HasModifier<InOutModifier>())
+ break;
+
+ case kParameterDirection_Out:
+ case kParameterDirection_InOut:
{
// This is a `out` or `inout` parameter, and so
// the argument must be lowered as an l-value.
@@ -2508,6 +2602,13 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
// In each of these cases, the safe option is to create
// a temporary variable to use for argument-passing,
// and then do copy-in/copy-out around the call.
+ //
+ // TODO: We should consider ruling out case (2) as undefined
+ // behavior, and specify that whether `inout` and `out` are
+ // handled via copy-in-copy-out or by-reference parameter
+ // passing is an implementation detail. That would allow
+ // us to avoid introducing a copy except where it is required
+ // for the semantics of (1).
LoweredValInfo tempVar = createVar(context, paramType);
@@ -2515,8 +2616,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
// to ensure that we pass in the original value stored
// in the argument, which we accomplish by assigning
// from the l-value to our temp.
- if (paramDecl->HasModifier<InModifier>()
- || paramDecl->HasModifier<InOutModifier>())
+ if(paramDirection == kParameterDirection_InOut)
{
assign(context, tempVar, loweredArg);
}
@@ -2536,13 +2636,67 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
(*ioFixups).add(fixup);
}
- else
+ break;
+
+ default:
{
// This is a pure input parameter, and so we will
// pass it as an r-value.
LoweredValInfo loweredArg = lowerRValueExpr(context, argExpr);
addArgs(context, ioArgs, loweredArg);
}
+ break;
+ }
+ }
+
+ void addDirectCallArgs(
+ InvokeExpr* expr,
+ DeclRef<CallableDecl> funcDeclRef,
+ List<IRInst*>* ioArgs,
+ List<OutArgumentFixup>* ioFixups)
+ {
+ UInt argCount = expr->Arguments.getCount();
+ UInt argCounter = 0;
+ for (auto paramDeclRef : getMembersOfType<ParamDecl>(funcDeclRef))
+ {
+ auto paramDecl = paramDeclRef.getDecl();
+ IRType* paramType = lowerType(context, GetType(paramDeclRef));
+ auto paramDirection = getParameterDirection(paramDecl);
+
+ UInt argIndex = argCounter++;
+ RefPtr<Expr> argExpr;
+ if(argIndex < argCount)
+ {
+ argExpr = expr->Arguments[argIndex];
+ }
+ else
+ {
+ // We have run out of arguments supplied at the call site,
+ // but there are still parameters remaining. This must mean
+ // that these parameters have default argument expressions
+ // associated with them.
+ argExpr = getInitExpr(paramDeclRef);
+
+ // Assert that such an expression must have been present.
+ SLANG_ASSERT(argExpr);
+
+ // TODO: The approach we are taking here to default arguments
+ // is simplistic, and has consequences for the front-end as
+ // well as binary serialization of modules.
+ //
+ // We could consider some more refined approaches where, e.g.,
+ // functions with default arguments generate multiple IR-level
+ // functions, that compute and provide the default values.
+ //
+ // Alternatively, each parameter with defaults could be generated
+ // into its own callable function that provides the default value,
+ // so that calling modules can call into a pre-generated function.
+ //
+ // Each of these options involves trade-offs, and we need to
+ // make a conscious decision at some point.
+ }
+
+ addCallArgsForParam(paramType, paramDirection, argExpr, ioArgs, ioFixups);
}
}
@@ -2693,8 +2847,14 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
// a member function:
if( baseExpr )
{
- auto loweredBaseVal = lowerRValueExpr(context, baseExpr);
- addArgs(context, &irArgs, loweredBaseVal);
+ auto thisType = getThisParamTypeForCallable(context, funcDeclRef);
+ auto irThisType = lowerType(context, thisType);
+ addCallArgsForParam(
+ irThisType,
+ getThisParamDirection(funcDeclRef.getDecl()),
+ baseExpr,
+ &irArgs,
+ &argFixups);
}
// Then we have the "direct" arguments to the call.
@@ -5240,21 +5400,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
- DeclRef<Decl> createDefaultSpecializedDeclRefImpl(Decl* decl)
- {
- DeclRef<Decl> declRef;
- declRef.decl = decl;
- declRef.substitutions = createDefaultSubstitutions(context->getSession(), decl);
- return declRef;
- }
- //
- // The client should actually call the templated wrapper, to preserve type information.
- template<typename D>
- DeclRef<D> createDefaultSpecializedDeclRef(D* decl)
- {
- DeclRef<Decl> declRef = createDefaultSpecializedDeclRefImpl(decl);
- return declRef.as<D>();
- }
// When lowering something callable (most commonly a function declaration),
@@ -5295,13 +5440,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// To handle this we break out the relevant data into derived
// structures:
//
- enum ParameterDirection
- {
- kParameterDirection_In, ///< Copy in
- kParameterDirection_Out, ///< Copy out
- kParameterDirection_InOut, ///< Copy in, copy out
- kParameterDirection_Ref, ///< By-reference
- };
struct ParameterInfo
{
// This AST-level type of the parameter
@@ -5318,36 +5456,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
bool isThisParam = false;
};
//
- // We need a way to compute the appropriate `ParameterDirection` for a
- // declared parameter:
- //
- ParameterDirection getParameterDirection(VarDeclBase* paramDecl)
- {
- if( paramDecl->HasModifier<RefModifier>() )
- {
- // The AST specified `ref`:
- return kParameterDirection_Ref;
- }
- if( paramDecl->HasModifier<InOutModifier>() )
- {
- // The AST specified `inout`:
- return kParameterDirection_InOut;
- }
- if (paramDecl->HasModifier<OutModifier>())
- {
- // We saw an `out` modifier, so now we need
- // to check if there was a paired `in`.
- if(paramDecl->HasModifier<InModifier>())
- return kParameterDirection_InOut;
- else
- return kParameterDirection_Out;
- }
- else
- {
- // No direction modifier, or just `in`:
- return kParameterDirection_In;
- }
- }
// We need a way to be able to create a `ParameterInfo` given the declaration
// of a parameter:
//
@@ -5412,22 +5520,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
ioParameterLists->params.add(info);
}
- void addThisParameter(
- ParameterDirection direction,
- AggTypeDecl* typeDecl,
- ParameterLists* ioParameterLists)
- {
- // We need to construct an appopriate declaration-reference
- // for the type declaration we were given. In particular,
- // we need to specialize it for any generic parameters
- // that are in scope here.
- auto declRef = createDefaultSpecializedDeclRef(typeDecl);
- RefPtr<Type> type = DeclRefType::Create(context->getSession(), declRef);
- addThisParameter(
- direction,
- type,
- ioParameterLists);
- }
//
// And here is our function that will do the recursive walk:
void collectParameterLists(
@@ -5457,26 +5549,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// parameter corresponding to the outer declaration.
if( innerMode != kParameterListCollectMode_Static )
{
- // For now we make any `this` parameter default to `in`.
- //
- ParameterDirection direction = kParameterDirection_In;
- //
- // Applications can opt in to a mutable `this` parameter,
- // by applying the `[mutating]` attribute to their
- // declaration.
- //
- if( decl->HasModifier<MutatingAttribute>() )
- {
- direction = kParameterDirection_InOut;
- }
-
- if( auto aggTypeDecl = as<AggTypeDecl>(parentDecl) )
- {
- addThisParameter(direction, aggTypeDecl, ioParameterLists);
- }
- else if( auto extensionDecl = as<ExtensionDecl>(parentDecl) )
+ ParameterDirection direction = getThisParamDirection(decl);
+ auto thisType = getThisParamTypeForContainer(context, createDefaultSpecializedDeclRef(context, parentDecl));
+ if(thisType)
{
- addThisParameter(direction, extensionDecl->targetType, ioParameterLists);
+ addThisParameter(direction, thisType, ioParameterLists);
}
}
}
diff --git a/tests/compute/mutating-and-inout.slang b/tests/compute/mutating-and-inout.slang
new file mode 100644
index 000000000..d06933b77
--- /dev/null
+++ b/tests/compute/mutating-and-inout.slang
@@ -0,0 +1,49 @@
+// mutating-and-inout.slang
+
+// Test that calling a `[mutating]` method on an `inout` function parameter works.
+
+//TEST(compute):COMPARE_COMPUTE:
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+struct A
+{
+ int x;
+
+ [mutating] void doThings(inout int y)
+ {
+ int tmp = x;
+ x = y;
+ y = tmp;
+ }
+}
+
+int doThings(inout A a, int val)
+{
+ a.doThings(val);
+ return val;
+}
+
+int test(int val)
+{
+ A a = { val };
+ int b = val ^ 3;
+
+ int c = doThings(a, b);
+
+ int result = 0;
+ result = result*16 + a.x;
+ result = result*16 + b;
+ result = result*16 + c;
+
+ return result;
+}
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 tid : SV_DispatchThreadID)
+{
+ int val = tid.x;
+ val = test(val);
+ outputBuffer[tid.x] = val;
+}
diff --git a/tests/compute/mutating-and-inout.slang.expected.txt b/tests/compute/mutating-and-inout.slang.expected.txt
new file mode 100644
index 000000000..6c842da55
--- /dev/null
+++ b/tests/compute/mutating-and-inout.slang.expected.txt
@@ -0,0 +1,4 @@
+330
+221
+112
+3