summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTim Foley <tfoleyNV@users.noreply.github.com>2017-11-17 19:01:58 -0800
committerGitHub <noreply@github.com>2017-11-17 19:01:58 -0800
commitba594d0d233f1bb1a345ff158571d16862a974cd (patch)
tree9a71661469de51cf43e1323b33260fd45218f3fe
parent54bf54bd0dda378f8400860b25855558f39cb52b (diff)
IR: Add support for `out` and `inout` parameters (#289)
These were already being handled a little bit, by lowering an `out T` or `inout T` function parameter in the AST to a function parameter with type `T*` in the IR, and then emiting explicit loads/stores. The HLSL emit logic, however, couldn't tell the difference between an `out` parameter, an `inout`, or a true pointer (if we ever needed to support them...). The intention (not fully implemented) was that we'd use a hierarchy of types rooted at `PtrTypeBase`: - `PtrTypeBase` - `Ptr`: "real" pointers in the C/C++ sense - `OutTypeBase`: pointers used to represent by-reference parameter passing - `OutType`: IR level type for an `out` parameter - `InOutType`: IR level type for an `inout` or `in out` parameter Actually implementing this involved: - Adding a bit more flexibility to the `Session::getPtrType` logic to allow for creating any of the concrete types above - Making the `lower-to-ir` logic create the right type for function parameters (instead of just using `PtrType`) - Making the HLSL emit logic check for the `OutType` and `InOutType` cases rather than just `PtrType` - Changing a bunch of small places in the code so that they use `PtrTypeBase` instead of `PtrType` when they should handle any of the above cases, and also make a few places check for `OutTypeBase` instead of `PtrType` or `PtrTypeBase`, when they are really trying to capture by-reference parameters - Add a test case that uses all of the different cases we care about (without these fixes, this test case generates errors from fxc because of variables being used before being initialized, becaues parameters get declared `out` that should be `inout`). A minor point here is that we are playing a bit fast and loose right now because the IR does not actually enforce any type checks. From the standpoint of the front end, `Ptr<T>`, `Out<T>`, and `InOut<T>` are all unrelated types (each is just a `struct` declared in `core.meta.slang`), but this doesn't really matter because none of these are types our current users are explicitly using. In the IR it makes perfect sense to allow `Out<T>` or `InOut<T>` as the operand of a `load` or `store` instruction (and ditto for `getFieldAddr`, etc.) - there instructions just apply to any `PtrTypeBase`. The place where this potentially gets tricky is whether an `Out<T>` can be used where a `Ptr<T>` is expected, or vice vers (e.g., can I just pass my local variable's pointer directly to an `Out<T>` function parameter? I'm going to ignore these issues for now, since the code currently works for our test case.
-rw-r--r--source/slang/compiler.h17
-rw-r--r--source/slang/core.meta.slang10
-rw-r--r--source/slang/core.meta.slang.h10
-rw-r--r--source/slang/emit.cpp20
-rw-r--r--source/slang/ir.cpp12
-rw-r--r--source/slang/lower-to-ir.cpp38
-rw-r--r--source/slang/syntax.cpp29
-rw-r--r--tests/compute/inout.slang47
-rw-r--r--tests/compute/inout.slang.expected.txt4
9 files changed, 144 insertions, 43 deletions
diff --git a/source/slang/compiler.h b/source/slang/compiler.h
index 303be6624..a48ad2287 100644
--- a/source/slang/compiler.h
+++ b/source/slang/compiler.h
@@ -442,9 +442,24 @@ namespace Slang
// Should not be used in front-end code
Type* getIRBasicBlockType();
- // Construct pointer types on-demand
+ // Construct the type `Ptr<valueType>`, where `Ptr`
+ // is looked up as a builtin type.
RefPtr<PtrType> getPtrType(RefPtr<Type> valueType);
+ // Construct the type `Out<valueType>`
+ RefPtr<OutType> getOutType(RefPtr<Type> valueType);
+
+ // Construct the type `InOut<valueType>`
+ RefPtr<InOutType> getInOutType(RefPtr<Type> valueType);
+
+ // Construct a pointer type like `Ptr<valueType>`, but where
+ // the actual type name for the pointer type is given by `ptrTypeName`
+ RefPtr<PtrTypeBase> getPtrType(RefPtr<Type> valueType, char const* ptrTypeName);
+
+ // Construct a pointer type like `Ptr<valueType>`, but where
+ // the generic declaration for the pointer type is `genericDecl`
+ RefPtr<PtrTypeBase> getPtrType(RefPtr<Type> valueType, GenericDecl* genericDecl);
+
RefPtr<ArrayExpressionType> getArrayType(
Type* elementType,
IntVal* elementCount);
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index f36b53227..85137482d 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -101,6 +101,16 @@ __magic_type(PtrType)
struct Ptr
{};
+__generic<T>
+__magic_type(OutType)
+struct Out
+{};
+
+__generic<T>
+__magic_type(InOutType)
+struct InOut
+{};
+
${{{{
diff --git a/source/slang/core.meta.slang.h b/source/slang/core.meta.slang.h
index 1c0f28b26..8644c7c90 100644
--- a/source/slang/core.meta.slang.h
+++ b/source/slang/core.meta.slang.h
@@ -102,6 +102,16 @@ sb << "__magic_type(PtrType)\n";
sb << "struct Ptr\n";
sb << "{};\n";
sb << "\n";
+sb << "__generic<T>\n";
+sb << "__magic_type(OutType)\n";
+sb << "struct Out\n";
+sb << "{};\n";
+sb << "\n";
+sb << "__generic<T>\n";
+sb << "__magic_type(InOutType)\n";
+sb << "struct InOut\n";
+sb << "{};\n";
+sb << "\n";
sb << "";
diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp
index 614e8f474..383442236 100644
--- a/source/slang/emit.cpp
+++ b/source/slang/emit.cpp
@@ -5946,19 +5946,15 @@ emitDeclImpl(decl, nullptr);
// encoded as a parameter of pointer type, so
// we need to decode that here.
//
- if( auto ptrType = type->As<PtrType>() )
+ if( auto outType = type->As<OutType>() )
{
- // TODO: we need a way to distinguish `out`
- // from `inout`. The easiest way to do
- // that might be to have each be a distinct
- // sub-case of `IRPtrType` - this would also
- // ensure that they can be distinguished from
- // real pointers when the user means to use
- // them.
-
emit("out ");
-
- type = ptrType->getValueType();
+ type = outType->getValueType();
+ }
+ else if( auto inOutType = type->As<InOutType>() )
+ {
+ emit("inout ");
+ type = inOutType->getValueType();
}
emitIRType(ctx, type, name);
@@ -6595,7 +6591,7 @@ emitDeclImpl(decl, nullptr);
{
emitIRUsedType(ctx, genericType->elementType);
}
- else if( auto ptrType = type->As<PtrType>() )
+ else if( auto ptrType = type->As<PtrTypeBase>() )
{
emitIRUsedType(ctx, ptrType->getValueType());
}
diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp
index bfc26643c..2d3127a61 100644
--- a/source/slang/ir.cpp
+++ b/source/slang/ir.cpp
@@ -956,7 +956,7 @@ namespace Slang
IRInst* IRBuilder::emitLoad(
IRValue* ptr)
{
- auto ptrType = ptr->getType()->As<PtrType>();
+ auto ptrType = ptr->getType()->As<PtrTypeBase>();
if( !ptrType )
{
// Bad!
@@ -2849,12 +2849,10 @@ namespace Slang
builder.curBlock = firstBlock;
builder.insertBeforeInst = firstBlock->getFirstInst();
- // TODO: We need to distinguish any true pointers in the
- // user's code from pointers that only exist for
- // parameter-passing. This `PtrType` here should actually
- // be `OutTypeBase`, but I'm not confident that all
- // the other code is handling that correctly...
- if(auto paramPtrType = paramType->As<PtrType>() )
+ // Is the parameter type a special pointer type
+ // that indicates the parameter is used for `out`
+ // or `inout` access?
+ if(auto paramPtrType = paramType->As<OutTypeBase>() )
{
// Okay, we have the more interesting case here,
// where the parameter was being passed by reference.
diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp
index 0f3e85805..b395d2a95 100644
--- a/source/slang/lower-to-ir.cpp
+++ b/source/slang/lower-to-ir.cpp
@@ -819,14 +819,6 @@ IRType* getIntType(
return context->getSession()->getBuiltinType(BaseType::Int);
}
-// Get a pointer type to the given element type
-RefPtr<PtrType> getPtrType(
- IRGenContext* context,
- IRType* valueType)
-{
- return context->getSession()->getPtrType(valueType);
-}
-
RefPtr<IRFuncType> getFuncType(
IRGenContext* context,
UInt paramCount,
@@ -1089,7 +1081,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
RefPtr<Type> loweredBaseType = loweredBaseVal->getType();
if (loweredBaseType->As<PointerLikeType>()
- || loweredBaseType->As<PtrType>())
+ || loweredBaseType->As<PtrTypeBase>())
{
// Note that we do *not* perform an actual `load` operation
// here, but rather just use the pointer value to construct
@@ -1461,7 +1453,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
case LoweredValInfo::Flavor::Ptr:
return LoweredValInfo::ptr(
builder->emitElementAddress(
- getPtrType(context, getSimpleType(type)),
+ context->getSession()->getPtrType(getSimpleType(type)),
baseVal.val,
indexVal));
@@ -1498,7 +1490,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
IRValue* irBasePtr = base.val;
return LoweredValInfo::ptr(
getBuilder()->emitFieldAddress(
- getPtrType(context, getSimpleType(fieldType)),
+ context->getSession()->getPtrType(getSimpleType(fieldType)),
irBasePtr,
getBuilder()->getDeclRefVal(field)));
}
@@ -3114,14 +3106,22 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
paramTypes.Add(irParamType);
break;
- default:
- // The parameter is being used for input/output purposes,
- // so it will lower to an actual parameter with a pointer type.
- //
- // TODO: Is this the best representation we can use?
+ // If the parameter is declared `out` or `inout`,
+ // then we will represent it with a pointer type in
+ // the IR, but we will use a specialized pointer
+ // type that encodes the parameter direction information.
+ case kParameterDirection_Out:
+ paramTypes.Add(
+ context->getSession()->getOutType(irParamType));
+ break;
+ case kParameterDirection_InOut:
+ paramTypes.Add(
+ context->getSession()->getInOutType(irParamType));
+ break;
- auto irPtrType = getPtrType(context, irParamType);
- paramTypes.Add(irPtrType);
+ default:
+ SLANG_UNEXPECTED("unknown parameter direction");
+ break;
}
}
@@ -3190,7 +3190,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
//
// TODO: Is this the best representation we can use?
- auto irPtrType = irParamType.As<PtrType>();
+ auto irPtrType = irParamType.As<PtrTypeBase>();
IRParam* irParamPtr = subBuilder->emitParam(irPtrType);
if(auto paramDecl = paramInfo.decl)
diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp
index fa9c88051..e43dd9074 100644
--- a/source/slang/syntax.cpp
+++ b/source/slang/syntax.cpp
@@ -280,19 +280,40 @@ void Type::accept(IValVisitor* visitor, void* extra)
RefPtr<PtrType> Session::getPtrType(
RefPtr<Type> valueType)
{
+ return getPtrType(valueType, "PtrType").As<PtrType>();
+ }
+
+ // Construct the type `Out<valueType>`
+ RefPtr<OutType> Session::getOutType(RefPtr<Type> valueType)
+ {
+ return getPtrType(valueType, "OutType").As<OutType>();
+ }
+
+ RefPtr<InOutType> Session::getInOutType(RefPtr<Type> valueType)
+ {
+ return getPtrType(valueType, "InOutType").As<InOutType>();
+ }
+
+ RefPtr<PtrTypeBase> Session::getPtrType(RefPtr<Type> valueType, char const* ptrTypeName)
+ {
auto genericDecl = findMagicDecl(
- this, "PtrType").As<GenericDecl>();
+ this, ptrTypeName).As<GenericDecl>();
+ return getPtrType(valueType, genericDecl);
+ }
+
+ RefPtr<PtrTypeBase> Session::getPtrType(RefPtr<Type> valueType, GenericDecl* genericDecl)
+ {
auto typeDecl = genericDecl->inner;
-
+
auto substitutions = new GenericSubstitution();
- substitutions->genericDecl = genericDecl.Ptr();
+ substitutions->genericDecl = genericDecl;
substitutions->args.Add(valueType);
auto declRef = DeclRef<Decl>(typeDecl.Ptr(), substitutions);
return DeclRefType::Create(
this,
- declRef)->As<PtrType>();
+ declRef)->As<PtrTypeBase>();
}
RefPtr<ArrayExpressionType> Session::getArrayType(
diff --git a/tests/compute/inout.slang b/tests/compute/inout.slang
new file mode 100644
index 000000000..d56887cf9
--- /dev/null
+++ b/tests/compute/inout.slang
@@ -0,0 +1,47 @@
+//TEST(compute):COMPARE_COMPUTE:-xslang -use-ir
+//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):dxbinding(0),glbinding(0),out
+
+// Test that we correctly support both `out`
+// and `inout` function parameters.
+
+void testOut(int x, out int y)
+{
+ y = x;
+}
+
+void testInOut(int x, in out int y)
+{
+ y = y + x;
+}
+
+void testInout(int x, inout int y)
+{
+ y = y + x;
+}
+
+int test(int inVal)
+{
+ int x0 = inVal;
+ int x1;
+
+ testOut(x0, x1);
+
+ int x2 = x0;
+ testInOut(x1, x2);
+
+ int x3 = x0;
+ testInout(x2, x3);
+
+ return x3;
+}
+
+RWStructuredBuffer<int> outputBuffer : register(u0);
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint tid = dispatchThreadID.x;
+ int inVal = outputBuffer[tid];
+ int outVal = test(inVal);
+ outputBuffer[tid] = outVal;
+} \ No newline at end of file
diff --git a/tests/compute/inout.slang.expected.txt b/tests/compute/inout.slang.expected.txt
new file mode 100644
index 000000000..f9d85ed42
--- /dev/null
+++ b/tests/compute/inout.slang.expected.txt
@@ -0,0 +1,4 @@
+0
+3
+6
+9